Newer
Older
minerva / Kernel / Arch / aarch64 / SafeMem.cpp
@minerva minerva on 13 Jul 11 KB Initial commit
/*
 * Copyright (c) 2022, Timon Kruiper <timonkruiper@gmail.com>
 * Copyright (c) 2023, Daniel Bertalan <dani@danielbertalan.dev>
 * Copyright (c) 2025, Sönke Holz <sholz8530@gmail.com>
 *
 * SPDX-License-Identifier: BSD-2-Clause
 */

#include <Kernel/Arch/RegisterState.h>
#include <Kernel/Arch/SafeMem.h>
#include <Kernel/Library/StdLib.h>

// Places the annotated function into the named linker section. The safe_* functions below are
// grouped into ".text.safemem" and ".text.safemem.atomic" so that handle_safe_access_fault() can
// first check whether a faulting PC lies inside one of these regions before comparing it against
// the exact instruction labels declared below.
#define CODE_SECTION(section_name) __attribute__((section(section_name)))

// Start of the ".text.safemem" region. The start_/end_ symbols are not defined in this file.
// NOTE(review): presumably emitted by the kernel linker script around .text.safemem — confirm.
extern "C" u8 start_of_safemem_text[];

// The labels below are emitted as ".global" symbols by the inline assembly in this file:
//   *_ins     - an instruction that is allowed to fault on the memory it accesses,
//   *_faulted - the address execution resumes at after handle_safe_access_fault() handled such a fault.
extern "C" u8 safe_memset_ins[];
extern "C" u8 safe_memset_faulted[];

extern "C" u8 safe_strnlen_ins[];
extern "C" u8 safe_strnlen_faulted[];

// safe_memcpy may fault on its byte load (_1) or on its byte store (_2).
extern "C" u8 safe_memcpy_ins_1[];
extern "C" u8 safe_memcpy_ins_2[];
extern "C" u8 safe_memcpy_faulted[];

// End (exclusive) of the ".text.safemem" region.
extern "C" u8 end_of_safemem_text[];

// Start of the ".text.safemem.atomic" region (same caveat as above about where it is defined).
extern "C" u8 start_of_safemem_atomic_text[];

// For the exclusive-monitor based atomics, _1 labels the exclusive load (ldxr) and _2 the
// exclusive store (stxr); both may fault.
extern "C" u8 safe_atomic_compare_exchange_relaxed_ins_1[];
extern "C" u8 safe_atomic_compare_exchange_relaxed_ins_2[];
extern "C" u8 safe_atomic_compare_exchange_relaxed_faulted[];

extern "C" u8 safe_atomic_load_relaxed_ins[];
extern "C" u8 safe_atomic_load_relaxed_faulted[];

extern "C" u8 safe_atomic_fetch_add_relaxed_ins_1[];
extern "C" u8 safe_atomic_fetch_add_relaxed_ins_2[];
extern "C" u8 safe_atomic_fetch_add_relaxed_faulted[];

extern "C" u8 safe_atomic_exchange_relaxed_ins_1[];
extern "C" u8 safe_atomic_exchange_relaxed_ins_2[];
extern "C" u8 safe_atomic_exchange_relaxed_faulted[];

extern "C" u8 safe_atomic_store_relaxed_ins[];
extern "C" u8 safe_atomic_store_relaxed_faulted[];

