Newer
Older
minerva / Kernel / Arch / riscv64 / SafeMem.cpp
@minerva minerva on 13 Jul 9 KB Initial commit
/*
 * Copyright (c) 2022, Timon Kruiper <timonkruiper@gmail.com>
 * Copyright (c) 2023, Daniel Bertalan <dani@danielbertalan.dev>
 * Copyright (c) 2024, Sönke Holz <sholz8530@gmail.com>
 *
 * SPDX-License-Identifier: BSD-2-Clause
 */

#include <Kernel/Arch/RegisterState.h>
#include <Kernel/Arch/SafeMem.h>
#include <Kernel/Library/StdLib.h>

// Places the annotated function into the linker section `section_name`.
// Every safe_* accessor in this file is tagged with either ".text.safemem" or
// ".text.safemem.atomic"; handle_safe_access_fault() tells the two groups apart by comparing the
// faulting address against the start_of_/end_of_ symbols declared below.
// NOTE(review): those boundary symbols are presumably emitted by the linker script around these
// sections - not visible from this file, confirm there.
#define CODE_SECTION(section_name) __attribute__((section(section_name)))

// Address markers used by handle_safe_access_fault() for the non-atomic accessors.
// Only the addresses of these symbols are ever used; they are never read as data.

// Lower (inclusive) bound of the address range checked for the non-atomic accessors.
// NOTE(review): not defined in this file - presumably provided by the linker script.
extern "C" u8 start_of_safemem_text[];

// safe_memset(): the store that may fault, and where to resume after a fault.
// Both labels are defined with `.global` inside safe_memset()'s assembly.
extern "C" u8 safe_memset_ins[];
extern "C" u8 safe_memset_faulted[];

// safe_strnlen(): the load that may fault, and where to resume after a fault.
extern "C" u8 safe_strnlen_ins[];
extern "C" u8 safe_strnlen_faulted[];

// safe_memcpy(): the load (ins_1) and the store (ins_2) that may fault, and the shared resume point.
extern "C" u8 safe_memcpy_ins_1[];
extern "C" u8 safe_memcpy_ins_2[];
extern "C" u8 safe_memcpy_faulted[];

// Upper (exclusive) bound of the address range checked for the non-atomic accessors.
// NOTE(review): not defined in this file - presumably provided by the linker script.
extern "C" u8 end_of_safemem_text[];

// Address markers used by handle_safe_access_fault() for the safe_atomic_* accessors.
// Only the addresses of these symbols are ever used; they are never read as data.

// Lower (inclusive) bound of the address range checked for the atomic accessors.
// NOTE(review): not defined in this file - presumably provided by the linker script.
extern "C" u8 start_of_safemem_atomic_text[];

// safe_atomic_compare_exchange_relaxed(): the lr.w (ins_1) and sc.w (ins_2) that may fault,
// and the label at the end of the assembly where execution resumes after a fault.
extern "C" u8 safe_atomic_compare_exchange_relaxed_ins_1[];
extern "C" u8 safe_atomic_compare_exchange_relaxed_ins_2[];
extern "C" u8 safe_atomic_compare_exchange_relaxed_faulted[];

// safe_atomic_load_relaxed(): faulting load and resume point.
extern "C" u8 safe_atomic_load_relaxed_ins[];
extern "C" u8 safe_atomic_load_relaxed_faulted[];

// safe_atomic_fetch_add_relaxed(): faulting amoadd.w and resume point.
extern "C" u8 safe_atomic_fetch_add_relaxed_ins[];
extern "C" u8 safe_atomic_fetch_add_relaxed_faulted[];

// safe_atomic_exchange_relaxed(): faulting amoswap.w and resume point.
extern "C" u8 safe_atomic_exchange_relaxed_ins[];
extern "C" u8 safe_atomic_exchange_relaxed_faulted[];

// safe_atomic_store_relaxed(): faulting store and resume point.
extern "C" u8 safe_atomic_store_relaxed_ins[];
extern "C" u8 safe_atomic_store_relaxed_faulted[];

// Upper (exclusive) bound of the address range checked for the atomic accessors.
// NOTE(review): not defined in this file - presumably provided by the linker script.
extern "C" u8 end_of_safemem_atomic_text[];

