From d6e1797620b8b0ef0fbae8f9061bc1037bd80fa9 Mon Sep 17 00:00:00 2001
From: 4lDO2 <4lDO2@protonmail.com>
Date: Fri, 25 Dec 2020 18:15:10 +0100
Subject: [PATCH] Make Mapper::map fallible
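
Previously, Mapper::map would panic with "out of frames" whenever
frame allocation failed, leaving callers no way to recover. It now
returns Result<PageFlush<RmmA>, Enomem>, where Enomem is a new unit
struct in src/memory, so out-of-memory conditions can be propagated
instead of aborting inside the mapper.

All existing callers are updated to unwrap the result explicitly.
Most of them use expect() with a TODO message, marking the places
where proper ENOMEM handling still needs to be added, e.g.:

    let result = active_table
        .map(page, flags)
        .expect("TODO: handle ENOMEM in Grant::map");
    flush_all.consume(result);

While touching these call sites, this patch also:

* renames map_tss to map_percpu, since the function maps the whole
  per-CPU thread data segment, not just the TSS; and
* adds the GLOBAL page table entry flag (bit 8), which the kernel
  heap and per-CPU mappings now set unless the "pti" feature is
  enabled.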

---
 src/allocator/mod.rs             |  6 +++---
 src/arch/x86_64/paging/entry.rs  |  1 +
 src/arch/x86_64/paging/mapper.rs | 11 +++++------
 src/arch/x86_64/paging/mod.rs    | 15 ++++++++++-----
 src/context/memory.rs            | 15 +++++++++++----
 src/memory/mod.rs                |  3 +++
 6 files changed, 33 insertions(+), 18 deletions(-)
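
Note for reviewers: the expect() calls in src/context/memory.rs are
deliberate placeholders. A follow-up could translate Enomem into a
syscall error at the user boundary; a minimal sketch, assuming the
kernel's existing syscall::error::{Error, ENOMEM} API (not part of
this patch):

    use crate::memory::Enomem;
    use crate::syscall::error::{Error, ENOMEM};

    impl From<Enomem> for Error {
        // Translate the kernel-internal allocation failure into the
        // ENOMEM errno returned to userspace.
        fn from(_: Enomem) -> Self {
            Error::new(ENOMEM)
        }
    }

With such an impl, callers like Grant::map and Memory::map could
return Result and use `?` instead of expect().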

diff --git a/src/allocator/mod.rs b/src/allocator/mod.rs
index 0d7c9393..dfc618a4 100644
--- a/src/allocator/mod.rs
+++ b/src/allocator/mod.rs
@@ -1,5 +1,4 @@
-use crate::paging::{ActivePageTable, Page, PageFlags, VirtualAddress};
-use crate::paging::mapper::PageFlushAll;
+use crate::paging::{ActivePageTable, Page, PageFlags, VirtualAddress, mapper::PageFlushAll, entry::EntryFlags};
 
 #[cfg(not(feature="slab"))]
 pub use self::linked_list::Allocator;
@@ -19,7 +18,8 @@ unsafe fn map_heap(active_table: &mut ActivePageTable, offset: usize, size: usiz
     let heap_start_page = Page::containing_address(VirtualAddress::new(offset));
     let heap_end_page = Page::containing_address(VirtualAddress::new(offset + size-1));
     for page in Page::range_inclusive(heap_start_page, heap_end_page) {
-        let result = active_table.map(page, PageFlags::new().write(true));
+        let result = active_table.map(page, PageFlags::new().write(true).custom_flag(EntryFlags::GLOBAL.bits(), cfg!(not(feature = "pti"))))
+            .expect("failed to map kernel heap");
         flush_all.consume(result);
     }
 
diff --git a/src/arch/x86_64/paging/entry.rs b/src/arch/x86_64/paging/entry.rs
index 09f0afde..4092ab18 100644
--- a/src/arch/x86_64/paging/entry.rs
+++ b/src/arch/x86_64/paging/entry.rs
@@ -13,6 +13,7 @@ bitflags! {
     pub struct EntryFlags: usize {
         const NO_CACHE =        1 << 4;
         const HUGE_PAGE =       1 << 7;
+        const GLOBAL =          1 << 8;
     }
 }
 
diff --git a/src/arch/x86_64/paging/mapper.rs b/src/arch/x86_64/paging/mapper.rs
index ab0c5c7f..babefcca 100644
--- a/src/arch/x86_64/paging/mapper.rs
+++ b/src/arch/x86_64/paging/mapper.rs
@@ -1,5 +1,5 @@
-use crate::memory::{allocate_frames, deallocate_frames, Frame};
 use super::{linear_phys_to_virt, Page, PAGE_SIZE, PageFlags, PhysicalAddress, VirtualAddress};
+use crate::memory::{allocate_frames, deallocate_frames, Enomem, Frame};
 
 use super::RmmA;
 use super::table::{Table, Level4};
@@ -36,8 +36,7 @@ impl<'table> Mapper<'table> {
     /// For this to be safe, the caller must have exclusive access to the frame argument. The frame
     /// must also be valid, and the frame must not outlive the lifetime.
     pub unsafe fn from_p4_unchecked(frame: &mut Frame) -> Self {
-        let phys = frame.start_address();
-        let virt = linear_phys_to_virt(phys)
+        let virt = linear_phys_to_virt(frame.start_address())
             .expect("expected page table frame to fit within linear mapping");
 
         Self {
@@ -70,9 +69,9 @@ impl<'table> Mapper<'table> {
     }
 
     /// Map a page to the next free frame
-    pub fn map(&mut self, page: Page, flags: PageFlags<RmmA>) -> PageFlush<RmmA> {
-        let frame = allocate_frames(1).expect("out of frames");
-        self.map_to(page, frame, flags)
+    pub fn map(&mut self, page: Page, flags: PageFlags<RmmA>) -> Result<PageFlush<RmmA>, Enomem> {
+        let frame = allocate_frames(1).ok_or(Enomem)?;
+        Ok(self.map_to(page, frame, flags))
     }
 
     /// Update flags for a page
diff --git a/src/arch/x86_64/paging/mod.rs b/src/arch/x86_64/paging/mod.rs
index 02dac6de..70d1f0d6 100644
--- a/src/arch/x86_64/paging/mod.rs
+++ b/src/arch/x86_64/paging/mod.rs
@@ -8,6 +8,7 @@ use x86::msr;
 
 use crate::memory::Frame;
 
+use self::entry::EntryFlags;
 use self::mapper::{Mapper, PageFlushAll};
 use self::table::{Level4, Table};
 
@@ -94,8 +95,8 @@ unsafe fn init_pat() {
     );
 }
 
-/// Map TSS
-unsafe fn map_tss(cpu_id: usize, mapper: &mut Mapper) -> PageFlushAll<RmmA> {
+/// Map the per-CPU data segment for the given CPU
+unsafe fn map_percpu(cpu_id: usize, mapper: &mut Mapper) -> PageFlushAll<RmmA> {
     extern "C" {
         /// The starting byte of the thread data segment
         static mut __tdata_start: u8;
@@ -115,7 +116,11 @@ unsafe fn map_tss(cpu_id: usize, mapper: &mut Mapper) -> PageFlushAll<RmmA> {
     let start_page = Page::containing_address(VirtualAddress::new(start));
     let end_page = Page::containing_address(VirtualAddress::new(end - 1));
     for page in Page::range_inclusive(start_page, end_page) {
-        let result = mapper.map(page, PageFlags::new().write(true));
+        let result = mapper.map(
+            page,
+            PageFlags::new().write(true).custom_flag(EntryFlags::GLOBAL.bits(), cfg!(not(feature = "pti"))),
+        )
+        .expect("failed to allocate page table frames while mapping percpu");
         flush_all.consume(result);
     }
     flush_all
@@ -188,7 +193,7 @@ pub unsafe fn init(
 
     let mut active_table = ActivePageTable::new_unlocked(TableKind::User);
 
-    let flush_all = map_tss(cpu_id, &mut active_table);
+    let flush_all = map_percpu(cpu_id, &mut active_table);
     flush_all.flush();
 
     return (active_table, init_tcb(cpu_id));
@@ -205,7 +210,7 @@ pub unsafe fn init_ap(
     let mut new_table = InactivePageTable::from_address(bsp_table);
 
     {
-        let flush_all = map_tss(cpu_id, &mut new_table.mapper());
+        let flush_all = map_percpu(cpu_id, &mut new_table.mapper());
         // The flush can be ignored as this is not the active table. See later active_table.switch
         flush_all.ignore();
     };
diff --git a/src/context/memory.rs b/src/context/memory.rs
index 46512997..d422457f 100644
--- a/src/context/memory.rs
+++ b/src/context/memory.rs
@@ -339,7 +339,9 @@ impl Grant {
         let start_page = Page::containing_address(to);
         let end_page = Page::containing_address(VirtualAddress::new(to.data() + size - 1));
         for page in Page::range_inclusive(start_page, end_page) {
-            let result = active_table.map(page, flags);
+            let result = active_table
+                .map(page, flags)
+                .expect("TODO: handle ENOMEM in Grant::map");
             flush_all.consume(result);
         }
 
@@ -408,7 +410,8 @@ impl Grant {
 
             let new_page = Page::containing_address(VirtualAddress::new(page.start_address().data() - self.region.start.data() + new_start.data()));
             if self.owned {
-                let result = active_table.map(new_page, PageFlags::new().write(true));
+                let result = active_table.map(new_page, PageFlags::new().write(true))
+                    .expect("TODO: handle ENOMEM in Grant::secret_clone");
                 flush_all.consume(result);
             } else {
                 let result = active_table.map_to(new_page, frame, flags);
@@ -692,7 +695,9 @@ impl Memory {
         let flush_all = PageFlushAll::new();
 
         for page in self.pages() {
-            let result = active_table.map(page, self.flags);
+            let result = active_table
+                .map(page, self.flags)
+                .expect("TODO: handle ENOMEM in Memory::map");
             flush_all.consume(result);
         }
 
@@ -769,7 +774,9 @@ impl Memory {
             let end_page = Page::containing_address(VirtualAddress::new(self.start.data() + new_size - 1));
             for page in Page::range_inclusive(start_page, end_page) {
                 if active_table.translate_page(page).is_none() {
-                    let result = active_table.map(page, self.flags);
+                    let result = active_table
+                        .map(page, self.flags)
+                        .expect("TODO: handle ENOMEM in Memory::resize");
                     flush_all.consume(result);
                 }
             }
diff --git a/src/memory/mod.rs b/src/memory/mod.rs
index f2f09292..98946660 100644
--- a/src/memory/mod.rs
+++ b/src/memory/mod.rs
@@ -118,3 +118,6 @@ impl Iterator for FrameIter {
         }
     }
 }
+
+#[derive(Debug)]
+pub struct Enomem;
-- 
GitLab