// End (exclusive) of the ".text.safemem.atomic" region.
extern "C" u8 end_of_safemem_atomic_text[];
namespace Kernel {

// Fills n bytes starting at dest_ptr with the low byte of c, one byte at a time, tolerating a
// fault on the destination.
//
// Returns true if all n bytes were written (n == 0 trivially succeeds). If a store faults,
// handle_safe_access_fault() resumes execution at safe_memset_faulted with x0 set to 0 and writes
// the faulting address into fault_at (reached through x3), so this returns false.
CODE_SECTION(".text.safemem")
NEVER_INLINE FLATTEN bool safe_memset(void* dest_ptr, int c, size_t n, void*& fault_at)
{
    register FlatPtr result asm("x0") = 0;                // handle_safe_access_fault sets x0 to 0 if a fault occurred.
    register void** fault_at_in_x3 asm("x3") = &fault_at; // ensure fault_at stays in x3 so handle_safe_access_fault can set it to the faulting address.
    asm volatile(R"(
    cbz %[n], 2f
    add x4, %[dest_ptr], %[n]          // x4: pointer to the (exclusive) end of the target memory area

1:
.global safe_memset_ins
safe_memset_ins:
    strb %w[c], [%[dest_ptr]], #1
    cmp %[dest_ptr], x4
    b.ne 1b

2:
    mov %[result], #1
.global safe_memset_faulted
safe_memset_faulted:
)"
                 : [dest_ptr] "+&r"(dest_ptr), [result] "+&r"(result), "+&r"(fault_at_in_x3)
                 : [n] "r"(n), [c] "r"(c)
                 // "cc": the loop's cmp overwrites the NZCV condition flags.
                 : "cc", "memory", "x4");
    return result != 0;
}

// Computes the length of the NUL-terminated string at str, reading at most max_n bytes, while
// tolerating a fault on the string memory.
//
// Returns the number of bytes before the first NUL byte, or max_n if no NUL byte occurs within the
// first max_n bytes (0 when max_n == 0). If a load faults, handle_safe_access_fault() resumes
// execution at safe_strnlen_faulted with x0 set to -1 and writes the faulting address into
// fault_at (reached through x2), so this returns -1.
CODE_SECTION(".text.safemem")
NEVER_INLINE FLATTEN ssize_t safe_strnlen(char const* str, unsigned long max_n, void*& fault_at)
{
    register ssize_t result asm("x0") = 0;                // handle_safe_access_fault sets x0 to -1 if a fault occurred.
    register void** fault_at_in_x2 asm("x2") = &fault_at; // ensure fault_at stays in x2 so handle_safe_access_fault can set it to the faulting address.
    asm volatile(R"(
    cbz %[max_n], 2f
    mov %[result], #0

1:
.global safe_strnlen_ins
safe_strnlen_ins:
    ldrb w3, [%[str], %[result]]     // w3: current char
    cbz w3, 2f
    add %[result], %[result], #1
    cmp %[result], %[max_n]
    b.ne 1b

2:
.global safe_strnlen_faulted
safe_strnlen_faulted:
)"
                 : [result] "+&r"(result), "+&r"(fault_at_in_x2)
                 : [str] "r"(str), [max_n] "r"(max_n)
                 // "cc": the loop's cmp overwrites the NZCV condition flags.
                 : "cc", "memory", "w3");
    return result;
}

// Copies n bytes from src_ptr to dest_ptr, one byte at a time in ascending address order, while
// tolerating a fault on either the source or the destination.
//
// Returns true if all n bytes were copied (n == 0 trivially succeeds). If the byte load or the byte
// store faults, handle_safe_access_fault() resumes execution at safe_memcpy_faulted with x0 set to 0
// and writes the faulting address into fault_at (reached through x3), so this returns false.
// Bytes copied before the fault remain written.
CODE_SECTION(".text.safemem")
NEVER_INLINE FLATTEN bool safe_memcpy(void* dest_ptr, void const* src_ptr, unsigned long n, void*& fault_at)
{
    register FlatPtr result asm("x0") = 0;                // handle_safe_access_fault sets x0 to 0 if a fault occurred.
    register void** fault_at_in_x3 asm("x3") = &fault_at; // ensure fault_at stays in x3 so handle_safe_access_fault can set it to the faulting address.
    asm volatile(R"(
    cbz %[n], 2f
    mov x4, #0                 // x4: current index

1:
.global safe_memcpy_ins_1
safe_memcpy_ins_1:
    ldrb w5, [%[src_ptr], x4]  // w5: byte to copy
.global safe_memcpy_ins_2
safe_memcpy_ins_2:
    strb w5, [%[dest_ptr], x4]
    add x4, x4, #1
    cmp x4, %[n]
    b.ne 1b

2:
    mov %[result], #1
.global safe_memcpy_faulted
safe_memcpy_faulted:
)"
                 : [result] "+&r"(result), "+&r"(fault_at_in_x3)
                 : [dest_ptr] "r"(dest_ptr), [src_ptr] "r"(src_ptr), [n] "r"(n)
                 // "cc": the loop's cmp overwrites the NZCV condition flags.
                 : "cc", "memory", "x4", "w5");
    return result != 0;
}