namespace Kernel {

// Fills `n` bytes starting at `dest_ptr` with the low byte of `c`, one byte store at a time,
// tolerating a fault on the destination.
//
// Returns true when every byte was written. The only instruction expected to fault is the store
// at the safe_memset_ins label: handle_safe_access_fault() recognises that address, writes the
// faulting address through `fault_at`, sets a0 to 0 and resumes at safe_memset_faulted, so the
// caller sees `false`.
//
// The body is hand-written assembly that reads its arguments directly from a0-a3 (the parameters
// are therefore unnamed). a3 is never written by the assembly, so it still holds `fault_at` when
// the fault handler reads it from the saved register state.
CODE_SECTION(".text.safemem")
NEVER_INLINE NAKED bool safe_memset(void*, int, size_t, void*&)
{
    // a0: void* dest_ptr
    // a1: int c
    // a2: size_t n
    // a3: void*& fault_at

    // Flow: n == 0 skips the loop entirely. Otherwise one byte is stored per iteration until a0
    // reaches the end pointer kept in t0. The success path loads 1 into a0 and falls through into
    // the `ret` at safe_memset_faulted; the fault path enters at that label with a0 already set
    // to 0 by the handler.
    asm(R"(
    beqz a2, 2f
    add t0, a0, a2      # t0: pointer to the (exclusive) end of the target memory area

1:
.global safe_memset_ins
safe_memset_ins:
    sb a1, (a0)
    addi a0, a0, 1
    bne a0, t0, 1b

2:
    li a0, 1
.global safe_memset_faulted
safe_memset_faulted:
    ret
)");
}

// Counts the bytes preceding the first NUL byte in `str`, examining at most `max_n` bytes,
// tolerating a fault while reading the string.
//
// Returns the length found (equal to `max_n` if no NUL byte was seen within the limit, 0 if
// `max_n` is 0). The only instruction expected to fault is the load at the safe_strnlen_ins
// label: handle_safe_access_fault() recognises that address, writes the faulting address through
// `fault_at`, sets a0 to -1 and resumes at safe_strnlen_faulted, so the caller sees -1.
//
// The body is hand-written assembly that reads its arguments directly from a0-a2 (the parameters
// are therefore unnamed). a2 is never written by the assembly, so it still holds `fault_at` when
// the fault handler reads it from the saved register state.
CODE_SECTION(".text.safemem")
NEVER_INLINE NAKED ssize_t safe_strnlen(char const*, unsigned long, void*&)
{
    // a0: char const* str
    // a1: unsigned long max_n
    // a2: void*& fault_at

    // Flow: a0 doubles as the running length and the return value, so the original pointer is
    // moved to t0 first. The loop stops on a NUL byte or once a0 == max_n. Both the normal exits
    // and the fault path end at the single `ret` at safe_strnlen_faulted.
    asm(R"(
    mv t0, a0            # t0: original string pointer
    li a0, 0             # a0: result
    beqz a1, 2f

1:
    add t1, t0, a0       # t1: pointer to current char
.global safe_strnlen_ins
safe_strnlen_ins:
    lbu t2, (t1)         # t2: current char
    beqz t2, 2f
    addi a0, a0, 1
    bne a0, a1, 1b

2:
.global safe_strnlen_faulted
safe_strnlen_faulted:
    ret
)");
}

// Copies `n` bytes from `src_ptr` to `dest_ptr` one byte at a time, in ascending address order,
// tolerating a fault on either side.
//
// Returns true when every byte was copied. Two instructions are expected to fault: the load from
// the source (safe_memcpy_ins_1) and the store to the destination (safe_memcpy_ins_2). For either
// address handle_safe_access_fault() writes the faulting address through `fault_at`, sets a0 to 0
// and resumes at safe_memcpy_faulted, so the caller sees `false`.
//
// The body is hand-written assembly that reads its arguments directly from a0-a3 (the parameters
// are therefore unnamed). a3 is never written by the assembly, so it still holds `fault_at` when
// the fault handler reads it from the saved register state.
CODE_SECTION(".text.safemem")
NEVER_INLINE NAKED bool safe_memcpy(void*, void const*, unsigned long, void*&)
{
    // a0: void* dest_ptr
    // a1: void const* src_ptr
    // a2: unsigned long n
    // a3: void*& fault_at

    // Flow: n == 0 skips the loop entirely. Otherwise each iteration loads one byte into t1,
    // stores it, and advances both pointers until a0 reaches the end pointer kept in t0. The
    // success path loads 1 into a0 and falls through into the `ret` at safe_memcpy_faulted; the
    // fault path enters at that label with a0 already set to 0 by the handler.
    asm(R"(
    beqz a2, 2f
    add t0, a0, a2         # t0: pointer to the (exclusive) end of the target memory area

1:
.global safe_memcpy_ins_1
safe_memcpy_ins_1:
    lbu t1, (a1)           # t1: byte to copy
.global safe_memcpy_ins_2
safe_memcpy_ins_2:
    sb t1, (a0)
    addi a0, a0, 1
    addi a1, a1, 1
    bne a0, t0, 1b

2:
    li a0, 1
.global safe_memcpy_faulted
safe_memcpy_faulted:
    ret
)");
}

// Fault-tolerant 32-bit compare-and-exchange on `*var`, built from an lr.w/sc.w retry loop.
//
// Returns:
//   - true  if `*var` equalled `expected` and was replaced by `desired`;
//   - false if `*var` differed from `expected`; the observed value is written back to `expected`;
//   - an empty Optional if the lr.w or the sc.w on `var` faulted.
//
// The lr.w/sc.w instructions carry no .aq/.rl ordering suffix and no fence is issued, which
// matches the "_relaxed" in the name.
//
// Only the two accesses to `var` (labels ..._ins_1 and ..._ins_2) are known to
// handle_safe_access_fault(). The accesses to `expected` (the initial lw and the sw on mismatch)
// are not, so a fault there is not recovered by this mechanism.
// NOTE(review): presumably `expected` always refers to memory that cannot fault - confirm with callers.
CODE_SECTION(".text.safemem.atomic")
NEVER_INLINE FLATTEN Optional<bool> safe_atomic_compare_exchange_relaxed(u32 volatile* var, u32& expected, u32 desired)
{
    // Based on Example 2, https://github.com/riscv/riscv-isa-manual/releases/download/20240411/unpriv-isa-asciidoc.pdf
    FlatPtr result;
    register FlatPtr error asm("t6") = 0; // handle_safe_access_fault sets t6 to 1 when a page fault occurs in one of the safe_atomic_* functions.
    // `result` starts at 0 (failure) and is set to 1 only after a successful sc.w.
    // The ..._faulted label is the very end of the assembly: after a fault, execution simply
    // continues with the C++ code below, with t6 (`error`) set to 1 by the handler.
    // Both outputs are early-clobber ("&") so they are never allocated to the same register as an
    // input that is still needed later in the sequence; t0-t2 are scratch and listed as clobbers.
    asm volatile(R"(
    li %[result], 0
    lw t0, (%[expected_ptr])                       # t0: expected value

1:
.global safe_atomic_compare_exchange_relaxed_ins_1
safe_atomic_compare_exchange_relaxed_ins_1:
    lr.w t1, (%[var_ptr])                          # Load the value at *var into t1.
    bne t1, t0, 2f                                 # Doesn't match the expected value, so fail.
.global safe_atomic_compare_exchange_relaxed_ins_2
safe_atomic_compare_exchange_relaxed_ins_2:
    sc.w t2, %[desired], (%[var_ptr])              # Try to update the value at *var.
    bnez t2, 1b                                    # Retry if sc.w failed (that is when t2 != 0).
    li %[result], 1
    j 3f

2:
    sw t1, (%[expected_ptr])                       # Write the read value to expected on failure.
3:
.global safe_atomic_compare_exchange_relaxed_faulted
safe_atomic_compare_exchange_relaxed_faulted:
)"
                 : [result] "=&r"(result), "+&r"(error)
                 : [var_ptr] "r"(var), [expected_ptr] "r"(&expected), [desired] "r"(desired)
                 : "memory", "t0", "t1", "t2");
    // A fault takes priority: `result` is not meaningful when the sequence was cut short.
    if (error != 0)
        return {};
    return static_cast<bool>(result);
}

