Initial Commit

Initial commit of arm64 bring-up code, kernel core (libkernel) and build
infrastructure.
Matthew Leach
2025-06-14 16:05:40 +01:00
commit ca6fdd0da5
177 changed files with 27215 additions and 0 deletions

libkernel/Cargo.toml

@@ -0,0 +1,19 @@
[package]
name = "libkernel"
version = "0.0.0"
edition = "2024"
[dependencies]
paste = "1.0.15"
thiserror = { version = "2.0.12", default-features = false }
tock-registers = "0.10.1"
log = "0.4.27"
async-trait = "0.1.88"
object = { version = "0.37.1", default-features = false, features = ["core", "elf", "read_core"] }
bitflags = "2.9.1"
ringbuf = { version = "0.4.8", default-features = false, features = ["alloc"] }
intrusive-collections = { version = "0.9.7", default-features = false }
[dev-dependencies]
rand = "0.9.1"
tokio = { version = "1.47.1", features = ["full"] }


@@ -0,0 +1,4 @@
pub mod pg_descriptors;
pub mod pg_tables;
pub mod pg_walk;
pub mod tlb;


@@ -0,0 +1,574 @@
use paste::paste;
use tock_registers::interfaces::{ReadWriteable, Readable};
use tock_registers::{register_bitfields, registers::InMemoryRegister};
use crate::memory::PAGE_SHIFT;
use crate::memory::address::{PA, VA};
use crate::memory::permissions::PtePermissions;
use crate::memory::region::PhysMemoryRegion;
/// Trait for common behavior across different types of page table entries.
pub trait PageTableEntry: Sized + Copy + Clone {
/// Returns `true` if the entry is valid (i.e., not an Invalid/Fault entry).
fn is_valid(self) -> bool;
/// Returns the raw value of this page descriptor.
fn as_raw(self) -> u64;
/// Returns a representation of the page descriptor from a raw value.
fn from_raw(v: u64) -> Self;
/// Return a new invalid page descriptor.
fn invalid() -> Self;
}
/// Trait for descriptors that can point to a next-level table.
pub trait TableMapper: PageTableEntry {
/// Returns the physical address of the next-level table, if this descriptor
/// is a table descriptor.
fn next_table_address(self) -> Option<PA>;
/// Creates a new descriptor that points to the given next-level table.
fn new_next_table(pa: PA) -> Self;
}
/// A descriptor that maps a physical address (L1, L2 blocks and L3 page).
pub trait PaMapper: PageTableEntry {
/// Constructs a new valid page descriptor that maps a physical address.
fn new_map_pa(page_address: PA, memory_type: MemoryType, perms: PtePermissions) -> Self;
/// Return the shift (log2 of the size in bytes) that this descriptor type maps.
fn map_shift() -> usize;
/// Returns whether a mapping of this descriptor's size could be placed at the
/// start of `region` mapped at `va` (i.e. both addresses are suitably aligned
/// and the region is large enough).
fn could_map(region: PhysMemoryRegion, va: VA) -> bool;
/// Return the mapped physical address.
fn mapped_address(self) -> Option<PA>;
}
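// Illustrative sketch of how these traits compose (a hypothetical helper, not
// part of this commit): attempt to map the start of `region` at `va` with
// descriptor type `D`, returning `None` if the alignment/size check fails.
//
// fn try_map<D: PaMapper>(region: PhysMemoryRegion, va: VA, perms: PtePermissions) -> Option<D> {
//     D::could_map(region, va)
//         .then(|| D::new_map_pa(region.start_address(), MemoryType::Normal, perms))
// }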
#[derive(Clone, Copy)]
struct TableAddr(PA);
impl TableAddr {
fn as_raw_parts(&self) -> u64 {
(self.0.value() as u64) & !((1 << PAGE_SHIFT) - 1)
}
fn from_raw_parts(v: u64) -> Self {
Self(PA::from_value(v as usize & !((1 << PAGE_SHIFT) - 1)))
}
}
#[derive(Debug, Clone, Copy)]
pub enum MemoryType {
Device,
Normal,
}
macro_rules! define_descriptor {
(
$(#[$outer:meta])*
$name:ident,
// Optional: Implement TableMapper if this section is present
$( table: $table_bits:literal, )?
// Optional: Implement PaMapper if this section is present
$( map: {
bits: $map_bits:literal,
shift: $tbl_shift:literal,
oa_len: $oa_len:literal,
},
)?
) => {
#[repr(transparent)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
$(#[$outer])*
pub struct $name(u64);
impl PageTableEntry for $name {
fn is_valid(self) -> bool { (self.0 & 0b11) != 0 }
fn as_raw(self) -> u64 { self.0 }
fn from_raw(v: u64) -> Self { Self(v) }
fn invalid() -> Self { Self(0) }
}
$(
impl TableMapper for $name {
fn next_table_address(self) -> Option<PA> {
if (self.0 & 0b11) == $table_bits {
Some(TableAddr::from_raw_parts(self.0).0)
} else {
None
}
}
fn new_next_table(pa: PA) -> Self {
Self(TableAddr(pa).as_raw_parts() | $table_bits)
}
}
)?
$(
paste! {
#[allow(non_snake_case)]
mod [<$name Fields>] {
use super::*;
register_bitfields![u64,
pub BlockPageFields [
ATTR_INDEX OFFSET(2) NUMBITS(3) [],
AP OFFSET(6) NUMBITS(2) [ RW_EL1 = 0b00, RW_EL0 = 0b01, RO_EL1 = 0b10, RO_EL0 = 0b11 ],
SH OFFSET(8) NUMBITS(2) [ NonShareable = 0b00, Unpredictable = 0b01, OuterShareable = 0b10, InnerShareable = 0b11 ],
AF OFFSET(10) NUMBITS(1) [ Accessed = 1 ],
PXN OFFSET(53) NUMBITS(1) [ NotExecutableAtEL1 = 1, ExecutableAtEL1 = 0 ],
XN OFFSET(54) NUMBITS(1) [ NotExecutable = 1, Executable = 0 ],
// Software defined bit
COW OFFSET(55) NUMBITS(1) [ CowShared = 1, NotCowShared = 0 ],
OUTPUT_ADDR OFFSET($tbl_shift) NUMBITS($oa_len) []
]
];
}
impl $name {
/// Returns the interpreted permissions if this is a block/page
/// descriptor.
pub fn permissions(self) -> Option<PtePermissions> {
// Check if the descriptor bits match the block/page type
if (self.0 & 0b11) != $map_bits {
return None;
}
let reg = InMemoryRegister::new(self.0);
let ap_val = reg.read([<$name Fields>]::BlockPageFields::AP);
let (write, user) = match ap_val {
0b00 => (true, false), // RW_EL1
0b01 => (true, true), // RW_EL0
0b10 => (false, false), // RO_EL1
0b11 => (false, true), // RO_EL0
_ => unreachable!(),
};
let xn = reg.is_set([<$name Fields>]::BlockPageFields::XN);
let cow = reg.is_set([<$name Fields>]::BlockPageFields::COW);
let execute = !xn;
Some(PtePermissions::from_raw_bits(
true, // Always true if valid
write,
execute,
user,
cow,
))
}
pub fn set_permissions(self, perms: PtePermissions) -> Self {
let reg = InMemoryRegister::new(self.0);
use [<$name Fields>]::BlockPageFields;
let ap = match (perms.is_user(), perms.is_write()) {
(false, true) => BlockPageFields::AP::RW_EL1,
(true, true) => BlockPageFields::AP::RW_EL0,
(false, false) => BlockPageFields::AP::RO_EL1,
(true, false) => BlockPageFields::AP::RO_EL0,
};
reg.modify(ap);
if !perms.is_execute() {
reg.modify(BlockPageFields::XN::NotExecutable + BlockPageFields::PXN::NotExecutableAtEL1);
} else {
reg.modify(BlockPageFields::XN::Executable + BlockPageFields::PXN::ExecutableAtEL1);
}
if perms.is_cow() {
reg.modify(BlockPageFields::COW::CowShared)
} else {
reg.modify(BlockPageFields::COW::NotCowShared)
}
Self(reg.get())
}
}
impl PaMapper for $name {
fn map_shift() -> usize { $tbl_shift }
fn could_map(region: PhysMemoryRegion, va: VA) -> bool {
let is_aligned = |addr: usize| (addr & ((1 << $tbl_shift) - 1)) == 0;
is_aligned(region.start_address().value())
&& is_aligned(va.value())
&& region.size() >= (1 << $tbl_shift)
}
fn new_map_pa(page_address: PA, memory_type: MemoryType, perms: PtePermissions) -> Self {
let is_aligned = |addr: usize| (addr & ((1 << $tbl_shift) - 1)) == 0;
if !is_aligned(page_address.value()) {
panic!("Cannot map non-aligned physical address");
}
let reg = InMemoryRegister::new(0);
use [<$name Fields>]::BlockPageFields;
reg.modify(BlockPageFields::OUTPUT_ADDR.val((page_address.value() >> $tbl_shift) as u64)
+ BlockPageFields::AF::Accessed);
match memory_type {
MemoryType::Device => {
reg.modify(BlockPageFields::SH::NonShareable + BlockPageFields::ATTR_INDEX.val(1));
}
MemoryType::Normal => {
reg.modify(BlockPageFields::SH::InnerShareable + BlockPageFields::ATTR_INDEX.val(0));
}
}
Self(reg.get() | $map_bits).set_permissions(perms)
}
fn mapped_address(self) -> Option<PA> {
use [<$name Fields>]::BlockPageFields;
match self.0 & 0b11 {
0b00 => return None,
// Swapped out page.
0b10 => {},
$map_bits => {},
_ => return None,
}
let reg = InMemoryRegister::new(self.0);
let addr = reg.read(BlockPageFields::OUTPUT_ADDR);
Some(PA::from_value((addr << $tbl_shift) as usize))
}
}
}
)?
};
}
define_descriptor!(
/// A Level 0 descriptor. Can only be an invalid or table descriptor.
L0Descriptor,
table: 0b11,
);
define_descriptor!(
/// A Level 1 descriptor. Can be a block, table, or invalid descriptor.
L1Descriptor,
table: 0b11,
map: {
bits: 0b01, // L1 Block descriptor has bits[1:0] = 01
shift: 30, // Maps a 1GiB block
oa_len: 18, // Output address length for 48-bit PA
},
);
define_descriptor!(
/// A Level 2 descriptor. Can be a block, table, or invalid descriptor.
L2Descriptor,
table: 0b11,
map: {
bits: 0b01, // L2 Block descriptor has bits[1:0] = 01
shift: 21, // Maps a 2MiB block
oa_len: 27, // Output address length for 48-bit PA
},
);
define_descriptor!(
/// A Level 3 descriptor. Can be a page or invalid descriptor.
L3Descriptor,
// Note: No 'table' capability at L3.
map: {
bits: 0b11, // L3 Page descriptor has bits[1:0] = 11
shift: 12, // Maps a 4KiB page
oa_len: 36, // Output address length for 48-bit PA
},
);
pub enum L3DescriptorState {
Invalid,
Swapped,
Valid,
}
impl L3Descriptor {
const SWAPPED_BIT: u64 = 1 << 1;
const STATE_MASK: u64 = 0b11;
/// Returns the state of this entry: valid, invalid, or swapped (a non-present
/// entry that still carries metadata, e.g. a page written out to disk).
pub fn state(self) -> L3DescriptorState {
match self.0 & Self::STATE_MASK {
0b00 => L3DescriptorState::Invalid,
0b10 => L3DescriptorState::Swapped,
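// 0b01 is a reserved encoding at L3; treat it as invalid.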
0b01 => L3DescriptorState::Invalid,
0b11 => L3DescriptorState::Valid,
_ => unreachable!(),
}
}
/// Mark an existing PTE as swapped (invalid but containing valid
/// information).
pub fn mark_as_swapped(self) -> Self {
Self(Self::SWAPPED_BIT | (self.0 & !Self::STATE_MASK))
}
}
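// Illustrative sketch (assuming an aligned `pa`): marking a mapped page as
// swapped clears the valid encoding but keeps the output-address bits, so the
// original mapping can be recovered later.
//
// let pte = L3Descriptor::new_map_pa(pa, MemoryType::Normal, PtePermissions::rw(false));
// let swapped = pte.mark_as_swapped();
// assert!(matches!(swapped.state(), L3DescriptorState::Swapped));
// assert_eq!(swapped.mapped_address(), pte.mapped_address());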
#[cfg(test)]
mod tests {
use super::*;
use crate::memory::region::PhysMemoryRegion;
use crate::memory::{PAGE_SHIFT, PAGE_SIZE};
const KERNEL_PERMS: bool = false;
const USER_PERMS: bool = true;
#[test]
fn test_invalid_descriptor() {
let d = L0Descriptor::invalid();
assert!(!d.is_valid());
assert_eq!(d.as_raw(), 0);
}
#[test]
fn test_l0_table_descriptor() {
let pa = PA::from_value(0x1000_0000);
let d = L0Descriptor::new_next_table(pa);
assert!(d.is_valid());
assert_eq!(d.as_raw(), 0x1000_0000 | 0b11);
assert_eq!(d.next_table_address(), Some(pa));
}
#[test]
fn test_l1_table_descriptor() {
let pa = PA::from_value(0x2000_0000);
let d = L1Descriptor::new_next_table(pa);
assert!(d.is_valid());
assert_eq!(d.as_raw(), 0x2000_0000 | 0b11);
assert_eq!(d.next_table_address(), Some(pa));
assert!(d.mapped_address().is_none());
assert!(d.permissions().is_none());
}
#[test]
fn test_l1_block_creation() {
let pa = PA::from_value(1 << 30); // 1GiB aligned
let perms = PtePermissions::rw(KERNEL_PERMS);
let d = L1Descriptor::new_map_pa(pa, MemoryType::Normal, perms);
assert!(d.is_valid());
assert_eq!(d.as_raw() & 0b11, 0b01); // Is a block descriptor
assert!(d.next_table_address().is_none());
// Check address part (bits [47:30])
assert_eq!((d.as_raw() >> 30) & 0x3_FFFF, 1);
// AF bit should be set
assert_ne!(d.as_raw() & (1 << 10), 0);
}
#[test]
fn test_l1_block_permissions() {
let pa = PA::from_value(1 << 30);
let d_krw =
L1Descriptor::new_map_pa(pa, MemoryType::Normal, PtePermissions::rw(KERNEL_PERMS));
let d_kro =
L1Descriptor::new_map_pa(pa, MemoryType::Normal, PtePermissions::ro(KERNEL_PERMS));
let d_urw =
L1Descriptor::new_map_pa(pa, MemoryType::Normal, PtePermissions::rw(USER_PERMS));
let d_uro =
L1Descriptor::new_map_pa(pa, MemoryType::Normal, PtePermissions::ro(USER_PERMS));
let d_krwx =
L1Descriptor::new_map_pa(pa, MemoryType::Normal, PtePermissions::rwx(KERNEL_PERMS));
assert_eq!(d_krw.permissions(), Some(PtePermissions::rw(KERNEL_PERMS)));
assert_eq!(d_kro.permissions(), Some(PtePermissions::ro(KERNEL_PERMS)));
assert_eq!(d_urw.permissions(), Some(PtePermissions::rw(USER_PERMS)));
assert_eq!(d_uro.permissions(), Some(PtePermissions::ro(USER_PERMS)));
assert_eq!(
d_krwx.permissions(),
Some(PtePermissions::rwx(KERNEL_PERMS))
);
// Verify XN bit is NOT set for executable
assert_eq!(d_krwx.as_raw() & (1 << 54), 0);
// Verify XN bit IS set for non-executable
assert_ne!(d_krw.as_raw() & (1 << 54), 0);
}
#[test]
fn test_l1_could_map() {
let one_gib = 1 << 30;
let good_region = PhysMemoryRegion::new(PA::from_value(one_gib), one_gib);
let good_va = VA::from_value(one_gib * 2);
assert!(L1Descriptor::could_map(good_region, good_va));
// Bad region size
let small_region = PhysMemoryRegion::new(PA::from_value(one_gib), one_gib - 1);
assert!(!L1Descriptor::could_map(small_region, good_va));
// Bad region alignment
let unaligned_region = PhysMemoryRegion::new(PA::from_value(one_gib + 1), one_gib);
assert!(!L1Descriptor::could_map(unaligned_region, good_va));
// Bad VA alignment
let unaligned_va = VA::from_value(one_gib + 1);
assert!(!L1Descriptor::could_map(good_region, unaligned_va));
}
#[test]
#[should_panic]
fn test_l1_map_unaligned_pa_panics() {
let pa = PA::from_value((1 << 30) + 1); // Not 1GiB aligned
let perms = PtePermissions::rw(KERNEL_PERMS);
L1Descriptor::new_map_pa(pa, MemoryType::Normal, perms);
}
#[test]
fn test_l1_from_raw_roundtrip() {
let pa = PA::from_value(1 << 30);
let d = L1Descriptor::new_map_pa(pa, MemoryType::Normal, PtePermissions::rw(false));
let raw = d.as_raw();
let decoded = L1Descriptor::from_raw(raw);
assert_eq!(decoded.as_raw(), d.as_raw());
assert_eq!(decoded.mapped_address(), d.mapped_address());
assert_eq!(decoded.permissions(), d.permissions());
}
#[test]
fn test_l2_block_creation() {
let pa = PA::from_value(2 << 21); // 2MiB aligned
let perms = PtePermissions::rw(USER_PERMS);
let d = L2Descriptor::new_map_pa(pa, MemoryType::Normal, perms);
assert!(d.is_valid());
assert_eq!(d.as_raw() & 0b11, 0b01); // L2 block
assert!(d.next_table_address().is_none());
assert_eq!(d.mapped_address(), Some(pa));
}
#[test]
fn test_l2_block_permissions() {
let pa = PA::from_value(4 << 21); // 2MiB aligned
let d_kro =
L2Descriptor::new_map_pa(pa, MemoryType::Normal, PtePermissions::ro(KERNEL_PERMS));
let d_krwx =
L2Descriptor::new_map_pa(pa, MemoryType::Normal, PtePermissions::rwx(KERNEL_PERMS));
assert_eq!(d_kro.permissions(), Some(PtePermissions::ro(KERNEL_PERMS)));
assert_eq!(
d_krwx.permissions(),
Some(PtePermissions::rwx(KERNEL_PERMS))
);
// XN bit should not be set when execute = true
assert_eq!(d_krwx.as_raw() & (1 << 54), 0);
// XN bit should be set when execute = false
assert_ne!(d_kro.as_raw() & (1 << 54), 0);
}
#[test]
fn test_l2_could_map() {
let size = 1 << 21;
let good_region = PhysMemoryRegion::new(PA::from_value(size), size);
let good_va = VA::from_value(size * 3);
assert!(L2Descriptor::could_map(good_region, good_va));
let unaligned_pa = PhysMemoryRegion::new(PA::from_value(size + 1), size);
let unaligned_va = VA::from_value(size + 1);
assert!(!L2Descriptor::could_map(unaligned_pa, good_va));
assert!(!L2Descriptor::could_map(good_region, unaligned_va));
}
#[test]
#[should_panic]
fn test_l2_map_unaligned_pa_panics() {
let pa = PA::from_value((1 << 21) + 1);
L2Descriptor::new_map_pa(pa, MemoryType::Normal, PtePermissions::rw(false));
}
#[test]
fn test_l2_from_raw_roundtrip() {
let pa = PA::from_value(1 << 21);
let d = L2Descriptor::new_map_pa(pa, MemoryType::Normal, PtePermissions::rx(true));
let raw = d.as_raw();
let decoded = L2Descriptor::from_raw(raw);
assert_eq!(decoded.as_raw(), d.as_raw());
assert_eq!(decoded.mapped_address(), d.mapped_address());
assert_eq!(decoded.permissions(), d.permissions());
}
#[test]
fn test_l3_page_creation() {
let pa = PA::from_value(PAGE_SIZE * 10); // 4KiB aligned
let perms = PtePermissions::rx(USER_PERMS);
let d = L3Descriptor::new_map_pa(pa, MemoryType::Normal, perms);
assert!(d.is_valid());
assert_eq!(d.as_raw() & 0b11, 0b11); // Is a page descriptor
// Check address part (bits [47:12])
assert_eq!(
(d.as_raw() >> PAGE_SHIFT),
(pa.value() >> PAGE_SHIFT) as u64
);
// AF bit should be set
assert_ne!(d.as_raw() & (1 << 10), 0);
}
#[test]
fn test_l3_permissions() {
let pa = PA::from_value(PAGE_SIZE);
let d_urx =
L3Descriptor::new_map_pa(pa, MemoryType::Normal, PtePermissions::rx(USER_PERMS));
assert_eq!(d_urx.permissions(), Some(PtePermissions::rx(USER_PERMS)));
// Verify XN bit is NOT set for executable
assert_eq!(d_urx.as_raw() & (1 << 54), 0);
}
#[test]
fn test_l3_could_map() {
let good_region = PhysMemoryRegion::new(PA::from_value(PAGE_SIZE), PAGE_SIZE);
let good_va = VA::from_value(PAGE_SIZE * 2);
assert!(L3Descriptor::could_map(good_region, good_va));
// Bad region alignment
let unaligned_region = PhysMemoryRegion::new(PA::from_value(PAGE_SIZE + 1), PAGE_SIZE);
assert!(!L3Descriptor::could_map(unaligned_region, good_va));
}
#[test]
fn test_l3_from_raw_roundtrip() {
let pa = PA::from_value(PAGE_SIZE * 8);
let d = L3Descriptor::new_map_pa(pa, MemoryType::Device, PtePermissions::rw(true));
let raw = d.as_raw();
let decoded = L3Descriptor::from_raw(raw);
assert_eq!(decoded.as_raw(), d.as_raw());
assert_eq!(decoded.mapped_address(), d.mapped_address());
assert_eq!(decoded.permissions(), d.permissions());
}
#[test]
fn test_l2_invalid_descriptor() {
let d = L2Descriptor::invalid();
assert!(!d.is_valid());
assert_eq!(d.as_raw(), 0);
assert!(d.next_table_address().is_none());
assert!(d.mapped_address().is_none());
assert!(d.permissions().is_none());
}
}


File diff suppressed because it is too large.


@@ -0,0 +1,448 @@
use super::{
pg_descriptors::{L3Descriptor, PageTableEntry, TableMapper},
pg_tables::{L0Table, L3Table, PageTableMapper, PgTable, PgTableArray, TableMapperTable},
tlb::{NullTlbInvalidator, TLBInvalidator},
};
use crate::{
error::{MapError, Result},
memory::{
PAGE_SIZE,
address::{TPA, VA},
region::VirtMemoryRegion,
},
};
/// A collection of context required to modify page tables.
pub struct WalkContext<'a, PM>
where
PM: PageTableMapper + 'a,
{
pub mapper: &'a mut PM,
pub invalidator: &'a dyn TLBInvalidator,
}
trait RecursiveWalker: PgTable + Sized {
fn walk<F, PM>(
table_pa: TPA<PgTableArray<Self>>,
region: VirtMemoryRegion,
ctx: &mut WalkContext<PM>,
modifier: &mut F,
) -> Result<()>
where
PM: PageTableMapper,
F: FnMut(VA, L3Descriptor) -> L3Descriptor;
}
impl<T> RecursiveWalker for T
where
T: TableMapperTable,
T::NextLevel: RecursiveWalker,
{
fn walk<F, PM>(
table_pa: TPA<PgTableArray<Self>>,
region: VirtMemoryRegion,
ctx: &mut WalkContext<PM>,
modifier: &mut F,
) -> Result<()>
where
PM: PageTableMapper,
F: FnMut(VA, L3Descriptor) -> L3Descriptor,
{
let table_coverage = 1 << T::SHIFT;
let start_idx = Self::pg_index(region.start_address());
let end_idx = Self::pg_index(region.end_address_inclusive());
// Calculate the base address of the *entire* table.
let table_base_va = region.start_address().align(1 << (T::SHIFT + 9));
for idx in start_idx..=end_idx {
let entry_va = table_base_va.add_bytes(idx * table_coverage);
let desc = unsafe {
ctx.mapper
.with_page_table(table_pa, |pgtable| T::from_ptr(pgtable).get_desc(entry_va))?
};
if let Some(next_desc) = desc.next_table_address() {
let sub_region = VirtMemoryRegion::new(entry_va, table_coverage)
.intersection(region)
.expect("Sub region should overlap with parent region");
T::NextLevel::walk(next_desc.cast(), sub_region, ctx, modifier)?;
} else if desc.is_valid() {
Err(MapError::NotL3Mapped)?
} else {
// Permit sparse mappings.
continue;
}
}
Ok(())
}
}
impl RecursiveWalker for L3Table {
fn walk<F, PM>(
table_pa: TPA<PgTableArray<Self>>,
region: VirtMemoryRegion,
ctx: &mut WalkContext<PM>,
modifier: &mut F,
) -> Result<()>
where
PM: PageTableMapper,
F: FnMut(VA, L3Descriptor) -> L3Descriptor,
{
unsafe {
ctx.mapper.with_page_table(table_pa, |pgtable| {
let table = L3Table::from_ptr(pgtable);
for va in region.iter_pages() {
let desc = table.get_desc(va);
if desc.is_valid() {
table.set_desc(va, modifier(va, desc), ctx.invalidator);
}
}
})
}
}
}
/// Walks the page table hierarchy for a given virtual memory region and applies
/// a modifying closure to every L3 (4KiB page) descriptor within that region.
///
/// # Parameters
/// - `l0_table`: The physical address of the root (L0) page table.
/// - `region`: The virtual memory region to modify. Must be page-aligned.
/// - `ctx`: The context for the operation, including the page table mapper
/// and TLB invalidator.
/// - `modifier`: A closure that will be called for each valid L3 descriptor
/// found within the `region`. It receives the virtual address of the page and
/// the current `L3Descriptor`, and returns the descriptor to store in its place.
///
/// # Returns
/// - `Ok(())` on success. Unmapped (invalid) entries within the `region` are
/// skipped, so sparse regions are permitted.
///
/// # Errors
/// - `MapError::VirtNotAligned`: The provided `region` is not page-aligned.
/// - `MapError::NotL3Mapped`: Part of the `region` is covered by a larger
/// block mapping (1GiB or 2MiB), which cannot be modified at the L3 level.
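/// # Example
///
/// A doc-only sketch (assumes an `l0_table`, a `mapper` and a TLB `invalidator`
/// are already available; mirrors the tests below):
///
/// ```ignore
/// let mut ctx = WalkContext { mapper: &mut mapper, invalidator: &invalidator };
/// // Re-map the page at `va` as kernel read-write.
/// walk_and_modify_region(
///     l0_table,
///     VirtMemoryRegion::new(va, PAGE_SIZE),
///     &mut ctx,
///     |_va, desc| {
///         L3Descriptor::new_map_pa(
///             desc.mapped_address().unwrap(),
///             MemoryType::Normal,
///             PtePermissions::rw(false),
///         )
///     },
/// )?;
/// ```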
pub fn walk_and_modify_region<F, PM>(
l0_table: TPA<PgTableArray<L0Table>>,
region: VirtMemoryRegion,
ctx: &mut WalkContext<PM>,
mut modifier: F, // Pass closure as a mutable ref to be used across recursive calls
) -> Result<()>
where
PM: PageTableMapper,
F: FnMut(VA, L3Descriptor) -> L3Descriptor,
{
if !region.is_page_aligned() {
Err(MapError::VirtNotAligned)?;
}
if region.size() == 0 {
return Ok(()); // Nothing to do for an empty region.
}
L0Table::walk(l0_table, region, ctx, &mut modifier)
}
/// Obtain the PTE that maps the VA into the current address space.
pub fn get_pte<PM: PageTableMapper>(
l0_table: TPA<PgTableArray<L0Table>>,
va: VA,
mapper: &mut PM,
) -> Result<Option<L3Descriptor>> {
let mut descriptor = None;
let mut walk_ctx = WalkContext {
mapper,
// Safe to not invalidate the TLB, as we are not modifying any PTEs.
invalidator: &NullTlbInvalidator {},
};
walk_and_modify_region(
l0_table,
VirtMemoryRegion::new(va.page_aligned(), PAGE_SIZE),
&mut walk_ctx,
|_, pte| {
descriptor = Some(pte);
pte
},
)?;
Ok(descriptor)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::arch::arm64::memory::pg_descriptors::{L2Descriptor, MemoryType, PaMapper};
use crate::arch::arm64::memory::pg_tables::tests::TestHarness;
use crate::arch::arm64::memory::pg_tables::{L1Table, L2Table, map_at_level};
use crate::error::KernelError;
use crate::memory::PAGE_SIZE;
use crate::memory::address::{PA, VA};
use crate::memory::permissions::PtePermissions;
use std::sync::atomic::{AtomicUsize, Ordering};
#[test]
fn walk_modify_single_page() {
let mut harness = TestHarness::new(10);
let va = VA::from_value(0x1_0000_0000);
let pa = 0x8_0000;
// Map a single page with RO permissions
harness
.map_4k_pages(pa, va.value(), 1, PtePermissions::ro(false))
.unwrap();
harness.verify_perms(va, PtePermissions::ro(false));
// Walk and modify permissions to RW
let mut modifier_was_called = false;
walk_and_modify_region(
harness.l0_table,
VirtMemoryRegion::new(va, PAGE_SIZE),
&mut harness.create_walk_ctx(),
&mut |_va, desc: L3Descriptor| {
modifier_was_called = true;
// Create a new descriptor with new permissions
L3Descriptor::new_map_pa(
desc.mapped_address().unwrap(),
MemoryType::Normal,
PtePermissions::rw(false),
)
},
)
.unwrap();
assert!(modifier_was_called);
harness.verify_perms(va, PtePermissions::rw(false));
}
#[test]
fn walk_contiguous_region_in_one_l3_table() {
let mut harness = TestHarness::new(4);
let num_pages = 10;
let va_start = VA::from_value(0x2_0000_0000);
let pa_start = 0x9_0000;
let region = VirtMemoryRegion::new(va_start, num_pages * PAGE_SIZE);
harness
.map_4k_pages(
pa_start,
va_start.value(),
num_pages,
PtePermissions::ro(false),
)
.unwrap();
// Walk and count the pages modified
let counter = AtomicUsize::new(0);
walk_and_modify_region(
harness.l0_table,
region,
&mut harness.create_walk_ctx(),
&mut |_va, desc| {
counter.fetch_add(1, Ordering::SeqCst);
desc
},
)
.unwrap();
assert_eq!(counter.load(Ordering::SeqCst), num_pages);
}
#[test]
fn walk_region_spanning_l3_tables() {
let mut harness = TestHarness::new(5);
// This VA range will cross an L2 entry boundary, forcing a walk over
// two L3 tables. L2 entry covers 2MiB. Let's map a region around a 2MiB
// boundary.
let l2_boundary = 1 << L2Table::SHIFT; // 2MiB
let va_start = VA::from_value(l2_boundary - 5 * PAGE_SIZE);
let num_pages = 10;
let region = VirtMemoryRegion::new(va_start, num_pages * PAGE_SIZE);
harness
.map_4k_pages(
0x10_0000,
va_start.value(),
num_pages,
PtePermissions::ro(true),
)
.unwrap();
let counter = AtomicUsize::new(0);
walk_and_modify_region(
harness.l0_table,
region,
&mut harness.create_walk_ctx(),
&mut |_va, desc| {
counter.fetch_add(1, Ordering::SeqCst);
desc
},
)
.unwrap();
assert_eq!(counter.load(Ordering::SeqCst), num_pages);
}
#[test]
fn walk_region_spanning_l2_tables() {
let mut harness = TestHarness::new(6);
// This VA range will cross an L1 entry boundary, forcing a walk over two L2 tables.
let l1_boundary = 1 << L1Table::SHIFT; // 1GiB
let va_start = VA::from_value(l1_boundary - 5 * PAGE_SIZE);
let num_pages = 10;
let region = VirtMemoryRegion::new(va_start, num_pages * PAGE_SIZE);
harness
.map_4k_pages(
0x20_0000,
va_start.value(),
num_pages,
PtePermissions::ro(false),
)
.unwrap();
let counter = AtomicUsize::new(0);
walk_and_modify_region(
harness.l0_table,
region,
&mut harness.create_walk_ctx(),
&mut |_va, desc| {
counter.fetch_add(1, Ordering::SeqCst);
desc
},
)
.unwrap();
assert_eq!(counter.load(Ordering::SeqCst), num_pages);
}
#[test]
fn walk_sparse_region() {
let mut harness = TestHarness::new(10);
let va1 = VA::from_value(0x3_0000_0000);
let va2 = va1.add_pages(2);
let va3 = va1.add_pages(4);
// Map three pages with a "hole" in between
harness
.map_4k_pages(0x30000, va1.value(), 1, PtePermissions::ro(false))
.unwrap();
harness
.map_4k_pages(0x40000, va2.value(), 1, PtePermissions::ro(false))
.unwrap();
harness
.map_4k_pages(0x50000, va3.value(), 1, PtePermissions::ro(false))
.unwrap();
let counter = AtomicUsize::new(0);
let entire_region = VirtMemoryRegion::new(va1, 5 * PAGE_SIZE);
// Walk should succeed and only call the modifier for the valid pages
walk_and_modify_region(
harness.l0_table,
entire_region,
&mut harness.create_walk_ctx(),
&mut |_va, desc| {
counter.fetch_add(1, Ordering::SeqCst);
desc
},
)
.unwrap();
assert_eq!(counter.load(Ordering::SeqCst), 3);
}
#[test]
fn walk_into_block_mapping_fails() {
let mut harness = TestHarness::new(10);
let va = VA::from_value(0x4_0000_0000);
let pa = PA::from_value(0x80_0000); // 2MiB aligned
// Manually create a 2MiB block mapping
let l1 = map_at_level(harness.l0_table, va, &mut harness.create_map_ctx()).unwrap();
let l2 = map_at_level(l1, va, &mut harness.create_map_ctx()).unwrap();
let l2_desc = L2Descriptor::new_map_pa(pa, MemoryType::Normal, PtePermissions::rw(false));
unsafe {
harness
.mapper
.with_page_table(l2, |l2_tbl| {
let table = L2Table::from_ptr(l2_tbl);
table.set_desc(va, l2_desc, &harness.invalidator);
})
.unwrap();
}
let region = VirtMemoryRegion::new(va, PAGE_SIZE);
let result = walk_and_modify_region(
harness.l0_table,
region,
&mut harness.create_walk_ctx(),
&mut |_va, desc| desc,
);
assert!(matches!(
result,
Err(crate::error::KernelError::MappingError(
MapError::NotL3Mapped
))
));
}
#[test]
fn walk_unmapped_region_does_nothing() {
let mut harness = TestHarness::new(10);
let region = VirtMemoryRegion::new(VA::from_value(0xDEADBEEF000), PAGE_SIZE);
let counter = AtomicUsize::new(0);
let result = walk_and_modify_region(
harness.l0_table,
region,
&mut harness.create_walk_ctx(),
&mut |_va, desc| {
counter.fetch_add(1, Ordering::SeqCst);
desc
},
);
// The walk should succeed because it just finds nothing to modify.
assert!(result.is_ok());
// Crucially, the modifier should never have been called.
assert_eq!(counter.load(Ordering::SeqCst), 0);
}
#[test]
fn walk_empty_region() {
let mut harness = TestHarness::new(10);
let region = VirtMemoryRegion::new(VA::from_value(0x5_0000_0000), 0); // Zero size
let result = walk_and_modify_region(
harness.l0_table,
region,
&mut harness.create_walk_ctx(),
&mut |_va, _desc| panic!("Modifier should not be called for empty region"),
);
assert!(result.is_ok());
}
#[test]
fn walk_unaligned_region_fails() {
let mut harness = TestHarness::new(10);
let region = VirtMemoryRegion::new(VA::from_value(123), PAGE_SIZE); // Not page-aligned
let result = walk_and_modify_region(
harness.l0_table,
region,
&mut harness.create_walk_ctx(),
&mut |_va, desc| desc,
);
assert!(matches!(
result,
Err(KernelError::MappingError(MapError::VirtNotAligned))
));
}
}


@@ -0,0 +1,5 @@
pub trait TLBInvalidator {}
pub struct NullTlbInvalidator {}
impl TLBInvalidator for NullTlbInvalidator {}


@@ -0,0 +1 @@
pub mod memory;


@@ -0,0 +1 @@
pub mod arm64;

libkernel/src/driver.rs

@@ -0,0 +1,5 @@
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Debug)]
pub struct CharDevDescriptor {
pub major: u64,
pub minor: u64,
}

libkernel/src/error.rs

@@ -0,0 +1,178 @@
use core::convert::Infallible;
use thiserror::Error;
pub mod syscall_error;
#[derive(Debug, Error, PartialEq, Eq, Clone)]
pub enum ProbeError {
#[error("No registers present in FDT")]
NoReg,
#[error("No register bank size in FDT")]
NoRegSize,
#[error("No interrupts in FDT")]
NoInterrupts,
#[error("No parent interrupt controller in FDT")]
NoParentIntterupt,
#[error("The specified interrupt parent isn't an interrupt controller")]
NotInterruptController,
// Driver probing should be tried again after other probes have succeeded.
#[error("Driver probing deferred for other dependencies")]
Deferred,
}
#[derive(Debug, Error, PartialEq, Eq, Clone)]
pub enum MapError {
#[error("Physical address not page aligned")]
PhysNotAligned,
#[error("Physical address not page aligned")]
VirtNotAligned,
#[error("Physical and virtual range sizes do not match")]
SizeMismatch,
#[error("Failed to walk to the next level page table")]
WalkFailed,
#[error("Invalid page table descriptor encountered")]
InvalidDescriptor,
#[error("The region to be mapped is smaller than PAGE_SIZE")]
TooSmall,
#[error("The VA range is has already been mapped")]
AlreadyMapped,
#[error("Page table does not contain an L3 mapping")]
NotL3Mapped,
}
#[derive(Error, Debug, PartialEq, Eq, Clone)]
pub enum IoError {
#[error("The requested I/O operation was out of bounds for the block device")]
OutOfBounds,
#[error("Courruption found in the filesystem metadata")]
MetadataCorruption,
}
#[derive(Error, Debug, PartialEq, Eq, Clone)]
pub enum FsError {
#[error("The path or file was not found")]
NotFound,
#[error("The path component is not a directory.")]
NotADirectory,
#[error("The path component is a directory.")]
IsADirectory,
#[error("The file or directory already exists.")]
AlreadyExists,
#[error("Invalid input parameters.")]
InvalidInput,
#[error("The filesystem is corrupted or has an invalid format.")]
InvalidFs,
#[error("Attempted to access data out of bounds.")]
OutOfBounds,
#[error("The operation is not permitted.")]
PermissionDenied,
#[error("Could not find the specified FS driver")]
DriverNotFound,
#[error("Too many open files")]
TooManyFiles,
#[error("The device could not be found")]
NoDevice,
}
#[derive(Error, Debug, PartialEq, Eq, Clone)]
pub enum ExecError {
#[error("Invalid ELF Format")]
InvalidElfFormat,
#[error("Invalid Porgram Header Format")]
InvalidPHdrFormat,
}
#[derive(Error, Debug, PartialEq, Eq, Clone)]
pub enum KernelError {
#[error("Cannot allocate memory")]
NoMemory,
#[error("Memory region not found")]
NoMemRegion,
#[error("Invalid value")]
InvalidValue,
#[error("The current resource is already in use")]
InUse,
#[error("Page table mapping failed: {0}")]
MappingError(#[from] MapError),
#[error("Provided object is too large")]
TooLarge,
#[error("Operation not supported")]
NotSupported,
#[error("Device probe failed: {0}")]
Probe(#[from] ProbeError),
#[error("I/O operation failed: {0}")]
Io(#[from] IoError),
#[error("Filesystem operation failed: {0}")]
Fs(#[from] FsError),
#[error("Exec error: {0}")]
Exec(#[from] ExecError),
#[error("Not a tty")]
NotATty,
#[error("Fault errror during syscall")]
Fault,
#[error("Not an open file descriptor")]
BadFd,
#[error("Cannot seek on a pipe")]
SeekPipe,
#[error("Broken pipe")]
BrokenPipe,
#[error("Operation not permitted")]
NotPermitted,
#[error("Buffer is full")]
BufferFull,
#[error("No such process")]
NoProcess,
#[error("{0}")]
Other(&'static str),
}
pub type Result<T> = core::result::Result<T, KernelError>;
impl From<Infallible> for KernelError {
fn from(error: Infallible) -> Self {
match error {}
}
}


@@ -0,0 +1,52 @@
use crate::error::FsError;
use super::KernelError;
pub const EPERM: isize = -1;
pub const ENOENT: isize = -2;
pub const ESRCH: isize = -3;
pub const EINTR: isize = -4;
pub const EIO: isize = -5;
pub const ENXIO: isize = -6;
pub const E2BIG: isize = -7;
pub const ENOEXEC: isize = -8;
pub const EBADF: isize = -9;
pub const ECHILD: isize = -10;
pub const EAGAIN: isize = -11;
pub const ENOMEM: isize = -12;
pub const EACCES: isize = -13;
pub const EFAULT: isize = -14;
pub const ENOTBLK: isize = -15;
pub const EBUSY: isize = -16;
pub const EEXIST: isize = -17;
pub const EXDEV: isize = -18;
pub const ENODEV: isize = -19;
pub const ENOTDIR: isize = -20;
pub const EISDIR: isize = -21;
pub const EINVAL: isize = -22;
pub const ENFILE: isize = -23;
pub const EMFILE: isize = -24;
pub const ENOTTY: isize = -25;
pub const ETXTBSY: isize = -26;
pub const EFBIG: isize = -27;
pub const ENOSPC: isize = -28;
pub const ESPIPE: isize = -29;
pub const EROFS: isize = -30;
pub const EMLINK: isize = -31;
pub const EPIPE: isize = -32;
pub const EDOM: isize = -33;
pub const ERANGE: isize = -34;
pub const EWOULDBLOCK: isize = EAGAIN;
pub fn kern_err_to_syscall(err: KernelError) -> isize {
match err {
KernelError::BadFd => EBADF,
KernelError::InvalidValue => EINVAL,
KernelError::Fault => EFAULT,
KernelError::BrokenPipe => EPIPE,
KernelError::Fs(FsError::NotFound) => ENOENT,
KernelError::NotATty => ENOTTY,
KernelError::SeekPipe => ESPIPE,
_ => todo!(),
}
}

libkernel/src/fs/attr.rs

@@ -0,0 +1,327 @@
use crate::{
error::{KernelError, Result},
proc::ids::{Gid, Uid},
};
use super::{FileType, InodeId};
use bitflags::bitflags;
use core::time::Duration;
bitflags::bitflags! {
#[derive(Debug, Clone, Copy)]
pub struct AccessMode: i32 {
/// Execution is permitted
const X_OK = 1;
/// Writing is permitted
const W_OK = 2;
/// Reading is permitted
const R_OK = 4;
}
}
bitflags! {
#[derive(Clone, Copy, Debug)]
pub struct FilePermissions: u16 {
// Owner permissions
const S_IRUSR = 0o400; // Read permission, owner
const S_IWUSR = 0o200; // Write permission, owner
const S_IXUSR = 0o100; // Execute/search permission, owner
// Group permissions
const S_IRGRP = 0o040; // Read permission, group
const S_IWGRP = 0o020; // Write permission, group
const S_IXGRP = 0o010; // Execute/search permission, group
// Others permissions
const S_IROTH = 0o004; // Read permission, others
const S_IWOTH = 0o002; // Write permission, others
const S_IXOTH = 0o001; // Execute/search permission, others
// Optional: sticky/setuid/setgid bits
const S_ISUID = 0o4000; // Set-user-ID on execution
const S_ISGID = 0o2000; // Set-group-ID on execution
const S_ISVTX = 0o1000; // Sticky bit
}
}
/// Represents file metadata, similar to `stat`.
#[derive(Debug, Clone)]
pub struct FileAttr {
pub id: InodeId,
pub size: u64,
pub block_size: u32,
pub blocks: u64,
pub atime: Duration, // Access time (e.g., seconds since epoch)
pub mtime: Duration, // Modification time
pub ctime: Duration, // Change time
pub file_type: FileType,
pub mode: FilePermissions,
pub nlinks: u32,
pub uid: Uid,
pub gid: Gid,
}
impl Default for FileAttr {
fn default() -> Self {
Self {
id: InodeId::dummy(),
size: 0,
block_size: 0,
blocks: 0,
atime: Duration::new(0, 0),
mtime: Duration::new(0, 0),
ctime: Duration::new(0, 0),
file_type: FileType::File,
mode: FilePermissions::empty(),
nlinks: 1,
uid: Uid::new_root(),
gid: Gid::new_root_group(),
}
}
}
impl FileAttr {
/// Checks if a given set of credentials has the requested access permissions for this file.
///
/// # Arguments
/// * `uid` - The user-ID that will be checked against this file's uid field.
/// * `gid` - The group-ID that will be checked against this file's gid field.
/// * `requested_mode` - A bitmask of `AccessMode` flags (`R_OK`, `W_OK`, `X_OK`) to check.
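/// # Example
///
/// A doc-only sketch (mirrors the tests below):
///
/// ```ignore
/// let attr = FileAttr {
///     uid: Uid::new(1000),
///     gid: Gid::new(1000),
///     mode: FilePermissions::S_IRUSR | FilePermissions::S_IWUSR,
///     ..Default::default()
/// };
/// // The owner may read and write, but not execute.
/// assert!(attr.check_access(Uid::new(1000), Gid::new(1000), AccessMode::R_OK).is_ok());
/// assert!(attr.check_access(Uid::new(1000), Gid::new(1000), AccessMode::X_OK).is_err());
/// ```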
pub fn check_access(&self, uid: Uid, gid: Gid, requested_mode: AccessMode) -> Result<()> {
// root (UID 0) bypasses most permission checks. For execute, at
// least one execute bit must be set.
if uid.is_root() {
if requested_mode.contains(AccessMode::X_OK) {
// Root still needs at least one execute bit to be set for X_OK
if self.mode.intersects(
FilePermissions::S_IXUSR | FilePermissions::S_IXGRP | FilePermissions::S_IXOTH,
) {
return Ok(());
}
} else {
return Ok(());
}
}
// Determine which set of permission bits to use (owner, group, or other)
let perms_to_check = if self.uid == uid {
// User is the owner
self.mode
} else if self.gid == gid {
// User is in the file's group. Shift group bits to align with owner bits for easier checking.
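// e.g. S_IRGRP (0o040) << 3 == 0o400 == S_IRUSR.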
FilePermissions::from_bits_truncate(self.mode.bits() << 3)
} else {
// Others. Shift other bits to align with owner bits.
FilePermissions::from_bits_truncate(self.mode.bits() << 6)
};
if requested_mode.contains(AccessMode::R_OK)
&& !perms_to_check.contains(FilePermissions::S_IRUSR)
{
return Err(KernelError::NotPermitted);
}
if requested_mode.contains(AccessMode::W_OK)
&& !perms_to_check.contains(FilePermissions::S_IWUSR)
{
return Err(KernelError::NotPermitted);
}
if requested_mode.contains(AccessMode::X_OK)
&& !perms_to_check.contains(FilePermissions::S_IXUSR)
{
return Err(KernelError::NotPermitted);
}
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::error::KernelError;
const ROOT_UID: Uid = Uid::new(0);
const ROOT_GID: Gid = Gid::new(0);
const OWNER_UID: Uid = Uid::new(1000);
const OWNER_GID: Gid = Gid::new(1000);
const GROUP_MEMBER_UID: Uid = Uid::new(1001);
const FILE_GROUP_GID: Gid = Gid::new(2000);
const OTHER_UID: Uid = Uid::new(1002);
const OTHER_GID: Gid = Gid::new(3000);
fn setup_file(mode: FilePermissions) -> FileAttr {
FileAttr {
uid: OWNER_UID,
gid: FILE_GROUP_GID,
mode,
..Default::default()
}
}
#[test]
fn root_can_read_without_perms() {
let file = setup_file(FilePermissions::empty());
assert!(
file.check_access(ROOT_UID, ROOT_GID, AccessMode::R_OK)
.is_ok()
);
}
#[test]
fn root_can_write_without_perms() {
let file = setup_file(FilePermissions::empty());
assert!(
file.check_access(ROOT_UID, ROOT_GID, AccessMode::W_OK)
.is_ok()
);
}
#[test]
fn root_cannot_execute_if_no_exec_bits_are_set() {
let file = setup_file(FilePermissions::S_IRUSR | FilePermissions::S_IWUSR);
let result = file.check_access(ROOT_UID, ROOT_GID, AccessMode::X_OK);
assert!(matches!(result, Err(KernelError::NotPermitted)));
}
#[test]
fn root_can_execute_if_owner_exec_bit_is_set() {
let file = setup_file(FilePermissions::S_IXUSR);
assert!(
file.check_access(ROOT_UID, ROOT_GID, AccessMode::X_OK)
.is_ok()
);
}
#[test]
fn root_can_execute_if_group_exec_bit_is_set() {
let file = setup_file(FilePermissions::S_IXGRP);
assert!(
file.check_access(ROOT_UID, ROOT_GID, AccessMode::X_OK)
.is_ok()
);
}
#[test]
fn root_can_execute_if_other_exec_bit_is_set() {
let file = setup_file(FilePermissions::S_IXOTH);
assert!(
file.check_access(ROOT_UID, ROOT_GID, AccessMode::X_OK)
.is_ok()
);
}
#[test]
fn owner_can_read_when_permitted() {
let file = setup_file(FilePermissions::S_IRUSR);
assert!(
file.check_access(OWNER_UID, OWNER_GID, AccessMode::R_OK)
.is_ok()
);
}
#[test]
fn owner_cannot_read_when_denied() {
let file = setup_file(FilePermissions::S_IWUSR | FilePermissions::S_IXUSR);
let result = file.check_access(OWNER_UID, OWNER_GID, AccessMode::R_OK);
assert!(matches!(result, Err(KernelError::NotPermitted)));
}
#[test]
fn owner_can_write_when_permitted() {
let file = setup_file(FilePermissions::S_IWUSR);
assert!(
file.check_access(OWNER_UID, OWNER_GID, AccessMode::W_OK)
.is_ok()
);
}
#[test]
fn owner_cannot_write_when_denied() {
let file = setup_file(FilePermissions::S_IRUSR);
let result = file.check_access(OWNER_UID, OWNER_GID, AccessMode::W_OK);
assert!(matches!(result, Err(KernelError::NotPermitted)));
}
#[test]
fn owner_can_read_write_execute_when_permitted() {
let file = setup_file(
FilePermissions::S_IRUSR | FilePermissions::S_IWUSR | FilePermissions::S_IXUSR,
);
let mode = AccessMode::R_OK | AccessMode::W_OK | AccessMode::X_OK;
assert!(file.check_access(OWNER_UID, OWNER_GID, mode).is_ok());
}
#[test]
fn owner_access_denied_if_one_of_many_perms_is_missing() {
let file = setup_file(FilePermissions::S_IRUSR | FilePermissions::S_IXUSR);
let mode = AccessMode::R_OK | AccessMode::W_OK | AccessMode::X_OK; // Requesting Write is denied
let result = file.check_access(OWNER_UID, OWNER_GID, mode);
assert!(matches!(result, Err(KernelError::NotPermitted)));
}
#[test]
fn group_member_can_read_when_permitted() {
let file = setup_file(FilePermissions::S_IRGRP);
assert!(
file.check_access(GROUP_MEMBER_UID, FILE_GROUP_GID, AccessMode::R_OK)
.is_ok()
);
}
#[test]
fn group_member_cannot_write_when_owner_can() {
let file = setup_file(FilePermissions::S_IWUSR | FilePermissions::S_IRGRP);
let result = file.check_access(GROUP_MEMBER_UID, FILE_GROUP_GID, AccessMode::W_OK);
assert!(matches!(result, Err(KernelError::NotPermitted)));
}
#[test]
fn group_member_cannot_read_when_denied() {
let file = setup_file(FilePermissions::S_IWGRP);
let result = file.check_access(GROUP_MEMBER_UID, FILE_GROUP_GID, AccessMode::R_OK);
assert!(matches!(result, Err(KernelError::NotPermitted)));
}
#[test]
fn other_can_execute_when_permitted() {
let file = setup_file(FilePermissions::S_IXOTH);
assert!(
file.check_access(OTHER_UID, OTHER_GID, AccessMode::X_OK)
.is_ok()
);
}
#[test]
fn other_cannot_read_when_only_owner_and_group_can() {
let file = setup_file(FilePermissions::S_IRUSR | FilePermissions::S_IRGRP);
let result = file.check_access(OTHER_UID, OTHER_GID, AccessMode::R_OK);
assert!(matches!(result, Err(KernelError::NotPermitted)));
}
#[test]
fn other_cannot_write_when_denied() {
let file = setup_file(FilePermissions::S_IROTH);
let result = file.check_access(OTHER_UID, OTHER_GID, AccessMode::W_OK);
assert!(matches!(result, Err(KernelError::NotPermitted)));
}
#[test]
fn no_requested_mode_is_always_ok() {
// Checking for nothing should always succeed if the file exists.
let file = setup_file(FilePermissions::empty());
assert!(
file.check_access(OTHER_UID, OTHER_GID, AccessMode::empty())
.is_ok()
);
}
#[test]
fn user_in_different_group_is_treated_as_other() {
let file = setup_file(FilePermissions::S_IROTH); // Only other can read
// This user is not the owner and not in the file's group.
assert!(
file.check_access(GROUP_MEMBER_UID, OTHER_GID, AccessMode::R_OK)
.is_ok()
);
}
}


@@ -0,0 +1,113 @@
use core::{mem, slice};
use crate::{error::Result, fs::BlockDevice, pod::Pod};
use alloc::{boxed::Box, vec};
/// A buffer that provides byte-level access to an underlying BlockDevice.
///
/// This layer handles the logic of translating byte offsets and lengths into
/// block-based operations, including handling requests that span multiple
/// blocks or are not aligned to block boundaries.
///
/// TODO: Cache blocks.
pub struct BlockBuffer {
dev: Box<dyn BlockDevice>,
block_size: usize,
}
impl BlockBuffer {
/// Creates a new `BlockBuffer` that wraps the given block device.
pub fn new(dev: Box<dyn BlockDevice>) -> Self {
let block_size = dev.block_size();
Self { dev, block_size }
}
/// Reads a sequence of bytes starting at a specific offset.
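///
/// For example, with a 512-byte block size, a read of 1000 bytes at byte
/// offset 700 touches blocks 1..=3; the requested bytes begin at offset
/// 700 % 512 = 188 within the temporary three-block buffer.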
pub async fn read_at(&self, offset: u64, buf: &mut [u8]) -> Result<()> {
let len = buf.len();
if len == 0 {
return Ok(());
}
let start_block = offset / self.block_size as u64;
let end_offset = offset + len as u64;
let end_block = (end_offset - 1) / self.block_size as u64;
let num_blocks_to_read = end_block - start_block + 1;
let mut temp_buf = vec![0; num_blocks_to_read as usize * self.block_size];
self.dev.read(start_block, &mut temp_buf).await?;
let start_in_temp_buf = (offset % self.block_size as u64) as usize;
let end_in_temp_buf = start_in_temp_buf + len;
buf.copy_from_slice(&temp_buf[start_in_temp_buf..end_in_temp_buf]);
Ok(())
}
/// Reads a `Pod` struct directly from the device at a given offset.
pub async fn read_obj<T: Pod>(&self, offset: u64) -> Result<T> {
let mut dest = mem::MaybeUninit::<T>::uninit();
// SAFETY: We create a mutable byte slice that points to our
// uninitialized stack space. This is safe because:
// 1. The pointer is valid and properly aligned for T.
// 2. The size is correct for T.
// 3. We are only writing to this slice, not reading from it yet.
let buf: &mut [u8] =
unsafe { slice::from_raw_parts_mut(dest.as_mut_ptr() as *mut u8, mem::size_of::<T>()) };
// Read directly from the device into our stack-allocated space.
self.read_at(offset, buf).await?;
// SAFETY: The `read_at` call has now filled the buffer with bytes from
// the device. Since `T` is `Pod`, any combination of bytes is a valid
// `T`, so we can now safely assume it is initialized.
Ok(unsafe { dest.assume_init() })
}
/// Writes a sequence of bytes starting at a specific offset.
///
/// NOTE: This is a simple but potentially inefficient implementation that
/// uses a read-modify-write approach for all writes.
pub async fn write_at(&mut self, offset: u64, buf: &[u8]) -> Result<()> {
let len = buf.len();
if len == 0 {
return Ok(());
}
let start_block = offset / self.block_size as u64;
let end_offset = offset + len as u64;
let end_block = (end_offset - 1) / self.block_size as u64;
let num_blocks_to_rw = end_block - start_block + 1;
let mut temp_buf = vec![0; num_blocks_to_rw as usize * self.block_size];
// Read all affected blocks from the device into our temporary buffer.
// This preserves the data in the blocks that we are not modifying.
self.dev.read(start_block, &mut temp_buf).await?;
// Copy the user's data into the correct position in our temporary
// buffer.
let start_in_temp_buf = (offset % self.block_size as u64) as usize;
let end_in_temp_buf = start_in_temp_buf + len;
temp_buf[start_in_temp_buf..end_in_temp_buf].copy_from_slice(buf);
// Write the entire modified buffer back to the device.
self.dev.write(start_block, &temp_buf).await?;
Ok(())
}
/// Forwards a sync call to the underlying device.
pub async fn sync(&self) -> Result<()> {
self.dev.sync().await
}
}


@@ -0,0 +1,2 @@
pub mod buffer;
pub mod ramdisk;


@@ -0,0 +1,113 @@
use crate::{
KernAddressSpace,
error::{IoError, KernelError, Result},
fs::BlockDevice,
memory::{
PAGE_SIZE,
address::{TVA, VA},
permissions::PtePermissions,
region::{PhysMemoryRegion, VirtMemoryRegion},
},
};
use alloc::boxed::Box;
use async_trait::async_trait;
use core::ptr;
pub struct RamdiskBlkDev {
base: TVA<u8>,
num_blocks: u64,
}
const BLOCK_SIZE: usize = PAGE_SIZE;
impl RamdiskBlkDev {
/// Creates a new ramdisk.
///
/// Maps the given physical memory region into the kernel's address space at
/// the specified virtual base address.
pub fn new<K: KernAddressSpace>(
region: PhysMemoryRegion,
base: VA,
kern_addr_spc: &mut K,
) -> Result<Self> {
kern_addr_spc.map_normal(
region,
VirtMemoryRegion::new(base, region.size()),
PtePermissions::rw(false),
)?;
if !region.size().is_multiple_of(BLOCK_SIZE) {
return Err(KernelError::InvalidValue);
}
let num_blocks = (region.size() / BLOCK_SIZE) as u64;
Ok(Self {
base: TVA::from_value(base.value()),
num_blocks,
})
}
}
#[async_trait]
impl BlockDevice for RamdiskBlkDev {
/// Read one or more blocks starting at `block_id`.
/// The `buf` length must be a multiple of `block_size`.
async fn read(&self, block_id: u64, buf: &mut [u8]) -> Result<()> {
debug_assert!(buf.len().is_multiple_of(BLOCK_SIZE));
let num_blocks_to_read = (buf.len() / BLOCK_SIZE) as u64;
// Ensure the read doesn't go past the end of the ramdisk.
if block_id + num_blocks_to_read > self.num_blocks {
return Err(IoError::OutOfBounds.into());
}
let offset = block_id as usize * BLOCK_SIZE;
unsafe {
// SAFETY: VA can be accessed:
//
// 1. We have successfully mapped the ramdisk into virtual memory,
// starting at base.
// 2. We have bounds checked the access.
let src_ptr = self.base.as_ptr().add(offset);
ptr::copy_nonoverlapping(src_ptr, buf.as_mut_ptr(), buf.len());
}
Ok(())
}
/// Write one or more blocks starting at `block_id`.
/// The `buf` length must be a multiple of `block_size`.
async fn write(&self, block_id: u64, buf: &[u8]) -> Result<()> {
debug_assert!(buf.len().is_multiple_of(BLOCK_SIZE));
let num_blocks_to_write = (buf.len() / BLOCK_SIZE) as u64;
if block_id + num_blocks_to_write > self.num_blocks {
return Err(IoError::OutOfBounds.into());
}
let offset = block_id as usize * BLOCK_SIZE;
unsafe {
let dest_ptr = self.base.as_ptr_mut().add(offset);
ptr::copy_nonoverlapping(buf.as_ptr(), dest_ptr, buf.len());
}
Ok(())
}
/// The size of a single block in bytes.
fn block_size(&self) -> usize {
BLOCK_SIZE
}
/// Flushes any caches to the underlying device.
async fn sync(&self) -> Result<()> {
Ok(())
}
}


@@ -0,0 +1,260 @@
use crate::{
error::{FsError, Result},
fs::blk::buffer::BlockBuffer,
pod::Pod,
};
use log::warn;
use super::{Cluster, Sector};
#[repr(C, packed)]
#[derive(Debug)]
pub struct BiosParameterBlock {
_jump: [u8; 3],
_oem_id: [u8; 8],
/* DOS 2.0 BPB */
pub bytes_per_sector: u16,
pub sectors_per_cluster: u8,
// Number of sectors in the Reserved Region. Usually 32.
pub reserved_sector_count: u16,
pub num_fats: u8,
_root_entry_count: u16,
_total_sectors_16: u16,
_media_type: u8,
_fat_size_16: u16,
_sectors_per_track: u16,
_head_count: u16,
_hidden_sector_count: u32,
_total_sectors_32: u32,
/* FAT32 Extended BPB */
// The size of ONE FAT in sectors.
pub fat_size_32: u32,
_ext_flags: u16,
_fs_version: u16,
// The cluster number where the root directory starts.
pub root_cluster: Cluster,
pub fsinfo_sector: u16,
// Remaining BPB fields are ignored for now.
}
unsafe impl Pod for BiosParameterBlock {}
impl BiosParameterBlock {
pub async fn new(dev: &BlockBuffer) -> Result<Self> {
let bpb: Self = dev.read_obj(0).await?;
if bpb._fat_size_16 != 0 || bpb._root_entry_count != 0 {
warn!("Not a FAT32 volume (FAT16 fields are non-zero)");
return Err(FsError::InvalidFs.into());
}
if bpb.fat_size_32 == 0 {
warn!("FAT32 size is zero");
return Err(FsError::InvalidFs.into());
}
if bpb.num_fats == 0 {
warn!("Volume has 0 FATs, which is invalid.");
return Err(FsError::InvalidFs.into());
}
let bytes_per_sector = bpb.bytes_per_sector;
match bytes_per_sector {
512 | 1024 | 2048 | 4096 => {} // Good!
_ => {
warn!(
"Bytes per sector {} is not a valid value (must be 512, 1024, 2048, or 4096).",
bytes_per_sector
);
return Err(FsError::InvalidFs.into());
}
}
if !bpb.bytes_per_sector.is_power_of_two() {
let bytes_per_sector = bpb.bytes_per_sector;
warn!(
"Bytes per sector 0x{:X} not a power of two.",
bytes_per_sector
);
return Err(FsError::InvalidFs.into());
}
if !bpb.sectors_per_cluster.is_power_of_two() {
warn!(
"Sectors per cluster 0x{:X} not a power of two.",
bpb.sectors_per_cluster
);
return Err(FsError::InvalidFs.into());
}
if !bpb.root_cluster.is_valid() {
let root_cluster = bpb.root_cluster;
warn!("Root cluster {} < 2.", root_cluster);
return Err(FsError::InvalidFs.into());
}
Ok(bpb)
}
pub fn sector_offset(&self, sector: Sector) -> u64 {
sector.0 as u64 * self.bytes_per_sector as u64
}
pub fn fat_region(&self, fat_number: usize) -> Option<(Sector, Sector)> {
if fat_number >= self.num_fats as _ {
None
} else {
let start = self.fat_region_start() + self.fat_len() * fat_number;
let end = start + self.fat_len();
Some((start, end))
}
}
pub fn fat_region_start(&self) -> Sector {
Sector(self.reserved_sector_count as _)
}
pub fn fat_len(&self) -> Sector {
Sector(self.fat_size_32 as _)
}
pub fn data_region_start(&self) -> Sector {
self.fat_region_start() + self.fat_len() * self.num_fats as usize
}
pub fn sector_size(&self) -> usize {
self.bytes_per_sector as _
}
pub fn cluster_to_sectors(&self, cluster: Cluster) -> Result<impl Iterator<Item = Sector>> {
if cluster.0 < 2 {
warn!("Cannot conver sentinel cluster number");
Err(FsError::InvalidFs.into())
} else {
let root_sector = Sector(
self.data_region_start().0 + (cluster.0 - 2) * self.sectors_per_cluster as u32,
);
Ok(root_sector.sectors_until(Sector(root_sector.0 + self.sectors_per_cluster as u32)))
}
}
}
#[cfg(test)]
pub mod test {
use super::{BiosParameterBlock, Cluster, Sector};
// A helper to create a typical FAT32 BPB for testing.
pub fn create_test_bpb() -> BiosParameterBlock {
BiosParameterBlock {
_jump: [0; 3],
_oem_id: [0; 8],
bytes_per_sector: 512,
sectors_per_cluster: 8,
reserved_sector_count: 32,
num_fats: 2,
_root_entry_count: 0,
_total_sectors_16: 0,
_media_type: 0,
_fat_size_16: 0,
_sectors_per_track: 0,
_head_count: 0,
_hidden_sector_count: 0,
_total_sectors_32: 0,
fat_size_32: 1000, // Size of ONE FAT in sectors
_ext_flags: 0,
_fs_version: 0,
root_cluster: Cluster(2),
fsinfo_sector: 0,
}
}
#[test]
fn sector_iter() {
let sec = Sector(3);
let mut iter = sec.sectors_until(Sector(6));
assert_eq!(iter.next(), Some(Sector(3)));
assert_eq!(iter.next(), Some(Sector(4)));
assert_eq!(iter.next(), Some(Sector(5)));
assert_eq!(iter.next(), None);
}
#[test]
fn cluster_validity() {
assert!(!Cluster(0).is_valid());
assert!(!Cluster(1).is_valid());
assert!(Cluster(2).is_valid());
assert!(Cluster(u32::MAX).is_valid());
}
#[test]
fn fat_layout_calculations() {
let bpb = create_test_bpb();
// The first FAT should start immediately after the reserved region.
assert_eq!(bpb.fat_region_start(), Sector(32));
assert_eq!(bpb.fat_len(), Sector(1000));
// 32 (reserved) + 2 * 1000 (fats) = 2032
assert_eq!(bpb.data_region_start(), Sector(2032));
}
#[test]
fn fat_region_lookup() {
let bpb = create_test_bpb();
// First FAT (FAT 0)
let (start0, end0) = bpb.fat_region(0).unwrap();
assert_eq!(start0, Sector(32));
assert_eq!(end0, Sector(32 + 1000));
// Second FAT (FAT 1)
let (start1, end1) = bpb.fat_region(1).unwrap();
assert_eq!(start1, Sector(32 + 1000));
assert_eq!(end1, Sector(32 + 1000 + 1000));
// There is no third FAT
assert!(bpb.fat_region(2).is_none());
}
#[test]
fn cluster_to_sector_conversion() {
let bpb = create_test_bpb();
let data_start = bpb.data_region_start(); // 2032
let spc = bpb.sectors_per_cluster as u32; // 8
// Cluster 2 is the first data cluster. It should start at the beginning
// of the data region.
let mut cluster2_sectors = bpb.cluster_to_sectors(Cluster(2)).unwrap();
assert_eq!(cluster2_sectors.next(), Some(Sector(data_start.0))); // 2032
assert_eq!(
cluster2_sectors.last(),
Some(Sector(data_start.0 + spc - 1))
); // 2039
// Cluster 3 is the second data cluster.
let mut cluster3_sectors = bpb.cluster_to_sectors(Cluster(3)).unwrap();
assert_eq!(cluster3_sectors.next(), Some(Sector(data_start.0 + spc))); // 2040
assert_eq!(
cluster3_sectors.last(),
Some(Sector(data_start.0 + 2 * spc - 1))
); // 2047
}
#[test]
fn cluster_to_sector_invalid_input() {
let bpb = create_test_bpb();
// Clusters 0 and 1 are reserved and should not be converted to sectors.
assert!(matches!(bpb.cluster_to_sectors(Cluster(0)), Err(_)));
assert!(matches!(bpb.cluster_to_sectors(Cluster(1)), Err(_)));
}
}


@@ -0,0 +1,671 @@
use core::ptr;
use super::{Cluster, Fat32Operations, file::Fat32FileNode, reader::Fat32Reader};
use crate::{
error::{FsError, KernelError, Result},
fs::{
DirStream, Dirent, FileType, Inode, InodeId,
attr::{FileAttr, FilePermissions},
},
};
use alloc::{boxed::Box, string::String, sync::Arc, vec::Vec};
use async_trait::async_trait;
use log::warn;
bitflags::bitflags! {
#[derive(Clone, Copy, Debug)]
struct Fat32Attributes: u8 {
const READ_ONLY = 0x01;
const HIDDEN = 0x02;
const SYSTEM = 0x04;
const VOLUME_LABEL = 0x08;
const DIRECTORY = 0x10;
const ARCHIVE = 0x20;
const DEVICE = 0x40;
}
}
impl TryFrom<Fat32Attributes> for FileType {
type Error = KernelError;
fn try_from(value: Fat32Attributes) -> Result<Self> {
if value.contains(Fat32Attributes::DIRECTORY) {
Ok(FileType::Directory)
} else if value.contains(Fat32Attributes::ARCHIVE) {
Ok(FileType::File)
} else {
warn!("Entry is neither a regular file nor a directory. Ignoring.");
Err(FsError::InvalidFs.into())
}
}
}
#[derive(Clone, Copy)]
#[repr(C, packed)]
struct LfnEntry {
sequence_number: u8,
name1: [u16; 5], // Chars 1-5, UTF-16 LE
attributes: u8, // Always 0x0F
entry_type: u8, // Always 0x00
checksum: u8, // Checksum of the 8.3 name
name2: [u16; 6], // Chars 6-11
first_cluster: u16, // Always 0x0000
name3: [u16; 2], // Chars 12-13
}
impl LfnEntry {
fn extract_chars(&self) -> Vec<u16> {
let mut chars = Vec::with_capacity(13);
unsafe {
let name1 = ptr::read_unaligned(&raw const self.name1);
let name2 = ptr::read_unaligned(&raw const self.name2);
let name3 = ptr::read_unaligned(&raw const self.name3);
chars.extend_from_slice(&name1);
chars.extend_from_slice(&name2);
chars.extend_from_slice(&name3);
}
chars
}
}
#[derive(Clone, Copy, Debug)]
#[repr(C, packed)]
struct DirEntry {
dos_file_name: [u8; 8],
dos_extension: [u8; 3],
attributes: Fat32Attributes,
_reserved: u8,
ctime_ms: u8,
ctime: u16,
cdate: u16,
adate: u16,
clust_high: u16,
mtime: u16,
mdate: u16,
clust_low: u16,
size: u32,
}
impl DirEntry {
pub fn parse_filename(&self) -> String {
let name_part = self
.dos_file_name
.iter()
.position(|&c| c == b' ')
.unwrap_or(self.dos_file_name.len());
// Find the end of the extension part
let ext_part = self
.dos_extension
.iter()
.position(|&c| c == b' ')
.unwrap_or(self.dos_extension.len());
let mut result = String::from_utf8_lossy(&self.dos_file_name[..name_part]).into_owned();
if ext_part > 0 {
result.push('.');
result.push_str(&String::from_utf8_lossy(&self.dos_extension[..ext_part]));
}
result
}
}
struct Fat32DirEntry {
attr: FileAttr,
cluster: Cluster,
name: String,
offset: u64,
}
struct Fat32DirStream<T: Fat32Operations> {
reader: Fat32Reader<T>,
offset: u64,
lfn_buffer: Vec<u16>,
fs_id: u64,
}
impl<T: Fat32Operations> Clone for Fat32DirStream<T> {
fn clone(&self) -> Self {
Self {
reader: self.reader.clone(),
offset: self.offset,
lfn_buffer: self.lfn_buffer.clone(),
fs_id: self.fs_id,
}
}
}
impl<T: Fat32Operations> Fat32DirStream<T> {
pub fn new(fs: Arc<T>, root: Cluster) -> Self {
let max_sz = fs.iter_clusters(root).count() as u64 * fs.bytes_per_cluster() as u64;
let fs_id = fs.id();
// Directory inodes record a size of 0, so instead derive a fake size from
// the cluster chain (number of clusters times bytes per cluster) so that
// the reader never runs past the end of the chain.
Self {
reader: Fat32Reader::new(fs, root, max_sz),
offset: 0,
lfn_buffer: Vec::new(),
fs_id,
}
}
pub fn advance(&mut self, offset: u64) {
self.offset = offset;
}
async fn next_fat32_entry(&mut self) -> Result<Option<Fat32DirEntry>> {
loop {
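// Directory entries (both 8.3 and LFN) are 32 bytes on disk, so
// `self.offset` counts entries rather than bytes.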
let mut entry_bytes = [0; 32];
self.reader
.read_at(self.offset * 32, &mut entry_bytes)
.await?;
match entry_bytes[0] {
0x00 => return Ok(None), // End of directory, no more entries
0xE5 => {
// Deleted entry
self.lfn_buffer.clear();
self.offset += 1;
continue;
}
_ => (), // A normal entry
}
// Check attribute byte (offset 11) to differentiate LFN vs 8.3
let attribute_byte = entry_bytes[11];
if attribute_byte == 0x0F {
// It's an LFN entry
let lfn_entry: LfnEntry =
unsafe { ptr::read_unaligned(entry_bytes.as_ptr() as *const _) };
// LFN entries are stored backwards, so we prepend.
let new_chars = lfn_entry.extract_chars();
self.lfn_buffer.splice(0..0, new_chars);
self.offset += 1;
continue;
}
let dir_entry: DirEntry =
unsafe { ptr::read_unaligned(entry_bytes.as_ptr() as *const _) };
if dir_entry.attributes.contains(Fat32Attributes::VOLUME_LABEL) {
self.lfn_buffer.clear();
self.offset += 1;
continue;
}
let name = if !self.lfn_buffer.is_empty() {
let len = self
.lfn_buffer
.iter()
.position(|&c| c == 0x0000 || c == 0xFFFF)
.unwrap_or(self.lfn_buffer.len());
String::from_utf16_lossy(&self.lfn_buffer[..len])
} else {
// No LFN, parse the 8.3 name.
dir_entry.parse_filename()
};
// Process the metadata from the 8.3 entry
let file_type = FileType::try_from(dir_entry.attributes)?;
let cluster = Cluster::from_high_low(dir_entry.clust_high, dir_entry.clust_low);
let attr = FileAttr {
size: dir_entry.size as u64,
file_type,
mode: FilePermissions::from_bits_retain(0o755),
// TODO: parse date/time fields.
..Default::default()
};
self.lfn_buffer.clear();
self.offset += 1;
return Ok(Some(Fat32DirEntry {
attr,
cluster,
name,
// Note that the reported offset should point at the *next* entry, so
// using the already-incremented `self.offset` here is correct.
offset: self.offset,
}));
}
}
}
#[async_trait]
impl<T: Fat32Operations> DirStream for Fat32DirStream<T> {
async fn next_entry(&mut self) -> Result<Option<Dirent>> {
let entry = self.next_fat32_entry().await?;
Ok(entry.map(|x| Dirent {
id: InodeId::from_fsid_and_inodeid(self.fs_id, x.cluster.value() as _),
name: x.name.clone(),
file_type: x.attr.file_type,
offset: x.offset,
}))
}
}
pub struct Fat32DirNode<T: Fat32Operations> {
attr: FileAttr,
root: Cluster,
fs: Arc<T>,
streamer: Fat32DirStream<T>,
}
impl<T: Fat32Operations> Fat32DirNode<T> {
pub fn new(fs: Arc<T>, root: Cluster, attr: FileAttr) -> Self {
let streamer = Fat32DirStream::new(fs.clone(), root);
Self {
attr,
root,
fs,
streamer,
}
}
}
#[async_trait]
impl<T: Fat32Operations> Inode for Fat32DirNode<T> {
fn id(&self) -> InodeId {
InodeId::from_fsid_and_inodeid(self.fs.id(), self.root.value() as _)
}
async fn lookup(&self, name: &str) -> Result<Arc<dyn Inode>> {
let mut dir_iter = self.streamer.clone();
while let Some(entry) = dir_iter.next_fat32_entry().await? {
if entry.name == name {
return match entry.attr.file_type {
FileType::File => Ok(Arc::new(Fat32FileNode::new(
self.fs.clone(),
entry.cluster,
entry.attr.clone(),
)?)),
FileType::Directory => Ok(Arc::new(Self::new(
self.fs.clone(),
entry.cluster,
entry.attr.clone(),
))),
_ => Err(KernelError::NotSupported),
};
}
}
Err(FsError::NotFound.into())
}
async fn readdir(&self, start_offset: u64) -> Result<Box<dyn DirStream>> {
let mut iter = self.streamer.clone();
iter.advance(start_offset);
Ok(Box::new(iter))
}
async fn getattr(&self) -> Result<FileAttr> {
Ok(self.attr.clone())
}
}
#[cfg(test)]
mod test {
use super::*;
use crate::fs::filesystems::fat32::file::test::MockFs;
mod raw_test;
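/// Computes the VFAT checksum of an 8.3 name: for each of the 11 bytes,
/// rotate the running sum right by one bit, then add the byte.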
fn checksum_83(name: &[u8; 8], ext: &[u8; 3]) -> u8 {
let mut sum: u8 = 0;
for &byte in name.iter().chain(ext.iter()) {
sum = (sum >> 1) | ((sum & 1) << 7); // Rotate right
sum = sum.wrapping_add(byte);
}
sum
}
/// A builder to easily create 32-byte DirEntry byte arrays for tests.
struct DirEntryBuilder {
entry: DirEntry,
}
impl DirEntryBuilder {
fn new(name: &str, ext: &str) -> Self {
let mut dos_file_name = [b' '; 8];
let mut dos_extension = [b' '; 3];
dos_file_name[..name.len()].copy_from_slice(name.as_bytes());
dos_extension[..ext.len()].copy_from_slice(ext.as_bytes());
Self {
entry: DirEntry {
dos_file_name,
dos_extension,
attributes: Fat32Attributes::ARCHIVE,
_reserved: 0,
ctime_ms: 0,
ctime: 0,
cdate: 0,
adate: 0,
clust_high: 0,
mtime: 0,
mdate: 0,
clust_low: 0,
size: 0,
},
}
}
fn attributes(mut self, attrs: Fat32Attributes) -> Self {
self.entry.attributes = attrs;
self
}
fn cluster(mut self, cluster: u32) -> Self {
self.entry.clust_high = (cluster >> 16) as u16;
self.entry.clust_low = cluster as u16;
self
}
fn size(mut self, size: u32) -> Self {
self.entry.size = size;
self
}
fn build(self) -> [u8; 32] {
unsafe { core::mem::transmute(self.entry) }
}
}
/// A builder to create the series of LFN entries for a long name.
struct LfnBuilder {
long_name: String,
checksum: u8,
}
impl LfnBuilder {
fn new(long_name: &str, checksum: u8) -> Self {
Self {
long_name: long_name.to_string(),
checksum,
}
}
fn build(self) -> Vec<[u8; 32]> {
let mut utf16_chars: Vec<u16> = self.long_name.encode_utf16().collect();
utf16_chars.push(0x0000); // Null terminator
// Pad to a multiple of 13 characters
while utf16_chars.len() % 13 != 0 {
utf16_chars.push(0xFFFF);
}
let mut lfn_entries = Vec::new();
let num_entries = utf16_chars.len() / 13;
for i in 0..num_entries {
let sequence_number = (num_entries - i) as u8;
let chunk = &utf16_chars[(num_entries - 1 - i) * 13..][..13];
let mut lfn = LfnEntry {
sequence_number,
name1: [0; 5],
attributes: 0x0F,
entry_type: 0,
checksum: self.checksum,
name2: [0; 6],
first_cluster: 0,
name3: [0; 2],
};
if i == 0 {
// First entry on disk is the last logically
lfn.sequence_number |= 0x40;
}
unsafe {
ptr::write_unaligned(&raw mut lfn.name1, chunk[0..5].try_into().unwrap());
ptr::write_unaligned(&raw mut lfn.name2, chunk[5..11].try_into().unwrap());
ptr::write_unaligned(&raw mut lfn.name3, chunk[11..13].try_into().unwrap());
}
lfn_entries.push(unsafe { core::mem::transmute(lfn) });
}
lfn_entries
}
}
/// Creates a mock filesystem (`MockFs`) containing the given raw directory data.
async fn setup_dir_test(dir_data: Vec<u8>) -> Arc<MockFs> {
Arc::new(MockFs::new(&dir_data, 512, 1))
}
async fn collect_entries<T: Fat32Operations>(
mut iter: Fat32DirStream<T>,
) -> Vec<Fat32DirEntry> {
let mut ret = Vec::new();
while let Some(entry) = iter.next_fat32_entry().await.unwrap() {
ret.push(entry);
}
ret
}
#[tokio::test]
async fn test_simple_83_entries() {
let mut data = Vec::new();
data.extend_from_slice(
&DirEntryBuilder::new("FILE", "TXT")
.attributes(Fat32Attributes::ARCHIVE)
.cluster(10)
.size(1024)
.build(),
);
data.extend_from_slice(
&DirEntryBuilder::new("SUBDIR", "")
.attributes(Fat32Attributes::DIRECTORY)
.cluster(11)
.build(),
);
let fs = setup_dir_test(data).await;
let dir_stream = Fat32DirStream::new(fs, Cluster(2));
let entries = collect_entries(dir_stream).await;
assert_eq!(entries.len(), 2);
let file = &entries[0];
assert_eq!(file.name, "FILE.TXT");
assert_eq!(file.attr.file_type, FileType::File);
assert_eq!(file.cluster, Cluster(10));
assert_eq!(file.attr.size, 1024);
let dir = &entries[1];
assert_eq!(dir.name, "SUBDIR");
assert_eq!(dir.attr.file_type, FileType::Directory);
assert_eq!(dir.cluster, Cluster(11));
}
#[tokio::test]
async fn test_single_lfn_entry() {
let sfn = DirEntryBuilder::new("TESTFI~1", "TXT")
.attributes(Fat32Attributes::ARCHIVE)
.cluster(5)
.build();
let checksum = checksum_83(
&sfn[0..8].try_into().unwrap(),
&sfn[8..11].try_into().unwrap(),
);
let lfn = LfnBuilder::new("testfile.txt", checksum).build();
let mut data = Vec::new();
lfn.into_iter().for_each(|e| data.extend_from_slice(&e));
data.extend_from_slice(&sfn);
let fs = setup_dir_test(data).await;
let dir_stream = Fat32DirStream::new(fs, Cluster(2));
let entries = collect_entries(dir_stream).await;
assert_eq!(entries.len(), 1);
let entry = &entries[0];
assert_eq!(entry.name, "testfile.txt");
assert_eq!(entry.cluster, Cluster(5));
assert_eq!(entry.attr.file_type, FileType::File);
}
#[tokio::test]
async fn test_multi_part_lfn_entry() {
let sfn = DirEntryBuilder::new("AVERYL~1", "LOG")
.attributes(Fat32Attributes::ARCHIVE)
.cluster(42)
.build();
let checksum = checksum_83(
&sfn[0..8].try_into().unwrap(),
&sfn[8..11].try_into().unwrap(),
);
let lfn = LfnBuilder::new("a very long filename indeed.log", checksum).build();
let mut data = Vec::new();
lfn.into_iter().for_each(|e| data.extend_from_slice(&e));
data.extend_from_slice(&sfn);
let fs = setup_dir_test(data).await;
let dir_stream = Fat32DirStream::new(fs, Cluster(2));
let entries = collect_entries(dir_stream).await;
assert_eq!(entries.len(), 1);
let entry = &entries[0];
assert_eq!(entry.name, "a very long filename indeed.log");
assert_eq!(entry.cluster, Cluster(42));
}
#[tokio::test]
async fn ignores_deleted_entries() {
let mut data = Vec::new();
let mut deleted_entry = DirEntryBuilder::new("DELETED", "TMP").build();
deleted_entry[0] = 0xE5; // Mark as deleted
data.extend_from_slice(&deleted_entry);
data.extend_from_slice(&DirEntryBuilder::new("GOODFILE", "DAT").build());
let fs = setup_dir_test(data).await;
let dir_stream = Fat32DirStream::new(fs, Cluster(2));
let entries = collect_entries(dir_stream).await;
assert_eq!(entries.len(), 1);
assert_eq!(entries[0].name, "GOODFILE.DAT");
assert_eq!(entries[0].offset, 2);
}
#[tokio::test]
async fn stops_at_end_of_dir_marker() {
let mut data = Vec::new();
data.extend_from_slice(&DirEntryBuilder::new("FIRST", "FIL").build());
data.extend_from_slice(&[0u8; 32]); // End of directory marker
data.extend_from_slice(&DirEntryBuilder::new("JUNK", "FIL").build()); // Should not be parsed
let fs = setup_dir_test(data).await;
let dir_stream = Fat32DirStream::new(fs, Cluster(2));
let entries = collect_entries(dir_stream).await;
assert_eq!(entries.len(), 1);
assert_eq!(entries[0].name, "FIRST.FIL");
}
#[tokio::test]
async fn test_ignores_volume_label() {
let mut data = Vec::new();
data.extend_from_slice(
&DirEntryBuilder::new("MYVOLUME", "")
.attributes(Fat32Attributes::VOLUME_LABEL)
.build(),
);
data.extend_from_slice(&DirEntryBuilder::new("REALFILE", "TXT").build());
let fs = setup_dir_test(data).await;
let dir_stream = Fat32DirStream::new(fs, Cluster(2));
let entries = collect_entries(dir_stream).await;
assert_eq!(entries.len(), 1);
assert_eq!(entries[0].name, "REALFILE.TXT");
assert_eq!(entries[0].offset, 2);
}
#[tokio::test]
async fn raw() {
let mut data = Vec::new();
data.extend_from_slice(&raw_test::RAW_DATA);
let fs = setup_dir_test(data).await;
let dir_stream = Fat32DirStream::new(fs, Cluster(2));
let entries = collect_entries(dir_stream).await;
assert_eq!(entries.len(), 2);
assert_eq!(entries[0].name, "test.txt");
assert_eq!(entries[0].cluster, Cluster(0xb));
assert_eq!(entries[0].attr.size, 0xd);
assert_eq!(
entries[1].name,
"some-really-long-file-name-that-should-span-over-multiple-entries.txt"
);
assert_eq!(entries[1].cluster, Cluster(0));
assert_eq!(entries[1].attr.size, 0);
}
#[tokio::test]
async fn test_mixed_directory_listing() {
let mut data = Vec::new();
// 1. A standard 8.3 file
data.extend_from_slice(
&DirEntryBuilder::new("KERNEL", "ELF")
.attributes(Fat32Attributes::ARCHIVE)
.cluster(3)
.build(),
);
// 2. A deleted LFN entry
let deleted_sfn = DirEntryBuilder::new("OLDLOG~1", "TMP").build();
let deleted_checksum = checksum_83(
&deleted_sfn[0..8].try_into().unwrap(),
&deleted_sfn[8..11].try_into().unwrap(),
);
let mut deleted_lfn_bytes = LfnBuilder::new("old-log-file.tmp", deleted_checksum)
.build()
.remove(0);
deleted_lfn_bytes[0] = 0xE5;
data.extend_from_slice(&deleted_lfn_bytes);
// 3. A valid LFN file
let sfn = DirEntryBuilder::new("MYNOTE~1", "MD")
.attributes(Fat32Attributes::ARCHIVE)
.cluster(4)
.build();
let checksum = checksum_83(
&sfn[0..8].try_into().unwrap(),
&sfn[8..11].try_into().unwrap(),
);
let lfn = LfnBuilder::new("my notes.md", checksum).build();
lfn.into_iter().for_each(|e| data.extend_from_slice(&e));
data.extend_from_slice(&sfn);
let fs = setup_dir_test(data).await;
let dir_stream = Fat32DirStream::new(fs, Cluster(2));
let entries = collect_entries(dir_stream).await;
assert_eq!(entries.len(), 2);
assert_eq!(entries[0].name, "KERNEL.ELF");
assert_eq!(entries[0].cluster, Cluster(3));
assert_eq!(entries[1].name, "my notes.md");
assert_eq!(entries[1].cluster, Cluster(4));
}
}

View File

@@ -0,0 +1,24 @@
pub const RAW_DATA: [u8; 352] = [
0xe5, 0x2e, 0x00, 0x74, 0x00, 0x65, 0x00, 0x73, 0x00, 0x74, 0x00, 0x0f, 0x00, 0xa1, 0x2e, 0x00,
0x74, 0x00, 0x78, 0x00, 0x74, 0x00, 0x2e, 0x00, 0x73, 0x00, 0x00, 0x00, 0x77, 0x00, 0x70, 0x00,
0xe5, 0x45, 0x53, 0x54, 0x54, 0x58, 0x7e, 0x31, 0x53, 0x57, 0x50, 0x20, 0x00, 0x80, 0x1a, 0x66,
0x2d, 0x5b, 0x2d, 0x5b, 0x00, 0x00, 0x1a, 0x66, 0x2d, 0x5b, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x41, 0x74, 0x00, 0x65, 0x00, 0x73, 0x00, 0x74, 0x00, 0x2e, 0x00, 0x0f, 0x00, 0x8f, 0x74, 0x00,
0x78, 0x00, 0x74, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff,
0x54, 0x45, 0x53, 0x54, 0x20, 0x20, 0x20, 0x20, 0x54, 0x58, 0x54, 0x20, 0x00, 0xa1, 0x1c, 0x66,
0x2d, 0x5b, 0x2d, 0x5b, 0x00, 0x00, 0x1c, 0x66, 0x2d, 0x5b, 0x0b, 0x00, 0x0d, 0x00, 0x00, 0x00,
0x46, 0x2e, 0x00, 0x74, 0x00, 0x78, 0x00, 0x74, 0x00, 0x00, 0x00, 0x0f, 0x00, 0x34, 0xff, 0xff,
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff,
0x05, 0x74, 0x00, 0x69, 0x00, 0x70, 0x00, 0x6c, 0x00, 0x65, 0x00, 0x0f, 0x00, 0x34, 0x2d, 0x00,
0x65, 0x00, 0x6e, 0x00, 0x74, 0x00, 0x72, 0x00, 0x69, 0x00, 0x00, 0x00, 0x65, 0x00, 0x73, 0x00,
0x04, 0x73, 0x00, 0x70, 0x00, 0x61, 0x00, 0x6e, 0x00, 0x2d, 0x00, 0x0f, 0x00, 0x34, 0x6f, 0x00,
0x76, 0x00, 0x65, 0x00, 0x72, 0x00, 0x2d, 0x00, 0x6d, 0x00, 0x00, 0x00, 0x75, 0x00, 0x6c, 0x00,
0x03, 0x2d, 0x00, 0x74, 0x00, 0x68, 0x00, 0x61, 0x00, 0x74, 0x00, 0x0f, 0x00, 0x34, 0x2d, 0x00,
0x73, 0x00, 0x68, 0x00, 0x6f, 0x00, 0x75, 0x00, 0x6c, 0x00, 0x00, 0x00, 0x64, 0x00, 0x2d, 0x00,
0x02, 0x6f, 0x00, 0x6e, 0x00, 0x67, 0x00, 0x2d, 0x00, 0x66, 0x00, 0x0f, 0x00, 0x34, 0x69, 0x00,
0x6c, 0x00, 0x65, 0x00, 0x2d, 0x00, 0x6e, 0x00, 0x61, 0x00, 0x00, 0x00, 0x6d, 0x00, 0x65, 0x00,
0x01, 0x73, 0x00, 0x6f, 0x00, 0x6d, 0x00, 0x65, 0x00, 0x2d, 0x00, 0x0f, 0x00, 0x34, 0x72, 0x00,
0x65, 0x00, 0x61, 0x00, 0x6c, 0x00, 0x6c, 0x00, 0x79, 0x00, 0x00, 0x00, 0x2d, 0x00, 0x6c, 0x00,
0x53, 0x4f, 0x4d, 0x45, 0x2d, 0x52, 0x7e, 0x31, 0x54, 0x58, 0x54, 0x20, 0x00, 0x6e, 0xce, 0x3a,
0x31, 0x5b, 0x31, 0x5b, 0x00, 0x00, 0xce, 0x3a, 0x31, 0x5b, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
];

View File

@@ -0,0 +1,338 @@
use crate::{
error::{FsError, IoError, Result},
fs::blk::buffer::BlockBuffer,
};
use alloc::vec;
use alloc::vec::Vec;
use super::{Cluster, bpb::BiosParameterBlock};
#[derive(PartialEq, Eq, Debug)]
pub enum FatEntry {
Eoc,
NextCluster(Cluster),
Bad,
Reserved,
Free,
}
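// Classify a raw 32-bit FAT entry. Only the low 28 bits are significant on
// FAT32; the top nibble is reserved and masked off before matching.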
impl From<u32> for FatEntry {
fn from(value: u32) -> Self {
match value & 0x0fffffff {
0 => Self::Free,
1 => Self::Reserved,
n @ 2..=0xFFFFFF6 => Self::NextCluster(Cluster(n)),
0xFFFFFF7 => Self::Bad,
0xFFFFFF8..=0xFFFFFFF => Self::Eoc,
_ => unreachable!("The top nibble has been masked off"),
}
}
}
#[derive(PartialEq, Eq, Debug)]
pub struct Fat {
data: Vec<FatEntry>,
}
pub struct ClusterChainIterator<'a> {
fat: &'a Fat,
current_or_next: Option<Cluster>,
}
impl<'a> Iterator for ClusterChainIterator<'a> {
type Item = Result<Cluster>;
fn next(&mut self) -> Option<Self::Item> {
let cluster_to_return = self.current_or_next?;
let entry = match self.fat.data.get(cluster_to_return.value()) {
Some(entry) => entry,
None => {
self.current_or_next = None;
return Some(Err(IoError::OutOfBounds.into()));
}
};
match entry {
FatEntry::Eoc => {
self.current_or_next = None;
}
FatEntry::NextCluster(next) => {
self.current_or_next = Some(*next);
}
FatEntry::Bad | FatEntry::Reserved | FatEntry::Free => {
self.current_or_next = None;
return Some(Err(IoError::MetadataCorruption.into()));
}
}
Some(Ok(cluster_to_return))
}
}
impl Fat {
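/// Reads FAT copy `fat_number` from the device one sector at a time and
/// decodes each 32-bit little-endian entry.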
pub async fn read_fat(
dev: &BlockBuffer,
bpb: &BiosParameterBlock,
fat_number: usize,
) -> Result<Self> {
let (start, end) = bpb.fat_region(fat_number).ok_or(FsError::InvalidFs)?;
let mut fat: Vec<FatEntry> = Vec::with_capacity(
(bpb.sector_offset(end) as usize - bpb.sector_offset(start) as usize) / 4,
);
let mut buf = vec![0; bpb.sector_size()];
for sec in start.sectors_until(end) {
dev.read_at(bpb.sector_offset(sec), &mut buf).await?;
fat.extend(
buf.chunks_exact(4)
.map(|chunk| u32::from_le_bytes(chunk.try_into().unwrap()))
.map(|v| v.into()),
);
}
Ok(Self { data: fat })
}
pub fn get_cluster_chain(&self, root: Cluster) -> impl Iterator<Item = Result<Cluster>> {
ClusterChainIterator {
fat: self,
current_or_next: Some(root),
}
}
}
#[cfg(test)]
mod test {
use crate::error::{IoError, KernelError, Result};
use crate::fs::filesystems::fat32::Cluster;
use crate::fs::filesystems::fat32::bpb::test::create_test_bpb;
use crate::fs::filesystems::fat32::fat::{Fat, FatEntry};
use crate::fs::{BlockDevice, blk::buffer::BlockBuffer};
use async_trait::async_trait;
const EOC: u32 = 0xFFFFFFFF;
const BAD: u32 = 0xFFFFFFF7;
const FREE: u32 = 0;
const RESERVED: u32 = 1;
struct MemBlkDevice {
data: Vec<u8>,
}
#[async_trait]
impl BlockDevice for MemBlkDevice {
/// Read one or more blocks starting at `block_id`.
/// The `buf` length must be a multiple of `block_size`.
async fn read(&self, block_id: u64, buf: &mut [u8]) -> Result<()> {
buf.copy_from_slice(&self.data[block_id as usize..block_id as usize + buf.len()]);
Ok(())
}
/// Write one or more blocks starting at `block_id`.
/// The `buf` length must be a multiple of `block_size`.
async fn write(&self, _block_id: u64, _buf: &[u8]) -> Result<()> {
unimplemented!()
}
/// The size of a single block in bytes.
fn block_size(&self) -> usize {
1
}
/// Flushes any caches to the underlying device.
async fn sync(&self) -> Result<()> {
unimplemented!()
}
}
fn setup_fat_test(fat_data: &[u32]) -> BlockBuffer {
let mut data = Vec::new();
data.extend(fat_data.iter().flat_map(|x| x.to_le_bytes()));
BlockBuffer::new(Box::new(MemBlkDevice { data }))
}
#[tokio::test]
async fn test_read_fat_simple_parse() {
let fat_data = [
FREE, // Cluster 0
RESERVED, // Cluster 1
EOC, // Cluster 2
5, // Cluster 3 -> 5
BAD, // Cluster 4
EOC, // Cluster 5
0xDEADBEEF, // Cluster 6: the top nibble must be masked off by the parser
];
let device = setup_fat_test(&fat_data);
let mut bpb = create_test_bpb();
bpb.bytes_per_sector = fat_data.len() as u16 * 4;
bpb.sectors_per_cluster = 1;
bpb.num_fats = 1;
bpb.fat_size_32 = 1;
bpb.reserved_sector_count = 0;
let fat = Fat::read_fat(&device, &bpb, 0)
.await
.expect("read_fat should succeed");
assert_eq!(
fat.data.len(),
fat_data.len(),
"Parsed FAT has incorrect length"
);
assert_eq!(fat.data[0], FatEntry::Free);
assert_eq!(fat.data[1], FatEntry::Reserved);
assert_eq!(fat.data[2], FatEntry::Eoc);
assert_eq!(fat.data[3], FatEntry::NextCluster(Cluster(5)));
assert_eq!(fat.data[4], FatEntry::Bad);
assert_eq!(fat.data[5], FatEntry::Eoc);
// Ensure the top 4 bits are ignored.
assert_eq!(fat.data[6], FatEntry::NextCluster(Cluster(0x0EADBEEF)));
}
#[tokio::test]
async fn test_read_fat_across_multiple_sectors() {
// A 300-byte sector holds 75 u32 entries, so a 150-entry (600-byte) FAT
// spans two sectors and forces a multi-sector read.
let mut fat_data = Vec::with_capacity(150);
for i in 0..150 {
fat_data.push(i + 2); // Create a simple chain: 0->2, 1->3, etc.
}
fat_data[149] = 0xFFFFFFFF; // End the last chain
let device = setup_fat_test(&fat_data);
let mut bpb = create_test_bpb();
bpb.bytes_per_sector = 300;
bpb.num_fats = 1;
bpb.reserved_sector_count = 0;
bpb.sectors_per_cluster = 1;
bpb.fat_size_32 = 2;
let fat = super::Fat::read_fat(&device, &bpb, 0)
.await
.expect("read_fat should succeed");
assert!(super::Fat::read_fat(&device, &bpb, 1).await.is_err());
assert_eq!(fat.data.len(), 150, "Parsed FAT has incorrect length");
assert_eq!(fat.data[0], FatEntry::NextCluster(Cluster(2)));
assert_eq!(fat.data[74], FatEntry::NextCluster(Cluster(76))); // Last entry of the 1st sector
assert_eq!(fat.data[75], FatEntry::NextCluster(Cluster(77))); // First entry of the 2nd sector
assert_eq!(fat.data[149], FatEntry::Eoc);
}
fn setup_chain_test_fat() -> super::Fat {
#[rustfmt::skip]
let fat_data = [
/* 0 */ FREE,
/* 1 */ RESERVED,
/* 2 */ EOC, // Single-cluster file
/* 3 */ 4, // Start of linear chain
/* 4 */ 5,
/* 5 */ EOC,
/* 6 */ 10, // Start of fragmented chain
/* 7 */ 9, // Chain leading to a bad cluster
/* 8 */ EOC,
/* 9 */ BAD,
/* 10 */ 8,
/* 11 */ 12, // Chain leading to a free cluster
/* 12 */ FREE,
/* 13 */ 14, // Chain with a cycle
/* 14 */ 15,
/* 15 */ 13,
/* 16 */ 99, // Chain pointing out of bounds
];
let data = fat_data.iter().map(|&v| FatEntry::from(v)).collect();
Fat { data }
}
#[test]
fn test_chain_single_cluster() {
let fat = setup_chain_test_fat();
let chain: Vec<_> = fat.get_cluster_chain(Cluster(2)).collect();
assert_eq!(chain, vec![Ok(Cluster(2))]);
}
#[test]
fn test_chain_linear() {
let fat = setup_chain_test_fat();
let chain: Vec<_> = fat.get_cluster_chain(Cluster(3)).collect();
assert_eq!(chain, vec![Ok(Cluster(3)), Ok(Cluster(4)), Ok(Cluster(5))]);
}
#[test]
fn test_chain_fragmented() {
let fat = setup_chain_test_fat();
let chain: Vec<_> = fat.get_cluster_chain(Cluster(6)).collect();
assert_eq!(chain, vec![Ok(Cluster(6)), Ok(Cluster(10)), Ok(Cluster(8))]);
}
#[test]
fn test_chain_points_to_bad_cluster() {
let fat = setup_chain_test_fat();
let chain: Vec<_> = fat.get_cluster_chain(Cluster(7)).collect();
assert_eq!(chain.len(), 2);
assert!(
chain[1].is_err(),
"Should fail when chain encounters a bad cluster"
);
assert!(matches!(
chain[1],
Err(KernelError::Io(IoError::MetadataCorruption))
));
}
#[test]
fn test_chain_points_to_free_cluster() {
let fat = setup_chain_test_fat();
let chain: Vec<_> = fat.get_cluster_chain(Cluster(11)).collect();
assert_eq!(chain.len(), 2);
assert!(
chain[1].is_err(),
"Should fail when chain encounters a free cluster"
);
assert!(matches!(
chain[1],
Err(KernelError::Io(IoError::MetadataCorruption))
));
}
#[test]
fn test_chain_points_out_of_bounds() {
let fat = setup_chain_test_fat();
let result: Vec<_> = fat.get_cluster_chain(Cluster(16)).collect();
dbg!(&result);
assert_eq!(result.len(), 2);
assert!(
result[1].is_err(),
"Should fail when chain points to an out-of-bounds cluster"
);
assert!(matches!(
result[1],
Err(KernelError::Io(IoError::OutOfBounds))
));
}
#[test]
fn test_chain_starts_out_of_bounds() {
let fat = setup_chain_test_fat();
// Start with a cluster number that is larger than the FAT itself.
let chain: Vec<_> = fat.get_cluster_chain(Cluster(100)).collect();
assert!(
chain[0].is_err(),
"Should fail when the starting cluster is out-of-bounds"
);
assert!(matches!(
chain[0],
Err(KernelError::Io(IoError::OutOfBounds))
));
}
}

View File

@@ -0,0 +1,196 @@
use crate::{
error::Result,
fs::{Inode, InodeId, attr::FileAttr},
};
use alloc::boxed::Box;
use alloc::sync::Arc;
use async_trait::async_trait;
use super::{Cluster, Fat32Operations, reader::Fat32Reader};
pub struct Fat32FileNode<T: Fat32Operations> {
reader: Fat32Reader<T>,
attr: FileAttr,
id: InodeId,
}
impl<T: Fat32Operations> Fat32FileNode<T> {
pub fn new(fs: Arc<T>, root: Cluster, attr: FileAttr) -> Result<Self> {
let id = InodeId::from_fsid_and_inodeid(fs.id() as _, root.value() as _);
Ok(Self {
reader: Fat32Reader::new(fs, root, attr.size),
attr,
id,
})
}
}
#[async_trait]
impl<T: Fat32Operations> Inode for Fat32FileNode<T> {
fn id(&self) -> InodeId {
self.id
}
async fn read_at(&self, offset: u64, buf: &mut [u8]) -> Result<usize> {
self.reader.read_at(offset, buf).await
}
async fn getattr(&self) -> Result<FileAttr> {
Ok(self.attr.clone())
}
}
#[cfg(test)]
pub mod test {
use crate::{error::FsError, fs::filesystems::fat32::Sector};
use super::*;
use alloc::{collections::BTreeMap, sync::Arc, vec};
pub struct MockFs {
file_data: BTreeMap<u32, Vec<u8>>, // Map Sector(u32) -> data
sector_size: usize,
sectors_per_cluster: usize,
}
impl MockFs {
pub fn new(file_content: &[u8], sector_size: usize, sectors_per_cluster: usize) -> Self {
let mut file_data = BTreeMap::new();
// Data region starts at sector 100 for simplicity
let data_start_sector = 100;
for (i, chunk) in file_content.chunks(sector_size).enumerate() {
let mut sector_data = vec![0; sector_size];
sector_data[..chunk.len()].copy_from_slice(chunk);
file_data.insert((data_start_sector + i) as u32, sector_data);
}
Self {
file_data,
sector_size,
sectors_per_cluster,
}
}
}
impl Fat32Operations for MockFs {
async fn read_sector(
&self,
sector: Sector,
offset: usize,
buf: &mut [u8],
) -> Result<usize> {
let sector_data = self.file_data.get(&sector.0).ok_or(FsError::OutOfBounds)?;
let bytes_in_sec = sector_data.len() - offset;
let read_size = core::cmp::min(buf.len(), bytes_in_sec);
buf[..read_size].copy_from_slice(&sector_data[offset..offset + read_size]);
Ok(read_size)
}
fn id(&self) -> u64 {
0
}
fn sector_size(&self) -> usize {
self.sector_size
}
fn sectors_per_cluster(&self) -> usize {
self.sectors_per_cluster
}
fn bytes_per_cluster(&self) -> usize {
self.sector_size * self.sectors_per_cluster
}
fn cluster_to_sectors(&self, cluster: Cluster) -> Result<impl Iterator<Item = Sector>> {
// Simple mapping for the test: Cluster C -> Sectors (100 + (C-2)*SPC ..)
let data_start_sector = 100;
let start = data_start_sector + (cluster.value() - 2) * self.sectors_per_cluster;
let end = start + self.sectors_per_cluster;
Ok((start as u32..end as u32).map(Sector))
}
fn iter_clusters(&self, root: Cluster) -> impl Iterator<Item = Result<Cluster>> {
// Assume a simple contiguous chain for testing.
let num_clusters =
(self.file_data.len() + self.sectors_per_cluster - 1) / self.sectors_per_cluster;
(0..num_clusters).map(move |i| Ok(Cluster((root.value() + i) as u32)))
}
}
async fn setup_file_test(content: &[u8]) -> Fat32FileNode<MockFs> {
let fs = Arc::new(MockFs::new(content, 512, 4));
Fat32FileNode::new(
fs,
Cluster(2),
FileAttr {
size: content.len() as _,
..FileAttr::default()
},
)
.unwrap()
}
#[tokio::test]
async fn test_read_simple() {
let file_content: Vec<u8> = (0..100).collect();
let inode = setup_file_test(&file_content).await;
let mut buf = vec![0; 50];
let bytes_read = inode.read_at(10, &mut buf).await.unwrap();
assert_eq!(bytes_read, 50);
assert_eq!(buf, &file_content[10..60]);
}
#[tokio::test]
async fn test_read_crossing_sector_boundary() {
let file_content: Vec<u8> = (0..1024).map(|i| (i % 256) as u8).collect();
let inode = setup_file_test(&file_content).await;
// Read from offset 510 for 4 bytes. Should read 2 bytes from sector 0
// and 2 bytes from sector 1.
let mut buf = vec![0; 4];
let bytes_read = inode.read_at(510, &mut buf).await.unwrap();
assert_eq!(bytes_read, 4);
assert_eq!(buf, &file_content[510..514]);
}
#[tokio::test]
async fn test_read_crossing_cluster_boundary() {
// Sector size = 512, Sectors per cluster = 4 -> Cluster size = 2048
let file_content: Vec<u8> = (0..4096).map(|i| (i % 256) as u8).collect();
let inode = setup_file_test(&file_content).await;
// Read from offset 2040 for 16 bytes. Should cross from cluster 2 to cluster 3.
let mut buf = vec![0; 16];
let bytes_read = inode.read_at(2040, &mut buf).await.unwrap();
assert_eq!(bytes_read, 16);
assert_eq!(buf, &file_content[2040..2056]);
}
#[tokio::test]
async fn test_read_past_eof() {
let file_content: Vec<u8> = (0..100).collect();
let inode = setup_file_test(&file_content).await;
let mut buf = vec![0; 50];
// Start reading at offset 80, but buffer is 50. Should only read 20 bytes.
let bytes_read = inode.read_at(80, &mut buf).await.unwrap();
assert_eq!(bytes_read, 20);
assert_eq!(buf[..20], file_content[80..100]);
}
#[tokio::test]
async fn test_read_at_eof() {
let file_content: Vec<u8> = (0..100).collect();
let inode = setup_file_test(&file_content).await;
let mut buf = vec![0; 50];
let bytes_read = inode.read_at(100, &mut buf).await.unwrap();
assert_eq!(bytes_read, 0);
}
}

View File

@@ -0,0 +1,183 @@
use crate::{
error::{FsError, Result},
fs::{FileType, Filesystem, Inode, InodeId, attr::FileAttr, blk::buffer::BlockBuffer},
};
use alloc::{
boxed::Box,
sync::{Arc, Weak},
};
use async_trait::async_trait;
use bpb::BiosParameterBlock;
use core::{
cmp::min,
fmt::Display,
ops::{Add, Mul},
};
use dir::Fat32DirNode;
use fat::Fat;
use log::warn;
mod bpb;
mod dir;
mod fat;
mod file;
mod reader;
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub struct Sector(u32);
impl Mul<usize> for Sector {
type Output = Sector;
fn mul(self, rhs: usize) -> Self::Output {
Self(self.0 * rhs as u32)
}
}
impl Add<Sector> for Sector {
type Output = Sector;
fn add(self, rhs: Sector) -> Self::Output {
Self(self.0 + rhs.0)
}
}
impl Sector {
pub fn sectors_until(self, other: Self) -> impl Iterator<Item = Self> {
(self.0..other.0).map(Sector)
}
}
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Debug)]
pub struct Cluster(u32);
impl Cluster {
pub fn value(self) -> usize {
self.0 as _
}
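/// Directory entries store the starting cluster split across two 16-bit
/// fields; recombine them into a single cluster number.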
pub fn from_high_low(clust_high: u16, clust_low: u16) -> Cluster {
Cluster((clust_high as u32) << 16 | clust_low as u32)
}
pub fn is_valid(self) -> bool {
self.0 >= 2
}
}
impl Display for Cluster {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
self.0.fmt(f)
}
}
pub struct Fat32Filesystem {
dev: BlockBuffer,
bpb: BiosParameterBlock,
fat: Fat,
id: u64,
this: Weak<Self>,
}
impl Fat32Filesystem {
pub async fn new(dev: BlockBuffer, id: u64) -> Result<Arc<Self>> {
let bpb = BiosParameterBlock::new(&dev).await?;
let fat = Fat::read_fat(&dev, &bpb, 0).await?;
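// FAT32 keeps `num_fats` mirrored copies of the allocation table; refuse to
// mount if any copy disagrees with the primary one.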
for fat_num in 1..bpb.num_fats {
let other_fat = Fat::read_fat(&dev, &bpb, fat_num as _).await?;
if other_fat != fat {
warn!("Failing to mount, FAT disagree.");
return Err(FsError::InvalidFs.into());
}
}
Ok(Arc::new_cyclic(|weak| Self {
bpb,
dev,
fat,
this: weak.clone(),
id,
}))
}
}
trait Fat32Operations: Send + Sync + 'static {
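/// Reads from `sector` starting at byte `offset` within it, returning the
/// number of bytes copied (at most the remainder of the sector).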
fn read_sector(
&self,
sector: Sector,
offset: usize,
buf: &mut [u8],
) -> impl Future<Output = Result<usize>> + Send;
fn id(&self) -> u64;
fn sector_size(&self) -> usize;
fn sectors_per_cluster(&self) -> usize;
fn bytes_per_cluster(&self) -> usize {
self.sectors_per_cluster() * self.sector_size()
}
fn cluster_to_sectors(&self, cluster: Cluster) -> Result<impl Iterator<Item = Sector> + Send>;
fn iter_clusters(&self, root: Cluster) -> impl Iterator<Item = Result<Cluster>> + Send;
}
impl Fat32Operations for Fat32Filesystem {
async fn read_sector(&self, sector: Sector, offset: usize, buf: &mut [u8]) -> Result<usize> {
debug_assert!(offset < self.bpb.sector_size());
let bytes_left_in_sec = self.bpb.sector_size() - offset;
let read_sz = min(buf.len(), bytes_left_in_sec);
self.dev
.read_at(
self.bpb.sector_offset(sector) + offset as u64,
&mut buf[..read_sz],
)
.await?;
Ok(read_sz)
}
fn id(&self) -> u64 {
self.id
}
fn sector_size(&self) -> usize {
self.bpb.sector_size()
}
fn sectors_per_cluster(&self) -> usize {
self.bpb.sectors_per_cluster as _
}
fn cluster_to_sectors(&self, cluster: Cluster) -> Result<impl Iterator<Item = Sector>> {
self.bpb.cluster_to_sectors(cluster)
}
fn iter_clusters(&self, root: Cluster) -> impl Iterator<Item = Result<Cluster>> {
self.fat.get_cluster_chain(root)
}
}
#[async_trait]
impl Filesystem for Fat32Filesystem {
fn id(&self) -> u64 {
self.id
}
/// Get the root inode of this filesystem.
async fn root_inode(&self) -> Result<Arc<dyn Inode>> {
Ok(Arc::new(Fat32DirNode::new(
self.this.upgrade().unwrap(),
self.bpb.root_cluster,
FileAttr {
id: InodeId::from_fsid_and_inodeid(self.id, self.bpb.root_cluster.0 as _),
file_type: FileType::Directory,
..FileAttr::default()
},
)))
}
}

View File

@@ -0,0 +1,106 @@
use crate::error::Result;
use alloc::sync::Arc;
use super::{Cluster, Fat32Operations};
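/// Reads a byte stream backed by a FAT32 cluster chain, translating byte
/// offsets into (cluster, sector, in-sector offset) coordinates via
/// `Fat32Operations`.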
pub struct Fat32Reader<T: Fat32Operations> {
fs: Arc<T>,
root: Cluster,
max_sz: u64,
}
impl<T: Fat32Operations> Clone for Fat32Reader<T> {
fn clone(&self) -> Self {
Self {
fs: self.fs.clone(),
root: self.root,
max_sz: self.max_sz,
}
}
}
impl<T: Fat32Operations> Fat32Reader<T> {
pub fn new(fs: Arc<T>, root: Cluster, max_sz: u64) -> Self {
Self { fs, root, max_sz }
}
pub async fn read_at(&self, offset: u64, buf: &mut [u8]) -> Result<usize> {
// Ensure we don't read past the end of the stream.
if offset >= self.max_sz {
return Ok(0);
}
let bytes_to_read = core::cmp::min(buf.len() as u64, self.max_sz - offset) as usize;
if bytes_to_read == 0 {
return Ok(0);
}
let buf = &mut buf[..bytes_to_read];
let mut total_bytes_read = 0;
let bpc = self.fs.bytes_per_cluster();
let sector_size = self.fs.sector_size();
// Determine the maximum possible length of the cluster chain from the
// file size. This acts as a safety rail against cycles in the FAT.
let max_clusters = self.max_sz.div_ceil(bpc as _);
// Calculate the starting position.
let start_cluster_idx = (offset / bpc as u64) as usize;
let offset_in_first_cluster = (offset % bpc as u64) as usize;
let start_sector_idx_in_cluster = offset_in_first_cluster / sector_size;
let offset_in_first_sector = offset_in_first_cluster % sector_size;
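// Example: with 512-byte sectors and 4 sectors per cluster (2048-byte
// clusters), offset 2040 yields cluster index 0, sector index 3 within that
// cluster, and byte 504 within that sector.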
// Get the cluster iterator and advance it to our starting cluster.
let mut cluster_iter = self.fs.iter_clusters(self.root).take(max_clusters as _);
if let Some(cluster_result) = cluster_iter.nth(start_cluster_idx) {
let cluster = cluster_result?;
let mut sectors = self.fs.cluster_to_sectors(cluster)?;
// Advance the sector iterator to the correct starting sector.
if let Some(sector) = sectors.nth(start_sector_idx_in_cluster) {
// Read the first, possibly partial, chunk from the first sector.
let bytes_read = self
.fs
.read_sector(sector, offset_in_first_sector, &mut buf[total_bytes_read..])
.await?;
total_bytes_read += bytes_read;
// Read any remaining full sectors within this first cluster.
for sector in sectors {
if total_bytes_read >= bytes_to_read {
break;
}
let buf_slice = &mut buf[total_bytes_read..];
let bytes_read = self.fs.read_sector(sector, 0, buf_slice).await?;
total_bytes_read += bytes_read;
}
}
}
'aligned_loop: for cluster_result in cluster_iter {
if total_bytes_read >= bytes_to_read {
break;
}
let cluster = cluster_result?;
// Read all sectors in a full cluster.
for sector in self.fs.cluster_to_sectors(cluster)? {
if total_bytes_read >= bytes_to_read {
break 'aligned_loop;
}
let buf_slice = &mut buf[total_bytes_read..];
let bytes_read = self.fs.read_sector(sector, 0, buf_slice).await?;
total_bytes_read += bytes_read;
}
}
// Slicing `buf[total_bytes_read..]` before every `read_sector` call means
// the final call receives only the remaining bytes, so any tail
// misalignment is handled by `read_sector` clamping the read to the slice
// length.
Ok(total_bytes_read)
}
}

View File

@@ -0,0 +1 @@
pub mod fat32;

207
libkernel/src/fs/mod.rs Normal file
View File

@@ -0,0 +1,207 @@
//! Virtual Filesystem (VFS) Interface Definitions
//!
//! This module defines the core traits and data structures for the kernel's I/O subsystem.
//! It is based on a layered design:
//!
//! 1. `BlockDevice`: An abstraction for raw block-based hardware (e.g., disks).
//! 2. `Filesystem`: An abstraction for a mounted filesystem instance (e.g.,
//! ext4, fat32). Its main role is to provide the root `Inode`.
//! 3. `Inode`: A stateless representation of a filesystem object (file,
//! directory, etc.). It handles operations by explicit offsets (`read_at`,
//! `write_at`).
//! 4. `File`: A stateful open file handle. It maintains a cursor and provides
//! the familiar `read`, `write`, and `seek` operations.
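//!
//! An illustrative sketch of the layering (not part of the interface; the
//! file name, buffer size, and surrounding async executor are assumptions):
//!
//! ```ignore
//! async fn read_motd(fs: alloc::sync::Arc<dyn Filesystem>) -> Result<alloc::vec::Vec<u8>> {
//!     let root = fs.root_inode().await?;           // Filesystem -> root Inode
//!     let inode = root.lookup("motd.txt").await?;  // resolve a single component
//!     let mut buf = alloc::vec![0u8; 128];
//!     let n = inode.read_at(0, &mut buf).await?;   // stateless read at an explicit offset
//!     buf.truncate(n);
//!     Ok(buf)
//! }
//! ```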
extern crate alloc;
pub mod attr;
pub mod blk;
pub mod filesystems;
pub mod path;
pub mod pathbuf;
use crate::{
driver::CharDevDescriptor,
error::{FsError, KernelError, Result},
};
use alloc::{boxed::Box, string::String, sync::Arc};
use async_trait::async_trait;
use attr::FileAttr;
bitflags::bitflags! {
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct OpenFlags: u32 {
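// Numeric values follow the Linux asm-generic `fcntl.h` encoding.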
const O_RDONLY = 0b000;
const O_WRONLY = 0b001;
const O_RDWR = 0b010;
const O_ACCMODE = 0b011;
const O_CREAT = 0o100;
const O_EXCL = 0o200;
const O_TRUNC = 0o1000;
const O_DIRECTORY = 0o200000;
const O_APPEND = 0o2000;
const O_NONBLOCK = 0o4000;
const O_CLOEXEC = 0o2000000;
}
}
// Reserved pseudo-filesystem instances created internally in the kernel.
pub const DEVFS_ID: u64 = 1;
pub const FS_ID_START: u64 = 10;
/// Trait for a mounted filesystem instance. Its main role is to act as a
/// factory for Inodes.
#[async_trait]
pub trait Filesystem: Send + Sync {
/// Get the root inode of this filesystem.
async fn root_inode(&self) -> Result<Arc<dyn Inode>>;
/// Returns the instance ID for this FS.
fn id(&self) -> u64;
}
/// A unique identifier for an inode across the entire VFS. A tuple of
/// (filesystem_id, inode_number).
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct InodeId(u64, u64);
impl InodeId {
pub fn from_fsid_and_inodeid(fs_id: u64, inode_id: u64) -> Self {
Self(fs_id, inode_id)
}
pub fn dummy() -> Self {
Self(u64::MAX, u64::MAX)
}
pub fn fs_id(self) -> u64 {
self.0
}
pub fn inode_id(self) -> u64 {
self.1
}
}
/// Standard POSIX file types.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum FileType {
File,
Directory,
Symlink,
BlockDevice(CharDevDescriptor),
CharDevice(CharDevDescriptor),
Fifo,
Socket,
}
/// A stateful, streaming iterator for reading directory entries.
#[async_trait]
pub trait DirStream: Send + Sync {
/// Fetches the next directory entry in the stream. Returns `Ok(None)` when
/// the end of the directory is reached.
async fn next_entry(&mut self) -> Result<Option<Dirent>>;
}
/// Represents a single directory entry.
#[derive(Debug, Clone)]
pub struct Dirent {
pub id: InodeId,
pub name: String,
pub file_type: FileType,
pub offset: u64,
}
impl Dirent {
pub fn new(name: String, id: InodeId, file_type: FileType, offset: u64) -> Self {
Self {
id,
name,
file_type,
offset,
}
}
}
/// Specifies how to seek within a file, mirroring `std::io::SeekFrom`.
#[derive(Debug, Copy, Clone)]
pub enum SeekFrom {
Start(u64),
End(i64),
Current(i64),
}
/// Trait for a raw block device.
#[async_trait]
pub trait BlockDevice: Send + Sync {
/// Read one or more blocks starting at `block_id`.
/// The `buf` length must be a multiple of `block_size`.
async fn read(&self, block_id: u64, buf: &mut [u8]) -> Result<()>;
/// Write one or more blocks starting at `block_id`.
/// The `buf` length must be a multiple of `block_size`.
async fn write(&self, block_id: u64, buf: &[u8]) -> Result<()>;
/// The size of a single block in bytes.
fn block_size(&self) -> usize;
/// Flushes any caches to the underlying device.
async fn sync(&self) -> Result<()>;
}
/// A stateless representation of a filesystem object.
///
/// This trait represents an object on the disk (a file, a directory, etc.). All
/// operations are stateless from the VFS's perspective; for instance, `read_at`
/// takes an explicit offset instead of using a hidden cursor.
#[async_trait]
pub trait Inode: Send + Sync {
/// Get the unique ID for this inode.
fn id(&self) -> InodeId;
/// Reads data from the inode at a specific `offset`.
/// Returns the number of bytes read.
async fn read_at(&self, _offset: u64, _buf: &mut [u8]) -> Result<usize> {
Err(KernelError::NotSupported)
}
/// Writes data to the inode at a specific `offset`.
/// Returns the number of bytes written.
async fn write_at(&self, _offset: u64, _buf: &[u8]) -> Result<usize> {
Err(KernelError::NotSupported)
}
/// Truncates the inode to a specific `size`.
async fn truncate(&self, _size: u64) -> Result<()> {
Err(KernelError::NotSupported)
}
/// Gets the metadata for this inode.
async fn getattr(&self) -> Result<FileAttr> {
Err(KernelError::NotSupported)
}
/// Looks up a name within a directory, returning the corresponding inode.
async fn lookup(&self, _name: &str) -> Result<Arc<dyn Inode>> {
Err(KernelError::NotSupported)
}
/// Creates a new object within a directory.
async fn create(
&self,
_name: &str,
_file_type: FileType,
_permissions: u16,
) -> Result<Arc<dyn Inode>> {
Err(KernelError::NotSupported)
}
/// Removes a link to an inode from a directory.
async fn unlink(&self, _name: &str) -> Result<()> {
Err(KernelError::NotSupported)
}
/// Reads the contents of a directory.
async fn readdir(&self, _start_offset: u64) -> Result<Box<dyn DirStream>> {
Err(FsError::NotADirectory.into())
}
}

337
libkernel/src/fs/path.rs Normal file
View File

@@ -0,0 +1,337 @@
//! A module for path manipulation that works with string slices.
//!
//! This module provides a `Path` struct that is a thin wrapper around `&str`,
//! offering various methods for path inspection and manipulation.
use alloc::vec::Vec;
use super::pathbuf::PathBuf;
/// Represents a path slice, akin to `&str`.
///
/// This struct provides a number of methods for inspecting a path,
/// including breaking the path into its components, determining if it's
/// absolute, and more.
#[derive(Debug, Eq, PartialEq, PartialOrd, Ord, Hash)]
pub struct Path {
inner: str,
}
impl Path {
/// Creates a new `Path` from a string slice.
///
/// This is a cost-free conversion.
///
/// # Examples
///
/// ```
/// use libkernel::fs::path::Path;
///
/// let path = Path::new("/usr/bin/ls");
/// ```
pub fn new<S: AsRef<str> + ?Sized>(s: &S) -> &Self {
unsafe { &*(s.as_ref() as *const str as *const Path) }
}
/// Returns the underlying string slice.
///
/// # Examples
///
/// ```
/// use libkernel::fs::path::Path;
///
/// let path = Path::new("/etc/passwd");
/// assert_eq!(path.as_str(), "/etc/passwd");
/// ```
pub fn as_str(&self) -> &str {
&self.inner
}
/// Determines whether the path is absolute.
///
/// An absolute path starts with a `/`.
///
/// # Examples
///
/// ```
/// use libkernel::fs::path::Path;
///
/// assert!(Path::new("/home/user").is_absolute());
/// assert!(!Path::new("home/user").is_absolute());
/// ```
pub fn is_absolute(&self) -> bool {
self.inner.starts_with('/')
}
/// Determines whether the path is relative.
///
/// A relative path does not start with a `/`.
///
/// # Examples
///
/// ```
/// use libkernel::fs::path::Path;
///
/// assert!(Path::new("src/main.rs").is_relative());
/// assert!(!Path::new("/src/main.rs").is_relative());
/// ```
pub fn is_relative(&self) -> bool {
!self.is_absolute()
}
/// Produces an iterator over the components of the path.
///
/// Components are the non-empty parts of the path separated by `/`.
/// Repeated separators and `.` components are skipped.
///
/// # Examples
///
/// ```
/// use libkernel::fs::path::Path;
///
/// let path = Path::new("/some//path/./file.txt");
/// let components: Vec<_> = path.components().collect();
/// assert_eq!(components, vec!["some", "path", "file.txt"]);
/// ```
pub fn components(&self) -> Components<'_> {
Components {
remaining: &self.inner,
}
}
/// Joins two paths together.
///
/// This will allocate a new `PathBuf` to hold the joined path. If the
/// `other` path is absolute, it replaces the current one.
///
/// # Examples
///
/// ```
/// use libkernel::fs::path::Path;
/// use libkernel::fs::pathbuf::PathBuf;
///
/// let path1 = Path::new("/usr/local");
/// let path2 = Path::new("bin/rustc");
/// assert_eq!(path1.join(path2), PathBuf::from("/usr/local/bin/rustc"));
///
/// let path3 = Path::new("/etc/init.d");
/// assert_eq!(path1.join(path3), PathBuf::from("/etc/init.d"));
/// ```
pub fn join(&self, other: &Path) -> PathBuf {
let mut ret: PathBuf = PathBuf::with_capacity(self.inner.len() + other.inner.len());
ret.push(self);
ret.push(other);
ret
}
/// Strips a prefix from the path.
///
/// If the path starts with `base`, returns a new `Path` slice with the
/// prefix removed. Otherwise, returns `None`.
///
/// # Examples
///
/// ```
/// use libkernel::fs::path::Path;
///
/// let path = Path::new("/usr/lib/x86_64-linux-gnu");
/// let prefix = Path::new("/usr");
///
/// assert_eq!(path.strip_prefix(prefix), Some(Path::new("lib/x86_64-linux-gnu")));
/// assert_eq!(prefix.strip_prefix(path), None);
/// ```
pub fn strip_prefix(&self, base: &Path) -> Option<&Path> {
if self.inner.starts_with(&base.inner) {
// If the prefixes are the same and they have the same length, the
// whole string is the prefix.
if base.inner.len() == self.inner.len() {
return Some(Path::new(""));
}
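// Only accept the prefix when it ends on a component boundary: the byte
// immediately after `base` must be a '/', so "/a/b" is not stripped from
// "/a/bc".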
if self.inner.as_bytes().get(base.inner.len()) == Some(&b'/') {
let stripped = &self.inner[base.inner.len()..];
// If the base ends with a slash, we don't want a leading slash on the result
if base.inner.ends_with('/') {
return Some(Path::new(stripped));
}
return Some(Path::new(&stripped[1..]));
}
}
None
}
/// Returns the parent directory of the path.
///
/// Returns `None` if the path is a root directory or has no parent.
///
/// # Examples
///
/// ```
/// use libkernel::fs::path::Path;
///
/// let path = Path::new("/foo/bar");
/// assert_eq!(path.parent(), Some(Path::new("/foo")));
///
/// let path_root = Path::new("/");
/// assert_eq!(path_root.parent(), None);
/// ```
pub fn parent(&self) -> Option<&Path> {
let mut components = self.components().collect::<Vec<_>>();
if components.len() <= 1 {
return None;
}
components.pop();
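// Reassemble the parent's length from the remaining components plus the
// separators between them; e.g. "/a/b/c" keeps ["a", "b"], giving 2 + 1 = 3
// bytes ("a/b"), plus one more below for the leading '/'.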
let parent_len = components.iter().map(|s| s.len()).sum::<usize>() + components.len() - 1;
let end = if self.is_absolute() {
parent_len + 1
} else {
parent_len
};
Some(Path::new(&self.inner[..end]))
}
/// Returns the final component of the path, if there is one.
///
/// This is the file name for a file path, or the directory name for a
/// directory path.
///
/// # Examples
///
/// ```
/// use libkernel::fs::path::Path;
///
/// assert_eq!(Path::new("/home/user/file.txt").file_name(), Some("file.txt"));
/// assert_eq!(Path::new("/home/user/").file_name(), Some("user"));
/// assert_eq!(Path::new("/").file_name(), None);
/// ```
pub fn file_name(&self) -> Option<&str> {
self.components().last()
}
}
impl AsRef<Path> for str {
fn as_ref(&self) -> &Path {
Path::new(self)
}
}
/// An iterator over the components of a `Path`.
#[derive(Clone, Debug)]
pub struct Components<'a> {
remaining: &'a str,
}
impl<'a> Iterator for Components<'a> {
type Item = &'a str;
fn next(&mut self) -> Option<Self::Item> {
// Trim leading slashes
self.remaining = self.remaining.trim_start_matches('/');
if self.remaining.is_empty() {
return None;
}
match self.remaining.find('/') {
Some(index) => {
let component = &self.remaining[..index];
self.remaining = &self.remaining[index..];
if component == "." {
self.next()
} else {
Some(component)
}
}
None => {
let component = self.remaining;
self.remaining = "";
if component == "." {
self.next()
} else {
Some(component)
}
}
}
}
}
#[cfg(test)]
mod tests {
use super::Path;
use alloc::vec::Vec;
#[test]
fn test_new_path() {
let p = Path::new("/a/b/c");
assert_eq!(p.as_str(), "/a/b/c");
}
#[test]
fn test_is_absolute() {
assert!(Path::new("/").is_absolute());
assert!(Path::new("/a/b").is_absolute());
assert!(!Path::new("a/b").is_absolute());
assert!(!Path::new("").is_absolute());
}
#[test]
fn test_is_relative() {
assert!(Path::new("a/b").is_relative());
assert!(Path::new("").is_relative());
assert!(!Path::new("/a/b").is_relative());
}
#[test]
fn test_components() {
let p = Path::new("/a/b/c");
let mut comps = p.components();
assert_eq!(comps.next(), Some("a"));
assert_eq!(comps.next(), Some("b"));
assert_eq!(comps.next(), Some("c"));
assert_eq!(comps.next(), None);
let p2 = Path::new("a//b/./c/");
assert_eq!(p2.components().collect::<Vec<_>>(), vec!["a", "b", "c"]);
}
#[test]
fn test_join() {
assert_eq!(Path::new("/a/b").join(Path::new("c/d")), "/a/b/c/d".into());
assert_eq!(Path::new("/a/b/").join(Path::new("c")), "/a/b/c".into());
assert_eq!(Path::new("a").join(Path::new("b")), "a/b".into());
assert_eq!(Path::new("/a").join(Path::new("/b")), "/b".into());
assert_eq!(Path::new("").join(Path::new("a")), "a".into());
}
#[test]
fn test_strip_prefix() {
let p = Path::new("/a/b/c");
assert_eq!(p.strip_prefix(Path::new("/a")), Some(Path::new("b/c")));
assert_eq!(p.strip_prefix(Path::new("/a/b")), Some(Path::new("c")));
assert_eq!(p.strip_prefix(Path::new("/a/b/c")), Some(Path::new("")));
assert_eq!(p.strip_prefix(Path::new("/d")), None);
assert_eq!(p.strip_prefix(Path::new("/a/b/c/d")), None);
assert_eq!(Path::new("/a/bc").strip_prefix(Path::new("/a/b")), None);
}
#[test]
fn test_parent() {
assert_eq!(Path::new("/a/b/c").parent(), Some(Path::new("/a/b")));
assert_eq!(Path::new("/a/b").parent(), Some(Path::new("/a")));
assert_eq!(Path::new("/a").parent(), None);
assert_eq!(Path::new("/").parent(), None);
assert_eq!(Path::new("a/b").parent(), Some(Path::new("a")));
assert_eq!(Path::new("a").parent(), None);
}
#[test]
fn test_file_name() {
assert_eq!(Path::new("/a/b/c.txt").file_name(), Some("c.txt"));
assert_eq!(Path::new("/a/b/").file_name(), Some("b"));
assert_eq!(Path::new("c.txt").file_name(), Some("c.txt"));
assert_eq!(Path::new("/").file_name(), None);
assert_eq!(Path::new(".").file_name(), None);
}
}

226
libkernel/src/fs/pathbuf.rs Normal file
View File

@@ -0,0 +1,226 @@
//! A module for owned, mutable paths.
//!
//! This module provides a `PathBuf` struct that is an owned, mutable counterpart
//! to the `Path` slice. It provides methods for manipulating the path in place,
//! such as `push` and `pop`.
use super::path::Path;
use alloc::string::String;
use core::ops::Deref;
/// An owned, mutable path, akin to `String`.
///
/// This struct provides methods like `push` and `pop` that mutate the path
/// in place. It also implements `Deref` to `Path`, meaning that all methods
/// on `Path` slices are available on `PathBuf` values as well.
#[derive(Debug, Eq, PartialEq, PartialOrd, Ord, Hash, Clone, Default)]
pub struct PathBuf {
inner: String,
}
impl PathBuf {
/// Creates a new, empty `PathBuf`.
///
/// # Examples
///
/// ```
/// use libkernel::fs::pathbuf::PathBuf;
///
/// let path = PathBuf::new();
/// ```
pub fn new() -> Self {
Self {
inner: String::new(),
}
}
/// Creates a new `PathBuf` with a given capacity.
pub fn with_capacity(capacity: usize) -> Self {
Self {
inner: String::with_capacity(capacity),
}
}
/// Coerces to a `Path` slice.
///
/// # Examples
///
/// ```
/// use libkernel::fs::path::Path;
/// use libkernel::fs::pathbuf::PathBuf;
///
/// let p = PathBuf::from("/test");
/// assert_eq!(Path::new("/test"), p.as_path());
/// ```
pub fn as_path(&self) -> &Path {
self
}
/// Extends `self` with `path`.
///
/// If `path` is absolute, it replaces the current path.
///
/// # Examples
///
/// ```
/// use libkernel::fs::pathbuf::PathBuf;
///
/// let mut path = PathBuf::from("/usr");
/// path.push("bin");
/// assert_eq!(path.as_str(), "/usr/bin");
///
/// let mut path2 = PathBuf::from("/tmp");
/// path2.push("/etc/passwd");
/// assert_eq!(path2.as_str(), "/etc/passwd");
/// ```
pub fn push<P: AsRef<Path>>(&mut self, path: P) {
let path = path.as_ref();
if path.is_absolute() {
self.inner = path.as_str().into();
return;
}
if !self.inner.is_empty() && !self.inner.ends_with('/') {
self.inner.push('/');
}
self.inner.push_str(path.as_str());
}
/// Truncates `self` to its parent.
///
/// Returns `true` if the path was truncated, `false` otherwise.
///
/// # Examples
///
/// ```
/// use libkernel::fs::pathbuf::PathBuf;
///
/// let mut path = PathBuf::from("/a/b/c");
/// assert!(path.pop());
/// assert_eq!(path.as_str(), "/a/b");
/// assert!(path.pop());
/// assert_eq!(path.as_str(), "/a");
/// assert!(!path.pop());
/// assert_eq!(path.as_str(), "/a");
/// ```
pub fn pop(&mut self) -> bool {
match self.as_path().parent() {
Some(parent) => {
self.inner.truncate(parent.as_str().len());
true
}
None => false,
}
}
/// Updates the file name of the path.
///
/// If there is no file name, it is appended. Otherwise, the existing
/// file name is replaced.
///
/// # Examples
///
/// ```
/// use libkernel::fs::pathbuf::PathBuf;
///
/// let mut path = PathBuf::from("/tmp/foo");
/// path.set_file_name("bar.txt");
/// assert_eq!(path.as_str(), "/tmp/bar.txt");
///
/// let mut path2 = PathBuf::from("/tmp");
/// path2.set_file_name("foo");
/// assert_eq!(path2.as_str(), "/tmp/foo");
/// ```
pub fn set_file_name<S: AsRef<str>>(&mut self, file_name: S) {
if self.as_path().file_name().is_some() {
self.pop();
}
self.push(Path::new(file_name.as_ref()));
}
}
impl AsRef<Path> for Path {
fn as_ref(&self) -> &Path {
self
}
}
impl<T: AsRef<str>> From<T> for PathBuf {
fn from(s: T) -> Self {
Self {
inner: s.as_ref().into(),
}
}
}
impl Deref for PathBuf {
type Target = Path;
fn deref(&self) -> &Self::Target {
Path::new(&self.inner)
}
}
#[cfg(test)]
mod tests {
use super::PathBuf;
#[test]
fn test_push() {
let mut p = PathBuf::from("/a/b");
p.push("c");
assert_eq!(p.as_str(), "/a/b/c");
let mut p2 = PathBuf::from("/a/b/");
p2.push("c");
assert_eq!(p2.as_str(), "/a/b/c");
let mut p3 = PathBuf::from("a");
p3.push("b");
assert_eq!(p3.as_str(), "a/b");
let mut p4 = PathBuf::new();
p4.push("a");
assert_eq!(p4.as_str(), "a");
let mut p5 = PathBuf::from("/a/b/");
p5.push("c/d");
assert_eq!(p5.as_str(), "/a/b/c/d");
}
#[test]
fn test_pop() {
let mut p = PathBuf::from("/a/b/c");
assert!(p.pop());
assert_eq!(p.as_str(), "/a/b");
assert!(p.pop());
assert_eq!(p.as_str(), "/a");
assert!(!p.pop());
assert_eq!(p.as_str(), "/a");
let mut p2 = PathBuf::from("a/b");
assert!(p2.pop());
assert_eq!(p2.as_str(), "a");
assert!(!p2.pop());
}
#[test]
fn test_set_file_name() {
let mut p = PathBuf::from("/a/b");
p.set_file_name("c");
assert_eq!(p.as_str(), "/a/c");
let mut p2 = PathBuf::from("/a/");
p2.set_file_name("b");
assert_eq!(p2.as_str(), "/a/b");
}
#[test]
fn test_deref() {
let p = PathBuf::from("/a/b/c");
assert!(p.is_absolute());
}
}

269
libkernel/src/lib.rs Normal file
View File

@@ -0,0 +1,269 @@
#![cfg_attr(not(test), no_std)]
use alloc::vec::Vec;
use error::Result;
use memory::{
address::VA,
page::PageFrame,
permissions::PtePermissions,
region::{PhysMemoryRegion, VirtMemoryRegion},
};
use sync::spinlock::SpinLockIrq;
pub mod arch;
pub mod driver;
pub mod error;
pub mod fs;
pub mod memory;
pub mod pod;
pub mod proc;
pub mod sync;
extern crate alloc;
pub trait CpuOps {
/// Returns the ID of the currently executing core.
fn id() -> usize;
/// Halts the CPU indefinitely.
fn halt() -> !;
/// Disables all maskable interrupts on the current CPU core, returning the
/// previous state prior to masking.
fn disable_interrupts() -> usize;
/// Restore the previous interrupt state obtained from `disable_interrupts`.
fn restore_interrupt_state(flags: usize);
/// Explicitly enables maskable interrupts on the current CPU core.
fn enable_interrupts();
}
/// An architecture-independent representation of a page table entry (PTE).
pub struct PageInfo {
pub pfn: PageFrame,
pub perms: PtePermissions,
}
/// Represents a process's memory context, abstracting the hardware-specific
/// details of page tables, address translation, and TLB management.
///
/// This trait defines the fundamental interface that the kernel's
/// architecture-independent memory management code uses to interact with an
/// address space. Each supported architecture must provide a concrete
/// implementation.
pub trait UserAddressSpace: Send + Sync {
/// Creates a new, empty page table hierarchy for a new process.
///
/// The resulting address space should be configured for user space access
/// but will initially contain no user mappings. It is the responsibility of
/// the implementation to also map the kernel's own address space into the
/// upper region of the virtual address space, making kernel code and data
/// accessible.
///
/// # Returns
///
/// `Ok(Self)` on success, or an `Err` if memory for the top-level page
/// table could not be allocated.
fn new() -> Result<Self>
where
Self: Sized;
/// Activates this address space for the current CPU.
///
/// The implementation must load the address of this space's root page table
/// into the appropriate CPU register (e.g., `CR3` on x86-64 or `TTBR0_EL1`
/// on AArch64). This action makes the virtual address mappings defined by
/// this space active on the current core.
///
/// The implementation must also ensure that any necessary TLB invalidations
/// occur so that stale translations from the previously active address
/// space are flushed.
fn activate(&self);
/// Deactivates this address space for the current CPU.
///
/// This should be called to leave the CPU without any current process
/// state. Used on process termination code-paths.
fn deactivate(&self);
/// Maps a single physical page frame to a virtual address.
///
/// This function creates a page table entry (PTE) that maps the given
/// physical `page` to the specified virtual address `va` with the provided
/// `perms`. The implementation must handle the allocation and setup of any
/// intermediate page tables (e.g., L1 or L2 tables) if they do not already
/// exist.
///
/// # Arguments
///
/// * `page`: The `PageFrame` of physical memory to map.
/// * `va`: The page-aligned virtual address to map to.
/// * `perms`: The `PtePermissions` to apply to the mapping.
///
/// # Errors
///
/// Returns an error if a mapping already exists at `va` or if memory for
/// intermediate page tables cannot be allocated.
fn map_page(&mut self, page: PageFrame, va: VA, perms: PtePermissions) -> Result<()>;
/// Unmaps a single virtual page, returning the physical page it was mapped
/// to.
///
/// The implementation must invalidate the PTE for the given `va` and ensure
/// the corresponding TLB entry is flushed.
///
/// # Returns
///
/// The `PageFrame` that was previously mapped at `va`. This allows the
/// caller to manage the lifecycle of the physical memory (e.g., decrement a
/// reference count or free it). Returns an error if no page is mapped at
/// `va`.
fn unmap(&mut self, va: VA) -> Result<PageFrame>;
/// Atomically unmaps a page at `va` and maps a new page in its place.
///
/// # Returns
///
/// The `PageFrame` of the *previously* mapped page, allowing the caller to
/// manage its lifecycle. Returns an error if no page was originally mapped
/// at `va`.
fn remap(&mut self, va: VA, new_page: PageFrame, perms: PtePermissions) -> Result<PageFrame>;
/// Changes the protection flags for a range of virtual addresses.
///
/// This is the low-level implementation for services like `mprotect`. It
/// walks the page tables for the given `va_range` and updates the
/// permissions of each PTE to match `perms`.
///
/// The implementation must ensure that the TLB is invalidated for the
/// entire range.
fn protect_range(&mut self, va_range: VirtMemoryRegion, perms: PtePermissions) -> Result<()>;
/// Unmaps an entire range of virtual addresses.
///
/// This is the low-level implementation for services like `munmap`. It
/// walks the page tables for the given `va_range` and invalidates all PTEs
/// within it.
///
/// # Returns
///
/// A `Vec<PageFrame>` containing all the physical frames that were
/// unmapped. This allows the caller to free all associated physical memory.
fn unmap_range(&mut self, va_range: VirtMemoryRegion) -> Result<Vec<PageFrame>>;
/// Translates a virtual address to its corresponding physical mapping
/// information.
///
/// This function performs a page table walk to find the PTE for the given
/// `va`. It is a read-only operation used by fault handlers and other
/// kernel subsystems to inspect the state of a mapping.
///
/// # Returns
///
/// `Some(PageInfo)` containing the mapped physical frame and the *actual*
/// hardware and software permissions from the PTE, or `None` if no valid
/// mapping exists for `va`.
fn translate(&self, va: VA) -> Option<PageInfo>;
/// Atomically protects a region in the source address space and clones the
/// mappings into a destination address space.
///
/// This is the core operation for implementing an efficient `fork()` system
/// call using Copy-on-Write (CoW). It is specifically optimized to perform
/// the protection of the parent's pages and the creation of the child's
/// shared mappings in a single, atomic walk of the source page tables.
///
/// ## Operation
///
/// For each present Page Table Entry (PTE) within the `region` of `self`:
///
/// 1. The physical frame's reference count is incremented to reflect the
/// new mapping in `other`. This *must* be done by the
/// architecture-specific implementation.
/// 2. The PTE in the `other` (child) address space is created, mapping the
/// same physical frame at the same virtual address with the given
/// `perms`.
/// 3. The PTE in the `self` (parent) address space has its permissions
/// updated to `perms`.
///
/// # Arguments
///
/// * `&mut self`: The source (parent) address space. Its permissions for
/// the region will be modified.
/// * `region`: The virtual memory region to operate on.
/// * `other`: The destination (child) address space.
/// * `perms`: The `PtePermissions` to set for the mappings in both `self`
/// and `other`.
fn protect_and_clone_region(
&mut self,
region: VirtMemoryRegion,
other: &mut Self,
perms: PtePermissions,
) -> Result<()>
where
Self: Sized;
}
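// Illustrative sketch (not part of the original file): one way a fork()-style
// path could drive `protect_and_clone_region`. The helper name, the choice of
// region and the CoW permission derivation are assumptions for the example.
#[allow(dead_code)]
fn fork_into_child<A: UserAddressSpace>(
parent: &mut A,
user_region: VirtMemoryRegion,
) -> Result<A> {
let mut child = A::new()?;
// Downgrade the parent's writable mappings to read-only CoW and mirror them
// into the child in a single walk of the parent's page tables.
let cow_perms = PtePermissions::rw(true).into_cow();
parent.protect_and_clone_region(user_region, &mut child, cow_perms)?;
Ok(child)
}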
/// Represents the kernel's memory context.
pub trait KernAddressSpace: Send {
/// Map the given region as MMIO memory.
fn map_mmio(&mut self, region: PhysMemoryRegion) -> Result<VA>;
/// Map the given region as normal memory.
fn map_normal(
&mut self,
phys_range: PhysMemoryRegion,
virt_range: VirtMemoryRegion,
perms: PtePermissions,
) -> Result<()>;
}
/// The types and functions required for the virtual memory subsystem.
pub trait VirtualMemory: CpuOps + Sized {
/// The type representing an entry in the top-level page table. For AArch64
/// this is L0, for x86_64 it's PML4.
type PageTableRoot;
/// The address space type used for all user-space processes.
type ProcessAddressSpace: UserAddressSpace;
/// The address space type used for the kernel.
type KernelAddressSpace: KernAddressSpace;
/// The starting address for the logical mapping of all physical RAM.
const PAGE_OFFSET: usize;
/// Obtain a reference to the kernel's address space.
fn kern_address_space() -> &'static SpinLockIrq<Self::KernelAddressSpace, Self>;
}
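// Illustrative sketch (not part of the original file): with `PAGE_OFFSET` as
// the base of the linear physical-memory map, converting a physical address to
// a kernel virtual address is a plain offset. This assumes all of RAM is
// mapped contiguously at `PAGE_OFFSET`, as the constant's documentation states.
#[allow(dead_code)]
fn linear_map_va<V: VirtualMemory>(pa: memory::address::PA) -> VA {
VA::from_value(pa.value() + V::PAGE_OFFSET)
}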
#[cfg(test)]
pub mod test {
use core::hint::spin_loop;
use crate::CpuOps;
// A CPU mock object that can be used in unit-tests.
pub struct MockCpuOps {}
impl CpuOps for MockCpuOps {
fn id() -> usize {
0
}
fn halt() -> ! {
loop {
spin_loop();
}
}
fn disable_interrupts() -> usize {
0
}
fn restore_interrupt_state(_flags: usize) {}
fn enable_interrupts() {}
}
}

View File

@@ -0,0 +1,395 @@
//! `address` module: Type-safe handling of virtual and physical addresses.
//!
//! This module defines strongly-typed address representations for both physical
//! and virtual memory. It provides abstractions to ensure correct usage and
//! translation between address spaces, as well as alignment and page-based
//! operations.
//!
//! ## Key Features
//!
//! - Typed Addresses: Differentiates between physical and virtual addresses at
//! compile time.
//! - Generic Address Types: Supports pointer-type phantom typing for safety and
//! clarity.
//! - Page Alignment & Arithmetic: Includes utility methods for page alignment
//! and manipulation.
//! - Translators: Provides an `AddressTranslator` trait for converting between
//! physical and virtual addresses.
//!
//! ## Safety
//! Raw pointer access to physical addresses (`TPA<T>::as_ptr`, etc.) is marked
//! `unsafe` and assumes the caller ensures correctness (e.g., MMU off/ID
//! mappings).
//!
//! ## Example
//! ```rust
//! use libkernel::memory::address::*;
//!
//! let pa: PA = PA::from_value(0x1000);
//! assert!(pa.is_page_aligned());
//!
//! let va = pa.to_va::<IdentityTranslator>();
//! let pa2 = va.to_pa::<IdentityTranslator>();
//! assert_eq!(pa, pa2);
//! ```
use super::{PAGE_MASK, PAGE_SHIFT, PAGE_SIZE, page::PageFrame, region::MemoryRegion};
use core::{
fmt::{self, Debug, Display},
marker::PhantomData,
};
mod sealed {
/// Sealed trait to prevent external implementations of `MemKind`.
pub trait Sealed {}
}
/// Marker trait for kinds of memory addresses (virtual or physical).
///
/// Implemented only for `Virtual` and `Physical`.
pub trait MemKind: sealed::Sealed + Ord + Clone + Copy + PartialEq + Eq {}
/// Marker for virtual memory address type.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct Virtual;
/// Marker for physical memory address type.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct Physical;
/// Marker for user memory address type.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct User;
impl sealed::Sealed for Virtual {}
impl sealed::Sealed for Physical {}
impl sealed::Sealed for User {}
impl MemKind for Virtual {}
impl MemKind for Physical {}
impl MemKind for User {}
/// A memory address with a kind (`Virtual`, `Physical`, or `User`) and an
/// associated data type.
///
/// The `T` phantom type is useful for distinguishing address purposes (e.g.,
/// different hardware devices).
#[derive(PartialEq, Eq, PartialOrd, Ord)]
#[repr(C)]
pub struct Address<K: MemKind, T> {
inner: usize,
_phantom: PhantomData<K>,
_phantom_type: PhantomData<T>,
}
impl<K: MemKind, T> Clone for Address<K, T> {
fn clone(&self) -> Self {
*self
}
}
impl<K: MemKind, T> Copy for Address<K, T> {}
impl<K: MemKind, T> Debug for Address<K, T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "0x{:04x}", self.inner)
}
}
impl<K: MemKind, T> Address<K, T> {
/// Construct an address from a raw usize value.
pub const fn from_value(addr: usize) -> Self {
Self {
inner: addr,
_phantom: PhantomData,
_phantom_type: PhantomData,
}
}
/// Return the underlying raw address value.
pub const fn value(self) -> usize {
self.inner
}
/// Check if the address is aligned to the system's page size.
pub const fn is_page_aligned(self) -> bool {
self.inner & PAGE_MASK == 0
}
/// Add `count` pages to the address.
#[must_use]
pub const fn add_pages(self, count: usize) -> Self {
Self::from_value(self.inner + (PAGE_SIZE * count))
}
/// Return an address aligned down to `align` (must be a power of two).
#[must_use]
pub const fn align(self, align: usize) -> Self {
assert!(align.is_power_of_two());
Self::from_value(self.inner & !(align - 1))
}
/// Return an address aligned down to the next page boundary.
#[must_use]
pub const fn page_aligned(self) -> Self {
Self::from_value(self.inner & !PAGE_MASK)
}
/// Return an address aligned up to `align` (must be a power of two).
#[must_use]
pub const fn align_up(self, align: usize) -> Self {
assert!(align.is_power_of_two());
Self::from_value((self.inner + (align - 1)) & !(align - 1))
}
/// Get the offset of the address within its page.
pub fn page_offset(self) -> usize {
self.inner & PAGE_MASK
}
pub const fn null() -> Self {
Self::from_value(0)
}
#[must_use]
pub fn add_bytes(self, n: usize) -> Self {
Self::from_value(self.value() + n)
}
#[must_use]
pub fn sub_bytes(self, n: usize) -> Self {
Self::from_value(self.value() - n)
}
pub fn is_null(self) -> bool {
self.inner == 0
}
pub fn to_pfn(&self) -> PageFrame {
PageFrame::from_pfn(self.inner >> PAGE_SHIFT)
}
}
impl<K: MemKind, T: Sized> Address<K, T> {
#[must_use]
/// Increments the pointer by the number of T *objects* n.
pub fn add_objs(self, n: usize) -> Self {
Self::from_value(self.value() + core::mem::size_of::<T>() * n)
}
/// Decrements the pointer by the number of T *objects* n.
#[must_use]
pub fn sub_objs(self, n: usize) -> Self {
Self::from_value(self.value() - core::mem::size_of::<T>() * n)
}
}
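// Illustrative example (not part of the original file): `add_objs`/`sub_objs`
// step by whole `T` objects rather than bytes, so a `TVA<u64>` moves in 8-byte
// increments.
#[cfg(test)]
#[test]
fn add_objs_example() {
let base: TVA<u64> = TVA::from_value(0x1000);
assert_eq!(base.add_objs(2).value(), 0x1010);
assert_eq!(base.add_objs(2).sub_objs(1).value(), 0x1008);
}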
/// A typed physical address.
pub type TPA<T> = Address<Physical, T>;
/// A typed virtual address.
pub type TVA<T> = Address<Virtual, T>;
/// A typed user address.
pub type TUA<T> = Address<User, T>;
/// An untyped physical address.
pub type PA = Address<Physical, ()>;
/// An untyped virtual address.
pub type VA = Address<Virtual, ()>;
/// An untyped user address.
pub type UA = Address<User, ()>;
impl<T> TPA<T> {
/// Convert to a raw const pointer.
///
/// # Safety
/// Caller must ensure memory is accessible and valid for read.
pub unsafe fn as_ptr(self) -> *const T {
self.value() as _
}
/// Convert to a raw mutable pointer.
///
/// # Safety
/// Caller must ensure memory is accessible and valid for read/write.
pub unsafe fn as_ptr_mut(self) -> *mut T {
self.value() as _
}
/// Convert to an untyped physical address.
pub fn to_untyped(self) -> PA {
PA::from_value(self.value())
}
/// Convert to a virtual address using a translator.
pub fn to_va<A: AddressTranslator<T>>(self) -> TVA<T> {
A::phys_to_virt(self)
}
}
impl<T> Display for TPA<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "Px{:08x}", self.inner)
}
}
impl<T> TVA<T> {
/// Convert to a raw const pointer.
pub fn as_ptr(self) -> *const T {
self.value() as _
}
/// Convert to a raw mutable pointer.
pub fn as_ptr_mut(self) -> *mut T {
self.value() as _
}
/// Convert a raw pointer to a TVA.
pub fn from_ptr(ptr: *const T) -> TVA<T> {
Self::from_value(ptr.addr())
}
/// Convert a raw mutable pointer to a TVA.
pub fn from_ptr_mut(ptr: *mut T) -> TVA<T> {
Self::from_value(ptr.addr())
}
/// Convert to an untyped virtual address.
pub fn to_untyped(self) -> VA {
VA::from_value(self.value())
}
/// Convert to a physical address using a translator.
pub fn to_pa<A: AddressTranslator<T>>(self) -> TPA<T> {
A::virt_to_phys(self)
}
}
impl<T> TUA<T> {
/// Convert to an untyped user address.
pub fn to_untyped(self) -> UA {
UA::from_value(self.value())
}
}
impl UA {
/// Cast to a typed user address.
pub fn cast<T>(self) -> TUA<T> {
TUA::from_value(self.value())
}
}
impl VA {
/// Cast to a typed virtual address.
pub fn cast<T>(self) -> TVA<T> {
TVA::from_value(self.value())
}
/// Return a region representing the page to which this address belongs.
pub fn page_region(self) -> MemoryRegion<Virtual> {
MemoryRegion::new(self.page_aligned().cast(), PAGE_SIZE)
}
}
impl PA {
/// Cast to a typed physical address.
pub fn cast<T>(self) -> TPA<T> {
TPA::from_value(self.value())
}
}
/// Trait for translating between physical and virtual addresses.
pub trait AddressTranslator<T> {
fn virt_to_phys(va: TVA<T>) -> TPA<T>;
fn phys_to_virt(pa: TPA<T>) -> TVA<T>;
}
/// A simple address translator that performs identity mapping.
pub struct IdentityTranslator {}
impl<T> AddressTranslator<T> for IdentityTranslator {
fn virt_to_phys(va: TVA<T>) -> TPA<T> {
TPA::from_value(va.value())
}
fn phys_to_virt(pa: TPA<T>) -> TVA<T> {
TVA::from_value(pa.value())
}
}
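// Illustrative sketch (not part of the original file): a translator for a
// fixed linear-offset mapping (e.g. a kernel direct map). The `OFFSET` const
// parameter is an assumption for the example; a real kernel would plug in its
// own `PAGE_OFFSET`.
#[allow(dead_code)]
struct LinearTranslator<const OFFSET: usize>;
impl<T, const OFFSET: usize> AddressTranslator<T> for LinearTranslator<OFFSET> {
fn virt_to_phys(va: TVA<T>) -> TPA<T> {
TPA::from_value(va.value() - OFFSET)
}
fn phys_to_virt(pa: TPA<T>) -> TVA<T> {
TVA::from_value(pa.value() + OFFSET)
}
}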
#[cfg(test)]
mod tests {
use super::*;
const TEST_ADDR: usize = 0x1000;
const NON_ALIGNED_ADDR: usize = 0x1234;
#[test]
fn test_va_creation_and_value() {
let va = VA::from_value(TEST_ADDR);
assert_eq!(va.value(), TEST_ADDR);
}
#[test]
fn test_pa_creation_and_value() {
let pa = PA::from_value(TEST_ADDR);
assert_eq!(pa.value(), TEST_ADDR);
}
#[test]
fn test_page_alignment() {
let aligned_va = VA::from_value(0x2000);
let unaligned_va = VA::from_value(0x2001);
assert!(aligned_va.is_page_aligned());
assert!(!unaligned_va.is_page_aligned());
}
#[test]
fn test_add_pages() {
let base = VA::from_value(0x1000);
let added = base.add_pages(2);
assert_eq!(added.value(), 0x1000 + 2 * PAGE_SIZE);
}
#[test]
fn test_align_down() {
let addr = VA::from_value(NON_ALIGNED_ADDR);
let aligned = addr.align(0x1000);
assert_eq!(aligned.value(), NON_ALIGNED_ADDR & !0xfff);
}
#[test]
fn test_align_up() {
let addr = VA::from_value(NON_ALIGNED_ADDR);
let aligned = addr.align_up(0x1000);
let expected = (NON_ALIGNED_ADDR + 0xfff) & !0xfff;
assert_eq!(aligned.value(), expected);
}
#[test]
fn test_identity_translation_va_to_pa() {
let va = VA::from_value(TEST_ADDR);
let pa = va.to_pa::<IdentityTranslator>();
assert_eq!(pa.value(), TEST_ADDR);
}
#[test]
fn test_identity_translation_pa_to_va() {
let pa = PA::from_value(TEST_ADDR);
let va = pa.to_va::<IdentityTranslator>();
assert_eq!(va.value(), TEST_ADDR);
}
#[test]
fn test_va_pointer_conversion() {
let va = VA::from_value(0x1000);
let ptr: *const u8 = va.as_ptr() as *const _;
assert_eq!(ptr as usize, 0x1000);
}
#[test]
fn test_va_mut_pointer_conversion() {
let va = VA::from_value(0x1000);
let ptr: *mut u8 = va.as_ptr_mut() as *mut _;
assert_eq!(ptr as usize, 0x1000);
}
}

View File

@@ -0,0 +1,713 @@
//! A page-backed async-aware circular kernel buffer.
use crate::{
CpuOps,
sync::{
spinlock::SpinLockIrq,
waker_set::{WakerSet, wait_until},
},
};
use alloc::sync::Arc;
use core::{cmp::min, future, mem::MaybeUninit, task::Poll};
use ringbuf::{
SharedRb,
storage::Storage,
traits::{Consumer, Observer, Producer, SplitRef},
};
struct KBufInner<T, S: Storage<Item = T>> {
buf: SharedRb<S>,
read_waiters: WakerSet,
write_waiters: WakerSet,
}
pub struct KBufCore<T, S: Storage<Item = T>, C: CpuOps> {
inner: Arc<SpinLockIrq<KBufInner<T, S>, C>>,
}
impl<T, S: Storage<Item = T>, C: CpuOps> Clone for KBufCore<T, S, C> {
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
}
}
}
impl<T, S: Storage<Item = T>, C: CpuOps> KBufCore<T, S, C> {
pub fn new(storage: S) -> Self {
let rb = unsafe { SharedRb::from_raw_parts(storage, 0, 0) };
Self {
inner: Arc::new(SpinLockIrq::new(KBufInner {
buf: rb,
read_waiters: WakerSet::new(),
write_waiters: WakerSet::new(),
})),
}
}
pub fn read_ready(&self) -> impl Future<Output = ()> + use<T, S, C> {
let lock = self.inner.clone();
wait_until(
lock,
|inner| &mut inner.read_waiters,
|inner| {
if inner.buf.is_empty() { None } else { Some(()) }
},
)
}
pub async fn write_ready(&self) {
wait_until(
self.inner.clone(),
|inner| &mut inner.write_waiters,
|inner| if inner.buf.is_full() { None } else { Some(()) },
)
.await
}
/// Pushes a value of type `T` into the buffer. If the buffer is full, this
/// function will wait for a slot.
pub async fn push(&self, mut obj: T) {
loop {
self.write_ready().await;
match self.try_push(obj) {
Ok(()) => return,
Err(o) => obj = o,
}
}
}
pub fn try_push(&self, obj: T) -> core::result::Result<(), T> {
let mut inner = self.inner.lock_save_irq();
let res = inner.buf.try_push(obj);
if res.is_ok() {
inner.read_waiters.wake_one();
}
res
}
pub async fn pop(&self) -> T {
loop {
self.read_ready().await;
if let Some(obj) = self.try_pop() {
return obj;
}
}
}
pub fn try_pop(&self) -> Option<T> {
let mut inner = self.inner.lock_save_irq();
let res = inner.buf.try_pop();
if res.is_some() {
inner.write_waiters.wake_one();
}
res
}
}
impl<T: Copy, S: Storage<Item = T>, C: CpuOps> KBufCore<T, S, C> {
pub async fn pop_slice(&self, buf: &mut [T]) -> usize {
wait_until(
self.inner.clone(),
|inner| &mut inner.read_waiters,
|inner| {
let size = inner.buf.pop_slice(buf);
if size != 0 {
// Wake up any writers that may be waiting on the pipe.
inner.write_waiters.wake_one();
Some(size)
} else {
// Sleep.
None
}
},
)
.await
}
pub fn try_pop_slice(&self, buf: &mut [T]) -> usize {
let mut guard = self.inner.lock_save_irq();
let size = guard.buf.pop_slice(buf);
if size > 0 {
guard.write_waiters.wake_one();
}
size
}
pub async fn push_slice(&self, buf: &[T]) -> usize {
wait_until(
self.inner.clone(),
|inner| &mut inner.write_waiters,
|inner| {
let bytes_written = inner.buf.push_slice(buf);
// If we didn't fill the buffer completely, other pending writes may be
// able to complete.
if !inner.buf.is_full() {
inner.write_waiters.wake_one();
}
if bytes_written > 0 {
// We wrote some data, wake up any blocking readers.
inner.read_waiters.wake_one();
Some(bytes_written)
} else {
// Sleep.
None
}
},
)
.await
}
pub fn try_push_slice(&self, buf: &[T]) -> usize {
let mut guard = self.inner.lock_save_irq();
let size = guard.buf.push_slice(buf);
if size > 0 {
guard.read_waiters.wake_one();
}
size
}
/// Moves up to `count` objects from the `source` KBuf into `self`.
///
/// It performs a direct memory copy between the kernel buffers without an
/// intermediate stack buffer. It also handles async waiting and deadlock
/// avoidance.
pub async fn splice_from(&self, source: &KBufCore<T, S, C>, count: usize) -> usize {
if count == 0 {
return 0;
}
// Splicing a buffer into itself would immediately deadlock on its own
// lock, so treat it as a no-op.
if Arc::ptr_eq(&self.inner, &source.inner) {
return 0;
}
future::poll_fn(|cx| -> Poll<usize> {
// Lock two KBufs with the lower memory address first to prevent
// AB-BA deadlocks.
let self_ptr = Arc::as_ptr(&self.inner);
let source_ptr = Arc::as_ptr(&source.inner);
let (mut self_guard, mut source_guard) = if self_ptr < source_ptr {
(self.inner.lock_save_irq(), source.inner.lock_save_irq())
} else {
let source_g = source.inner.lock_save_irq();
let self_g = self.inner.lock_save_irq();
(self_g, source_g)
};
let (_, source_consumer) = source_guard.buf.split_ref();
let (mut self_producer, _) = self_guard.buf.split_ref();
// Determine the maximum number of bytes we can move in one go.
let bytes_to_move = min(
count,
min(source_consumer.occupied_len(), self_producer.vacant_len()),
);
if bytes_to_move > 0 {
// We can move data. Get the memory slices for direct copy.
let (src_head, src_tail) = source_consumer.occupied_slices();
let (dst_head, dst_tail) = self_producer.vacant_slices_mut();
// Perform the copy, which may involve multiple steps if the
// source or destination wraps around the end of the ring
// buffer.
let copied =
Self::copy_slices((src_head, src_tail), (dst_head, dst_tail), bytes_to_move);
// Advance the read/write heads in the ring buffers.
unsafe {
source_consumer.advance_read_index(copied);
self_producer.advance_write_index(copied);
}
drop(source_consumer);
drop(self_producer);
// Wake up anyone waiting for the opposite condition. A reader
// might be waiting for data in `self`.
self_guard.read_waiters.wake_one();
// A writer might be waiting for space in `source`.
source_guard.write_waiters.wake_one();
Poll::Ready(copied)
} else {
// We can't move data. We need to wait. If source is empty, we
// must wait for a writer on the source.
if source_consumer.is_empty() {
drop(source_consumer);
source_guard.read_waiters.register(cx.waker());
}
// If destination is full, we must wait for a reader on the
// destination.
if self_producer.is_full() {
drop(self_producer);
self_guard.write_waiters.register(cx.waker());
}
Poll::Pending
}
})
.await
}
/// Helper function to copy data between pairs of buffer slices.
fn copy_slices(
(src_head, src_tail): (&[MaybeUninit<T>], &[MaybeUninit<T>]),
(dst_head, dst_tail): (&mut [MaybeUninit<T>], &mut [MaybeUninit<T>]),
mut amount: usize,
) -> usize {
let original_amount = amount;
// Copy from src_head to dst_head
let n1 = min(amount, min(src_head.len(), dst_head.len()));
if n1 > 0 {
dst_head[..n1].copy_from_slice(&src_head[..n1]);
amount -= n1;
}
if amount == 0 {
return original_amount;
}
// Copy from the remainder of src_head to dst_tail
let src_after_head = &src_head[n1..];
let n2 = min(amount, min(src_after_head.len(), dst_tail.len()));
if n2 > 0 {
dst_tail[..n2].copy_from_slice(&src_after_head[..n2]);
amount -= n2;
}
if amount == 0 {
return original_amount;
}
// Copy from src_tail to the remainder of dst_head
let dst_after_head = &mut dst_head[n1..];
let n3 = min(amount, min(src_tail.len(), dst_after_head.len()));
if n3 > 0 {
dst_after_head[..n3].copy_from_slice(&src_tail[..n3]);
amount -= n3;
}
original_amount - amount
}
}
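// Illustrative sketch (not part of the original file): a pipe-style transfer
// between two kernel buffers using `splice_from`. The `Heap` storage type is
// borrowed from the tests below purely for the signature; a real pipe would
// likely use page-backed storage instead.
#[allow(dead_code)]
async fn splice_example<C: CpuOps>(
from: &KBufCore<u8, ringbuf::storage::Heap<u8>, C>,
to: &KBufCore<u8, ringbuf::storage::Heap<u8>, C>,
) -> usize {
// Moves up to 512 bytes in one call; waits while no data can be moved and
// returns how many bytes were actually transferred.
to.splice_from(from, 512).await
}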
#[cfg(test)]
mod tests {
use super::*;
use crate::{memory::PAGE_SIZE, test::MockCpuOps};
use ringbuf::storage::Heap;
use tokio::time::{Duration, timeout};
// Helper to create a KBufCore backed by a dynamically allocated buffer for testing.
fn make_kbuf(size: usize) -> KBufCore<u8, Heap<u8>, MockCpuOps> {
let storage = Heap::new(size);
KBufCore::new(storage)
}
#[tokio::test]
async fn simple_read_write() {
let kbuf = make_kbuf(16);
let in_buf = [1, 2, 3];
let mut out_buf = [0; 3];
let written = kbuf.push_slice(&in_buf).await;
assert_eq!(written, 3);
// read_ready should complete immediately since there's data.
kbuf.read_ready().await;
let read = kbuf.pop_slice(&mut out_buf).await;
assert_eq!(read, 3);
assert_eq!(in_buf, out_buf);
}
#[tokio::test]
async fn read_blocks_when_empty() {
let kbuf = make_kbuf(16);
let mut out_buf = [0; 3];
// We expect the read to time out because the buffer is empty and it should block.
let result = timeout(Duration::from_millis(10), kbuf.pop_slice(&mut out_buf)).await;
assert!(result.is_err(), "Read should have blocked and timed out");
}
#[tokio::test]
async fn write_blocks_when_full() {
let kbuf = make_kbuf(16);
let big_buf = [0; 16];
let small_buf = [1];
// Fill the buffer completely.
let written = kbuf.push_slice(&big_buf).await;
assert_eq!(written, 16);
assert!(kbuf.inner.lock_save_irq().buf.is_full());
// The next write should block.
let result = timeout(Duration::from_millis(10), kbuf.push_slice(&small_buf)).await;
assert!(result.is_err(), "Write should have blocked and timed out");
}
#[tokio::test]
async fn write_wakes_reader() {
let kbuf = make_kbuf(16);
let kbuf_clone = kbuf.clone();
// Spawn a task that will block reading from the empty buffer.
let reader_task = tokio::spawn(async move {
assert_eq!(kbuf_clone.pop().await, 10);
assert_eq!(kbuf_clone.pop().await, 20);
assert_eq!(kbuf_clone.pop().await, 30);
assert_eq!(kbuf_clone.pop().await, 40);
});
// Give the reader a moment to start and block.
tokio::time::sleep(Duration::from_millis(5)).await;
// Now, write to the buffer. This should wake up the reader.
kbuf.push(10).await;
// Give the reader a moment to block again.
tokio::time::sleep(Duration::from_millis(5)).await;
// Now, write to the buffer. This should wake up the reader.
kbuf.push(20).await;
assert!(kbuf.try_push(30).is_ok());
assert!(kbuf.try_push(40).is_ok());
reader_task.await.unwrap();
}
#[tokio::test]
async fn write_slice_wakes_reader() {
let kbuf = make_kbuf(16);
let kbuf_clone = kbuf.clone();
let in_buf = [10, 20, 30];
let mut out_buf = [0; 3];
// Spawn a task that will block reading from the empty buffer.
let reader_task = tokio::spawn(async move {
kbuf_clone.pop_slice(&mut out_buf).await;
out_buf // Return the buffer to check the result
});
// Give the reader a moment to start and block.
tokio::time::sleep(Duration::from_millis(5)).await;
// Now, write to the buffer. This should wake up the reader.
kbuf.push_slice(&in_buf).await;
// The reader task should now complete.
let result_buf = reader_task.await.unwrap();
assert_eq!(result_buf, in_buf);
}
#[tokio::test]
async fn read_wakes_writer() {
let kbuf = make_kbuf(8);
let kbuf_clone = kbuf.clone();
let mut buf = [0; 8];
// Fill the buffer.
kbuf.push_slice(&[1; 8]).await;
assert!(kbuf.inner.lock_save_irq().buf.is_full());
// Spawn a task that will block writing to the full buffer.
let writer_task = tokio::spawn(async move {
kbuf_clone.push(10).await;
kbuf_clone.push(20).await;
kbuf_clone.push(30).await;
kbuf_clone.push(40).await;
});
// Give the writer a moment to start and block.
tokio::time::sleep(Duration::from_millis(5)).await;
// Now, read from the buffer. This should wake up the writer.
assert_eq!(kbuf.pop().await, 1);
// Give the writer a moment to block again.
tokio::time::sleep(Duration::from_millis(5)).await;
// Now, read from the buffer again. This should wake up the writer.
assert_eq!(kbuf.pop().await, 1);
assert!(kbuf.try_pop().is_some());
assert!(kbuf.try_pop().is_some());
writer_task.await.unwrap();
kbuf.pop_slice(&mut buf).await;
assert_eq!(&buf, &[1, 1, 1, 1, 10, 20, 30, 40]);
}
#[tokio::test]
async fn read_slice_wakes_writer() {
let kbuf = make_kbuf(8);
let kbuf_clone = kbuf.clone();
let mut out_buf = [0; 4];
// Fill the buffer.
kbuf.push_slice(&[1; 8]).await;
assert!(kbuf.inner.lock_save_irq().buf.is_full());
// Spawn a task that will block trying to write to the full buffer.
let writer_task = tokio::spawn(async move {
let written = kbuf_clone.push_slice(&[2; 4]).await;
assert_eq!(written, 4);
});
// Give the writer a moment to start and block.
tokio::time::sleep(Duration::from_millis(5)).await;
// Now, read from the buffer. This should make space and wake the writer.
let read = kbuf.pop_slice(&mut out_buf).await;
assert_eq!(read, 4);
// The writer task should now complete.
writer_task.await.unwrap();
// The buffer should contain the remaining 4 ones and the 4 twos from the writer.
kbuf.pop_slice(&mut out_buf).await;
assert_eq!(out_buf, [1, 1, 1, 1]);
kbuf.pop_slice(&mut out_buf).await;
assert_eq!(out_buf, [2, 2, 2, 2]);
}
#[tokio::test]
async fn concurrent_producer_consumer() {
const ITERATIONS: usize = 5000;
let kbuf = make_kbuf(PAGE_SIZE);
let producer_kbuf = kbuf.clone();
let consumer_kbuf = kbuf.clone();
let producer = tokio::spawn(async move {
for i in 0..ITERATIONS {
let byte = (i % 256) as u8;
producer_kbuf.push_slice(&[byte]).await;
}
});
let consumer = tokio::spawn(async move {
let mut received = 0;
while received < ITERATIONS {
let mut buf = [0; 1];
let count = consumer_kbuf.pop_slice(&mut buf).await;
if count > 0 {
let expected_byte = (received % 256) as u8;
assert_eq!(buf[0], expected_byte);
received += 1;
}
}
});
let (prod_res, cons_res) = tokio::join!(producer, consumer);
prod_res.unwrap();
cons_res.unwrap();
}
// --- Splice Tests ---
#[tokio::test]
async fn splice_simple_transfer() {
let src = make_kbuf(PAGE_SIZE);
let dest = make_kbuf(PAGE_SIZE);
let data: Vec<u8> = (0..100).collect();
let mut out_buf = vec![0; 100];
src.push_slice(&data).await;
assert_eq!(src.inner.lock_save_irq().buf.occupied_len(), 100);
assert_eq!(dest.inner.lock_save_irq().buf.occupied_len(), 0);
let spliced = dest.splice_from(&src, 100).await;
assert_eq!(spliced, 100);
assert_eq!(src.inner.lock_save_irq().buf.occupied_len(), 0);
assert_eq!(dest.inner.lock_save_irq().buf.occupied_len(), 100);
dest.pop_slice(&mut out_buf).await;
assert_eq!(out_buf, data);
}
#[tokio::test]
async fn splice_limited_by_count() {
let src = make_kbuf(PAGE_SIZE);
let dest = make_kbuf(PAGE_SIZE);
let data: Vec<u8> = (0..100).collect();
src.push_slice(&data).await;
let spliced = dest.splice_from(&src, 50).await;
assert_eq!(spliced, 50);
assert_eq!(src.inner.lock_save_irq().buf.occupied_len(), 50);
assert_eq!(dest.inner.lock_save_irq().buf.occupied_len(), 50);
let mut out_buf = vec![0; 50];
dest.pop_slice(&mut out_buf).await;
assert_eq!(out_buf, &data[0..50]);
}
#[tokio::test]
async fn splice_limited_by_dest_capacity() {
let src = make_kbuf(200);
let dest = make_kbuf(100); // Smaller destination
let data: Vec<u8> = (0..200).collect();
src.push_slice(&data).await;
// Splice more than dest has capacity for.
let spliced = dest.splice_from(&src, 200).await;
assert_eq!(spliced, 100);
assert_eq!(src.inner.lock_save_irq().buf.occupied_len(), 100);
assert_eq!(dest.inner.lock_save_irq().buf.occupied_len(), 100);
assert!(dest.inner.lock_save_irq().buf.is_full());
}
#[tokio::test]
async fn splice_blocks_on_full_dest_and_wakes() {
let src = make_kbuf(100);
let dest = make_kbuf(50);
let dest_clone = dest.clone();
src.push_slice(&(0..100).collect::<Vec<u8>>()).await;
dest.push_slice(&(0..50).collect::<Vec<u8>>()).await;
assert!(dest.inner.lock_save_irq().buf.is_full());
// This splice will block because dest is full.
let splice_task = tokio::spawn(async move {
// It will only be able to splice 0 bytes initially, then block,
// then it will splice 25 bytes once space is made.
let spliced_bytes = dest_clone.splice_from(&src, 100).await;
assert_eq!(spliced_bytes, 25);
});
tokio::time::sleep(Duration::from_millis(5)).await;
// Make space in the destination buffer.
let mut read_buf = [0; 25];
let read_bytes = dest.pop_slice(&mut read_buf).await;
assert_eq!(read_bytes, 25);
// The splice task should now unblock and complete.
timeout(Duration::from_millis(50), splice_task)
.await
.expect("Splice task should have completed")
.unwrap();
}
#[tokio::test]
async fn splice_to_self_is_noop_and_doesnt_deadlock() {
let kbuf = make_kbuf(100);
let data = [1, 2, 3, 4, 5];
kbuf.push_slice(&data).await;
// This should return immediately with 0 and not deadlock.
let spliced = kbuf.splice_from(&kbuf, 50).await;
assert_eq!(spliced, 0);
// Verify data is untouched.
let mut out_buf = [0; 5];
kbuf.pop_slice(&mut out_buf).await;
assert_eq!(out_buf, data);
}
#[tokio::test]
async fn splice_into_partially_full_buffer() {
let src = make_kbuf(PAGE_SIZE);
let dest = make_kbuf(PAGE_SIZE);
// Setup: `dest` already has 20 bytes of data.
let old_data = vec![255; 20];
dest.push_slice(&old_data).await;
// `src` has 50 bytes of new data to be spliced.
let splice_data: Vec<u8> = (0..50).collect();
src.push_slice(&splice_data).await;
// Action: Splice the 50 bytes from `src` into `dest`.
// There is enough room, so the full amount should be spliced.
let spliced = dest.splice_from(&src, 50).await;
// Assertions:
assert_eq!(spliced, 50, "Should have spliced the requested 50 bytes");
// `src` should now be empty.
assert!(src.inner.lock_save_irq().buf.is_empty());
// `dest` should contain the old data followed by the new data.
assert_eq!(
dest.inner.lock_save_irq().buf.occupied_len(),
old_data.len() + splice_data.len()
);
let mut final_dest_data = vec![0; 70];
dest.pop_slice(&mut final_dest_data).await;
// Check that the original data is at the start.
assert_eq!(&final_dest_data[0..20], &old_data[..]);
// Check that the spliced data comes after it.
assert_eq!(&final_dest_data[20..70], &splice_data[..]);
}
#[tokio::test]
async fn splice_into_almost_full_buffer_is_limited() {
let src = make_kbuf(PAGE_SIZE);
// Use a smaller destination buffer to make capacity relevant.
let dest = make_kbuf(100);
// `dest` has 80 bytes, leaving only 20 bytes of free space.
let old_data = vec![255; 80];
dest.push_slice(&old_data).await;
// `src` has 50 bytes, more than the available space in `dest`.
let splice_data: Vec<u8> = (0..50).collect();
src.push_slice(&splice_data).await;
// Attempt to splice 50 bytes. This should be limited by `dest`'s
// capacity.
let spliced = dest.splice_from(&src, 50).await;
assert_eq!(
spliced, 20,
"Splice should be limited to the 20 bytes of available space"
);
// `dest` should now be completely full.
assert!(dest.inner.lock_save_irq().buf.is_full());
// `src` should have the remaining 30 bytes that couldn't be spliced.
assert_eq!(src.inner.lock_save_irq().buf.occupied_len(), 30);
// Verify the contents of `dest`.
let mut final_dest_data = vec![0; 100];
dest.pop_slice(&mut final_dest_data).await;
assert_eq!(&final_dest_data[0..80], &old_data[..]);
assert_eq!(&final_dest_data[80..100], &splice_data[0..20]); // Only the first 20 bytes
// Verify the remaining contents of `src`.
let mut remaining_src_data = vec![0; 30];
src.pop_slice(&mut remaining_src_data).await;
assert_eq!(&remaining_src_data[..], &splice_data[20..50]); // The last 30 bytes
}
}

View File

@@ -0,0 +1,13 @@
pub mod address;
pub mod kbuf;
pub mod page;
pub mod page_alloc;
pub mod permissions;
pub mod pg_offset;
pub mod proc_vm;
pub mod region;
pub mod smalloc;
pub const PAGE_SIZE: usize = 4096;
pub const PAGE_SHIFT: usize = PAGE_SIZE.trailing_zeros() as usize;
pub const PAGE_MASK: usize = PAGE_SIZE - 1;
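// For reference: with the 4 KiB page size above these evaluate to
// PAGE_SHIFT == 12 and PAGE_MASK == 0xfff, so `addr & PAGE_MASK` yields the
// in-page offset and `addr & !PAGE_MASK` the page base.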

View File

@@ -0,0 +1,38 @@
use core::fmt::Display;
use super::{PAGE_SHIFT, PAGE_SIZE, address::PA, region::PhysMemoryRegion};
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Debug)]
pub struct PageFrame {
n: usize,
}
impl Display for PageFrame {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
self.n.fmt(f)
}
}
impl PageFrame {
pub fn from_pfn(n: usize) -> Self {
Self { n }
}
pub fn pa(&self) -> PA {
PA::from_value(self.n << PAGE_SHIFT)
}
pub fn as_phys_range(&self) -> PhysMemoryRegion {
PhysMemoryRegion::new(self.pa(), PAGE_SIZE)
}
pub fn value(&self) -> usize {
self.n
}
pub(crate) fn buddy(self, order: usize) -> Self {
Self {
n: self.n ^ (1 << order),
}
}
}
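// Illustrative example (not part of the original file): the buddy of a block
// is found by flipping bit `order` of the PFN, so PFN 8 has buddy 9 at order 0
// and buddy 12 at order 2.
#[cfg(test)]
#[test]
fn buddy_example() {
let pfn = PageFrame::from_pfn(8);
assert_eq!(pfn.buddy(0).value(), 9);
assert_eq!(pfn.buddy(2).value(), 12);
}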

View File

@@ -0,0 +1,784 @@
use crate::{
CpuOps,
error::{KernelError, Result},
memory::{PAGE_SHIFT, address::AddressTranslator, page::PageFrame, smalloc::Smalloc},
sync::spinlock::SpinLockIrq,
};
use core::{
cmp::min,
mem::{MaybeUninit, size_of, transmute},
};
use intrusive_collections::{LinkedList, LinkedListLink, UnsafeRef, intrusive_adapter};
use log::info;
use super::region::PhysMemoryRegion;
// The maximum order for the buddy system. This corresponds to blocks of size
// 2^MAX_ORDER pages.
const MAX_ORDER: usize = 10;
#[derive(Clone, Copy, Debug)]
pub struct AllocatedInfo {
/// Current ref count of the allocated block.
pub ref_count: u32,
/// The order of the entire allocated block.
pub order: u8,
}
/// Holds metadata for a page that is part of an allocated block but is not the head.
/// It simply points back to the head of the block.
#[derive(Clone, Copy, Debug)]
pub struct TailInfo {
pub head: PageFrame,
}
#[derive(Debug, Clone)]
pub enum FrameState {
/// The frame has not yet been processed by the allocator's init function.
Uninitialized,
/// The frame is the head of a free block of a certain order.
Free { order: u8 },
/// The frame is the head of an allocated block.
AllocatedHead(AllocatedInfo),
/// The frame is a tail page of an allocated block.
AllocatedTail(TailInfo),
/// The frame is reserved by hardware/firmware.
Reserved,
/// The frame is part of the kernel's own image.
Kernel,
}
#[derive(Debug, Clone)]
struct Frame {
state: FrameState,
link: LinkedListLink, // only used in free nodes.
pfn: PageFrame,
}
intrusive_adapter!(FrameAdapter = UnsafeRef<Frame>: Frame { link: LinkedListLink });
impl Frame {
fn new(pfn: PageFrame) -> Self {
Self {
state: FrameState::Uninitialized,
link: LinkedListLink::new(),
pfn,
}
}
}
struct FrameAllocatorInner {
pages: &'static mut [Frame],
base_page: PageFrame,
total_pages: usize,
free_pages: usize,
free_lists: [LinkedList<FrameAdapter>; MAX_ORDER + 1],
}
impl FrameAllocatorInner {
/// Frees a previously allocated block of frames.
/// The PFN can point to any page within the allocated block.
fn free_frames(&mut self, region: PhysMemoryRegion) {
let head_pfn = region.start_address().to_pfn();
debug_assert!(matches!(
self.get_frame(head_pfn).state,
FrameState::AllocatedHead(_)
));
let initial_order =
if let FrameState::AllocatedHead(ref mut info) = self.get_frame_mut(head_pfn).state {
if info.ref_count > 1 {
info.ref_count -= 1;
return;
}
info.order as usize
} else {
unreachable!("Logic error: head PFN is not an AllocatedHead");
};
// Before merging, the block we're freeing is no longer allocated. Set
// it to a temporary state. This prevents stale AllocatedHead states if
// this block gets absorbed by its lower buddy.
self.get_frame_mut(head_pfn).state = FrameState::Uninitialized;
let mut merged_order = initial_order;
let mut current_pfn = head_pfn;
for order in initial_order..MAX_ORDER {
let buddy_pfn = current_pfn.buddy(order);
if buddy_pfn < self.base_page
|| buddy_pfn.value() >= self.base_page.value() + self.total_pages
{
break;
}
if let FrameState::Free { order: buddy_order } = self.get_frame(buddy_pfn).state
&& buddy_order as usize == order
{
// Buddy is free and of the same order. Merge them.
// Remove the existing free buddy from its list. This function
// already sets its state to Uninitialized.
self.remove_from_free_list(buddy_pfn, order);
// The new, larger block's PFN is the lower of the two.
current_pfn = min(current_pfn, buddy_pfn);
merged_order += 1;
} else {
break;
}
}
// Update the state of the final merged block's head.
self.get_frame_mut(current_pfn).state = FrameState::Free {
order: merged_order as u8,
};
// Add the merged block to the free list for its final order.
self.add_to_free_list(current_pfn, merged_order);
self.free_pages += 1 << initial_order;
}
#[inline]
fn pfn_to_slice_index(&self, pfn: PageFrame) -> usize {
assert!(pfn.value() >= self.base_page.value(), "PFN is below base");
let offset = pfn.value() - self.base_page.value();
assert!(offset < self.pages.len(), "PFN is outside managed range");
offset
}
#[inline]
fn get_frame(&self, pfn: PageFrame) -> &Frame {
&self.pages[self.pfn_to_slice_index(pfn)]
}
#[inline]
fn get_frame_mut(&mut self, pfn: PageFrame) -> &mut Frame {
let idx = self.pfn_to_slice_index(pfn);
&mut self.pages[idx]
}
fn add_to_free_list(&mut self, pfn: PageFrame, order: usize) {
#[cfg(test)]
assert!(matches!(self.get_frame(pfn).state, FrameState::Free { .. }));
self.free_lists[order]
.push_front(unsafe { UnsafeRef::from_raw(self.get_frame(pfn) as *const _) });
}
fn remove_from_free_list(&mut self, pfn: PageFrame, order: usize) {
let Some(_) = (unsafe {
self.free_lists[order]
.cursor_mut_from_ptr(self.get_frame(pfn) as *const _)
.remove()
}) else {
panic!("Attempted to remove non-free block");
};
// Mark the removed frame as uninitialized to prevent dangling pointers.
self.get_frame_mut(pfn).state = FrameState::Uninitialized;
}
}
pub struct FrameAllocator<CPU: CpuOps> {
inner: SpinLockIrq<FrameAllocatorInner, CPU>,
}
pub struct PageAllocation<'a, CPU: CpuOps> {
region: PhysMemoryRegion,
inner: &'a SpinLockIrq<FrameAllocatorInner, CPU>,
}
impl<CPU: CpuOps> PageAllocation<'_, CPU> {
pub fn leak(self) -> PhysMemoryRegion {
let region = self.region;
core::mem::forget(self);
region
}
pub fn region(&self) -> &PhysMemoryRegion {
&self.region
}
}
impl<CPU: CpuOps> Clone for PageAllocation<'_, CPU> {
fn clone(&self) -> Self {
let mut inner = self.inner.lock_save_irq();
match inner
.get_frame_mut(self.region.start_address().to_pfn())
.state
{
FrameState::AllocatedHead(ref mut alloc_info) => {
alloc_info.ref_count += 1;
}
_ => panic!("Inconsistent memory metadata detected"),
}
Self {
region: self.region,
inner: self.inner,
}
}
}
impl<CPU: CpuOps> Drop for PageAllocation<'_, CPU> {
fn drop(&mut self) {
self.inner.lock_save_irq().free_frames(self.region);
}
}
unsafe impl Send for FrameAllocatorInner {}
impl<CPU: CpuOps> FrameAllocator<CPU> {
/// Allocates a physically contiguous block of frames.
///
/// # Arguments
/// * `order`: The order of the allocation, where the number of pages is `2^order`.
/// `order = 0` requests a single page.
pub fn alloc_frames(&self, order: u8) -> Result<PageAllocation<'_, CPU>> {
let mut inner = self.inner.lock_save_irq();
let requested_order = order as usize;
if requested_order > MAX_ORDER {
return Err(KernelError::InvalidValue);
}
// Find the smallest order >= the requested order that has a free block.
let Some((free_block, mut current_order)) =
(requested_order..=MAX_ORDER).find_map(|order| {
let pg_block = inner.free_lists[order].pop_front()?;
Some((pg_block, order))
})
else {
return Err(KernelError::NoMemory);
};
let free_block = inner.get_frame_mut(free_block.pfn);
free_block.state = FrameState::Uninitialized;
let block_pfn = free_block.pfn;
// Split the block down until it's the correct size.
while current_order > requested_order {
current_order -= 1;
let buddy = block_pfn.buddy(current_order);
inner.get_frame_mut(buddy).state = FrameState::Free {
order: current_order as _,
};
inner.add_to_free_list(buddy, current_order);
}
// Mark the final block metadata.
let pfn_idx = inner.pfn_to_slice_index(block_pfn);
inner.pages[pfn_idx].state = FrameState::AllocatedHead(AllocatedInfo {
ref_count: 1,
order: requested_order as u8,
});
let num_pages_in_block = 1 << requested_order;
for i in 1..num_pages_in_block {
inner.pages[pfn_idx + i].state =
FrameState::AllocatedTail(TailInfo { head: block_pfn });
}
inner.free_pages -= num_pages_in_block;
Ok(PageAllocation {
region: PhysMemoryRegion::new(block_pfn.pa(), num_pages_in_block << PAGE_SHIFT),
inner: &self.inner,
})
}
/// Constructs an allocation from a phys mem region.
///
/// # Safety
///
/// This function does no checks to ensure that the region passed is
/// actually allocated and the region is of the correct size. The *only* way
/// to ensure safety is to use a region that was previously leaked with
/// [PageAllocation::leak].
pub unsafe fn alloc_from_region(&self, region: PhysMemoryRegion) -> PageAllocation<'_, CPU> {
PageAllocation {
region,
inner: &self.inner,
}
}
/// Returns `true` if the page is part of an allocated block, `false`
/// otherwise.
pub fn is_allocated(&self, pfn: PageFrame) -> bool {
matches!(
self.inner.lock_save_irq().get_frame(pfn).state,
FrameState::AllocatedHead(_) | FrameState::AllocatedTail(_)
)
}
/// Returns `true` if the page is part of an allocated block and has a ref
/// count of 1, `false` otherwise.
pub fn is_allocated_exclusive(&self, mut pfn: PageFrame) -> bool {
let inner = self.inner.lock_save_irq();
loop {
match inner.get_frame(pfn).state {
FrameState::AllocatedTail(TailInfo { head }) => pfn = head,
FrameState::AllocatedHead(AllocatedInfo { ref_count: 1, .. }) => {
return true;
}
_ => return false,
}
}
}
/// Initializes the frame allocator. This is the main bootstrap function.
///
/// # Safety
/// It's unsafe because it deals with raw pointers and takes ownership of
/// the metadata memory. It should only be called once.
pub unsafe fn init<T: AddressTranslator<()>>(mut smalloc: Smalloc<T>) -> Self {
let highest_addr = smalloc
.iter_memory()
.map(|r| r.end_address())
.max()
.unwrap();
let lowest_addr = smalloc
.iter_memory()
.map(|r| r.start_address())
.min()
.unwrap();
let total_pages = (highest_addr.value() - lowest_addr.value()) >> PAGE_SHIFT;
let metadata_size = total_pages * size_of::<Frame>();
let metadata_addr = smalloc
.alloc(metadata_size, align_of::<Frame>())
.expect("Failed to allocate memory for page metadata")
.cast::<MaybeUninit<Frame>>();
let pages_uninit: &mut [MaybeUninit<Frame>] = unsafe {
core::slice::from_raw_parts_mut(
metadata_addr.to_untyped().to_va::<T>().cast().as_ptr_mut(),
total_pages,
)
};
// Initialize all frames to a known state.
for (i, p) in pages_uninit.iter_mut().enumerate() {
p.write(Frame::new(PageFrame::from_pfn(
lowest_addr.to_pfn().value() + i,
)));
}
// The transmute is safe because we just initialized all elements.
let pages: &mut [Frame] = unsafe { transmute(pages_uninit) };
let mut allocator = FrameAllocatorInner {
pages,
base_page: lowest_addr.to_pfn(),
total_pages,
free_pages: 0,
free_lists: core::array::from_fn(|_| LinkedList::new(FrameAdapter::new())),
};
for region in smalloc.res.iter() {
for pfn in region.iter_pfns() {
if pfn >= allocator.base_page
&& pfn.value() < allocator.base_page.value() + allocator.total_pages
{
allocator.get_frame_mut(pfn).state = FrameState::Kernel;
}
}
}
for region in smalloc.iter_free() {
// Align the start address to the first naturally aligned MAX_ORDER
// block.
let region =
region.with_start_address(region.start_address().align_up(1 << (MAX_ORDER + PAGE_SHIFT)));
let mut current_pfn = region.start_address().to_pfn();
let end_pfn = region.end_address().to_pfn();
while current_pfn.value() + (1 << MAX_ORDER) <= end_pfn.value() {
allocator.get_frame_mut(current_pfn).state = FrameState::Free {
order: MAX_ORDER as _,
};
allocator.add_to_free_list(current_pfn, MAX_ORDER);
allocator.free_pages += 1 << MAX_ORDER;
current_pfn = PageFrame::from_pfn(current_pfn.value() + (1 << MAX_ORDER));
}
}
info!(
"Buddy allocator initialized. Managing {} pages, {} free.",
allocator.total_pages, allocator.free_pages
);
FrameAllocator {
inner: SpinLockIrq::new(allocator),
}
}
}
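// Illustrative sketch (not part of the original file): a typical allocation
// round trip. `alloc_frames(2)` returns a physically contiguous 4-page block;
// cloning the `PageAllocation` bumps its refcount, and the frames only return
// to the buddy free lists once every clone has been dropped.
#[allow(dead_code)]
fn alloc_example<C: CpuOps>(frames: &FrameAllocator<C>) -> Result<()> {
let block = frames.alloc_frames(2)?;
let shared = block.clone(); // refcount is now 2
drop(block); // refcount drops to 1; the block stays allocated
drop(shared); // refcount hits 0; the block is freed and merged with its buddies
Ok(())
}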
#[cfg(test)]
mod tests {
use super::*;
use crate::{
memory::{
address::{IdentityTranslator, PA},
region::PhysMemoryRegion,
smalloc::{RegionList, Smalloc},
},
test::MockCpuOps,
};
use core::{alloc::Layout, mem::MaybeUninit};
use std::vec::Vec; // For collecting results in tests
const KIB: usize = 1024;
const MIB: usize = 1024 * KIB;
const PAGE_SIZE: usize = 4096;
struct TestFixture {
allocator: FrameAllocator<MockCpuOps>,
base_ptr: *mut u8,
layout: Layout,
}
impl TestFixture {
/// Creates a new test fixture.
///
/// - `mem_regions`: A slice of `(start, size)` tuples defining available memory regions.
/// The `start` is relative to the beginning of the allocated memory block.
/// - `res_regions`: A slice of `(start, size)` tuples for reserved regions (e.g., kernel).
fn new(mem_regions: &[(usize, usize)], res_regions: &[(usize, usize)]) -> Self {
// Determine the total memory size required for the test environment.
let total_size = mem_regions
.iter()
.map(|(start, size)| start + size)
.max()
.unwrap_or(16 * MIB);
let layout =
Layout::from_size_align(total_size, 1 << (MAX_ORDER + PAGE_SHIFT)).unwrap();
let base_ptr = unsafe { std::alloc::alloc(layout) };
assert!(!base_ptr.is_null(), "Test memory allocation failed");
// Leaking is a common pattern in kernel test code to get static slices.
let mem_region_list: &mut [MaybeUninit<PhysMemoryRegion>] =
Vec::from([MaybeUninit::uninit(); 16]).leak();
let res_region_list: &mut [MaybeUninit<PhysMemoryRegion>] =
Vec::from([MaybeUninit::uninit(); 16]).leak();
let mut smalloc: Smalloc<IdentityTranslator> = Smalloc::new(
RegionList::new(16, mem_region_list.as_mut_ptr().cast()),
RegionList::new(16, res_region_list.as_mut_ptr().cast()),
);
let base_addr = base_ptr as usize;
for &(start, size) in mem_regions {
smalloc
.add_memory(PhysMemoryRegion::new(
PA::from_value(base_addr + start),
size,
))
.unwrap();
}
for &(start, size) in res_regions {
smalloc
.add_reservation(PhysMemoryRegion::new(
PA::from_value(base_addr + start),
size,
))
.unwrap();
}
let allocator = unsafe { FrameAllocator::init(smalloc) };
Self {
allocator,
base_ptr,
layout,
}
}
/// Get the state of a specific frame.
fn frame_state(&self, pfn: PageFrame) -> FrameState {
self.allocator
.inner
.lock_save_irq()
.get_frame(pfn)
.state
.clone()
}
/// Checks that the number of blocks in each free list matches the expected counts.
fn assert_free_list_counts(&self, expected_counts: &[usize; MAX_ORDER + 1]) {
for order in 0..=MAX_ORDER {
let count = self.allocator.inner.lock_save_irq().free_lists[order]
.iter()
.count();
assert_eq!(
count, expected_counts[order],
"Mismatch in free list count for order {}",
order
);
}
}
fn free_pages(&self) -> usize {
self.allocator.inner.lock_save_irq().free_pages
}
}
impl Drop for TestFixture {
fn drop(&mut self) {
unsafe {
self.allocator
.inner
.lock_save_irq()
.free_lists
.iter_mut()
.for_each(|x| x.clear());
std::alloc::dealloc(self.base_ptr, self.layout);
}
}
}
/// Tests basic allocator initialization with a single large, contiguous memory region.
#[test]
fn init_simple() {
let fixture = TestFixture::new(&[(0, (1 << (MAX_ORDER + PAGE_SHIFT)) * 2)], &[]);
let pages_in_max_block = 1 << MAX_ORDER;
assert_eq!(fixture.free_pages(), pages_in_max_block);
assert!(!fixture.allocator.inner.lock_save_irq().free_lists[MAX_ORDER].is_empty());
// Check that all other lists are empty
for i in 0..MAX_ORDER {
assert!(fixture.allocator.inner.lock_save_irq().free_lists[i].is_empty());
}
}
#[test]
fn init_with_kernel_reserved() {
let block_size = (1 << MAX_ORDER) * PAGE_SIZE;
// A region large enough for 3 max-order blocks
let total_size = 4 * block_size;
// Reserve the middle block. Even a single page anywhere in that block
// should wipe out the whole block.
let res_regions = &[(block_size * 2 + 4 * PAGE_SIZE, PAGE_SIZE)];
let fixture = TestFixture::new(&[(0, total_size)], res_regions);
let pages_in_max_block = 1 << MAX_ORDER;
// We should have 2 max-order blocks, not 3.
assert_eq!(fixture.free_pages(), 2 * pages_in_max_block);
// The middle pages should be marked as Kernel
let reserved_pfn = PageFrame::from_pfn(
fixture.allocator.inner.lock_save_irq().base_page.value()
+ (pages_in_max_block * 2 + 4),
);
assert!(matches!(
fixture.frame_state(reserved_pfn),
FrameState::Kernel
));
// Allocation of a MAX_ORDER block should succeed twice.
fixture
.allocator
.alloc_frames(MAX_ORDER as u8)
.unwrap()
.leak();
fixture
.allocator
.alloc_frames(MAX_ORDER as u8)
.unwrap()
.leak();
// A third should fail.
assert!(matches!(
fixture.allocator.alloc_frames(MAX_ORDER as u8),
Err(KernelError::NoMemory)
));
}
/// Tests a simple allocation and deallocation cycle.
#[test]
fn simple_alloc_and_free() {
let fixture = TestFixture::new(&[(0, (1 << (MAX_ORDER + PAGE_SHIFT)) * 2)], &[]);
let initial_free_pages = fixture.free_pages();
// Ensure we start with a single MAX_ORDER block.
let mut expected_counts = [0; MAX_ORDER + 1];
expected_counts[MAX_ORDER] = 1;
fixture.assert_free_list_counts(&expected_counts);
// Allocate a single page
let alloc = fixture
.allocator
.alloc_frames(0)
.expect("Allocation failed");
assert_eq!(fixture.free_pages(), initial_free_pages - 1);
// Check its state
match fixture.frame_state(alloc.region.start_address().to_pfn()) {
FrameState::AllocatedHead(info) => {
assert_eq!(info.order, 0);
assert_eq!(info.ref_count, 1);
}
_ => panic!("Incorrect frame state after allocation"),
}
// Free the page
drop(alloc);
assert_eq!(fixture.free_pages(), initial_free_pages);
// Ensure we merged back to a single MAX_ORDER block.
fixture.assert_free_list_counts(&expected_counts);
}
/// Tests allocation that requires splitting a large block.
#[test]
fn alloc_requires_split() {
let fixture = TestFixture::new(&[(0, (1 << (MAX_ORDER + PAGE_SHIFT)) * 2)], &[]);
// Allocate a single page (order 0)
let _pfn = fixture.allocator.alloc_frames(0).unwrap();
// Check free pages
let pages_in_block = 1 << MAX_ORDER;
assert_eq!(fixture.free_pages(), pages_in_block - 1);
// Splitting a MAX_ORDER block to get an order 0 page should leave
// one free block at each intermediate order.
let mut expected_counts = [0; MAX_ORDER + 1];
for i in 0..MAX_ORDER {
expected_counts[i] = 1;
}
fixture.assert_free_list_counts(&expected_counts);
}
/// Tests the allocation of a multi-page block and verifies head/tail metadata.
#[test]
fn alloc_multi_page_block() {
let fixture = TestFixture::new(&[(0, (1 << (MAX_ORDER + PAGE_SHIFT)) * 2)], &[]);
let order = 3; // 8 pages
let head_region = fixture.allocator.alloc_frames(order).unwrap();
assert_eq!(head_region.region.iter_pfns().count(), 8);
// Check head page
match fixture.frame_state(head_region.region.iter_pfns().next().unwrap()) {
FrameState::AllocatedHead(info) => assert_eq!(info.order, order as u8),
_ => panic!("Head page has incorrect state"),
}
// Check tail pages
for (i, pfn) in head_region.region.iter_pfns().skip(1).enumerate() {
match fixture.frame_state(pfn) {
FrameState::AllocatedTail(info) => {
assert_eq!(info.head, head_region.region.start_address().to_pfn())
}
_ => panic!("Tail page {} has incorrect state", i),
}
}
}
/// Tests that freeing a tail page correctly frees the entire block.
#[test]
fn free_tail_page() {
let fixture = TestFixture::new(&[(0, (1 << (MAX_ORDER + PAGE_SHIFT)) * 2)], &[]);
let initial_free = fixture.free_pages();
let order = 4; // 16 pages
let num_pages = 1 << order;
let head_alloc = fixture.allocator.alloc_frames(order as u8).unwrap();
assert_eq!(fixture.free_pages(), initial_free - num_pages);
drop(head_alloc);
// All pages should be free again
assert_eq!(fixture.free_pages(), initial_free);
}
/// Tests exhausting memory and handling the out-of-memory condition.
#[test]
fn alloc_out_of_memory() {
let fixture = TestFixture::new(&[(0, (1 << (MAX_ORDER + PAGE_SHIFT)) * 2)], &[]);
let total_pages = fixture.free_pages();
assert!(total_pages > 0);
let mut allocs = Vec::new();
for _ in 0..total_pages {
match fixture.allocator.alloc_frames(0) {
Ok(pfn) => allocs.push(pfn),
Err(e) => panic!("Allocation failed prematurely: {:?}", e),
}
}
assert_eq!(fixture.free_pages(), 0);
// Next allocation should fail
let result = fixture.allocator.alloc_frames(0);
assert!(matches!(result, Err(KernelError::NoMemory)));
// Free everything and check if memory is recovered
drop(allocs);
assert_eq!(fixture.free_pages(), total_pages);
}
/// Tests that requesting an invalid order fails gracefully.
#[test]
fn alloc_invalid_order() {
let fixture = TestFixture::new(&[(0, (1 << (MAX_ORDER + PAGE_SHIFT)) * 2)], &[]);
let result = fixture.allocator.alloc_frames((MAX_ORDER + 1) as u8);
assert!(matches!(result, Err(KernelError::InvalidValue)));
}
/// Tests the reference counting mechanism in `free_frames`.
#[test]
fn ref_count_free() {
let fixture = TestFixture::new(&[(0, (1 << (MAX_ORDER + PAGE_SHIFT)) * 2)], &[]);
let initial_free = fixture.free_pages();
let alloc1 = fixture.allocator.alloc_frames(2).unwrap();
let alloc2 = alloc1.clone();
let alloc3 = alloc2.clone();
let pages_in_block = 1 << 2;
assert_eq!(fixture.free_pages(), initial_free - pages_in_block);
let pfn = alloc1.region().start_address().to_pfn();
// First free should just decrement the count
drop(alloc1);
assert_eq!(fixture.free_pages(), initial_free - pages_in_block);
if let FrameState::AllocatedHead(info) = fixture.frame_state(pfn) {
assert_eq!(info.ref_count, 2);
} else {
panic!("Page state changed unexpectedly");
}
// Second free, same thing
drop(alloc2);
assert_eq!(fixture.free_pages(), initial_free - pages_in_block);
if let FrameState::AllocatedHead(info) = fixture.frame_state(pfn) {
assert_eq!(info.ref_count, 1);
} else {
panic!("Page state changed unexpectedly");
}
// Third free should actually release the memory
drop(alloc3);
assert_eq!(fixture.free_pages(), initial_free);
assert!(matches!(fixture.frame_state(pfn), FrameState::Free { .. }));
}
}

View File

@@ -0,0 +1,286 @@
use core::fmt;
use super::proc_vm::vmarea::VMAPermissions;
/// Represents the memory permissions for a virtual memory mapping.
#[derive(Clone, Copy, PartialEq, Eq, Default)]
pub struct PtePermissions {
read: bool,
write: bool,
execute: bool,
user: bool,
cow: bool,
}
impl From<VMAPermissions> for PtePermissions {
fn from(value: VMAPermissions) -> Self {
Self {
read: value.read,
write: value.write,
execute: value.execute,
user: true, // VMAs only represent user address spaces.
cow: false, // a VMA will only be COW when it's cloned.
}
}
}
impl PtePermissions {
/// Creates a new `PtePermissions` from its raw boolean components.
///
/// This constructor is intended exclusively for use by the
/// architecture-specific MMU implementation when decoding a raw page table
/// entry. It is marked `pub(crate)` to prevent its use outside the
/// libkernel crate, preserving the safety invariants for the rest of the
/// kernel.
#[inline]
pub(crate) const fn from_raw_bits(
read: bool,
write: bool,
execute: bool,
user: bool,
cow: bool,
) -> Self {
debug_assert!(
!(write && cow),
"PTE permissions cannot be simultaneously writable and CoW"
);
Self {
read,
write,
execute,
user,
cow,
}
}
/// Creates a new read-only permission set.
pub const fn ro(user: bool) -> Self {
Self {
read: true,
write: false,
execute: false,
user,
cow: false,
}
}
/// Creates a new read-write permission set.
pub const fn rw(user: bool) -> Self {
Self {
read: true,
write: true,
execute: false,
user,
cow: false,
}
}
/// Creates a new read-execute permission set.
pub const fn rx(user: bool) -> Self {
Self {
read: true,
write: false,
execute: true,
user,
cow: false,
}
}
/// Creates a new read-write-execute permission set.
pub const fn rwx(user: bool) -> Self {
Self {
read: true,
write: true,
execute: true,
user,
cow: false,
}
}
/// Returns `true` if the mapping is readable.
pub const fn is_read(&self) -> bool {
self.read
}
/// Returns `true` if the mapping is writable. This will be `false` for a
/// CoW mapping.
pub const fn is_write(&self) -> bool {
self.write
}
/// Returns `true` if the mapping is executable.
pub const fn is_execute(&self) -> bool {
self.execute
}
/// Returns `true` if the mapping is accessible from user space.
pub const fn is_user(&self) -> bool {
self.user
}
/// Returns `true` if the mapping is a Copy-on-Write mapping.
pub const fn is_cow(&self) -> bool {
self.cow
}
/// Converts a writable permission set into its Copy-on-Write equivalent.
///
/// This method enforces the invariant that a mapping cannot be both
/// writable and CoW by explicitly setting `write` to `false`.
///
/// # Example
/// ```
/// use libkernel::memory::permissions::PtePermissions;
///
/// let perms = PtePermissions::rw(true);
/// let cow_perms = perms.into_cow();
/// assert!(!cow_perms.is_write());
/// assert!(cow_perms.is_cow());
/// assert!(cow_perms.is_read());
/// ```
///
/// # Panics
///
/// Panics in debug builds if the permissions are not originally writable,
/// as it is a logical error to make a non-writable page Copy-on-Write.
pub fn into_cow(self) -> Self {
debug_assert!(self.write, "Cannot make a non-writable mapping CoW");
Self {
write: false,
cow: true,
..self
}
}
/// Converts a Copy-on-Write permission set back into a writable one.
///
/// This is used by the page fault handler after a page has been copied or
/// exclusively claimed. It makes the page writable by the hardware and
/// removes the kernel's `CoW` marker.
///
/// # Example
/// ```
/// use libkernel::memory::permissions::PtePermissions;
///
/// let cow_perms = PtePermissions::rw(true).into_cow();
/// let writable_perms = cow_perms.from_cow();
/// assert!(writable_perms.is_write());
/// assert!(!writable_perms.is_cow());
/// assert!(writable_perms.is_read());
/// ```
///
/// # Panics
///
/// Panics in debug builds if the permissions are not CoW, as this indicates a
/// logic error in the fault handler.
pub fn from_cow(self) -> Self {
debug_assert!(self.cow, "Cannot make a non-CoW mapping writable again");
Self {
write: true, // The invariant is enforced here.
cow: false,
..self
}
}
}
impl fmt::Display for PtePermissions {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let r = if self.read { 'r' } else { '-' };
let x = if self.execute { 'x' } else { '-' };
let user = if self.user { 'u' } else { 'k' };
// Display 'w' for writable, or 'c' for CoW. The invariant guarantees
// that `self.write` and `self.cow` cannot both be true.
let w_or_c = if self.write {
'w'
} else if self.cow {
'c'
} else {
'-'
};
write!(f, "{}{}{} {}", r, w_or_c, x, user)
}
}
impl fmt::Debug for PtePermissions {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("MemPermissions")
.field("read", &self.read)
.field("write", &self.write)
.field("execute", &self.execute)
.field("user", &self.user)
.field("cow", &self.cow)
.finish()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_constructors() {
let p = PtePermissions::rw(true);
assert!(p.is_read());
assert!(p.is_write());
assert!(!p.is_execute());
assert!(p.is_user());
assert!(!p.is_cow());
let p = PtePermissions::rx(false);
assert!(p.is_read());
assert!(!p.is_write());
assert!(p.is_execute());
assert!(!p.is_user());
assert!(!p.is_cow());
}
#[test]
fn test_cow_transition() {
let p_rw = PtePermissions::rw(true);
let p_cow = p_rw.into_cow();
// Check CoW state
assert!(p_cow.is_read());
assert!(!p_cow.is_write());
assert!(!p_cow.is_execute());
assert!(p_cow.is_user());
assert!(p_cow.is_cow());
// Transition back
let p_final = p_cow.from_cow();
assert_eq!(p_rw, p_final);
}
#[test]
#[should_panic]
fn test_into_cow_panic() {
// Cannot make a read-only page CoW
let p_ro = PtePermissions::ro(true);
let _ = p_ro.into_cow();
}
#[test]
#[should_panic]
fn test_from_cow_panic() {
// Cannot convert a non-CoW page from CoW
let p_rw = PtePermissions::rw(true);
let _ = p_rw.from_cow();
}
#[test]
fn test_display_format() {
assert_eq!(format!("{}", PtePermissions::rw(true)), "rw- u");
assert_eq!(format!("{}", PtePermissions::rwx(false)), "rwx k");
assert_eq!(format!("{}", PtePermissions::ro(true)), "r-- u");
assert_eq!(format!("{}", PtePermissions::rx(false)), "r-x k");
let cow_perms = PtePermissions::rw(true).into_cow();
assert_eq!(format!("{}", cow_perms), "rc- u");
let cow_exec_perms = PtePermissions::rwx(false).into_cow();
assert_eq!(format!("{}", cow_exec_perms), "rcx k");
}
}

View File

@@ -0,0 +1,25 @@
use super::address::{AddressTranslator, TPA, TVA};
use crate::VirtualMemory;
use core::marker::PhantomData;
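/// Translates between virtual and physical addresses for a simple
/// "page offset" (linear) mapping: a physical address is mapped at
/// `VM::PAGE_OFFSET + pa`, so each translation is a single add or subtract.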
pub struct PageOffsetTranslator<VM: VirtualMemory> {
_phantom: PhantomData<VM>,
}
impl<T, VM: VirtualMemory> AddressTranslator<T> for PageOffsetTranslator<VM> {
fn virt_to_phys(va: TVA<T>) -> TPA<T> {
let mut v = va.value();
v -= VM::PAGE_OFFSET;
TPA::from_value(v)
}
fn phys_to_virt(pa: TPA<T>) -> TVA<T> {
let mut v = pa.value();
v += VM::PAGE_OFFSET;
TVA::from_value(v)
}
}

View File

@@ -0,0 +1,447 @@
use super::vmarea::{VMAPermissions, VMArea, VMAreaKind};
use crate::{
UserAddressSpace,
error::{KernelError, Result},
memory::{
PAGE_MASK, PAGE_SIZE, address::VA, page::PageFrame, permissions::PtePermissions,
region::VirtMemoryRegion,
},
};
use alloc::{collections::BTreeMap, vec::Vec};
const MMAP_BASE: usize = 0x4000_0000_0000;
/// Manages mappings in a process's address space.
pub struct MemoryMap<AS: UserAddressSpace> {
vmas: BTreeMap<VA, VMArea>,
address_space: AS,
}
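/// How the caller wants the mapping address to be chosen by `MemoryMap::mmap`:
/// `Any` lets the kernel pick a free region, `Hint` prefers the given address
/// but falls back to any free region, and `Fixed` maps at exactly the given
/// address, optionally clobbering existing mappings when `permit_overlap` is set.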
#[derive(Debug, PartialEq, Eq)]
pub enum AddressRequest {
Any,
Hint(VA),
Fixed { address: VA, permit_overlap: bool },
}
impl<AS: UserAddressSpace> MemoryMap<AS> {
/// Creates a new, empty address space.
pub fn new() -> Result<Self> {
Ok(Self {
vmas: BTreeMap::new(),
address_space: AS::new()?,
})
}
pub(super) fn with_addr_spc(address_space: AS) -> Self {
Self {
vmas: BTreeMap::new(),
address_space,
}
}
/// Create an address space from a pre-populated list of VMAs. Used by the
/// ELF loader.
pub fn from_vmas(vmas: Vec<VMArea>) -> Result<Self> {
let mut map = BTreeMap::new();
for vma in vmas {
map.insert(vma.region.start_address(), vma);
}
Ok(Self {
vmas: map,
address_space: AS::new()?,
})
}
/// Finds the `VMArea` that contains the given virtual address.
///
/// # Arguments
/// * `addr`: The virtual address to look up.
///
/// # Returns
/// * `Some(VMArea)` if the address is part of a valid mapping.
/// * `None` if the address is in a "hole" in the address space.
pub fn find_vma(&self, addr: VA) -> Option<&VMArea> {
let candidate = self.vmas.range(..=addr).next_back();
match candidate {
Some((_, vma)) => {
if vma.contains_address(addr) {
Some(vma)
} else {
None
}
}
None => None, // No VMA starts at or before this address.
}
}
/// Maps a region of memory.
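///
/// The requested length is rounded up to a whole number of pages and the
/// mapping is placed according to `requested_address` (see [`AddressRequest`]).
/// On success, the page-aligned start address of the new mapping is returned;
/// newly mapped pages are populated lazily, on page fault.
///
/// A minimal usage sketch (not compiled; `SomeAddressSpace` stands in for any
/// `UserAddressSpace` implementation):
///
/// ```ignore
/// let mut mm: MemoryMap<SomeAddressSpace> = MemoryMap::new()?;
/// // Ask for three anonymous, read-write pages anywhere in the mmap area.
/// let addr = mm.mmap(
///     AddressRequest::Any,
///     3 * PAGE_SIZE,
///     VMAPermissions::rw(),
///     VMAreaKind::Anon,
/// )?;
/// assert!(addr.is_page_aligned());
/// ```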
pub fn mmap(
&mut self,
requested_address: AddressRequest,
mut len: usize,
perms: VMAPermissions,
kind: VMAreaKind,
) -> Result<VA> {
if len == 0 {
return Err(KernelError::InvalidValue);
}
// Ensure the length is page-aligned.
if len & PAGE_MASK != 0 {
len = (len & !PAGE_MASK) + PAGE_SIZE;
}
let region = match requested_address {
AddressRequest::Any => self.find_free_region(len).ok_or(KernelError::NoMemory)?,
AddressRequest::Hint(address) => {
// Be more permissive when it's a hint.
let address = if !address.is_page_aligned() {
address.page_aligned()
} else {
address
};
let region = VirtMemoryRegion::new(address, len);
if self.is_region_free(region) {
region
} else {
self.find_free_region(len).ok_or(KernelError::NoMemory)?
}
}
AddressRequest::Fixed {
address,
permit_overlap,
} => {
if !address.is_page_aligned() {
return Err(KernelError::InvalidValue);
}
let region = VirtMemoryRegion::new(address, len);
if !permit_overlap && !self.is_region_free(region) {
return Err(KernelError::InvalidValue);
}
region
}
};
// At this point, `region` describes a valid, free region.
// We can now create and insert the new VMA, handling merges.
let new_vma = VMArea::new(region, kind, perms);
self.insert_and_merge(new_vma);
Ok(region.start_address())
}
/// Unmaps a region of memory, similar to the `munmap` syscall.
///
/// This is the most complex operation, as it may involve removing,
/// resizing, or splitting one or more existing VMAs.
///
/// # Arguments
/// * `range`: The virtual memory region to unmap. Its start must be
/// page-aligned and its size non-zero; the region is aligned to page
/// boundaries before unmapping.
///
/// # Returns
/// * `Ok(pages)` on success, containing the page frames that were unmapped.
/// * `Err(KernelError::InvalidValue)` if the range is unaligned or empty.
pub fn munmap(&mut self, range: VirtMemoryRegion) -> Result<Vec<PageFrame>> {
if !range.is_page_aligned() {
return Err(KernelError::InvalidValue);
}
if range.size() == 0 {
return Err(KernelError::InvalidValue);
}
// Align the region to page boundaries before unmapping.
self.unmap_region(range.align_to_page_boundary(), None)
}
/// Checks if a given virtual memory region is completely free.
fn is_region_free(&self, region: VirtMemoryRegion) -> bool {
// Find the VMA that might overlap with the start of our desired region.
let candidate = self.vmas.range(..=region.start_address()).next_back();
if let Some((_, prev_vma)) = candidate {
// If the previous VMA extends into our desired region, it's not
// free.
if prev_vma.region.end_address() > region.start_address() {
return false;
}
}
// Check if the next VMA starts within our desired region.
if let Some((next_vma_start, _)) = self.vmas.range(region.start_address()..).next()
&& *next_vma_start < region.end_address()
{
false
} else {
true
}
}
/// Finds a free region of at least `len` bytes. Searches downwards from
/// `MMAP_BASE`.
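/// Walks the existing VMAs from highest to lowest address and returns the
/// highest gap that can hold `len` bytes, or `None` if no such gap exists
/// below `MMAP_BASE`.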
fn find_free_region(&self, len: usize) -> Option<VirtMemoryRegion> {
let mut last_vma_end = VA::from_value(MMAP_BASE);
// Iterate through VMAs in reverse order to find a gap.
for (_, vma) in self.vmas.iter().rev() {
let vma_start = vma.region.start_address();
let vma_end = vma.region.end_address();
if last_vma_end >= vma_end {
let gap_start = vma_end;
let gap_size = last_vma_end.value() - gap_start.value();
if gap_size >= len {
// Found a large enough gap. Place the new mapping at the top of it.
return Some(VirtMemoryRegion::new(
VA::from_value(last_vma_end.value() - len),
len,
));
}
}
last_vma_end = vma_start;
}
// Check the final gap at the beginning of the mmap area.
if last_vma_end.value() >= len {
Some(VirtMemoryRegion::new(
VA::from_value(last_vma_end.value() - len),
len,
))
} else {
None
}
}
/// Inserts a new VMA, handling overlaps and merging it with neighbors if
/// possible.
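/// Any existing mappings that overlap `vma`'s region are first shrunk or
/// removed (clobbered), then the new VMA is coalesced with the next and/or
/// previous VMA when `can_merge_with` allows it.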
pub(super) fn insert_and_merge(&mut self, mut vma: VMArea) {
let _ = self.unmap_region(vma.region, Some(vma.clone()));
// Try to merge with next VMA.
if let Some(next_vma) = self.vmas.get(&vma.region.end_address())
&& vma.can_merge_with(next_vma)
{
// The properties are compatible. We take the region from the
// next VMA, remove it from the map, and expand our new VMA
// to cover the combined area.
let next_vma_region = self
.vmas
.remove(&next_vma.region.start_address())
.unwrap() // Should not fail, as we just got this VMA.
.region;
vma.region.expand(next_vma_region.size());
// `vma` now represents the merged region of [new, next].
}
// Try to merge with the previous VMA.
if let Some((_key, prev_vma)) = self
.vmas
.range_mut(..vma.region.start_address())
.next_back()
{
// Check if it's contiguous and compatible.
if prev_vma.region.end_address() == vma.region.start_address()
&& prev_vma.can_merge_with(&vma)
{
// The VMAs are mergeable. Expand the previous VMA to absorb the
// new one's region.
prev_vma.region.expand(vma.region.size());
return;
}
}
// If we didn't merge into a previous VMA, insert the new (and possibly
// already merged with the next) VMA into the map.
self.vmas.insert(vma.region.start_address(), vma);
}
/// Fix up the underlying page tables whenever a VMArea is being modified.
fn fixup_pg_tables(
&mut self,
fixup_region: VirtMemoryRegion,
old_vma: VMArea,
new_vma: Option<VMArea>,
) -> Result<Vec<PageFrame>> {
let intersecting_region = fixup_region.intersection(old_vma.region);
if let Some(intersection) = intersecting_region {
match new_vma {
Some(new_vma) => {
// We always unmap if file backing stores are involved.
if old_vma.is_file_backed() || new_vma.is_file_backed() {
self.address_space.unmap_range(intersection)
} else {
// the VMAs are anonymously mapped. Preserve data.
if new_vma.permissions != old_vma.permissions {
self.address_space
.protect_range(
intersection,
PtePermissions::from(new_vma.permissions),
)
.map(|_| Vec::new())
} else {
// If permissions match, fixup is a noop
Ok(Vec::new())
}
}
}
None => self.address_space.unmap_range(intersection),
}
} else {
Ok(Vec::new())
}
}
/// Create a hole in the address space identified by `unmap_region`. VMAs that
/// partially overlap the region are shrunk or split; VMAs that lie entirely
/// inside it are removed.
///
/// This function is called by both the unmap code (replace_with = None),
/// and the insert_and_merge code (replace_with = Some(<new vma>)). The
/// `replace_with` parameter can be used to update the underlying page
/// tables accordingly.
///
/// # Returns
/// A list of all pages that were unmapped.
fn unmap_region(
&mut self,
unmap_region: VirtMemoryRegion,
replace_with: Option<VMArea>,
) -> Result<Vec<PageFrame>> {
let mut affected_vmas = Vec::new();
let unmap_start = unmap_region.start_address();
let unmap_end = unmap_region.end_address();
let mut pages_unmapped = Vec::new();
// Find all VMAs that intersect with the unmap region. Start with the
// VMA that could contain the start address.
if let Some((_, vma)) = self.vmas.range(..unmap_start).next_back()
&& vma.region.end_address() > unmap_start
{
affected_vmas.push(vma.clone());
}
// Add all other VMAs that start within the unmap region.
for (_, vma) in self.vmas.range(unmap_start..) {
if vma.region.start_address() < unmap_end {
affected_vmas.push(vma.clone());
} else {
break; // We're past the unmap region now.
}
}
if affected_vmas.is_empty() {
return Ok(Vec::new());
}
for vma in affected_vmas {
let vma_start = vma.region.start_address();
let vma_end = vma.region.end_address();
self.vmas.remove(&vma_start).unwrap();
pages_unmapped.append(&mut self.fixup_pg_tables(
unmap_region,
vma.clone(),
replace_with.clone(),
)?);
// VMA is completely contained within the unmap region. Handled by
// just removing it.
// VMA needs to be split (unmap punches a hole).
if vma_start < unmap_start && vma_end > unmap_end {
// Create left part.
let left_region =
VirtMemoryRegion::new(vma_start, unmap_start.value() - vma_start.value());
let left_vma = vma.clone_with_new_region(left_region);
self.vmas.insert(left_vma.region.start_address(), left_vma);
// Create right part.
let right_region =
VirtMemoryRegion::new(unmap_end, vma_end.value() - unmap_end.value());
let right_vma = vma.clone_with_new_region(right_region);
self.vmas
.insert(right_vma.region.start_address(), right_vma);
continue;
}
// VMA needs to be truncated at the end.
if vma_start < unmap_start {
let new_size = unmap_start.value() - vma_start.value();
let new_region = VirtMemoryRegion::new(vma_start, new_size);
let new_vma = vma.clone_with_new_region(new_region);
self.vmas.insert(new_vma.region.start_address(), new_vma);
}
// VMA needs to be truncated at the beginning.
if vma_end > unmap_end {
let new_start = unmap_end;
let new_size = vma_end.value() - new_start.value();
let new_region = VirtMemoryRegion::new(new_start, new_size);
let mut new_vma = vma.clone_with_new_region(new_region);
// Adjust file mapping offset if it's a file-backed VMA.
if let VMAreaKind::File(mapping) = &mut new_vma.kind {
let offset_change = new_start.value() - vma_start.value();
mapping.offset += offset_change as u64;
}
self.vmas.insert(new_vma.region.start_address(), new_vma);
}
}
Ok(pages_unmapped)
}
/// Attempts to clone this memory map, sharing any already-mapped writable
/// pages as CoW pages. If the VMA isn't writable, the ref count is
/// incremented.
pub fn clone_as_cow(&mut self) -> Result<Self> {
let mut new_as = AS::new()?;
let new_vmas = self.vmas.clone();
for vma in new_vmas.values() {
let mut pte_perms = PtePermissions::from(vma.permissions);
// Mark all writable pages as CoW.
if pte_perms.is_write() {
pte_perms = pte_perms.into_cow();
}
self.address_space.protect_and_clone_region(
vma.region.align_to_page_boundary(),
&mut new_as,
pte_perms,
)?;
}
Ok(Self {
vmas: new_vmas,
address_space: new_as,
})
}
pub fn address_space_mut(&mut self) -> &mut AS {
&mut self.address_space
}
pub fn vma_count(&self) -> usize {
self.vmas.len()
}
}
#[cfg(test)]
pub mod tests;

View File

@@ -0,0 +1,648 @@
use super::MemoryMap;
use crate::{
PageInfo, UserAddressSpace,
error::Result,
fs::Inode,
memory::{
PAGE_SIZE,
address::VA,
page::PageFrame,
permissions::PtePermissions,
proc_vm::{
memory_map::{AddressRequest, MMAP_BASE},
vmarea::{VMAPermissions, VMArea, VMAreaKind, VMFileMapping, tests::DummyTestInode},
},
region::VirtMemoryRegion,
},
};
use alloc::sync::Arc;
use std::sync::Mutex;
/// Represents a single operation performed on the mock page table.
#[derive(Debug, PartialEq, Eq, Clone)]
pub enum MockPageTableOp {
UnmapRange {
region: VirtMemoryRegion,
},
ProtectRange {
region: VirtMemoryRegion,
perms: PtePermissions,
},
}
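/// A `UserAddressSpace` test double that records the page-table operations it
/// is asked to perform in `ops_log` instead of touching any real page tables.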
pub struct MockAddressSpace {
pub ops_log: Mutex<Vec<MockPageTableOp>>,
}
impl UserAddressSpace for MockAddressSpace {
fn new() -> Result<Self> {
Ok(Self {
ops_log: Mutex::new(Vec::new()),
})
}
fn activate(&self) {
unimplemented!()
}
fn deactivate(&self) {
unimplemented!()
}
fn map_page(&mut self, _page: PageFrame, _va: VA, _perms: PtePermissions) -> Result<()> {
panic!("Should be called by the demand-pager");
}
fn unmap(&mut self, va: VA) -> Result<PageFrame> {
let region = VirtMemoryRegion::new(va, PAGE_SIZE);
self.ops_log
.lock()
.unwrap()
.push(MockPageTableOp::UnmapRange { region });
// Return a dummy page, as the caller doesn't use it.
Ok(PageFrame::from_pfn(0))
}
fn protect_range(&mut self, va_range: VirtMemoryRegion, perms: PtePermissions) -> Result<()> {
self.ops_log
.lock()
.unwrap()
.push(MockPageTableOp::ProtectRange {
region: va_range,
perms,
});
Ok(())
}
fn unmap_range(&mut self, va_range: VirtMemoryRegion) -> Result<Vec<PageFrame>> {
self.ops_log
.lock()
.unwrap()
.push(MockPageTableOp::UnmapRange { region: va_range });
Ok(Vec::new())
}
fn translate(&self, _va: VA) -> Option<PageInfo> {
None
}
fn protect_and_clone_region(
&mut self,
_region: VirtMemoryRegion,
_other: &mut Self,
_perms: PtePermissions,
) -> Result<()>
where
Self: Sized,
{
unreachable!("Not called")
}
fn remap(
&mut self,
_va: VA,
_new_page: PageFrame,
_perms: PtePermissions,
) -> Result<PageFrame> {
unreachable!("Not called")
}
}
// Helper to create a new inode Arc.
fn new_inode() -> Arc<dyn Inode> {
Arc::new(DummyTestInode)
}
// Creates a file-backed VMA for testing.
fn create_file_vma(
start: usize,
size: usize,
perms: VMAPermissions,
offset: u64,
inode: Arc<dyn Inode>,
) -> VMArea {
VMArea::new(
VirtMemoryRegion::new(VA::from_value(start), size),
VMAreaKind::File(VMFileMapping {
file: inode,
offset,
len: size as u64,
}),
perms,
)
}
// Creates an anonymous VMA for testing.
fn create_anon_vma(start: usize, size: usize, perms: VMAPermissions) -> VMArea {
VMArea::new(
VirtMemoryRegion::new(VA::from_value(start), size),
VMAreaKind::Anon,
perms,
)
}
/// Asserts that a VMA with the given properties exists.
fn assert_vma_exists(pvm: &MemoryMap<MockAddressSpace>, start: usize, size: usize) {
let vma = pvm
.find_vma(VA::from_value(start))
.expect("VMA not found at start address");
assert_eq!(
vma.region.start_address().value(),
start,
"VMA start address mismatch"
);
assert_eq!(vma.region.size(), size, "VMA size mismatch");
}
#[test]
fn test_mmap_any_empty() {
let mut pvm: MemoryMap<MockAddressSpace> = MemoryMap::new().unwrap();
let size = 3 * PAGE_SIZE;
let addr = pvm
.mmap(
AddressRequest::Any,
size,
VMAPermissions::rw(),
VMAreaKind::Anon,
)
.unwrap();
assert_eq!(addr.value(), MMAP_BASE - size);
assert_eq!(pvm.vmas.len(), 1);
assert_vma_exists(&pvm, MMAP_BASE - size, size);
assert!(pvm.address_space.ops_log.lock().unwrap().is_empty());
}
#[test]
fn test_mmap_any_with_existing() {
let mut pvm: MemoryMap<MockAddressSpace> = MemoryMap::new().unwrap();
let size = 2 * PAGE_SIZE;
let existing_addr = MMAP_BASE - 5 * PAGE_SIZE;
pvm.insert_and_merge(create_anon_vma(existing_addr, size, VMAPermissions::rw()));
// This should find the gap above the existing VMA.
let new_addr = pvm
.mmap(
AddressRequest::Any,
size,
VMAPermissions::ro(),
VMAreaKind::Anon,
)
.unwrap();
assert_eq!(new_addr.value(), MMAP_BASE - size);
assert_eq!(pvm.vmas.len(), 2);
// This should find the gap below the existing VMA.
let bottom_addr = pvm
.mmap(
AddressRequest::Any,
size,
VMAPermissions::ro(), // different permissions to prevent merge.
VMAreaKind::Anon,
)
.unwrap();
assert_eq!(bottom_addr.value(), existing_addr - size);
assert_eq!(pvm.vmas.len(), 3);
assert_vma_exists(&pvm, existing_addr, 2 * PAGE_SIZE);
assert_vma_exists(&pvm, MMAP_BASE - 2 * PAGE_SIZE, 2 * PAGE_SIZE);
assert_vma_exists(&pvm, MMAP_BASE - 7 * PAGE_SIZE, 2 * PAGE_SIZE);
assert!(pvm.address_space.ops_log.lock().unwrap().is_empty());
}
#[test]
fn test_mmap_hint_free() {
let mut pvm: MemoryMap<MockAddressSpace> = MemoryMap::new().unwrap();
let size = PAGE_SIZE;
let hint_addr = VA::from_value(MMAP_BASE - 10 * PAGE_SIZE);
let addr = pvm
.mmap(
AddressRequest::Hint(hint_addr),
size,
VMAPermissions::rw(),
VMAreaKind::Anon,
)
.unwrap();
assert_eq!(addr, hint_addr);
assert_vma_exists(&pvm, hint_addr.value(), size);
assert!(pvm.address_space.ops_log.lock().unwrap().is_empty());
}
#[test]
fn test_mmap_hint_taken() {
let mut pvm: MemoryMap<MockAddressSpace> = MemoryMap::new().unwrap();
let size = 2 * PAGE_SIZE;
let hint_addr = VA::from_value(MMAP_BASE - 10 * PAGE_SIZE);
// Occupy the space where the hint is.
pvm.insert_and_merge(create_anon_vma(
hint_addr.value(),
size,
VMAPermissions::rw(),
));
// The mmap should ignore the hint and find a new spot at the top.
let new_addr = pvm
.mmap(
AddressRequest::Hint(hint_addr),
size,
VMAPermissions::rw(),
VMAreaKind::Anon,
)
.unwrap();
assert_ne!(new_addr, hint_addr);
assert_eq!(new_addr.value(), MMAP_BASE - size);
assert_eq!(pvm.vmas.len(), 2);
assert!(pvm.address_space.ops_log.lock().unwrap().is_empty());
}
#[test]
fn test_mmap_fixed_clobber_complete_overlap() {
let mut pvm: MemoryMap<MockAddressSpace> = MemoryMap::new().unwrap();
let addr = MMAP_BASE - 10 * PAGE_SIZE;
// Old VMA, read-only
pvm.insert_and_merge(create_anon_vma(addr, 3 * PAGE_SIZE, VMAPermissions::ro()));
// New VMA, completely overwrites the old one
let mapped_addr = pvm
.mmap(
AddressRequest::Fixed {
address: VA::from_value(addr),
permit_overlap: true,
},
3 * PAGE_SIZE,
VMAPermissions::rw(),
VMAreaKind::Anon,
)
.unwrap();
assert_eq!(mapped_addr.value(), addr);
assert_eq!(pvm.vmas.len(), 1);
let vma = pvm.find_vma(VA::from_value(addr)).unwrap();
assert!(vma.permissions().write); // Check it's the new VMA
assert_vma_exists(&pvm, addr, 3 * PAGE_SIZE);
assert_eq!(
*pvm.address_space.ops_log.lock().unwrap(),
&[MockPageTableOp::ProtectRange {
region: VirtMemoryRegion::new(VA::from_value(addr), 3 * PAGE_SIZE),
perms: PtePermissions::rw(true)
}]
);
}
#[test]
fn test_mmap_fixed_clobber_partial_end() {
let mut pvm: MemoryMap<MockAddressSpace> = MemoryMap::new().unwrap();
let addr = MMAP_BASE - 10 * PAGE_SIZE;
pvm.insert_and_merge(create_anon_vma(addr, 5 * PAGE_SIZE, VMAPermissions::ro()));
// New VMA overwrites the end of the old one.
let new_addr = addr + 3 * PAGE_SIZE;
let new_size = 2 * PAGE_SIZE;
pvm.mmap(
AddressRequest::Fixed {
address: VA::from_value(new_addr),
permit_overlap: true,
},
new_size,
VMAPermissions::rw(),
VMAreaKind::Anon,
)
.unwrap();
assert_eq!(pvm.vmas.len(), 2);
assert_vma_exists(&pvm, addr, 3 * PAGE_SIZE); // Original is truncated
assert_vma_exists(&pvm, new_addr, new_size); // New VMA exists
assert_eq!(
*pvm.address_space.ops_log.lock().unwrap(),
&[MockPageTableOp::ProtectRange {
region: VirtMemoryRegion::new(VA::from_value(new_addr), 2 * PAGE_SIZE),
perms: PtePermissions::rw(true),
}]
);
}
#[test]
fn test_mmap_fixed_clobber_partial_end_spill() {
let mut pvm: MemoryMap<MockAddressSpace> = MemoryMap::new().unwrap();
let addr = MMAP_BASE - 10 * PAGE_SIZE;
pvm.insert_and_merge(create_anon_vma(addr, 5 * PAGE_SIZE, VMAPermissions::ro()));
// New VMA overwrites the end of the old one.
let new_addr = addr + 3 * PAGE_SIZE;
let new_size = 4 * PAGE_SIZE;
pvm.mmap(
AddressRequest::Fixed {
address: VA::from_value(new_addr),
permit_overlap: true,
},
new_size,
VMAPermissions::rw(),
VMAreaKind::Anon,
)
.unwrap();
assert_eq!(pvm.vmas.len(), 2);
assert_vma_exists(&pvm, addr, 3 * PAGE_SIZE); // Original is truncated
assert_vma_exists(&pvm, new_addr, new_size); // New VMA exists
// Ensure protect region is just the overlapping region.
assert_eq!(
*pvm.address_space.ops_log.lock().unwrap(),
&[MockPageTableOp::ProtectRange {
region: VirtMemoryRegion::new(VA::from_value(new_addr), 2 * PAGE_SIZE),
perms: PtePermissions::rw(true),
}]
);
}
#[test]
fn test_mmap_fixed_no_clobber_fails() {
let mut pvm: MemoryMap<MockAddressSpace> = MemoryMap::new().unwrap();
let addr = MMAP_BASE - 10 * PAGE_SIZE;
pvm.insert_and_merge(create_anon_vma(addr, 5 * PAGE_SIZE, VMAPermissions::ro()));
let new_addr = addr + 3 * PAGE_SIZE;
let new_size = 2 * PAGE_SIZE;
assert!(
pvm.mmap(
AddressRequest::Fixed {
address: VA::from_value(new_addr),
permit_overlap: false,
},
new_size,
VMAPermissions::rw(),
VMAreaKind::Anon,
)
.is_err()
);
}
#[test]
fn test_mmap_fixed_clobber_punch_hole() {
let mut pvm: MemoryMap<MockAddressSpace> = MemoryMap::new().unwrap();
let addr = MMAP_BASE - 10 * PAGE_SIZE;
// A large VMA
pvm.insert_and_merge(create_anon_vma(addr, 10 * PAGE_SIZE, VMAPermissions::rw()));
// A new VMA is mapped right in the middle.
let new_addr = addr + 3 * PAGE_SIZE;
let new_size = 4 * PAGE_SIZE;
// Use different perms to prevent merging.
pvm.mmap(
AddressRequest::Fixed {
address: VA::from_value(new_addr),
permit_overlap: true,
},
new_size,
VMAPermissions::ro(),
VMAreaKind::Anon,
)
.unwrap();
assert_eq!(pvm.vmas.len(), 3);
// Left part of the original VMA
assert_vma_exists(&pvm, addr, 3 * PAGE_SIZE);
// The new VMA
assert_vma_exists(&pvm, new_addr, new_size);
// Right part of the original VMA
assert_vma_exists(&pvm, new_addr + new_size, 3 * PAGE_SIZE);
assert_eq!(
*pvm.address_space.ops_log.lock().unwrap(),
&[MockPageTableOp::ProtectRange {
region: VirtMemoryRegion::new(VA::from_value(new_addr), new_size),
perms: PtePermissions::ro(true),
}]
);
}
#[test]
fn test_merge_with_previous_and_next() {
let mut pvm: MemoryMap<MockAddressSpace> = MemoryMap::new().unwrap();
let perms = VMAPermissions::rw();
let addr1 = MMAP_BASE - 20 * PAGE_SIZE;
let addr2 = addr1 + 5 * PAGE_SIZE;
let addr3 = addr2 + 5 * PAGE_SIZE;
pvm.insert_and_merge(create_anon_vma(addr1, 5 * PAGE_SIZE, perms));
pvm.insert_and_merge(create_anon_vma(addr3, 5 * PAGE_SIZE, perms));
assert_eq!(pvm.vmas.len(), 2);
// Insert the middle part, which should merge with both.
pvm.insert_and_merge(create_anon_vma(addr2, 5 * PAGE_SIZE, perms));
assert_eq!(pvm.vmas.len(), 1);
assert_vma_exists(&pvm, addr1, 15 * PAGE_SIZE);
assert!(pvm.address_space.ops_log.lock().unwrap().is_empty());
}
#[test]
fn test_merge_file_backed_contiguous() {
let mut pvm: MemoryMap<MockAddressSpace> = MemoryMap::new().unwrap();
let perms = VMAPermissions::rw();
let inode = new_inode();
let addr1 = MMAP_BASE - 10 * PAGE_SIZE;
let size1 = 2 * PAGE_SIZE;
let offset1 = 0;
let addr2 = addr1 + size1;
let size2 = 3 * PAGE_SIZE;
let offset2 = offset1 + size1 as u64;
// Insert two contiguous, file-backed VMAs. They should merge.
pvm.insert_and_merge(create_file_vma(
addr1,
size1,
perms,
offset1,
Arc::clone(&inode),
));
pvm.insert_and_merge(create_file_vma(
addr2,
size2,
perms,
offset2,
Arc::clone(&inode),
));
assert_eq!(pvm.vmas.len(), 1);
assert_vma_exists(&pvm, addr1, size1 + size2);
let vma = pvm.find_vma(VA::from_value(addr1)).unwrap();
match &vma.kind {
VMAreaKind::File(fm) => assert_eq!(fm.offset, offset1),
_ => panic!("Expected file-backed VMA"),
}
assert!(pvm.address_space.ops_log.lock().unwrap().is_empty());
}
#[test]
fn test_no_merge_file_backed_non_contiguous() {
let mut pvm: MemoryMap<MockAddressSpace> = MemoryMap::new().unwrap();
let perms = VMAPermissions::rw();
let inode = new_inode();
let addr1 = MMAP_BASE - 10 * PAGE_SIZE;
let size1 = 2 * PAGE_SIZE;
let offset1 = 0;
let addr2 = addr1 + size1;
let size2 = 3 * PAGE_SIZE;
let offset2 = offset1 + size1 as u64 + 123; // Non-contiguous offset!
pvm.insert_and_merge(create_file_vma(
addr1,
size1,
perms,
offset1,
Arc::clone(&inode),
));
pvm.insert_and_merge(create_file_vma(
addr2,
size2,
perms,
offset2,
Arc::clone(&inode),
));
assert_eq!(pvm.vmas.len(), 2); // Should not merge
assert_vma_exists(&pvm, addr1, size1);
assert_vma_exists(&pvm, addr2, size2);
assert!(pvm.address_space.ops_log.lock().unwrap().is_empty());
}
#[test]
fn test_munmap_full_vma() {
let mut pvm: MemoryMap<MockAddressSpace> = MemoryMap::new().unwrap();
let addr = MMAP_BASE - 10 * PAGE_SIZE;
let size = 5 * PAGE_SIZE;
let region = VirtMemoryRegion::new(VA::from_value(addr), size);
pvm.insert_and_merge(create_anon_vma(addr, size, VMAPermissions::rw()));
assert_eq!(pvm.vmas.len(), 1);
pvm.munmap(region).unwrap();
assert!(pvm.vmas.is_empty());
assert_eq!(
*pvm.address_space.ops_log.lock().unwrap(),
&[MockPageTableOp::UnmapRange { region: region }]
);
}
#[test]
fn test_munmap_truncate_start() {
let mut pvm: MemoryMap<MockAddressSpace> = MemoryMap::new().unwrap();
let addr = MMAP_BASE - 10 * PAGE_SIZE;
let size = 5 * PAGE_SIZE;
pvm.insert_and_merge(create_anon_vma(addr, size, VMAPermissions::rw()));
let unmap_size = 2 * PAGE_SIZE;
let region = VirtMemoryRegion::new(VA::from_value(addr), unmap_size);
pvm.munmap(region).unwrap();
assert_eq!(pvm.vmas.len(), 1);
let new_start = addr + unmap_size;
let new_size = size - unmap_size;
assert_vma_exists(&pvm, new_start, new_size);
assert_eq!(
*pvm.address_space.ops_log.lock().unwrap(),
&[MockPageTableOp::UnmapRange { region: region }]
);
}
#[test]
fn test_munmap_truncate_end() {
let mut pvm: MemoryMap<MockAddressSpace> = MemoryMap::new().unwrap();
let addr = MMAP_BASE - 10 * PAGE_SIZE;
let size = 5 * PAGE_SIZE;
pvm.insert_and_merge(create_anon_vma(addr, size, VMAPermissions::rw()));
// Unmap the last two pages
let unmap_size = 2 * PAGE_SIZE;
let region = VirtMemoryRegion::new(VA::from_value(addr + (size - unmap_size)), unmap_size);
pvm.munmap(region).unwrap();
assert_eq!(pvm.vmas.len(), 1);
let new_size = size - unmap_size;
assert_vma_exists(&pvm, addr, new_size);
assert_eq!(
*pvm.address_space.ops_log.lock().unwrap(),
&[MockPageTableOp::UnmapRange { region: region }]
);
}
#[test]
fn test_munmap_punch_hole() {
let mut pvm: MemoryMap<MockAddressSpace> = MemoryMap::new().unwrap();
let addr = MMAP_BASE - 10 * PAGE_SIZE;
let size = 10 * PAGE_SIZE;
pvm.insert_and_merge(create_anon_vma(addr, size, VMAPermissions::rw()));
// Unmap a 4-page hole in the middle
let unmap_start = addr + 3 * PAGE_SIZE;
let unmap_size = 4 * PAGE_SIZE;
let region = VirtMemoryRegion::new(VA::from_value(unmap_start), unmap_size);
pvm.munmap(region).unwrap();
assert_eq!(pvm.vmas.len(), 2);
// Left part
assert_vma_exists(&pvm, addr, 3 * PAGE_SIZE);
// Right part
let right_start = unmap_start + unmap_size;
let right_size = 3 * PAGE_SIZE;
assert_vma_exists(&pvm, right_start, right_size);
assert_eq!(
*pvm.address_space.ops_log.lock().unwrap(),
&[MockPageTableOp::UnmapRange { region: region }]
);
}
#[test]
fn test_munmap_over_multiple_vmas() {
let mut pvm: MemoryMap<MockAddressSpace> = MemoryMap::new().unwrap();
let addr1 = MMAP_BASE - 20 * PAGE_SIZE;
let addr2 = addr1 + 5 * PAGE_SIZE;
let addr3 = addr2 + 5 * PAGE_SIZE;
pvm.insert_and_merge(create_anon_vma(addr1, 3 * PAGE_SIZE, VMAPermissions::rw()));
pvm.insert_and_merge(create_anon_vma(addr2, 3 * PAGE_SIZE, VMAPermissions::rw()));
pvm.insert_and_merge(create_anon_vma(addr3, 3 * PAGE_SIZE, VMAPermissions::rw()));
assert_eq!(pvm.vmas.len(), 3);
// Unmap from the middle of the first VMA to the middle of the last one.
let unmap_start = addr1 + PAGE_SIZE;
let unmap_end = addr3 + 2 * PAGE_SIZE;
let unmap_len = unmap_end - unmap_start;
let region = VirtMemoryRegion::new(VA::from_value(unmap_start), unmap_len);
pvm.munmap(region).unwrap();
assert_eq!(pvm.vmas.len(), 2);
// First VMA is truncated at the end
assert_vma_exists(&pvm, addr1, PAGE_SIZE);
// Last VMA is truncated at the start
assert_vma_exists(&pvm, unmap_end, PAGE_SIZE);
assert_eq!(
*pvm.address_space.ops_log.lock().unwrap(),
&[
MockPageTableOp::UnmapRange {
region: VirtMemoryRegion::new(VA::from_value(addr1 + PAGE_SIZE), 2 * PAGE_SIZE)
},
MockPageTableOp::UnmapRange {
region: VirtMemoryRegion::new(VA::from_value(addr2), 3 * PAGE_SIZE)
},
MockPageTableOp::UnmapRange {
region: VirtMemoryRegion::new(VA::from_value(addr3), 2 * PAGE_SIZE)
},
]
);
}

View File

@@ -0,0 +1,334 @@
//! Manages the virtual memory address space of a process.
use crate::{
UserAddressSpace,
error::{KernelError, Result},
};
use memory_map::{AddressRequest, MemoryMap};
use vmarea::{AccessKind, FaultValidation, VMAPermissions, VMArea, VMAreaKind};
use super::{PAGE_SIZE, address::VA, region::VirtMemoryRegion};
pub mod memory_map;
pub mod vmarea;
const BRK_PERMISSIONS: VMAPermissions = VMAPermissions::rw();
pub struct ProcessVM<AS: UserAddressSpace> {
mm: MemoryMap<AS>,
brk: VirtMemoryRegion,
}
impl<AS: UserAddressSpace> ProcessVM<AS> {
/// Constructs a new Process VM structure from the given VMA. The heap is
/// placed *after* the given VMA.
///
/// # Safety
/// Any pages that have been mapped into the provided address space *must*
/// correspond to the provided VMA.
pub unsafe fn from_vma_and_address_space(vma: VMArea, addr_spc: AS) -> Self {
let mut mm = MemoryMap::with_addr_spc(addr_spc);
mm.insert_and_merge(vma.clone());
let brk = VirtMemoryRegion::new(vma.region.end_address().align_up(PAGE_SIZE), 0);
Self { mm, brk }
}
/// Constructs a new Process VM structure from the given VMA. The heap is
/// placed *after* the given VMA.
pub fn from_vma(vma: VMArea) -> Result<Self> {
let mut mm = MemoryMap::new()?;
mm.insert_and_merge(vma.clone());
let brk = VirtMemoryRegion::new(vma.region.end_address().align_up(PAGE_SIZE), 0);
Ok(Self { mm, brk })
}
pub fn from_map(map: MemoryMap<AS>, brk: VA) -> Self {
Self {
mm: map,
brk: VirtMemoryRegion::new(brk.align_up(PAGE_SIZE), 0),
}
}
pub fn empty() -> Result<Self> {
Ok(Self {
mm: MemoryMap::new()?,
brk: VirtMemoryRegion::empty(),
})
}
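/// Returns the VMA containing `addr` if an access of kind `access_type` is
/// permitted there. Returns `None` both when the address is unmapped and when
/// the access violates the VMA's permissions.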
pub fn find_vma_for_fault(&self, addr: VA, access_type: AccessKind) -> Option<&VMArea> {
let vma = self.mm.find_vma(addr)?;
match vma.validate_fault(addr, access_type) {
FaultValidation::Valid => Some(vma),
FaultValidation::NotPresent => unreachable!(""),
FaultValidation::PermissionDenied => None,
}
}
pub fn mm_mut(&mut self) -> &mut MemoryMap<AS> {
&mut self.mm
}
pub fn current_brk(&self) -> VA {
self.brk.end_address()
}
/// Resizes the program break (the heap).
///
/// This function implements the semantics of the `brk` system call. It can
/// either grow or shrink the heap area. The new end address is always
/// aligned up to the nearest page boundary.
///
/// # Arguments
/// * `new_end_addr`: The desired new end address for the program break.
///
/// # Returns
/// * `Ok(())` on success.
/// * `Err(KernelError)` on failure. This can happen if the requested memory
/// region conflicts with an existing mapping, or if the request is invalid
/// (e.g., shrinking the break below its initial start address).
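///
/// A minimal usage sketch (not compiled; assumes `vm` is a `ProcessVM` over
/// some `UserAddressSpace` implementation):
///
/// ```ignore
/// let old_brk = vm.current_brk();
/// // Grow the heap by one page; the returned break is page-aligned.
/// let new_brk = vm.resize_brk(old_brk.add_pages(1))?;
/// assert_eq!(new_brk, old_brk.add_pages(1));
/// // Shrink it back, unmapping the page that was just added.
/// vm.resize_brk(old_brk)?;
/// ```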
pub fn resize_brk(&mut self, mut new_end_addr: VA) -> Result<VA> {
let brk_start = self.brk.start_address();
let current_end = self.brk.end_address();
// The break cannot be shrunk to an address lower than its starting
// point.
if new_end_addr < brk_start {
return Err(KernelError::InvalidValue);
}
new_end_addr = new_end_addr.align_up(PAGE_SIZE);
let new_brk_region = VirtMemoryRegion::from_start_end_address(brk_start, new_end_addr);
if new_end_addr == current_end {
// The requested break is the same as the current one. This is a
// no-op.
return Ok(new_end_addr);
}
// Grow the break
if new_end_addr > current_end {
let growth_size = new_end_addr.value() - current_end.value();
self.mm.mmap(
AddressRequest::Fixed {
address: current_end,
permit_overlap: false,
},
growth_size,
BRK_PERMISSIONS,
VMAreaKind::Anon,
)?;
self.brk = new_brk_region;
return Ok(new_end_addr);
}
// Shrink the break
// At this point, we know `new_end_addr < current_end`.
let unmap_region = VirtMemoryRegion::from_start_end_address(new_end_addr, current_end);
self.mm.munmap(unmap_region)?;
self.brk = new_brk_region;
Ok(new_end_addr)
}
pub fn clone_as_cow(&mut self) -> Result<Self> {
Ok(Self {
mm: self.mm.clone_as_cow()?,
brk: self.brk,
})
}
}
#[cfg(test)]
mod tests {
use super::memory_map::tests::MockAddressSpace;
use super::*;
use crate::error::KernelError;
fn setup_vm() -> ProcessVM<MockAddressSpace> {
let text_vma = VMArea {
region: VirtMemoryRegion::new(VA::from_value(0x1000), PAGE_SIZE),
kind: VMAreaKind::Anon, // Simplification for test
permissions: VMAPermissions::rx(),
};
ProcessVM::from_vma(text_vma).unwrap()
}
#[test]
fn test_initial_state() {
// Given: a newly created ProcessVM
let vm = setup_vm();
let initial_brk_start = VA::from_value(0x1000 + PAGE_SIZE);
assert_eq!(vm.brk.start_address(), initial_brk_start);
assert_eq!(vm.brk.size(), 0);
assert_eq!(vm.current_brk(), initial_brk_start);
// And the break region itself should not be mapped
assert!(vm.mm.find_vma(initial_brk_start).is_none());
}
#[test]
fn test_brk_first_growth() {
// Given: a VM with a zero-sized heap
let mut vm = setup_vm();
let initial_brk_start = vm.brk.start_address();
let brk_addr = initial_brk_start.add_bytes(1);
let new_brk = vm.resize_brk(brk_addr).unwrap();
// The new break should be page-aligned
let expected_brk_end = brk_addr.align_up(PAGE_SIZE);
assert_eq!(new_brk, expected_brk_end);
assert_eq!(vm.current_brk(), expected_brk_end);
assert_eq!(vm.brk.size(), PAGE_SIZE);
// And a new VMA for the heap should now exist with RW permissions
let heap_vma = vm
.mm
.find_vma(initial_brk_start)
.expect("Heap VMA should exist");
assert_eq!(heap_vma.region.start_address(), initial_brk_start);
assert_eq!(heap_vma.region.end_address(), expected_brk_end);
assert_eq!(heap_vma.permissions, VMAPermissions::rw());
}
#[test]
fn test_brk_subsequent_growth() {
// Given: a VM with an existing heap
let mut vm = setup_vm();
let initial_brk_start = vm.brk.start_address();
vm.resize_brk(initial_brk_start.add_bytes(1)).unwrap(); // First growth
assert_eq!(vm.brk.size(), PAGE_SIZE);
// When: we grow the break again
let new_brk = vm.resize_brk(vm.current_brk().add_pages(1)).unwrap();
// Then: the break should be extended
let expected_brk_end = initial_brk_start.add_pages(2);
assert_eq!(new_brk, expected_brk_end);
assert_eq!(vm.current_brk(), expected_brk_end);
assert_eq!(vm.brk.size(), 2 * PAGE_SIZE);
// And the single heap VMA should be larger, not a new one
let heap_vma = vm.mm.find_vma(initial_brk_start).unwrap();
assert_eq!(heap_vma.region.end_address(), expected_brk_end);
assert_eq!(vm.mm.vma_count(), 2); // Text VMA + one Heap VMA
}
#[test]
fn test_brk_shrink() {
// Given: a VM with a 3-page heap
let mut vm = setup_vm();
let initial_brk_start = vm.brk.start_address();
vm.resize_brk(initial_brk_start.add_pages(3)).unwrap();
assert_eq!(vm.brk.size(), 3 * PAGE_SIZE);
// When: we shrink the break by one page
let new_brk_addr = initial_brk_start.add_pages(2);
let new_brk = vm.resize_brk(new_brk_addr).unwrap();
// Then: the break should be updated
assert_eq!(new_brk, new_brk_addr);
assert_eq!(vm.current_brk(), new_brk_addr);
assert_eq!(vm.brk.size(), 2 * PAGE_SIZE);
// And the memory for the shrunken page should now be unmapped
assert!(vm.mm.find_vma(new_brk_addr.add_bytes(1)).is_none());
// But the remaining heap should still be mapped
assert!(vm.mm.find_vma(initial_brk_start).is_some());
}
#[test]
fn test_brk_shrink_to_zero() {
// Given: a VM with a 2-page heap
let mut vm = setup_vm();
let initial_brk_start = vm.brk.start_address();
vm.resize_brk(initial_brk_start.add_pages(2)).unwrap();
// When: we shrink the break all the way back to its start
let new_brk = vm.resize_brk(initial_brk_start).unwrap();
// Then: the break should be zero-sized again
assert_eq!(new_brk, initial_brk_start);
assert_eq!(vm.current_brk(), initial_brk_start);
assert_eq!(vm.brk.size(), 0);
// And the heap VMA should be completely gone
assert!(vm.mm.find_vma(initial_brk_start).is_none());
assert_eq!(vm.mm.vma_count(), 1); // Only the text VMA remains
}
#[test]
fn test_brk_no_op() {
// Given: a VM with a 2-page heap
let mut vm = setup_vm();
let initial_brk_start = vm.brk.start_address();
let current_brk_end = vm.resize_brk(initial_brk_start.add_pages(2)).unwrap();
// When: we resize the break to its current end
let new_brk = vm.resize_brk(current_brk_end).unwrap();
// Then: nothing should change
assert_eq!(new_brk, current_brk_end);
assert_eq!(vm.brk.size(), 2 * PAGE_SIZE);
assert_eq!(vm.mm.vma_count(), 2);
}
#[test]
fn test_brk_invalid_shrink_below_start() {
let mut vm = setup_vm();
let initial_brk_start = vm.brk.start_address();
vm.resize_brk(initial_brk_start.add_pages(1)).unwrap();
let original_len = vm.brk.size();
// We try to shrink the break below its starting point
let result = vm.resize_brk(VA::from_value(initial_brk_start.value() - 1));
// It should fail with an InvalidValue error
assert!(matches!(result, Err(KernelError::InvalidValue)));
// And the state of the break should not have changed
assert_eq!(vm.brk.start_address(), initial_brk_start);
assert_eq!(vm.brk.size(), original_len);
}
#[test]
fn test_brk_growth_collision() {
// Given: a VM with another mapping right where the heap would grow
let mut vm = setup_vm();
let initial_brk_start = vm.brk.start_address();
let obstacle_addr = initial_brk_start.add_pages(2);
let obstacle_vma = VMArea {
region: VirtMemoryRegion::new(obstacle_addr, PAGE_SIZE),
kind: VMAreaKind::Anon,
permissions: VMAPermissions::ro(),
};
vm.mm.insert_and_merge(obstacle_vma);
assert_eq!(vm.mm.vma_count(), 2);
// When: we try to grow the break past the obstacle
let result = vm.resize_brk(initial_brk_start.add_pages(3));
// Then: the mmap should fail, resulting in an error
// The specific error comes from the `mmap` implementation.
assert!(matches!(result, Err(KernelError::InvalidValue)));
// And the break should not have grown at all
assert_eq!(vm.brk.size(), 0);
assert_eq!(vm.current_brk(), initial_brk_start);
}
}

View File

@@ -0,0 +1,566 @@
//! Virtual Memory Areas (VMAs) within a process's address space.
//!
//! The [`VMArea`] represents a contiguous range of virtual memory with a
//! uniform set of properties, such as permissions and backing source. A
//! process's memory map is composed of a set of these VMAs.
//!
//! A VMA can be either:
//! - File-backed (via [`VMAreaKind::File`]): Used for loading executable code
//! and initialized data from files, most notably ELF binaries.
//! - Anonymous (via [`VMAreaKind::Anon`]): Used for demand-zeroed memory like
//! the process stack, heap, and BSS sections.
use crate::{
fs::Inode,
memory::{PAGE_MASK, PAGE_SIZE, address::VA, region::VirtMemoryRegion},
};
use alloc::sync::Arc;
use object::{
Endian,
elf::{PF_R, PF_W, PF_X, ProgramHeader64},
read::elf::ProgramHeader,
};
/// Describes the permissions assigned for this VMA.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct VMAPermissions {
pub read: bool,
pub write: bool,
pub execute: bool,
}
impl VMAPermissions {
pub const fn rw() -> Self {
Self {
read: true,
write: true,
execute: false,
}
}
pub const fn rx() -> Self {
Self {
read: true,
write: false,
execute: true,
}
}
pub const fn ro() -> Self {
Self {
read: true,
write: false,
execute: false,
}
}
}
/// Describes the kind of access that occurred during a page fault.
#[derive(Clone, Copy, PartialEq, Eq)]
pub enum AccessKind {
/// The CPU attempted to read the faulting address.
Read,
/// The CPU attempted to write to the faulting address.
Write,
/// The CPU attempted to execute the instruction at the faulting address.
Execute,
}
/// The result of checking a memory access against a `VMArea`.
///
/// This enum tells the fault handler how to proceed.
#[derive(Debug, PartialEq, Eq)]
pub enum FaultValidation {
/// The access is valid. The address is within the VMA and has the required
/// permissions. The fault handler should proceed with populating the page.
Valid,
/// The address is not within this VMA's region. The fault handler should
/// continue searching other VMAs.
NotPresent,
/// The address is within this VMA's region, but the access kind is not
/// permitted (e.g., writing to a read-only page). This is a definitive
/// segmentation fault. The fault handler can immediately stop its search
/// and terminate the process.
PermissionDenied,
}
/// Describes a read operation from a file required to satisfy a page fault.
pub struct VMAFileRead {
/// The absolute offset into the backing file to start reading from.
pub file_offset: u64,
/// The offset into the destination page where the read data should be
/// written.
pub page_offset: usize,
/// The number of bytes to read from the file and write to the page.
pub read_len: usize,
/// The file that backs this VMA mapping.
pub inode: Arc<dyn Inode>,
}
/// Represents a mapping to a region of a file that backs a `VMArea`.
///
/// This specifies a "slice" of a file, defined by an `offset` and `len`, that
/// contains the initialized data for a memory segment (e.g., the .text or .data
/// sections of an ELF binary).
#[derive(Clone)]
pub struct VMFileMapping {
pub(super) file: Arc<dyn Inode>,
pub(super) offset: u64,
pub(super) len: u64,
}
impl PartialEq for VMFileMapping {
fn eq(&self, other: &Self) -> bool {
Arc::ptr_eq(&self.file, &other.file) && self.offset == other.offset && self.len == other.len
}
}
impl VMFileMapping {
/// Returns a clone of the reference-counted `Inode` for this mapping.
pub fn file(&self) -> Arc<dyn Inode> {
self.file.clone()
}
/// Returns the starting offset of the mapping's data within the file.
pub fn offset(&self) -> u64 {
self.offset
}
/// Returns the length of the mapping's data within the file (`p_filesz`).
pub fn file_len(&self) -> u64 {
self.len
}
}
/// Defines the backing source for a `VMArea`.
#[derive(Clone, PartialEq)]
pub enum VMAreaKind {
/// The VMA is backed by a file.
///
/// On a page fault, the kernel will read data from the specified file
/// region to populate the page. Any part of the VMA's memory region that
/// extends beyond the file mapping's length (`p_memsz` > `p_filesz`) is
/// treated as BSS and will be zero-filled.
File(VMFileMapping),
/// The VMA is an anonymous, demand-zeroed mapping.
///
/// It has no backing file. On a page fault, the kernel will provide a new
/// physical page that has been zero-filled. This is used for the heap,
/// and the stack.
Anon,
}
impl VMAreaKind {
pub fn new_anon() -> Self {
Self::Anon
}
pub fn new_file(file: Arc<dyn Inode>, offset: u64, len: u64) -> Self {
Self::File(VMFileMapping { file, offset, len })
}
}
/// A Virtual Memory Area (VMA).
///
/// This represents a contiguous region of virtual memory within a process's
/// address space that shares a common set of properties, such as memory
/// permissions and backing source. It is the kernel's primary abstraction for
/// managing a process's memory layout.
#[derive(Clone, PartialEq)]
pub struct VMArea {
pub(super) region: VirtMemoryRegion,
pub(super) kind: VMAreaKind,
pub(super) permissions: VMAPermissions,
}
impl VMArea {
/// Creates a new `VMArea`.
///
/// # Arguments
/// * `region`: The virtual address range for this VMA.
/// * `kind`: The backing source (`File` or `Anon`).
/// * `permissions`: The memory permissions for the region.
pub fn new(region: VirtMemoryRegion, kind: VMAreaKind, permissions: VMAPermissions) -> Self {
Self {
region,
kind,
permissions,
}
}
/// Creates a file-backed `VMArea` directly from an ELF program header.
///
/// This is a convenience function used by the ELF loader. It parses the
/// header to determine the virtual address range, file mapping details, and
/// memory permissions.
///
/// # Arguments
/// * `f`: A handle to the ELF file's inode.
/// * `hdr`: The ELF program header (`LOAD` segment) to create the VMA from.
/// * `endian`: The endianness of the ELF file, for correctly parsing header fields.
pub fn from_pheader<E: Endian>(
f: Arc<dyn Inode>,
hdr: ProgramHeader64<E>,
endian: E,
) -> VMArea {
let mut permissions = VMAPermissions {
read: false,
write: false,
execute: false,
};
if hdr.p_flags(endian) & PF_X != 0 {
permissions.execute = true;
}
if hdr.p_flags(endian) & PF_R != 0 {
permissions.read = true;
}
if hdr.p_flags(endian) & PF_W != 0 {
permissions.write = true;
}
Self {
region: VirtMemoryRegion::new(
VA::from_value(hdr.p_vaddr(endian) as usize),
hdr.p_memsz(endian) as usize,
),
kind: VMAreaKind::File(VMFileMapping {
file: f,
offset: hdr.p_offset(endian),
len: hdr.p_filesz(endian),
}),
permissions,
}
}
/// Checks if a page fault is valid for this VMA.
///
/// This is the primary function for a fault handler to use. It verifies
/// both that the faulting address is within the VMA's bounds and that the
/// type of access (read, write, or execute) is permitted by the VMA's
/// permissions.
///
/// # Returns
///
/// - [`FaultValidation::Valid`]: If the address and permissions are valid.
/// - [`FaultValidation::NotPresent`]: If the address is outside this VMA.
/// - [`FaultValidation::PermissionDenied`]: If the address is inside this
/// VMA but the access is not allowed. This allows the caller to immediately
/// identify a segmentation fault without checking other VMAs.
pub fn validate_fault(&self, addr: VA, kind: AccessKind) -> FaultValidation {
if !self.contains_address(addr) {
return FaultValidation::NotPresent;
}
// The address is in our region. Now, check permissions.
let permissions_ok = match kind {
AccessKind::Read => self.permissions.read,
AccessKind::Write => self.permissions.write,
AccessKind::Execute => self.permissions.execute,
};
if permissions_ok {
FaultValidation::Valid
} else {
FaultValidation::PermissionDenied
}
}
/// Returns a reference to the kind of backing for this VMA.
pub fn kind(&self) -> &VMAreaKind {
&self.kind
}
/// Resolves a page fault within this VMA.
///
/// If the fault is in a region backed by a file, this function calculates
/// the file offset and page memory offset required to load the data. It
/// correctly handles the ELF alignment congruence rule and mixed pages
/// containing both file-backed data and BSS data.
///
/// # Arguments
/// * `faulting_addr`: The virtual address that caused the fault.
///
/// # Returns
/// * `Some(VMAFileRead)` if the faulting page contains any data that must
/// be loaded from the file. The caller is responsible for zeroing the
/// page *before* performing the read.
/// * `None` if the VMA is anonymous (`Anon`) or if the faulting page is
/// purely BSS (i.e., contains no data from the file) and should simply be
/// zero-filled.
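///
/// As a worked illustration (mirroring the `unaligned_segment` test below):
/// for a file-backed VMA at `0x40100` with file offset `0x5100` and
/// `filesz = 0x300`, a fault anywhere in that range resolves to
/// `file_offset = 0x5100`, `page_offset = 0x100` and `read_len = 0x300`,
/// i.e. the `0x300` file bytes are written `0x100` bytes into the zeroed page.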
pub fn resolve_fault(&self, faulting_addr: VA) -> Option<VMAFileRead> {
// Match on the kind of VMA. If it's anonymous, there's no file to read from.
let mapping = match &self.kind {
VMAreaKind::Anon => return None,
VMAreaKind::File(mapping) => mapping,
};
let vma_start_addr = self.region.start_address();
let p_vaddr = vma_start_addr.value() as u64;
let p_offset = mapping.offset();
let p_filesz = mapping.file_len();
let vaddr_page_offset = p_vaddr & PAGE_MASK as u64;
// The virtual address where the page-aligned mapping starts.
let map_start_vaddr = vma_start_addr.page_aligned().value() as u64;
// The file offset corresponding to the start of the page-aligned
// mapping.
let map_start_foffset = p_offset - vaddr_page_offset;
let fault_page_vaddr = faulting_addr.page_aligned().value() as u64;
let fault_page_offset_in_map = fault_page_vaddr.saturating_sub(map_start_vaddr);
let file_offset_for_page_start = map_start_foffset + fault_page_offset_in_map;
let page_write_offset = if fault_page_vaddr == map_start_vaddr {
vaddr_page_offset as usize
} else {
0
};
// The starting point in the file we need is the max of our calculated
// start and the actual start of the segment's data in the file.
let read_start_file_offset = core::cmp::max(file_offset_for_page_start, p_offset);
// The end point in the file is the segment's data end.
let read_end_file_offset = p_offset + p_filesz;
// The number of bytes to read is the length of this intersection.
let read_len = read_end_file_offset.saturating_sub(read_start_file_offset) as usize;
// The final read length cannot exceed what's left in the page.
let final_read_len = core::cmp::min(read_len, PAGE_SIZE - page_write_offset);
if final_read_len == 0 {
// There is no file data to read for this page. It is either fully
// BSS or a hole in the file mapping. The caller should use a zeroed
// page.
return None;
}
Some(VMAFileRead {
file_offset: read_start_file_offset,
page_offset: page_write_offset,
read_len: final_read_len,
inode: mapping.file.clone(),
})
}
pub fn permissions(&self) -> VMAPermissions {
self.permissions
}
pub fn contains_address(&self, addr: VA) -> bool {
self.region.contains_address(addr)
}
/// Checks if this VMA can be merged with an adjacent one.
///
/// This function assumes the other VMA is immediately adjacent in memory.
/// Merging is possible if permissions are identical and the backing storage
/// is of a compatible and contiguous nature.
pub(super) fn can_merge_with(&self, other: &VMArea) -> bool {
if self.permissions != other.permissions {
return false;
}
match (&self.kind, &other.kind) {
(VMAreaKind::Anon, VMAreaKind::Anon) => true,
(VMAreaKind::File(self_map), VMAreaKind::File(other_map)) => {
// Check that they point to the same inode.
let same_file = Arc::ptr_eq(&self_map.file, &other_map.file);
// Check that the file offsets are contiguous. `other` VMA's
// offset must be `self`'s offset + `self`'s size.
let contiguous_offset =
other_map.offset == self_map.offset + self.region.size() as u64;
same_file && contiguous_offset
}
_ => false,
}
}
#[must_use]
pub(super) fn clone_with_new_region(&self, new_region: VirtMemoryRegion) -> Self {
let mut clone = self.clone();
clone.region = new_region;
clone
}
/// Returns true if the VMA is backed by a file. False if it's an anonymous
/// mapping.
pub fn is_file_backed(&self) -> bool {
matches!(self.kind, VMAreaKind::File(_))
}
}
#[cfg(test)]
pub mod tests {
use crate::fs::InodeId;
use super::*;
use async_trait::async_trait;
#[derive(Debug)]
pub struct DummyTestInode;
#[async_trait]
impl Inode for DummyTestInode {
fn id(&self) -> InodeId {
unreachable!("Not called")
}
}
pub fn create_test_vma(vaddr: usize, memsz: usize, file_offset: u64, filesz: u64) -> VMArea {
let dummy_inode = Arc::new(DummyTestInode);
VMArea::new(
VirtMemoryRegion::new(VA::from_value(vaddr), memsz),
VMAreaKind::File(VMFileMapping {
file: dummy_inode,
offset: file_offset,
len: filesz,
}),
VMAPermissions::rw(),
)
}
#[test]
fn simple_aligned_segment() {
// A segment that is perfectly aligned to page boundaries.
let vma = create_test_vma(0x20000, 0x2000, 0x4000, 0x2000);
// Fault in the first page
let fault_addr = VA::from_value(0x20500);
let result = vma.resolve_fault(fault_addr).expect("Should resolve");
assert_eq!(result.file_offset, 0x4000);
assert_eq!(result.page_offset, 0);
assert_eq!(result.read_len, PAGE_SIZE);
// Fault in the second page
let fault_addr = VA::from_value(0x21500);
let result = vma.resolve_fault(fault_addr).expect("Should resolve");
assert_eq!(result.file_offset, 0x5000);
assert_eq!(result.page_offset, 0);
assert_eq!(result.read_len, PAGE_SIZE);
}
#[test]
fn unaligned_segment() {
// vaddr and file offset are not page-aligned, but are congruent modulo
// the page size. vaddr starts 0x100 bytes into page 0x40000. filesz
// (0x300) is smaller than a page.
let vma = create_test_vma(0x40100, 0x300, 0x5100, 0x300);
// Fault anywhere in this small segment
let fault_addr = VA::from_value(0x40280);
let result = vma.resolve_fault(fault_addr).expect("Should resolve");
// The handler needs to map page 0x40000.
// The read must start from the true file offset (p_offset).
assert_eq!(result.file_offset, 0x5100);
// The data must be written 0x100 bytes into the destination page.
assert_eq!(result.page_offset, 0x100);
// We only read the number of bytes specified in filesz.
assert_eq!(result.read_len, 0x300);
}
#[test]
fn unaligned_segment_spanning_pages() {
// An unaligned segment that is large enough to span multiple pages.
// Starts at 0x40F00, size 0x800. Ends at 0x41700. Covers the last 0x100
// bytes of page 0x40000 and the first 0x700 of page 0x41000.
let vma = create_test_vma(0x40F00, 0x800, 0x5F00, 0x800);
let fault_addr1 = VA::from_value(0x40F80);
let result1 = vma
.resolve_fault(fault_addr1)
.expect("Should resolve first page");
assert_eq!(result1.file_offset, 0x5F00, "File offset for first page");
assert_eq!(result1.page_offset, 0xF00, "Page offset for first page");
assert_eq!(result1.read_len, 0x100, "Read length for first page");
let fault_addr2 = VA::from_value(0x41200);
let result2 = vma
.resolve_fault(fault_addr2)
.expect("Should resolve second page");
// The read should start where the first one left off: 0x5F00 + 0x100
assert_eq!(result2.file_offset, 0x6000, "File offset for second page");
assert_eq!(result2.page_offset, 0, "Page offset for second page");
// The remaining bytes of the file need to be read. 0x800 - 0x100
assert_eq!(result2.read_len, 0x700, "Read length for second page");
}
#[test]
fn mixed_data_bss_page() {
// A segment where the data ends partway through a page, and BSS begins.
// filesz = 0x1250 (one full page, and 0x250 bytes into the second page)
// memsz = 0x3000 (lots of BSS)
let vma = create_test_vma(0x30000, 0x3000, 0x8000, 0x1250);
let fault_addr_data = VA::from_value(0x30100);
let result_data = vma
.resolve_fault(fault_addr_data)
.expect("Should resolve full data page");
assert_eq!(result_data.file_offset, 0x8000);
assert_eq!(result_data.page_offset, 0);
assert_eq!(result_data.read_len, 0x1000);
let fault_addr_mixed_data = VA::from_value(0x31100);
let result_mixed = vma
.resolve_fault(fault_addr_mixed_data)
.expect("Should resolve mixed page");
assert_eq!(result_mixed.file_offset, 0x9000);
assert_eq!(result_mixed.page_offset, 0);
assert_eq!(
result_mixed.read_len, 0x250,
"Should only read the remaining file bytes"
);
// Fault in the *BSS* part of the same mixed page. This should trigger
// the exact same read operation.
let fault_addr_mixed_bss = VA::from_value(0x31800);
let result_mixed_2 = vma
.resolve_fault(fault_addr_mixed_bss)
.expect("Should resolve mixed page from BSS fault");
assert_eq!(result_mixed_2.file_offset, 0x9000);
assert_eq!(result_mixed_2.page_offset, 0);
assert_eq!(result_mixed_2.read_len, 0x250);
}
#[test]
fn pure_bss_fault() {
// Using the same VMA as the mixed test, but faulting in a page that is
// entirely BSS.
let vma = create_test_vma(0x30000, 0x3000, 0x8000, 0x1250);
let fault_addr = VA::from_value(0x32100); // BSS starts at 0x30000 + 0x1250 = 0x31250; page 0x32000 holds no file data.
let result = vma.resolve_fault(fault_addr);
assert!(result.is_none(), "Pure BSS fault should return None");
}
#[test]
fn anonymous_mapping() {
// An anonymous VMA should never result in a file read.
let vma = VMArea::new(
VirtMemoryRegion::new(VA::from_value(0x50000), 0x1000),
VMAreaKind::Anon,
VMAPermissions::rw(),
);
let fault_addr = VA::from_value(0x50500);
let result = vma.resolve_fault(fault_addr);
assert!(result.is_none(), "Anonymous VMA fault should return None");
}
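// A minimal sketch of the merge rules documented on `can_merge_with`. The
// mappings below deliberately share a single `Arc`'d `DummyTestInode`, since
// merging requires pointer-identical backing files (unlike `create_test_vma`,
// which allocates a fresh inode per call).
#[test]
fn can_merge_file_backed_vmas() {
    let inode = Arc::new(DummyTestInode);
    let make = |vaddr: usize, offset: u64| {
        VMArea::new(
            VirtMemoryRegion::new(VA::from_value(vaddr), 0x1000),
            VMAreaKind::File(VMFileMapping {
                file: inode.clone(),
                offset,
                len: 0x1000,
            }),
            VMAPermissions::rw(),
        )
    };
    // [0x20000, 0x21000) backed by file bytes [0x4000, 0x5000).
    let first = make(0x20000, 0x4000);
    // Adjacent in memory and contiguous in the file: mergeable.
    let contiguous = make(0x21000, 0x5000);
    // Adjacent in memory but with a hole in the file offsets: not mergeable.
    let gapped = make(0x21000, 0x6000);
    assert!(first.can_merge_with(&contiguous));
    assert!(!first.can_merge_with(&gapped));
}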
}

View File

@@ -0,0 +1,538 @@
//! `region` module: Contiguous memory regions.
//!
//! This module defines `MemoryRegion<T>`, a generic abstraction for handling
//! ranges of memory in both physical and virtual address spaces.
//!
//! It provides utility methods for checking alignment, computing bounds,
//! checking containment and overlap, and mapping between physical and virtual
//! memory spaces using an `AddressTranslator`.
//!
//! ## Key Types
//! - `MemoryRegion<T>`: Generic over `MemKind` (either `Physical` or `Virtual`).
//! - `PhysMemoryRegion`: A physical memory region.
//! - `VirtMemoryRegion`: A virtual memory region.
//!
//! ## Common Operations
//! - Construction via start address + size or start + end.
//! - Checking containment and overlap of regions.
//! - Page-based offset shifting with `add_pages`.
//! - Mapping between address spaces using an `AddressTranslator`.
//!
//! ## Example
//! ```rust
//! use libkernel::memory::{address::*, region::*};
//!
//! let pa = PA::from_value(0x1000);
//! let region = PhysMemoryRegion::new(pa, 0x2000);
//! assert!(region.contains_address(PA::from_value(0x1FFF)));
//!
//! let mapped = region.map_via::<IdentityTranslator>();
//! ```
use crate::memory::PAGE_MASK;
use super::{
PAGE_SHIFT, PAGE_SIZE,
address::{Address, AddressTranslator, MemKind, Physical, User, Virtual},
page::PageFrame,
};
/// A contiguous memory region of a specific memory kind (e.g., physical or virtual).
///
/// The `T` parameter is either `Physical` or `Virtual`, enforcing type safety
/// between address spaces.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct MemoryRegion<T: MemKind> {
address: Address<T, ()>,
size: usize,
}
impl<T: MemKind> MemoryRegion<T> {
/// Create a new memory region from a start address and a size in bytes.
pub const fn new(address: Address<T, ()>, size: usize) -> Self {
Self { address, size }
}
/// Create an empty region with a size of 0 and address 0.
pub const fn empty() -> Self {
Self {
address: Address::from_value(0),
size: 0,
}
}
/// Create a memory region from a start and end address.
///
/// The size is calculated as `end - start`. No alignment is enforced.
pub fn from_start_end_address(start: Address<T, ()>, end: Address<T, ()>) -> Self {
assert!(end >= start);
Self {
address: start,
size: (end.value() - start.value()),
}
}
/// Return a new region with the same size but a different start address.
pub fn with_start_address(mut self, new_start: Address<T, ()>) -> Self {
self.address = new_start;
self
}
/// Return the starting address of the region.
pub const fn start_address(self) -> Address<T, ()> {
self.address
}
/// Return the size of the region in bytes.
pub const fn size(self) -> usize {
self.size
}
/// Return the end address (exclusive) of the region.
pub const fn end_address(self) -> Address<T, ()> {
Address::from_value(self.address.value() + self.size)
}
/// Return the end address (inclusive) of the region.
pub const fn end_address_inclusive(self) -> Address<T, ()> {
Address::from_value(self.address.value() + self.size.saturating_sub(1))
}
/// Returns `true` if the start address is page-aligned.
pub fn is_page_aligned(self) -> bool {
self.address.is_page_aligned()
}
/// Return a new region with the given size, keeping the same start address.
pub fn with_size(self, size: usize) -> Self {
Self {
address: self.address,
size,
}
}
/// Returns `true` if this region overlaps with `other`.
///
/// Overlap means any portion of the address range intersects.
pub fn overlaps(self, other: Self) -> bool {
let start1 = self.start_address().value();
let end1 = self.end_address().value();
let start2 = other.start_address().value();
let end2 = other.end_address().value();
!(end1 <= start2 || end2 <= start1)
}
/// Returns `true` if this region lies before `other`.
///
/// If the regions are adjacent, this function returns `true`. If the
/// regions overlap, this function returns `false`.
pub fn is_before(self, other: Self) -> bool {
self.end_address() <= other.start_address()
}
/// Returns `true` if this region lies after `other`.
///
/// If the regions are adjacent, this function returns `true`. If the
/// regions overlap, this function returns `false`.
pub fn is_after(self, other: Self) -> bool {
self.start_address() >= other.end_address()
}
/// Try to merge this region with another.
///
/// If the regions overlap or are contiguous, returns a new merged region.
/// Otherwise, returns `None`.
pub fn merge(self, other: Self) -> Option<Self> {
let start1 = self.address;
let end1 = self.end_address();
let start2 = other.address;
let end2 = other.end_address();
if end1 >= start2 && start1 <= end2 {
let merged_start = core::cmp::min(start1, start2);
let merged_end = core::cmp::max(end1, end2);
Some(Self {
address: merged_start,
size: merged_end.value() - merged_start.value(),
})
} else {
None
}
}
/// Returns `true` if this region fully contains `other`.
pub fn contains(self, other: Self) -> bool {
self.start_address().value() <= other.start_address().value()
&& self.end_address().value() >= other.end_address().value()
}
/// Returns `true` if this region contains the given address.
pub fn contains_address(self, addr: Address<T, ()>) -> bool {
let val = addr.value();
val >= self.start_address().value() && val < self.end_address().value()
}
/// Shift the region forward by `n` pages.
///
/// Decreases the size by the corresponding number of bytes.
/// Will saturate at zero if size underflows.
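///
/// A minimal sketch, assuming the 4 KiB page size used throughout this module:
///
/// ```
/// use libkernel::memory::PAGE_SIZE;
/// use libkernel::memory::address::VA;
/// use libkernel::memory::region::VirtMemoryRegion;
/// let region = VirtMemoryRegion::new(VA::from_value(0x10000), 3 * PAGE_SIZE);
/// let shifted = region.add_pages(1);
/// assert_eq!(shifted.start_address().value(), 0x11000);
/// assert_eq!(shifted.size(), 2 * PAGE_SIZE);
/// ```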
#[must_use]
pub fn add_pages(self, n: usize) -> Self {
let offset: Address<T, ()> = Address::from_value(n << PAGE_SHIFT);
Self {
address: Address::from_value(self.address.value() + offset.value()),
size: self.size.saturating_sub(offset.value()),
}
}
/// Converts this region into a `MappableRegion`.
///
/// This calculates the smallest page-aligned region that fully contains the
/// current region, and captures the original start address's offset from
/// the new aligned start.
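///
/// # Example
///
/// A minimal sketch, assuming 4 KiB pages: the unaligned region below is
/// widened to a full page, and the start address's offset within that page is
/// preserved.
///
/// ```
/// use libkernel::memory::address::VA;
/// use libkernel::memory::region::VirtMemoryRegion;
/// let region = VirtMemoryRegion::new(VA::from_value(0x1080), 0x100); // [0x1080, 0x1180)
/// let mappable = region.to_mappable_region();
/// assert_eq!(mappable.region().start_address().value(), 0x1000);
/// assert_eq!(mappable.region().size(), 0x1000);
/// assert_eq!(mappable.offset(), 0x80);
/// ```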
pub fn to_mappable_region(self) -> MappableRegion<T> {
let aligned_start_addr = self.start_address().align(PAGE_SIZE);
let aligned_end_addr = self.end_address().align_up(PAGE_SIZE);
let aligned_region =
MemoryRegion::from_start_end_address(aligned_start_addr, aligned_end_addr);
MappableRegion {
region: aligned_region,
offset_from_page_start: self.start_address().page_offset(),
}
}
/// Increases the size of the region by `size` bytes.
pub(crate) fn expand(&mut self, size: usize) {
assert!(size >= self.size);
assert!(size & PAGE_MASK == 0);
self.size += size;
}
/// Calculates the common overlapping region between `self` and `other`.
///
/// If the regions overlap, this returns a `Some(MemoryRegion)` representing
/// the shared intersection. If they are merely adjacent or do not overlap at
/// all, this returns `None`.
///
/// Visually, if region `A = [ A_start ... A_end )` and
/// region `B = [ B_start ... B_end )`, the intersection is
/// `[ max(A_start, B_start) ... min(A_end, B_end) )`.
///
/// # Example
///
/// ```
/// use libkernel::memory::region::VirtMemoryRegion;
/// use libkernel::memory::address::VA;
/// let region1 = VirtMemoryRegion::new(VA::from_value(0x1000), 0x2000); // Range: [0x1000, 0x3000)
/// let region2 = VirtMemoryRegion::new(VA::from_value(0x2000), 0x2000); // Range: [0x2000, 0x4000)
///
/// let intersection = region1.intersection(region2).unwrap();
///
/// assert_eq!(intersection.start_address().value(), 0x2000);
/// assert_eq!(intersection.end_address().value(), 0x3000);
/// assert_eq!(intersection.size(), 0x1000);
/// ```
pub fn intersection(self, other: Self) -> Option<Self> {
// Determine the latest start address and the earliest end address.
let intersection_start = core::cmp::max(self.start_address(), other.start_address());
let intersection_end = core::cmp::min(self.end_address(), other.end_address());
// A valid, non-empty overlap exists only if the start of the
// potential intersection is before its end.
if intersection_start < intersection_end {
Some(Self::from_start_end_address(
intersection_start,
intersection_end,
))
} else {
None
}
}
/// Returns a new region that is page-aligned and fully contains the
/// original.
///
/// This is achieved by rounding the region's start address down to the
/// nearest page boundary and rounding the region's end address up to the
/// nearest page boundary.
///
/// # Example
/// ```
/// use libkernel::memory::region::VirtMemoryRegion;
/// use libkernel::memory::address::VA;
/// let region = VirtMemoryRegion::new(VA::from_value(0x1050), 0x1F00);
/// // Original region: [0x1050, 0x2F50)
///
/// let aligned_region = region.align_to_page_boundary();
///
/// // Aligned region: [0x1000, 0x3000)
/// assert_eq!(aligned_region.start_address().value(), 0x1000);
/// assert_eq!(aligned_region.end_address().value(), 0x3000);
/// assert_eq!(aligned_region.size(), 0x2000);
/// ```
#[must_use]
pub fn align_to_page_boundary(self) -> Self {
let aligned_start = self.start_address().align(PAGE_SIZE);
let aligned_end = self.end_address().align_up(PAGE_SIZE);
Self::from_start_end_address(aligned_start, aligned_end)
}
/// Returns an iterator that yields the starting address of each 4KiB page
/// contained within the memory region.
///
/// The iterator starts at the region's base address and advances in 4KiB
/// (`PAGE_SIZE`) increments. If the region's size is not a perfect multiple
/// of `PAGE_SIZE`, any trailing memory fragment at the end of the region
/// that does not constitute a full page is ignored.
///
/// A region with a size of zero will produce an empty iterator.
///
/// # Examples
///
/// ```
/// use libkernel::memory::PAGE_SIZE;
/// use libkernel::memory::address::VA;
/// use libkernel::memory::region::VirtMemoryRegion;
/// let start_va = VA::from_value(0x10000);
/// // A region covering 2.5 pages of memory.
/// let region = VirtMemoryRegion::new(start_va, 2 * PAGE_SIZE + 100);
///
/// let pages: Vec<VA> = region.iter_pages().collect();
///
/// // The iterator should yield addresses for the two full pages.
/// assert_eq!(pages.len(), 2);
/// assert_eq!(pages[0], VA::from_value(0x10000));
/// assert_eq!(pages[1], VA::from_value(0x11000));
/// ```
pub fn iter_pages(self) -> impl Iterator<Item = Address<T, ()>> {
let mut count = 0;
let pages_count = self.size >> PAGE_SHIFT;
core::iter::from_fn(move || {
let addr = self.start_address().add_pages(count);
if count < pages_count {
count += 1;
Some(addr)
} else {
None
}
})
}
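/// Returns an iterator that yields the `PageFrame` of each full 4KiB page
/// contained within the memory region.
///
/// This mirrors `iter_pages`, but yields page frames rather than starting
/// addresses; any trailing fragment smaller than a full page is ignored.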
pub fn iter_pfns(self) -> impl Iterator<Item = PageFrame> {
let mut count = 0;
let pages_count = self.size >> PAGE_SHIFT;
let start = self.start_address().to_pfn();
core::iter::from_fn(move || {
let pfn = PageFrame::from_pfn(start.value() + count);
if count < pages_count {
count += 1;
Some(pfn)
} else {
None
}
})
}
}
/// A memory region in physical address space.
pub type PhysMemoryRegion = MemoryRegion<Physical>;
impl PhysMemoryRegion {
/// Map the physical region to virtual space using a translator.
pub fn map_via<T: AddressTranslator<()>>(self) -> VirtMemoryRegion {
VirtMemoryRegion::new(self.address.to_va::<T>(), self.size)
}
}
/// A memory region in virtual address space.
pub type VirtMemoryRegion = MemoryRegion<Virtual>;
impl VirtMemoryRegion {
/// Map the virtual region to physical space using a translator.
pub fn map_via<T: AddressTranslator<()>>(self) -> PhysMemoryRegion {
PhysMemoryRegion::new(self.address.to_pa::<T>(), self.size)
}
}
/// A memory region of user-space addresses.
pub type UserMemoryRegion = MemoryRegion<User>;
/// A representation of a `MemoryRegion` that has been expanded to be page-aligned.
///
/// This struct holds the new, larger, page-aligned region, as well as the
/// byte offset of the original region's start address from the start of the
/// new aligned region. This is essential for MMU operations that must map
/// full pages but return a pointer to an unaligned location within that page.
#[derive(Copy, Clone)]
pub struct MappableRegion<T: MemKind> {
/// The fully-encompassing, page-aligned memory region.
region: MemoryRegion<T>,
/// The byte offset of the original region's start from the aligned region's start.
offset_from_page_start: usize,
}
impl<T: MemKind> MappableRegion<T> {
/// Returns the full, page-aligned region that is suitable for mapping.
pub const fn region(&self) -> MemoryRegion<T> {
self.region
}
/// Returns the offset of the original data within the aligned region.
pub const fn offset(&self) -> usize {
self.offset_from_page_start
}
}
#[cfg(test)]
mod tests {
use super::PhysMemoryRegion;
use crate::memory::{
PAGE_SIZE,
address::{PA, VA},
region::VirtMemoryRegion,
};
#[test]
fn merge_adjacent() {
let a = PhysMemoryRegion::new(PA::from_value(0x100), 0x10);
let b = PhysMemoryRegion::new(PA::from_value(0x110), 0x10);
let merged = a.merge(b).unwrap();
assert_eq!(merged.address.value(), 0x100);
assert_eq!(merged.size(), 0x20);
}
#[test]
fn merge_overlap() {
let a = PhysMemoryRegion::new(PA::from_value(0x100), 0x20);
let b = PhysMemoryRegion::new(PA::from_value(0x110), 0x20);
let merged = a.merge(b).unwrap();
assert_eq!(merged.address.value(), 0x100);
assert_eq!(merged.size(), 0x30);
}
#[test]
fn merge_identical() {
let a = PhysMemoryRegion::new(PA::from_value(0x100), 0x20);
let b = PhysMemoryRegion::new(PA::from_value(0x100), 0x20);
let merged = a.merge(b).unwrap();
assert_eq!(merged.address.value(), 0x100);
assert_eq!(merged.size(), 0x20);
}
#[test]
fn merge_non_touching() {
let a = PhysMemoryRegion::new(PA::from_value(0x100), 0x10);
let b = PhysMemoryRegion::new(PA::from_value(0x200), 0x10);
assert!(a.merge(b).is_none());
}
#[test]
fn merge_reverse_order() {
let a = PhysMemoryRegion::new(PA::from_value(0x200), 0x20);
let b = PhysMemoryRegion::new(PA::from_value(0x100), 0x100);
let merged = a.merge(b).unwrap();
assert_eq!(merged.address.value(), 0x100);
assert_eq!(merged.size(), 0x120);
}
#[test]
fn merge_partial_overlap() {
let a = PhysMemoryRegion::new(PA::from_value(0x100), 0x30); // [0x100, 0x130)
let b = PhysMemoryRegion::new(PA::from_value(0x120), 0x20); // [0x120, 0x140)
let merged = a.merge(b).unwrap(); // should be [0x100, 0x140)
assert_eq!(merged.address.value(), 0x100);
assert_eq!(merged.size(), 0x40);
}
#[test]
fn test_contains_region() {
let a = PhysMemoryRegion::new(PA::from_value(0x1000), 0x300);
let b = PhysMemoryRegion::new(PA::from_value(0x1100), 0x100);
assert!(a.contains(b));
assert!(!b.contains(a));
}
#[test]
fn test_contains_address() {
let region = PhysMemoryRegion::new(PA::from_value(0x1000), 0x100);
assert!(region.contains_address(PA::from_value(0x1000)));
assert!(region.contains_address(PA::from_value(0x10FF)));
assert!(!region.contains_address(PA::from_value(0x1100)));
}
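// A small sketch exercising the overlap and ordering helpers documented above:
// adjacent regions are strictly ordered but do not overlap.
#[test]
fn test_overlaps_and_ordering() {
    let a = PhysMemoryRegion::new(PA::from_value(0x1000), 0x100); // [0x1000, 0x1100)
    let b = PhysMemoryRegion::new(PA::from_value(0x1080), 0x100); // [0x1080, 0x1180)
    let c = PhysMemoryRegion::new(PA::from_value(0x1100), 0x100); // [0x1100, 0x1200)
    assert!(a.overlaps(b));
    assert!(!a.overlaps(c));
    assert!(a.is_before(c));
    assert!(c.is_after(a));
    assert!(!a.is_before(b));
}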
#[test]
fn test_iter_pages_exact_multiple() {
let start_va = VA::from_value(0x10000);
let num_pages = 3;
let region = VirtMemoryRegion::new(start_va, num_pages * PAGE_SIZE);
let pages: Vec<VA> = region.iter_pages().collect();
assert_eq!(pages.len(), num_pages);
assert_eq!(pages[0], VA::from_value(0x10000));
assert_eq!(pages[1], VA::from_value(0x11000));
assert_eq!(pages[2], VA::from_value(0x12000));
}
#[test]
fn test_iter_pages_single_page() {
let start_va = VA::from_value(0x20000);
let region = VirtMemoryRegion::new(start_va, PAGE_SIZE);
let mut iter = region.iter_pages();
assert_eq!(iter.next(), Some(start_va));
assert_eq!(iter.next(), None);
}
#[test]
fn test_iter_pages_empty_region() {
let start_va = VA::from_value(0x30000);
let region = VirtMemoryRegion::new(start_va, 0);
assert_eq!(region.iter_pages().count(), 0);
}
#[test]
fn test_iter_pages_size_not_a_multiple_of_page_size() {
let start_va = VA::from_value(0x40000);
let num_pages = 5;
let size = num_pages * PAGE_SIZE + 123;
let region = VirtMemoryRegion::new(start_va, size);
assert_eq!(region.iter_pages().count(), num_pages);
assert_eq!(
region.iter_pages().last(),
Some(start_va.add_pages(num_pages - 1))
);
}
#[test]
fn test_iter_pages_large_count() {
let start_va = VA::from_value(0x5_0000_0000);
let num_pages = 1000;
let region = VirtMemoryRegion::new(start_va, num_pages * PAGE_SIZE);
assert_eq!(region.iter_pages().count(), num_pages);
}
}

View File

File diff suppressed because it is too large

31
libkernel/src/pod.rs Normal file
View File

@@ -0,0 +1,31 @@
/// An unsafe trait indicating that a type is "Plain Old Data".
///
/// A type is `Pod` if it is a simple collection of bytes with no invalid bit
/// patterns. This means it can be safely created by simply copying its byte
/// representation from memory or a device.
///
/// # Safety
///
/// The implementor of this trait MUST guarantee that:
/// 1. The type has a fixed, known layout. Using `#[repr(C)]` or
/// `#[repr(transparent)]` is a must! The Rust ABI is unstable.
/// 2. The type contains no padding bytes, or if it does, that reading those
/// padding bytes as uninitialized memory is not undefined behavior.
/// 3. All possible bit patterns for the type's size are valid instances of the type.
/// For example, a `bool` is NOT `Pod` because its valid representations are only
/// 0x00 and 0x01, not any other byte value. A `u32` is `Pod` because all
/// 2^32 bit patterns are valid `u32` values.
pub unsafe trait Pod: Sized {}
// Blanket implementations for primitive types that are definitely Pod.
unsafe impl Pod for u8 {}
unsafe impl Pod for u16 {}
unsafe impl Pod for u32 {}
unsafe impl Pod for u64 {}
unsafe impl Pod for u128 {}
unsafe impl Pod for i8 {}
unsafe impl Pod for i16 {}
unsafe impl Pod for i32 {}
unsafe impl Pod for i64 {}
unsafe impl Pod for i128 {}
unsafe impl<T: Pod, const N: usize> Pod for [T; N] {}
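// A hypothetical illustration (the `DiskHeader` type below is not part of the
// kernel): a #[repr(C)] struct whose fields are all `u32` has a fixed layout,
// no padding, and no invalid bit patterns, so it can soundly implement `Pod`
// and be materialized straight from a raw byte buffer.
#[cfg(test)]
mod tests {
    use super::Pod;

    // A purely illustrative on-disk header; not a type used by the kernel.
    #[repr(C)]
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    struct DiskHeader {
        magic: u32,
        block_count: u32,
    }

    // SAFETY: #[repr(C)] with only `u32` fields: fixed layout, no padding, and
    // every bit pattern is a valid value.
    unsafe impl Pod for DiskHeader {}

    #[test]
    fn pod_from_raw_bytes() {
        let mut bytes = [0u8; 8];
        bytes[..4].copy_from_slice(&0xDEAD_BEEFu32.to_ne_bytes());
        bytes[4..].copy_from_slice(&16u32.to_ne_bytes());
        // SAFETY: `DiskHeader: Pod`, so any 8-byte pattern is a valid instance;
        // `read_unaligned` tolerates the arbitrary alignment of `bytes`.
        let header: DiskHeader =
            unsafe { core::ptr::read_unaligned(bytes.as_ptr().cast::<DiskHeader>()) };
        assert_eq!(header.magic, 0xDEAD_BEEF);
        assert_eq!(header.block_count, 16);
    }
}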

43
libkernel/src/proc/ids.rs Normal file
View File

@@ -0,0 +1,43 @@
#[repr(C)]
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub struct Uid(u32);
impl Uid {
pub const fn new(id: u32) -> Self {
Self(id)
}
pub fn is_root(self) -> bool {
self.0 == 0
}
pub fn new_root() -> Self {
Self(0)
}
}
impl From<Uid> for u32 {
fn from(value: Uid) -> Self {
value.0
}
}
#[repr(C)]
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub struct Gid(u32);
impl Gid {
pub const fn new(id: u32) -> Self {
Self(id)
}
pub fn new_root_group() -> Self {
Self(0)
}
}
impl From<Gid> for u32 {
fn from(value: Gid) -> Self {
value.0
}
}

View File

@@ -0,0 +1 @@
pub mod ids;

View File

@@ -0,0 +1,224 @@
use super::spinlock::SpinLockIrq;
use super::waker_set::WakerSet;
use crate::CpuOps;
use alloc::sync::Arc;
/// The type of wakeup that should occur after a state update.
pub enum WakeupType {
None,
One,
All,
}
struct CondVarInner<S> {
state: S,
wakers: WakerSet,
}
impl<S> CondVarInner<S> {
fn new(initial_state: S) -> Self {
Self {
state: initial_state,
wakers: WakerSet::new(),
}
}
}
/// A condvar for managing asynchronous tasks that need to sleep while sharing
/// state.
///
/// This structure is thread-safe and allows registering wakers, which can be
/// woken selectively or all at once. It is "drop-aware": if a future that
/// registered a waker is dropped, its waker is automatically de-registered.
pub struct CondVar<S, C: CpuOps> {
inner: Arc<SpinLockIrq<CondVarInner<S>, C>>,
}
impl<S, C: CpuOps> CondVar<S, C> {
/// Creates a new, empty wait queue, initialized with state `initial_state`.
pub fn new(initial_state: S) -> Self {
Self {
inner: Arc::new(SpinLockIrq::new(CondVarInner::new(initial_state))),
}
}
/// Updates the internal state by calling `updater`.
///
/// The `updater` closure should return the kind of wakeup to perform on the
/// condvar after performing the update.
pub fn update(&self, updater: impl FnOnce(&mut S) -> WakeupType) {
let mut inner = self.inner.lock_save_irq();
match updater(&mut inner.state) {
WakeupType::None => (),
WakeupType::One => inner.wakers.wake_one(),
WakeupType::All => inner.wakers.wake_all(),
}
}
/// Creates a future that waits on the queue until a condition on the
/// internal state is met, returning a value `T` when finished waiting.
///
/// # Arguments
/// * `predicate`: A closure that checks the condition in the underlying
/// state. It should return `None` to continue waiting, or `Some(T)` to
/// stop waiting and yield `T` to the caller.
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub fn wait_until<T, F>(&self, predicate: F) -> impl Future<Output = T> + use<T, S, C, F>
where
F: Fn(&mut S) -> Option<T>,
{
super::waker_set::wait_until(
self.inner.clone(),
|inner| &mut inner.wakers,
move |inner| predicate(&mut inner.state),
)
}
}
impl<S, C: CpuOps> Clone for CondVar<S, C> {
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
}
}
}
#[cfg(test)]
mod condvar_tests {
use crate::test::MockCpuOps;
use super::*;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::Duration;
struct TestState {
counter: u32,
}
#[tokio::test]
async fn wait_and_wake_one() {
let condvar = CondVar::<_, MockCpuOps>::new(TestState { counter: 0 });
let condvar_clone = condvar.clone();
let handle = tokio::spawn(async move {
condvar
.wait_until(|state| {
if state.counter == 1 {
Some("Condition Met".to_string())
} else {
None
}
})
.await
});
tokio::time::sleep(Duration::from_millis(10)).await;
assert!(!handle.is_finished(), "Future finished prematurely");
condvar_clone.update(|state| {
state.counter += 1;
WakeupType::One
});
let result = tokio::time::timeout(Duration::from_millis(50), handle)
.await
.expect("Future timed out")
.unwrap();
assert_eq!(result, "Condition Met");
}
#[tokio::test]
async fn test_wait_and_wake_all() {
let condvar = Arc::new(CondVar::<_, MockCpuOps>::new(TestState { counter: 0 }));
let completion_count = Arc::new(AtomicUsize::new(0));
let mut handles = Vec::new();
// Spawn three tasks that wait for the counter to reach 5.
for _ in 0..3 {
let condvar_clone = condvar.clone();
let completion_count_clone = completion_count.clone();
let handle = tokio::spawn(async move {
let result = condvar_clone
.wait_until(|state| {
if state.counter >= 5 {
Some(state.counter)
} else {
None
}
})
.await;
completion_count_clone.fetch_add(1, Ordering::SeqCst);
result
});
handles.push(handle);
}
tokio::time::sleep(Duration::from_millis(10)).await;
assert_eq!(completion_count.load(Ordering::SeqCst), 0);
condvar.update(|state| {
state.counter = 5;
WakeupType::All
});
for handle in handles {
let result = handle.await.unwrap();
assert_eq!(result, 5);
}
assert_eq!(completion_count.load(Ordering::SeqCst), 3);
}
#[tokio::test]
async fn test_predicate_already_true() {
let condvar = CondVar::<_, MockCpuOps>::new(TestState { counter: 10 });
// This future should complete immediately without pending.
let result = condvar
.wait_until(|state| {
if state.counter == 10 {
Some("Already True")
} else {
None
}
})
.await;
assert_eq!(result, "Already True");
}
#[tokio::test]
async fn test_update_with_no_wakeup() {
let condvar = CondVar::<_, MockCpuOps>::new(TestState { counter: 0 });
let handle = {
let condvar = condvar.clone();
tokio::spawn(async move {
condvar
.wait_until(|state| if state.counter == 1 { Some(()) } else { None })
.await
})
};
tokio::time::sleep(Duration::from_millis(10)).await;
condvar.update(|state| {
state.counter = 1;
WakeupType::None
});
// Give some time to see if the future completes (it shouldn't).
tokio::time::sleep(Duration::from_millis(20)).await;
// The future should still be pending.
assert!(!handle.is_finished());
// Now, perform a wakeup to allow the test to clean up.
condvar.update(|_| WakeupType::One);
handle.await.unwrap();
}
}

View File

@@ -0,0 +1,7 @@
pub mod condvar;
pub mod mpsc;
pub mod mutex;
pub mod once_lock;
pub mod per_cpu;
pub mod spinlock;
pub mod waker_set;

159
libkernel/src/sync/mpsc.rs Normal file
View File

@@ -0,0 +1,159 @@
//! An asynchronous, multi-producer, single-consumer (MPSC) channel.
//!
//! This module provides a queue for sending values between asynchronous tasks
//! within the kernel.
use super::condvar::{CondVar, WakeupType};
use crate::CpuOps;
use alloc::collections::VecDeque;
struct MpscState<T: Send> {
data: VecDeque<T>,
senders: usize,
recv_gone: bool,
}
/// The receiving half of the MPSC channel.
///
/// There can only be one `Reciever` for a given channel.
///
/// If the `Reciever` is dropped, the channel is closed. Any subsequent messages
/// sent by a `Sender` will be dropped.
pub struct Reciever<T: Send, C: CpuOps> {
inner: CondVar<MpscState<T>, C>,
}
enum RxResult<T> {
Data(T),
SenderGone,
}
impl<T: Send, C: CpuOps> Reciever<T, C> {
/// Asynchronously waits for a message from the channel.
///
/// This function returns a `Future` that resolves to:
/// - `Some(T)`: If a message was successfully received from the channel.
/// - `None`: If all `Sender` instances have been dropped, indicating that
/// no more messages will ever be sent. The channel is now closed.
pub async fn recv(&self) -> Option<T> {
let result = self
.inner
.wait_until(|state| {
if let Some(data) = state.data.pop_front() {
Some(RxResult::Data(data))
} else if state.senders == 0 {
Some(RxResult::SenderGone)
} else {
None
}
})
.await;
match result {
RxResult::Data(d) => Some(d),
RxResult::SenderGone => None,
}
}
}
impl<T: Send, C: CpuOps> Drop for Reciever<T, C> {
fn drop(&mut self) {
self.inner.update(|state| {
// Since there can only be one receiver and we are now dropping
// it, drain the queue, and set a flag such that any more sends
// result in the value being dropped.
core::mem::take(&mut state.data);
state.recv_gone = true;
WakeupType::None
})
}
}
/// The sending half of the MPSC channel.
///
/// `Sender` handles can be cloned to allow multiple producers to send messages
/// to the single `Reciever`.
///
/// When the last `Sender` is dropped, the channel is closed. This will cause
/// the `Reciever::recv` future to resolve to `None`.
pub struct Sender<T: Send, C: CpuOps> {
inner: CondVar<MpscState<T>, C>,
}
impl<T: Send, C: CpuOps> Sender<T, C> {
/// Sends a message into the channel.
///
/// This method enqueues the given object `obj` for the `Reciever` to
/// consume. After enqueuing the message, it notifies one waiting `Reciever`
/// task, if one exists.
///
/// This operation is non-blocking from an async perspective, though it will
/// acquire a spinlock.
pub fn send(&self, obj: T) {
self.inner.update(|state| {
if state.recv_gone {
// Receiver has been dropped, so drop the message.
return WakeupType::None;
}
state.data.push_back(obj);
WakeupType::One
});
}
}
impl<T: Send, C: CpuOps> Clone for Sender<T, C> {
fn clone(&self) -> Self {
self.inner.update(|state| {
state.senders += 1;
WakeupType::None
});
Self {
inner: self.inner.clone(),
}
}
}
impl<T: Send, C: CpuOps> Drop for Sender<T, C> {
fn drop(&mut self) {
self.inner.update(|state| {
state.senders -= 1;
if state.senders == 0 {
// Wake the receiver to let it know the channel is now closed. We
// use wake_all as a safeguard, though only one task should be
// waiting.
WakeupType::All
} else {
WakeupType::None
}
});
}
}
/// Creates a new asynchronous, multi-producer, single-consumer channel.
///
/// Returns a tuple containing the `Sender` and `Reciever` halves. The `Sender`
/// can be cloned to create multiple producers, while the `Reciever` is the
/// single consumer.
pub fn channel<T: Send, C: CpuOps>() -> (Sender<T, C>, Reciever<T, C>) {
let state = MpscState {
data: VecDeque::new(),
senders: 1,
recv_gone: false,
};
let waitq = CondVar::new(state);
let tx = Sender {
inner: waitq.clone(),
};
let rx = Reciever { inner: waitq };
(tx, rx)
}
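// A minimal usage sketch, assuming the `crate::test::MockCpuOps` mock and the
// tokio dev-dependency used by the other sync tests in this crate.
#[cfg(test)]
mod tests {
    use super::channel;
    use crate::test::MockCpuOps;

    #[tokio::test]
    async fn send_then_recv_then_close() {
        let (tx, rx) = channel::<u32, MockCpuOps>();
        tx.send(42);
        assert_eq!(rx.recv().await, Some(42));
        // Dropping the last sender closes the channel; recv resolves to None.
        drop(tx);
        assert_eq!(rx.recv().await, None);
    }
}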

133
libkernel/src/sync/mutex.rs Normal file
View File

@@ -0,0 +1,133 @@
use alloc::collections::VecDeque;
use core::cell::UnsafeCell;
use core::future::Future;
use core::ops::{Deref, DerefMut};
use core::pin::Pin;
use core::task::{Context, Poll, Waker};
use crate::CpuOps;
use super::spinlock::SpinLockIrq;
struct MutexState {
is_locked: bool,
waiters: VecDeque<Waker>,
}
/// An asynchronous mutex primitive.
///
/// This mutex can be used to protect shared data across asynchronous tasks.
/// `lock()` returns a future that resolves to a guard. When the guard is
/// dropped, the lock is released.
pub struct Mutex<T: ?Sized, CPU: CpuOps> {
state: SpinLockIrq<MutexState, CPU>,
data: UnsafeCell<T>,
}
/// A guard that provides exclusive access to the data in a `Mutex`.
///
/// When an `AsyncMutexGuard` is dropped, it automatically releases the lock and
/// wakes up the next task.
#[must_use = "if unused, the Mutex will immediately unlock"]
pub struct AsyncMutexGuard<'a, T: ?Sized, CPU: CpuOps> {
mutex: &'a Mutex<T, CPU>,
}
/// A future that resolves to an `AsyncMutexGuard` when the lock is acquired.
pub struct MutexGuardFuture<'a, T: ?Sized, CPU: CpuOps> {
mutex: &'a Mutex<T, CPU>,
}
impl<T, CPU: CpuOps> Mutex<T, CPU> {
/// Creates a new asynchronous mutex in an unlocked state.
pub const fn new(data: T) -> Self {
Self {
state: SpinLockIrq::new(MutexState {
is_locked: false,
waiters: VecDeque::new(),
}),
data: UnsafeCell::new(data),
}
}
/// Consumes the mutex, returning the underlying data.
///
/// This is safe because consuming `self` guarantees no other code can
/// access the mutex.
pub fn into_inner(self) -> T {
self.data.into_inner()
}
}
impl<T: ?Sized, CPU: CpuOps> Mutex<T, CPU> {
/// Acquires the mutex lock.
///
/// Returns a future that resolves to a lock guard. The returned future must
/// be `.await`ed to acquire the lock. The lock is released when the
/// returned `AsyncMutexGuard` is dropped.
pub fn lock(&self) -> MutexGuardFuture<'_, T, CPU> {
MutexGuardFuture { mutex: self }
}
/// Returns a mutable reference to the underlying data.
///
/// Since this call borrows the `Mutex` mutably, no actual locking needs to
/// take place - the mutable borrow statically guarantees that no other
/// references to the `Mutex` exist.
pub fn get_mut(&mut self) -> &mut T {
// SAFETY: We can grant mutable access to the data because `&mut self`
// guarantees that no other threads are concurrently accessing the
// mutex. No other code can call `.lock()` because we hold the unique
// mutable reference. Thus, we can safely bypass the lock.
unsafe { &mut *self.data.get() }
}
}
impl<'a, T: ?Sized, CPU: CpuOps> Future for MutexGuardFuture<'a, T, CPU> {
type Output = AsyncMutexGuard<'a, T, CPU>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut state = self.mutex.state.lock_save_irq();
if !state.is_locked {
state.is_locked = true;
Poll::Ready(AsyncMutexGuard { mutex: self.mutex })
} else {
if state.waiters.iter().all(|w| !w.will_wake(cx.waker())) {
state.waiters.push_back(cx.waker().clone());
}
Poll::Pending
}
}
}
impl<T: ?Sized, CPU: CpuOps> Drop for AsyncMutexGuard<'_, T, CPU> {
fn drop(&mut self) {
let mut state = self.mutex.state.lock_save_irq();
if let Some(next_waker) = state.waiters.pop_front() {
next_waker.wake();
}
state.is_locked = false;
}
}
impl<T: ?Sized, CPU: CpuOps> Deref for AsyncMutexGuard<'_, T, CPU> {
type Target = T;
fn deref(&self) -> &T {
// SAFETY: This is safe because the existence of this guard guarantees
// we have exclusive access to the data.
unsafe { &*self.mutex.data.get() }
}
}
impl<T: ?Sized, CPU: CpuOps> DerefMut for AsyncMutexGuard<'_, T, CPU> {
fn deref_mut(&mut self) -> &mut T {
// SAFETY: as in `deref`, the guard guarantees exclusive access to the data.
unsafe { &mut *self.mutex.data.get() }
}
}
unsafe impl<T: ?Sized + Send, CPU: CpuOps> Send for Mutex<T, CPU> {}
unsafe impl<T: ?Sized + Send, CPU: CpuOps> Sync for Mutex<T, CPU> {}
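// A minimal usage sketch, assuming the `crate::test::MockCpuOps` mock and the
// tokio dev-dependency used by the other sync tests in this crate.
#[cfg(test)]
mod tests {
    use super::Mutex;
    use crate::test::MockCpuOps;

    #[tokio::test]
    async fn lock_serializes_access_and_unlocks_on_drop() {
        let mutex: Mutex<u32, MockCpuOps> = Mutex::new(0);
        {
            let mut guard = mutex.lock().await;
            *guard += 1;
        } // Dropping the guard releases the lock and wakes the next waiter.
        assert_eq!(*mutex.lock().await, 1);
        assert_eq!(mutex.into_inner(), 1);
    }
}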

View File

@@ -0,0 +1,133 @@
use core::fmt;
use crate::CpuOps;
use super::spinlock::SpinLockIrq;
/// A cell which can be written to only once.
///
/// This is a kernel-safe, `no_std` equivalent of `std::sync::OnceLock`, built
/// on top of a `SpinLockIrq`.
pub struct OnceLock<T, CPU: CpuOps> {
inner: SpinLockIrq<Option<T>, CPU>,
}
impl<T, CPU: CpuOps> OnceLock<T, CPU> {
/// Creates a new, empty `OnceLock`.
pub const fn new() -> Self {
OnceLock {
inner: SpinLockIrq::new(None),
}
}
/// Gets a reference to the contained value, if it has been initialized.
pub fn get(&self) -> Option<&T> {
let guard = self.inner.lock_save_irq();
if let Some(value) = guard.as_ref() {
// SAFETY: This is the only `unsafe` part. We are "extending" the
// lifetime of the reference beyond the scope of the lock guard.
//
// This is sound because we guarantee that once the `Option<T>` is
// `Some(T)`, it will *never* be changed back to `None` or to a
// different `Some(T)`. The value is stable in memory for the
// lifetime of the `OnceLock` itself.
let ptr: *const T = value;
Some(unsafe { &*ptr })
} else {
None
}
}
/// Gets a mutable reference to the contained value, if it has been
/// initialized.
pub fn get_mut(&mut self) -> Option<&mut T> {
let mut guard = self.inner.lock_save_irq();
if let Some(value) = guard.as_mut() {
// SAFETY: This is the only `unsafe` part. We are "extending" the
// lifetime of the reference beyond the scope of the lock guard.
//
// This is sound because we guarantee that once the `Option<T>` is
// `Some(T)`, it will *never* be changed back to `None` or to a
// different `Some(T)`. The value is stable in memory for the
// lifetime of the `OnceLock` itself.
let ptr: *mut T = value;
Some(unsafe { &mut *ptr })
} else {
None
}
}
/// Gets the contained value, or initializes it with a closure if it is empty.
pub fn get_or_init<F>(&self, f: F) -> &T
where
F: FnOnce() -> T,
{
if let Some(value) = self.get() {
return value;
}
// The value was not initialized. We need to acquire a full lock
// to potentially initialize it.
self.initialize(f)
}
#[cold]
fn initialize<F>(&self, f: F) -> &T
where
F: FnOnce() -> T,
{
let mut guard = self.inner.lock_save_irq();
// We must check again! Between our `get()` call and acquiring the lock,
// another core could have initialized the value. If we don't check
// again, we would initialize it a second time.
let value = match *guard {
Some(ref value) => value,
None => {
// It's still None, so we are the first. We run the closure
// and set the value.
let new_value = f();
guard.insert(new_value) // `insert` places the value and returns a &mut to it
}
};
// As before, we can now safely extend the lifetime of the reference.
let ptr: *const T = value;
unsafe { &*ptr }
}
/// Attempts to set the value of the `OnceLock`.
///
/// If the cell is already initialized, the given value is returned in an
/// `Err`. This is useful for when initialization might fail and you don't
/// want to use a closure-based approach.
pub fn set(&self, value: T) -> Result<(), T> {
let mut guard = self.inner.lock_save_irq();
if guard.is_some() {
Err(value)
} else {
*guard = Some(value);
Ok(())
}
}
}
impl<T, CPU: CpuOps> Default for OnceLock<T, CPU> {
fn default() -> Self {
Self::new()
}
}
// Implement Debug for nice printing, if the inner type supports it.
impl<T: fmt::Debug, CPU: CpuOps> fmt::Debug for OnceLock<T, CPU> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("OnceLock")
.field("inner", &self.get())
.finish()
}
}
unsafe impl<T: Sync + Send, CPU: CpuOps> Sync for OnceLock<T, CPU> {}
unsafe impl<T: Send, CPU: CpuOps> Send for OnceLock<T, CPU> {}
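// A minimal usage sketch, assuming the `crate::test::MockCpuOps` mock used by
// the other sync tests in this crate.
#[cfg(test)]
mod tests {
    use super::OnceLock;
    use crate::test::MockCpuOps;

    #[test]
    fn set_once_then_reads_are_stable() {
        let cell: OnceLock<u32, MockCpuOps> = OnceLock::new();
        assert!(cell.get().is_none());
        assert_eq!(*cell.get_or_init(|| 5), 5);
        // A second initialization attempt is rejected and hands the value back.
        assert_eq!(cell.set(7), Err(7));
        assert_eq!(cell.get(), Some(&5));
    }
}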

View File

@@ -0,0 +1,338 @@
//! A module for creating lock-free, per-CPU static variables.
//!
//! This module provides a mechanism for creating data that is unique to each
//! processor core. Accessing this data requires no locks: each access simply
//! indexes into the current core's private slot.
//!
//! The design relies on a custom linker section (`.percpu`) and a macro
//! (`per_cpu!`) to automatically register and initialize all per-CPU variables
//! at boot.
use alloc::boxed::Box;
use alloc::vec::Vec;
use core::cell::{Ref, RefCell, RefMut};
use core::marker::PhantomData;
use core::sync::atomic::{AtomicPtr, Ordering};
use log::info;
use crate::CpuOps;
/// A trait for type-erased initialization of `PerCpu` variables.
///
/// This allows the global initialization loop to call `init` on any `PerCpu<T>`
/// without knowing the concrete type `T`.
pub trait PerCpuInitializer {
/// Initializes the per-CPU data allocation.
fn init(&self, num_cpus: usize);
}
/// A container for a value that has a separate instance for each CPU core.
///
/// See the module-level documentation for detailed usage instructions.
pub struct PerCpu<T: Send, CPU: CpuOps> {
/// A pointer to the heap-allocated array of `RefCell<T>`s, one for each
/// CPU. It's `AtomicPtr` to ensure safe one-time initialization.
ptr: AtomicPtr<RefCell<T>>,
/// A function pointer to the initializer for type `T`.
/// This is stored so it can be called during the runtime `init` phase.
initializer: fn() -> T,
phantom: PhantomData<CPU>,
}
impl<T: Send, CPU: CpuOps> PerCpu<T, CPU> {
/// Creates a new, uninitialized `PerCpu` variable.
///
/// This is `const` so it can be used to initialize `static` variables.
pub const fn new(initializer: fn() -> T) -> Self {
Self {
ptr: AtomicPtr::new(core::ptr::null_mut()),
initializer,
phantom: PhantomData,
}
}
/// Returns a reference to the underlying data for the current CPU.
///
/// # Panics
///
/// Panics if the `PerCpu` variable has not been initialized.
fn get_cell(&self) -> &RefCell<T> {
let id = CPU::id();
let base_ptr = self.ptr.load(Ordering::Acquire);
if base_ptr.is_null() {
panic!("PerCpu variable accessed before initialization");
}
// SAFETY: We have checked for null, and `init` guarantees the allocation
// is valid for `id`. The returned reference is to a `RefCell`, which
// manages its own internal safety.
unsafe { &*base_ptr.add(id) }
}
/// Immutably borrows the per-CPU data.
///
/// The borrow lasts until the returned `Ref<T>` is dropped.
///
/// # Panics
/// Panics if the value is already mutably borrowed.
pub fn borrow(&self) -> Ref<'_, T> {
self.get_cell().borrow()
}
/// Mutably borrows the per-CPU data.
///
/// The borrow lasts until the returned `RefMut<T>` is dropped.
///
/// # Panics
/// Panics if the value is already borrowed (mutably or immutably).
pub fn borrow_mut(&self) -> RefMut<'_, T> {
self.get_cell().borrow_mut()
}
/// A convenience method to execute a closure with a mutable reference.
/// This is often simpler than holding onto the `RefMut` guard.
///
/// # Panics
/// Panics if the value is already borrowed.
pub fn with_mut<F, R>(&self, f: F) -> R
where
F: FnOnce(&mut T) -> R,
{
f(&mut *self.borrow_mut())
}
}
// Implement the type-erased initializer trait.
impl<T: Send, CPU: CpuOps> PerCpuInitializer for PerCpu<T, CPU> {
fn init(&self, num_cpus: usize) {
let mut values = Vec::with_capacity(num_cpus);
for _ in 0..num_cpus {
values.push(RefCell::new((self.initializer)()));
}
let leaked_ptr = Box::leak(values.into_boxed_slice()).as_mut_ptr();
let result = self.ptr.compare_exchange(
core::ptr::null_mut(),
leaked_ptr,
Ordering::Release,
Ordering::Relaxed, // We don't care about the value on failure.
);
if result.is_err() {
panic!("PerCpu::init called more than once on the same variable");
}
}
}
/// A `PerCpu<T>` can be safely shared between threads (`Sync`) if `T` is
/// `Send`.
///
/// # Safety
///
/// This is safe because although the `PerCpu` object itself is shared, the
/// underlying data `T` is partitioned. Each CPU can only access its own private
/// slot. The `T` value is effectively "sent" to its destination CPU's slot
/// during initialization. There is no cross-CPU data sharing at the `T` level.
unsafe impl<T: Send, CPU: CpuOps> Sync for PerCpu<T, CPU> {}
/// Initializes all `PerCpu` static variables defined with the `per_cpu!` macro.
///
/// This function iterates over the `.percpu` ELF section, which contains a list
/// of all `PerCpu` instances that need to be initialized.
///
/// # Safety
///
/// This function must only be called once during boot, before other cores are
/// started and before any `PerCpu` variable is accessed. It dereferences raw
/// pointers provided by the linker script. The caller must ensure that the
/// `__percpu_start` and `__percpu_end` symbols from the linker are valid.
pub unsafe fn setup_percpu(num_cpus: usize) {
unsafe extern "C" {
static __percpu_start: u8;
static __percpu_end: u8;
}
let start_ptr =
unsafe { &__percpu_start } as *const u8 as *const &'static (dyn PerCpuInitializer + Sync);
let end_ptr =
unsafe { &__percpu_end } as *const u8 as *const &'static (dyn PerCpuInitializer + Sync);
let mut current_ptr = start_ptr;
let mut objs_setup = 0;
while current_ptr < end_ptr {
let percpu_var = unsafe { &*current_ptr };
percpu_var.init(num_cpus);
current_ptr = unsafe { current_ptr.add(1) };
objs_setup += 1;
}
info!("Setup {} per_cpu objects.", objs_setup * num_cpus);
}
#[cfg(test)]
mod tests {
use super::*;
use std::cell::Cell;
use std::sync::Arc;
use std::thread;
thread_local! {
static MOCK_CPU_ID: Cell<usize> = Cell::new(0);
}
struct MockArch;
impl CpuOps for MockArch {
fn id() -> usize {
MOCK_CPU_ID.with(|id| id.get())
}
fn halt() -> ! {
unimplemented!()
}
fn disable_interrupts() -> usize {
unimplemented!()
}
fn restore_interrupt_state(_flags: usize) {
unimplemented!()
}
fn enable_interrupts() {
unimplemented!()
}
}
#[test]
fn test_initialization_and_basic_access() {
let data: PerCpu<_, MockArch> = PerCpu::new(|| 0u32);
data.init(4); // Simulate a 4-core system
// Act as CPU 0
MOCK_CPU_ID.with(|id| id.set(0));
assert_eq!(*data.borrow(), 0);
*data.borrow_mut() = 100;
assert_eq!(*data.borrow(), 100);
// Act as CPU 3
MOCK_CPU_ID.with(|id| id.set(3));
assert_eq!(*data.borrow(), 0); // Should still be the initial value
data.with_mut(|val| *val = 300);
assert_eq!(*data.borrow(), 300);
// Check CPU 0 again to ensure it wasn't affected
MOCK_CPU_ID.with(|id| id.set(0));
assert_eq!(*data.borrow(), 100);
}
#[test]
#[should_panic(expected = "PerCpu variable accessed before initialization")]
fn test_panic_on_uninitialized_access() {
let data: PerCpu<_, MockArch> = PerCpu::new(|| 0);
// This should panic because data.init() was not called.
let _ = data.borrow();
}
#[test]
#[should_panic(expected = "PerCpu::init called more than once")]
fn test_panic_on_double_init() {
let data: PerCpu<_, MockArch> = PerCpu::new(|| 0);
data.init(1);
data.init(1); // This second call should panic.
}
#[test]
#[should_panic(expected = "already borrowed")]
fn test_refcell_panic_on_double_mutable_borrow() {
let data: PerCpu<_, MockArch> = PerCpu::new(|| String::from("hello"));
data.init(1);
MOCK_CPU_ID.with(|id| id.set(0));
let _guard1 = data.borrow_mut();
// This second mutable borrow on the same "CPU" should panic.
let _guard2 = data.borrow_mut();
}
#[test]
#[should_panic(expected = "already borrowed")]
fn test_refcell_panic_on_mutable_while_immutable_borrow() {
let data: PerCpu<_, MockArch> = PerCpu::new(|| 0);
data.init(1);
MOCK_CPU_ID.with(|id| id.set(0));
let _guard = data.borrow();
// Attempting to mutably borrow while an immutable borrow exists should panic.
*data.borrow_mut() = 5;
}
#[test]
fn test_multithreaded_access_is_isolated() {
// This is the stress test. It gives high confidence that the `unsafe impl Sync`
// is correct because each thread will perform many isolated operations.
const NUM_THREADS: usize = 8;
const ITERATIONS_PER_THREAD: usize = 1000;
// The data must be in an Arc to be shared across threads.
let per_cpu_data: Arc<PerCpu<_, MockArch>> = Arc::new(PerCpu::new(|| 0));
per_cpu_data.init(NUM_THREADS);
let mut handles = vec![];
for i in 0..NUM_THREADS {
let data_clone = Arc::clone(&per_cpu_data);
let handle = thread::spawn(move || {
MOCK_CPU_ID.with(|id| id.set(i));
let initial_val = i * 100_000;
data_clone.with_mut(|val| *val = initial_val);
for j in 0..ITERATIONS_PER_THREAD {
data_clone.with_mut(|val| {
// VERIFY: Check that the value is what we expect from
// the previous iteration. If another thread interfered,
// this assert will fail.
let expected_val = initial_val + j;
assert_eq!(
*val, expected_val,
"Data corruption on CPU {}! Expected {}, found {}",
i, expected_val, *val
);
// MODIFY: Increment the value.
*val += 1;
});
if j % 10 == 0 {
// Don't yield on every single iteration
thread::yield_now();
}
}
// After the loop, verify the final value.
let final_val = *data_clone.borrow();
let expected_final_val = initial_val + ITERATIONS_PER_THREAD;
assert_eq!(
final_val, expected_final_val,
"Incorrect final value on CPU {}",
i
);
});
handles.push(handle);
}
// Wait for all threads to complete.
for handle in handles {
handle.join().unwrap();
}
// Optional: Final sanity check from the main thread (acting as CPU 0)
// to ensure its value was not corrupted by the other threads.
MOCK_CPU_ID.with(|id| id.set(0));
let expected_val_for_cpu0 = 0 * 100_000 + ITERATIONS_PER_THREAD;
assert_eq!(*per_cpu_data.borrow(), expected_val_for_cpu0);
}
}

View File

@@ -0,0 +1,98 @@
use core::cell::UnsafeCell;
use core::hint::spin_loop;
use core::marker::PhantomData;
use core::ops::{Deref, DerefMut};
use core::sync::atomic::{AtomicBool, Ordering};
use crate::CpuOps;
/// A spinlock that also disables interrupts on the local core while held.
///
/// This prevents deadlocks with interrupt handlers on the same core and
/// provides SMP-safety against other cores.
pub struct SpinLockIrq<T: ?Sized, CPU: CpuOps> {
lock: AtomicBool,
_phantom: PhantomData<CPU>,
data: UnsafeCell<T>,
}
unsafe impl<T: ?Sized + Send, CPU: CpuOps> Send for SpinLockIrq<T, CPU> {}
unsafe impl<T: ?Sized + Send, CPU: CpuOps> Sync for SpinLockIrq<T, CPU> {}
impl<T, CPU: CpuOps> SpinLockIrq<T, CPU> {
/// Creates a new IRQ-safe spinlock.
pub const fn new(data: T) -> Self {
Self {
lock: AtomicBool::new(false),
_phantom: PhantomData,
data: UnsafeCell::new(data),
}
}
}
impl<T: ?Sized, CPU: CpuOps> SpinLockIrq<T, CPU> {
/// Disables interrupts, acquires the lock, and returns a guard. The
/// original interrupt state is restored when the guard is dropped.
pub fn lock_save_irq(&self) -> SpinLockIrqGuard<'_, T, CPU> {
let saved_irq_flags = CPU::disable_interrupts();
while self
.lock
.compare_exchange_weak(false, true, Ordering::Acquire, Ordering::Relaxed)
.is_err()
{
// Spin while waiting for the lock to become available.
// The `Relaxed` load is sufficient here because the `Acquire`
// exchange in the loop will synchronize memory.
while self.lock.load(Ordering::Relaxed) {
spin_loop();
}
}
SpinLockIrqGuard {
lock: self,
irq_flags: saved_irq_flags,
_marker: PhantomData,
}
}
}
/// An RAII guard for an IRQ-safe spinlock.
///
/// When this guard is dropped, the spinlock is released and the original
/// interrupt state of the local CPU core is restored.
#[must_use]
pub struct SpinLockIrqGuard<'a, T: ?Sized + 'a, CPU: CpuOps> {
lock: &'a SpinLockIrq<T, CPU>,
irq_flags: usize, // The saved DAIF register state
_marker: PhantomData<*const ()>, // !Send
}
impl<'a, T: ?Sized, CPU: CpuOps> Deref for SpinLockIrqGuard<'a, T, CPU> {
type Target = T;
fn deref(&self) -> &Self::Target {
// SAFETY: The spinlock is held, guaranteeing exclusive access.
// Interrupts are disabled on the local core, preventing re-entrant
// access from an interrupt handler on this same core.
unsafe { &*self.lock.data.get() }
}
}
impl<'a, T: ?Sized, CPU: CpuOps> DerefMut for SpinLockIrqGuard<'a, T, CPU> {
fn deref_mut(&mut self) -> &mut Self::Target {
// SAFETY: The spinlock is held, guaranteeing exclusive access.
// Interrupts are disabled on the local core, preventing re-entrant
// access from an interrupt handler on this same core.
unsafe { &mut *self.lock.data.get() }
}
}
impl<'a, T: ?Sized, CPU: CpuOps> Drop for SpinLockIrqGuard<'a, T, CPU> {
/// Releases the lock and restores the previous interrupt state.
fn drop(&mut self) {
self.lock.lock.store(false, Ordering::Release);
CPU::restore_interrupt_state(self.irq_flags);
}
}
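// A minimal sketch of the lock/guard lifecycle, assuming the
// `crate::test::MockCpuOps` mock used by the other sync tests in this crate.
#[cfg(test)]
mod tests {
    use super::SpinLockIrq;
    use crate::test::MockCpuOps;

    #[test]
    fn guard_gives_exclusive_access_and_releases_on_drop() {
        let lock: SpinLockIrq<u32, MockCpuOps> = SpinLockIrq::new(0);
        {
            let mut guard = lock.lock_save_irq();
            *guard += 1;
        } // Guard dropped here: lock released, interrupt state restored.
        assert_eq!(*lock.lock_save_irq(), 1);
    }
}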

View File

@@ -0,0 +1,234 @@
use alloc::collections::BTreeMap;
use alloc::sync::Arc;
use core::{
pin::Pin,
task::{Context, Poll, Waker},
};
use crate::CpuOps;
use super::spinlock::SpinLockIrq;
pub struct WakerSet {
waiters: BTreeMap<u64, Waker>,
next_id: u64,
}
impl Default for WakerSet {
fn default() -> Self {
Self::new()
}
}
impl WakerSet {
pub fn new() -> Self {
Self {
waiters: BTreeMap::new(),
next_id: 0,
}
}
fn allocate_id(&mut self) -> u64 {
let id = self.next_id;
// Use wrapping_add to prevent panic on overflow, though it's
// astronomically unlikely.
self.next_id = self.next_id.wrapping_add(1);
id
}
/// Registers a waker, returning a token. The caller must later pass the token
/// to `remove` (for example from a future's `Drop` impl) to de-register the
/// waker once it is no longer needed.
pub fn register(&mut self, waker: &Waker) -> u64 {
let id = self.allocate_id();
self.waiters.insert(id, waker.clone());
id
}
/// Removes a waker using its token.
pub fn remove(&mut self, token: u64) {
self.waiters.remove(&token);
}
/// Wakes one waiting task, if any.
pub fn wake_one(&mut self) {
if let Some((_, waker)) = self.waiters.pop_first() {
waker.wake();
}
}
/// Wakes all waiting tasks.
pub fn wake_all(&mut self) {
for (_, waker) in core::mem::take(&mut self.waiters) {
waker.wake();
}
}
}
/// A future that waits until a condition on a shared state is met.
///
/// This future is designed to work with a state `T` protected by a
/// `SpinLockIrq`, where `T` contains one or more `WakerSet`s.
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct WaitUntil<C, T, F, G, R>
where
C: CpuOps,
F: FnMut(&mut T) -> Option<R>,
G: FnMut(&mut T) -> &mut WakerSet,
{
lock: Arc<SpinLockIrq<T, C>>,
get_waker_set: G,
predicate: F,
token: Option<u64>,
}
/// Creates a future that waits on a specific `WakerSet` within a shared,
/// locked state `T` until a condition is met.
pub fn wait_until<C, T, F, G, R>(
lock: Arc<SpinLockIrq<T, C>>,
get_waker_set: G,
predicate: F,
) -> WaitUntil<C, T, F, G, R>
where
C: CpuOps,
F: FnMut(&mut T) -> Option<R>,
G: FnMut(&mut T) -> &mut WakerSet,
{
WaitUntil {
lock,
get_waker_set,
predicate,
token: None,
}
}
impl<C, T, F, G, R> Future for WaitUntil<C, T, F, G, R>
where
C: CpuOps,
F: FnMut(&mut T) -> Option<R>,
G: FnMut(&mut T) -> &mut WakerSet,
{
type Output = R;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
// SAFETY: `get_unchecked_mut` is sound here because we never move out of
// the pinned struct; none of its fields require structural pinning.
let this = unsafe { self.get_unchecked_mut() };
let mut inner = this.lock.lock_save_irq();
// Check the condition first.
if let Some(result) = (this.predicate)(&mut inner) {
return Poll::Ready(result);
}
// If the condition is not met, register our waker if we haven't already.
if this.token.is_none() {
let waker_set = (this.get_waker_set)(&mut inner);
let id = waker_set.register(cx.waker());
this.token = Some(id);
}
Poll::Pending
}
}
impl<C, T, F, G, R> Drop for WaitUntil<C, T, F, G, R>
where
C: CpuOps,
F: FnMut(&mut T) -> Option<R>,
G: FnMut(&mut T) -> &mut WakerSet,
{
fn drop(&mut self) {
// If we have a token, it means we're registered in the waker set.
// We must acquire the lock and remove ourselves.
if let Some(token) = self.token {
let mut inner = self.lock.lock_save_irq();
let waker_set = (self.get_waker_set)(&mut inner);
waker_set.remove(token);
}
}
}
#[cfg(test)]
mod wait_until_tests {
use super::*;
use crate::test::MockCpuOps;
use std::sync::Arc;
use std::time::Duration;
struct SharedState {
condition_met: bool,
waker_set: WakerSet,
}
#[tokio::test]
async fn wait_until_completes_when_condition_is_met() {
let initial_state = SharedState {
condition_met: false,
waker_set: WakerSet::new(),
};
let lock = Arc::new(SpinLockIrq::<_, MockCpuOps>::new(initial_state));
let lock_clone = lock.clone();
let wait_future = wait_until(
lock.clone(),
|state| &mut state.waker_set,
|state| {
if state.condition_met { Some(()) } else { None }
},
);
let handle = tokio::spawn(wait_future);
// Give the future a chance to run and register its waker.
tokio::time::sleep(Duration::from_millis(10)).await;
// The future should not have completed yet.
assert!(!handle.is_finished());
// Now, meet the condition and wake the task.
{
let mut state = lock_clone.lock_save_irq();
state.condition_met = true;
state.waker_set.wake_one();
}
// The future should now complete.
let result = tokio::time::timeout(Duration::from_millis(50), handle).await;
assert!(result.is_ok(), "Future timed out");
}
#[tokio::test]
async fn wait_until_drop_removes_waker() {
let initial_state = SharedState {
condition_met: false,
waker_set: WakerSet::new(),
};
let lock = Arc::new(SpinLockIrq::<_, MockCpuOps>::new(initial_state));
let lock_clone = lock.clone();
let wait_future = wait_until(
lock.clone(),
|state| &mut state.waker_set,
|state| if state.condition_met { Some(()) } else { None },
);
let handle = tokio::spawn(async {
// Poll the future once to register the waker, then drop it.
let _ = tokio::time::timeout(Duration::from_millis(1), wait_future).await;
});
// Wait for the spawned task to complete.
handle.await.unwrap();
// Check that the waker has been removed from the waker set.
let state = lock_clone.lock_save_irq();
assert!(state.waker_set.waiters.is_empty());
}
}