// Compare-and-swap on the 32-bit value at var using ldxr/stxr without barriers (hence "relaxed"),
// tolerating a fault on var.
//
// Returns:
//   - true  if *var equaled expected and desired was stored (spurious stxr failures are retried),
//   - false if *var differed from expected; the observed value is then written into expected,
//   - an empty Optional if accessing var faulted (handle_safe_access_fault() resumes at
//     safe_atomic_compare_exchange_relaxed_faulted and sets x15 to 1); expected is not modified then.
CODE_SECTION(".text.safemem.atomic")
NEVER_INLINE FLATTEN Optional<bool> safe_atomic_compare_exchange_relaxed(u32 volatile* var, u32& expected, u32 desired)
{
    FlatPtr result;
    register FlatPtr error asm("x15") = 0; // handle_safe_access_fault sets x15 to 1 when a page fault occurs in one of the safe_atomic_* functions.
    asm volatile(R"(
    mov %[result], #0
    ldr w3, [%[expected_ptr]]                      // w3: expected value

1:
.global safe_atomic_compare_exchange_relaxed_ins_1
safe_atomic_compare_exchange_relaxed_ins_1:
    ldxr w4, [%[var_ptr]]                          // Load the value at *var into w4.
    cmp w4, w3
    b.ne 2f                                        // Doesn't match the expected value, so fail.
.global safe_atomic_compare_exchange_relaxed_ins_2
safe_atomic_compare_exchange_relaxed_ins_2:
    stxr w5, %w[desired], [%[var_ptr]]             // Try to update the value at *var.
    cbnz w5, 1b                                    // Retry if stxr failed (that is when w5 != 0).
    mov %[result], #1
    b 3f

2:
    str w4, [%[expected_ptr]]                      // Write the read value to expected on failure.
3:
.global safe_atomic_compare_exchange_relaxed_faulted
safe_atomic_compare_exchange_relaxed_faulted:
)"
                 : [result] "=&r"(result), "+&r"(error)
                 : [var_ptr] "r"(var), [expected_ptr] "r"(&expected), [desired] "r"(desired)
                 // "cc": the cmp against the expected value overwrites the NZCV condition flags.
                 : "cc", "memory", "w3", "w4", "w5");
    if (error != 0)
        return {};
    return static_cast<bool>(result);
}

// Reads the 32-bit value at var with a single plain ldr, tolerating a fault on var.
//
// Returns the value read, or an empty Optional if the load faulted. In that case
// handle_safe_access_fault() resumed execution at safe_atomic_load_relaxed_faulted and set x15 to 1.
CODE_SECTION(".text.safemem.atomic")
NEVER_INLINE FLATTEN Optional<u32> safe_atomic_load_relaxed(u32 volatile* var)
{
    // Must live in x15: that is the register the fault handler flags a fault in.
    register FlatPtr faulted asm("x15") = 0;
    u32 loaded_value;
    asm volatile(R"(
.global safe_atomic_load_relaxed_ins
safe_atomic_load_relaxed_ins:
    ldr %w[loaded_value], [%[address]]
.global safe_atomic_load_relaxed_faulted
safe_atomic_load_relaxed_faulted:
)"
                 : [loaded_value] "=r"(loaded_value), [faulted] "+r"(faulted)
                 : [address] "r"(var)
                 : "memory");
    if (faulted == 0)
        return loaded_value;
    return {};
}