// Fault-tolerant 32-bit load of `*var` using a single plain `lw` (no fence, matching "_relaxed").
//
// Returns the loaded value, or an empty Optional if the load faulted. On a fault,
// handle_safe_access_fault() resumes execution at safe_atomic_load_relaxed_faulted (directly after
// the load) with t6 set to 1.
CODE_SECTION(".text.safemem.atomic")
NEVER_INLINE FLATTEN Optional<u32> safe_atomic_load_relaxed(u32 volatile* var)
{
    // Left uninitialized on purpose: it is only read when the load completed (error == 0).
    u32 result;
    register FlatPtr error asm("t6") = 0; // handle_safe_access_fault sets t6 to 1 when a page fault occurs in one of the safe_atomic_* functions.
    asm volatile(R"(
.global safe_atomic_load_relaxed_ins
safe_atomic_load_relaxed_ins:
    lw %[result], (%[var_ptr])
.global safe_atomic_load_relaxed_faulted
safe_atomic_load_relaxed_faulted:
)"
                 : [result] "=r"(result), "+r"(error)
                 : [var_ptr] "r"(var)
                 : "memory");
    if (error != 0)
        return {};
    return result;
}

// Fault-tolerant atomic 32-bit fetch-and-add on `*var` using a single `amoadd.w` without .aq/.rl
// ordering suffixes (matching "_relaxed").
//
// Returns the value the instruction placed in its destination register (for amoadd.w this is the
// value `*var` held before `val` was added), or an empty Optional if the instruction faulted.
// On a fault, handle_safe_access_fault() resumes execution at
// safe_atomic_fetch_add_relaxed_faulted (directly after the instruction) with t6 set to 1.
CODE_SECTION(".text.safemem.atomic")
NEVER_INLINE FLATTEN Optional<u32> safe_atomic_fetch_add_relaxed(u32 volatile* var, u32 val)
{
    // Left uninitialized on purpose: it is only read when the AMO completed (error == 0).
    u32 result;
    register FlatPtr error asm("t6") = 0; // handle_safe_access_fault sets t6 to 1 when a page fault occurs in one of the safe_atomic_* functions.
    asm volatile(R"(
.global safe_atomic_fetch_add_relaxed_ins
safe_atomic_fetch_add_relaxed_ins:
    amoadd.w %[result], %[val], (%[var_ptr])
.global safe_atomic_fetch_add_relaxed_faulted
safe_atomic_fetch_add_relaxed_faulted:
)"
                 : [result] "=r"(result), "+r"(error)
                 : [val] "r"(val), [var_ptr] "r"(var)
                 : "memory");
    if (error != 0)
        return {};
    return result;
}

// Fault-tolerant atomic 32-bit exchange on `*var` using a single `amoswap.w` without .aq/.rl
// ordering suffixes (matching "_relaxed").
//
// Stores `desired` to `*var` and returns the value the instruction placed in its destination
// register (for amoswap.w this is the previous content of `*var`), or an empty Optional if the
// instruction faulted. On a fault, handle_safe_access_fault() resumes execution at
// safe_atomic_exchange_relaxed_faulted (directly after the instruction) with t6 set to 1.
CODE_SECTION(".text.safemem.atomic")
NEVER_INLINE FLATTEN Optional<u32> safe_atomic_exchange_relaxed(u32 volatile* var, u32 desired)
{
    // Left uninitialized on purpose: it is only read when the AMO completed (error == 0).
    u32 result;
    register FlatPtr error asm("t6") = 0; // handle_safe_access_fault sets t6 to 1 when a page fault occurs in one of the safe_atomic_* functions.
    asm volatile(R"(
.global safe_atomic_exchange_relaxed_ins
safe_atomic_exchange_relaxed_ins:
    amoswap.w %[result], %[desired], (%[var_ptr])
.global safe_atomic_exchange_relaxed_faulted
safe_atomic_exchange_relaxed_faulted:
)"
                 : [result] "=r"(result), "+r"(error)
                 : [desired] "r"(desired), [var_ptr] "r"(var)
                 : "memory");
    if (error != 0)
        return {};
    return result;
}

// Fault-tolerant 32-bit store of `desired` to `*var` using a single plain `sw` (no fence, matching
// "_relaxed").
//
// Returns true if the store completed and false if it faulted. On a fault,
// handle_safe_access_fault() resumes execution at safe_atomic_store_relaxed_faulted (directly
// after the store) with t6 set to 1.
CODE_SECTION(".text.safemem.atomic")
NEVER_INLINE FLATTEN bool safe_atomic_store_relaxed(u32 volatile* var, u32 desired)
{
    register FlatPtr error asm("t6") = 0; // handle_safe_access_fault sets t6 to 1 when a page fault occurs in one of the safe_atomic_* functions.
    asm volatile(R"(
.global safe_atomic_store_relaxed_ins
safe_atomic_store_relaxed_ins:
    sw %[desired], (%[var_ptr])
.global safe_atomic_store_relaxed_faulted
safe_atomic_store_relaxed_faulted:
)"
                 : "+r"(error)
                 : [desired] "r"(desired), [var_ptr] "r"(var)
                 : "memory");
    // `error` is still 0 unless the fault handler intervened.
    return error == 0;
}