// Atomically adds val to the 32-bit value at var using an ldxr/stxr retry loop without barriers
// (hence "relaxed"), tolerating a fault on var.
//
// Returns the value *var held before the addition, or an empty Optional if accessing var faulted.
// In that case handle_safe_access_fault() resumed execution at
// safe_atomic_fetch_add_relaxed_faulted and set x15 to 1.
CODE_SECTION(".text.safemem.atomic")
NEVER_INLINE FLATTEN Optional<u32> safe_atomic_fetch_add_relaxed(u32 volatile* var, u32 val)
{
    // Must live in x15: that is the register the fault handler flags a fault in.
    register FlatPtr faulted asm("x15") = 0;
    u32 previous_value;
    asm volatile(R"(
1:
.global safe_atomic_fetch_add_relaxed_ins_1
safe_atomic_fetch_add_relaxed_ins_1:
    ldxr %w[previous], [%[address]]            // exclusively load the current value
    add w2, %w[previous], %w[addend]           // w2: value to store back
.global safe_atomic_fetch_add_relaxed_ins_2
safe_atomic_fetch_add_relaxed_ins_2:
    stxr w3, w2, [%[address]]                  // w3 == 0 iff the exclusive store succeeded
    cbnz w3, 1b                                // lost exclusivity: retry
.global safe_atomic_fetch_add_relaxed_faulted
safe_atomic_fetch_add_relaxed_faulted:
)"
                 : [previous] "=&r"(previous_value), [faulted] "+&r"(faulted)
                 : [address] "r"(var), [addend] "r"(val)
                 : "memory", "w2", "w3");
    if (faulted == 0)
        return previous_value;
    return {};
}

// Atomically replaces the 32-bit value at var with desired using an ldxr/stxr retry loop without
// barriers (hence "relaxed"), tolerating a fault on var.
//
// Returns the value that was replaced, or an empty Optional if accessing var faulted. In that case
// handle_safe_access_fault() resumed execution at safe_atomic_exchange_relaxed_faulted and set
// x15 to 1.
CODE_SECTION(".text.safemem.atomic")
NEVER_INLINE FLATTEN Optional<u32> safe_atomic_exchange_relaxed(u32 volatile* var, u32 desired)
{
    // Must live in x15: that is the register the fault handler flags a fault in.
    register FlatPtr faulted asm("x15") = 0;
    u32 replaced_value;
    asm volatile(R"(
1:
.global safe_atomic_exchange_relaxed_ins_1
safe_atomic_exchange_relaxed_ins_1:
    ldxr %w[replaced], [%[address]]            // exclusively load the current value
.global safe_atomic_exchange_relaxed_ins_2
safe_atomic_exchange_relaxed_ins_2:
    stxr w2, %w[replacement], [%[address]]     // w2 == 0 iff the exclusive store succeeded
    cbnz w2, 1b                                // lost exclusivity: retry
.global safe_atomic_exchange_relaxed_faulted
safe_atomic_exchange_relaxed_faulted:
)"
                 : [replaced] "=&r"(replaced_value), [faulted] "+&r"(faulted)
                 : [address] "r"(var), [replacement] "r"(desired)
                 : "memory", "w2");
    if (faulted == 0)
        return replaced_value;
    return {};
}

// Writes desired to the 32-bit location var with a single plain str, tolerating a fault on var.
//
// Returns true if the store completed, false if it faulted. In the fault case
// handle_safe_access_fault() resumed execution at safe_atomic_store_relaxed_faulted and set x15 to 1.
CODE_SECTION(".text.safemem.atomic")
NEVER_INLINE FLATTEN bool safe_atomic_store_relaxed(u32 volatile* var, u32 desired)
{
    // Must live in x15: that is the register the fault handler flags a fault in.
    register FlatPtr faulted asm("x15") = 0;
    asm volatile(R"(
.global safe_atomic_store_relaxed_ins
safe_atomic_store_relaxed_ins:
    str %w[value], [%[address]]
.global safe_atomic_store_relaxed_faulted
safe_atomic_store_relaxed_faulted:
)"
                 : [faulted] "+r"(faulted)
                 : [value] "r"(desired), [address] "r"(var)
                 : "memory");
    bool const stored = (faulted == 0);
    return stored;
}