// Fault-handler hook for the safe_* accessors defined in this file.
//
// Inspects the faulting instruction address stored in `regs`. If it is one of the known *_ins
// labels, the saved register state is rewritten so that execution resumes at the matching
// *_faulted label with the failure reported to the interrupted accessor, and true is returned.
// For every other address `regs` is left untouched and false is returned.
//
// regs:          saved register state of the faulting context; modified only on a match.
// fault_address: the address whose access faulted; passed on to the non-atomic accessors
//                through their `fault_at` out-parameter.
bool handle_safe_access_fault(RegisterState& regs, FlatPtr fault_address)
{
    // Slots of regs.x holding the ABI registers this handler reads or patches.
    constexpr size_t a0_slot = 9;  // return value of the interrupted accessor
    constexpr size_t a2_slot = 11; // safe_strnlen: void*& fault_at
    constexpr size_t a3_slot = 12; // safe_memset / safe_memcpy: void*& fault_at
    constexpr size_t t6_slot = 30; // error flag of the safe_atomic_* accessors

    FlatPtr const faulting_ip = regs.ip();

    bool const inside_safemem_text = faulting_ip >= bit_cast<FlatPtr>(&start_of_safemem_text)
        && faulting_ip < bit_cast<FlatPtr>(&end_of_safemem_text);

    if (inside_safemem_text) {
        // safe_memset(), safe_strnlen() and safe_memcpy(): continue at the accessor's _faulted
        // label, hand it its failure return value in a0 and store the faulting address in fault_at.
        if (faulting_ip == bit_cast<FlatPtr>(&safe_memset_ins)) {
            regs.set_ip(bit_cast<FlatPtr>(&safe_memset_faulted));
            regs.x[a0_slot] = 0;
            *bit_cast<FlatPtr*>(regs.x[a3_slot]) = fault_address;
            return true;
        }

        if (faulting_ip == bit_cast<FlatPtr>(&safe_strnlen_ins)) {
            regs.set_ip(bit_cast<FlatPtr>(&safe_strnlen_faulted));
            regs.x[a0_slot] = -1;
            *bit_cast<FlatPtr*>(regs.x[a2_slot]) = fault_address;
            return true;
        }

        bool const at_memcpy_load = faulting_ip == bit_cast<FlatPtr>(&safe_memcpy_ins_1);
        bool const at_memcpy_store = faulting_ip == bit_cast<FlatPtr>(&safe_memcpy_ins_2);
        if (at_memcpy_load || at_memcpy_store) {
            regs.set_ip(bit_cast<FlatPtr>(&safe_memcpy_faulted));
            regs.x[a0_slot] = 0;
            *bit_cast<FlatPtr*>(regs.x[a3_slot]) = fault_address;
            return true;
        }

        // Inside the section, but not at an instruction that is allowed to fault.
        return false;
    }

    bool const inside_safemem_atomic_text = faulting_ip >= bit_cast<FlatPtr>(&start_of_safemem_atomic_text)
        && faulting_ip < bit_cast<FlatPtr>(&end_of_safemem_atomic_text);

    if (!inside_safemem_atomic_text)
        return false;

    // safe_atomic_*(): pick the _faulted label that belongs to the faulting instruction.
    FlatPtr resume_at;
    if (faulting_ip == bit_cast<FlatPtr>(&safe_atomic_compare_exchange_relaxed_ins_1)
        || faulting_ip == bit_cast<FlatPtr>(&safe_atomic_compare_exchange_relaxed_ins_2))
        resume_at = bit_cast<FlatPtr>(&safe_atomic_compare_exchange_relaxed_faulted);
    else if (faulting_ip == bit_cast<FlatPtr>(&safe_atomic_load_relaxed_ins))
        resume_at = bit_cast<FlatPtr>(&safe_atomic_load_relaxed_faulted);
    else if (faulting_ip == bit_cast<FlatPtr>(&safe_atomic_fetch_add_relaxed_ins))
        resume_at = bit_cast<FlatPtr>(&safe_atomic_fetch_add_relaxed_faulted);
    else if (faulting_ip == bit_cast<FlatPtr>(&safe_atomic_exchange_relaxed_ins))
        resume_at = bit_cast<FlatPtr>(&safe_atomic_exchange_relaxed_faulted);
    else if (faulting_ip == bit_cast<FlatPtr>(&safe_atomic_store_relaxed_ins))
        resume_at = bit_cast<FlatPtr>(&safe_atomic_store_relaxed_faulted);
    else
        return false; // Inside the section, but not at an instruction that is allowed to fault.

    // The atomic accessors keep their error flag in t6; 1 tells them the access faulted.
    regs.set_ip(resume_at);
    regs.x[t6_slot] = 1;
    return true;
}

}