// Checks whether a fault at regs.ip() was raised by one of the *_ins instructions of the safe_*
// functions in this file. If so, regs is rewritten so execution resumes at that function's
// *_faulted label with the registers it expects, and true is returned. Otherwise false is returned
// and regs is left untouched.
//
//   safe_memset / safe_memcpy : x0 = 0,  *fault_at (address held in x3) = fault_address
//   safe_strnlen              : x0 = -1, *fault_at (address held in x2) = fault_address
//   safe_atomic_*             : x15 = 1
bool handle_safe_access_fault(RegisterState& regs, FlatPtr fault_address)
{
    FlatPtr const faulting_pc = regs.ip();

    auto faulted_at = [faulting_pc](auto const* label) {
        return faulting_pc == bit_cast<FlatPtr>(label);
    };
    auto faulted_between = [faulting_pc](auto const* first, auto const* past_last) {
        return faulting_pc >= bit_cast<FlatPtr>(first) && faulting_pc < bit_cast<FlatPtr>(past_last);
    };

    if (faulted_between(start_of_safemem_text, end_of_safemem_text)) {
        // Resumes at resume_label, reports failure through x0 and writes the faulting address
        // through the void*& fault_at whose address the function keeps in x<fault_at_reg>.
        auto resume_with_fault_at = [&](auto const* resume_label, FlatPtr failure_value, size_t fault_at_reg) {
            regs.set_ip(bit_cast<FlatPtr>(resume_label));
            regs.x[0] = failure_value;
            *bit_cast<FlatPtr*>(regs.x[fault_at_reg]) = fault_address;
            return true;
        };

        if (faulted_at(safe_memset_ins))
            return resume_with_fault_at(safe_memset_faulted, 0, 3);
        if (faulted_at(safe_strnlen_ins))
            return resume_with_fault_at(safe_strnlen_faulted, static_cast<FlatPtr>(-1), 2);
        if (faulted_at(safe_memcpy_ins_1) || faulted_at(safe_memcpy_ins_2))
            return resume_with_fault_at(safe_memcpy_faulted, 0, 3);
        return false;
    }

    if (faulted_between(start_of_safemem_atomic_text, end_of_safemem_atomic_text)) {
        // The safe_atomic_* functions detect a fault by finding x15 != 0 at their *_faulted label.
        auto resume_atomic = [&](auto const* resume_label) {
            regs.set_ip(bit_cast<FlatPtr>(resume_label));
            regs.x[15] = 1;
            return true;
        };

        if (faulted_at(safe_atomic_compare_exchange_relaxed_ins_1) || faulted_at(safe_atomic_compare_exchange_relaxed_ins_2))
            return resume_atomic(safe_atomic_compare_exchange_relaxed_faulted);
        if (faulted_at(safe_atomic_load_relaxed_ins))
            return resume_atomic(safe_atomic_load_relaxed_faulted);
        if (faulted_at(safe_atomic_fetch_add_relaxed_ins_1) || faulted_at(safe_atomic_fetch_add_relaxed_ins_2))
            return resume_atomic(safe_atomic_fetch_add_relaxed_faulted);
        if (faulted_at(safe_atomic_exchange_relaxed_ins_1) || faulted_at(safe_atomic_exchange_relaxed_ins_2))
            return resume_atomic(safe_atomic_exchange_relaxed_faulted);
        if (faulted_at(safe_atomic_store_relaxed_ins))
            return resume_atomic(safe_atomic_store_relaxed_faulted);
        return false;
    }

    return false;
}

}