From e34a3cee42d4211196b3aab2177e1530a7af38a4 Mon Sep 17 00:00:00 2001
From: 4lDO2 <4lDO2@protonmail.com>
Date: Mon, 26 Jun 2023 19:44:42 +0200
Subject: [PATCH] Fix mmap/munmap, invert borrowed/owned refcount.

---
 src/context/memory.rs | 90 ++++++++++++++++++++++++++-----------------
 src/memory/mod.rs     | 15 +++++---
 2 files changed, 63 insertions(+), 42 deletions(-)

diff --git a/src/context/memory.rs b/src/context/memory.rs
index 768baf80..b3f3820f 100644
--- a/src/context/memory.rs
+++ b/src/context/memory.rs
@@ -1,12 +1,9 @@
-use alloc::boxed::Box;
 use alloc::collections::BTreeMap;
 use alloc::{sync::Arc, vec::Vec};
 use core::cmp;
 use core::fmt::Debug;
-use core::mem::ManuallyDrop;
 use core::num::NonZeroUsize;
-use core::ops::Deref;
-use core::sync::atomic::{AtomicUsize, Ordering};
+use core::sync::atomic::Ordering;
 use spin::{RwLock, RwLockWriteGuard, Once, RwLockUpgradableGuard};
 use syscall::{
     flag::MapFlags,
@@ -15,9 +12,8 @@ use syscall::{
 use rmm::Arch as _;
 
 use crate::arch::paging::PAGE_SIZE;
-use crate::common::{try_box_slice_new, try_new_vec_with_exact_size};
 use crate::context::file::FileDescriptor;
-use crate::memory::{Enomem, Frame, RaiiFrame, get_page_info, PageInfo};
+use crate::memory::{Enomem, Frame, get_page_info, PageInfo};
 use crate::paging::mapper::{Flusher, InactiveFlusher, PageFlushAll};
 use crate::paging::{KernelMapper, Page, PageFlags, PageMapper, RmmA, TableKind, VirtualAddress};
 
@@ -152,17 +148,25 @@ impl AddrSpace {
         }
         Ok(())
     }
-    pub fn munmap(mut self: RwLockWriteGuard<'_, Self>, requested_span: PageSpan) {
+    pub fn munmap(mut self: RwLockWriteGuard<'_, Self>, mut requested_span: PageSpan) {
         let mut notify_files = Vec::new();
 
         let mut flusher = PageFlushAll::new();
 
-        // TODO: Allocating may even be wrong!
-        let conflicting: Vec<PageSpan> = self.grants.conflicts(requested_span).map(|(base, info)| PageSpan::new(base, info.page_count)).collect();
+        let this = &mut *self;
+
+        let next = |grants: &mut UserGrants, span: PageSpan| grants.conflicts(span).map(|(base, info)| PageSpan::new(base, info.page_count)).next();
+
+        while let Some(conflicting_span) = next(&mut this.grants, requested_span) {
+            let grant = this.grants.remove(conflicting_span.base).expect("conflicting region didn't exist");
+
+            let intersection = conflicting_span.intersection(requested_span);
+
+            requested_span = {
+                let offset = conflicting_span.base.offset_from(requested_span.base);
+                PageSpan::new(conflicting_span.end(), requested_span.count - offset - conflicting_span.count)
+            };
 
-        for conflict in conflicting {
-            let grant = self.grants.remove(conflict.base).expect("conflicting region didn't exist");
-            let intersection = conflict.intersection(requested_span);
             let (before, mut grant, after) = grant.extract(intersection).expect("conflicting region shared no common parts");
 
             // Notify scheme that holds grant
@@ -173,14 +177,15 @@ impl AddrSpace {
 
             // Keep untouched regions
             if let Some(before) = before {
-                self.grants.insert(before);
+                this.grants.insert(before);
             }
             if let Some(after) = after {
-                self.grants.insert(after);
+                this.grants.insert(after);
             }
 
             // Remove irrelevant region
-            grant.unmap(&mut self.table.utable, &mut flusher);
+            grant.unmap(&mut this.table.utable, &mut flusher);
+
         }
         drop(self);
 
@@ -652,14 +657,16 @@ impl Grant {
 
         let src_span = PageSpan::new(src_base, page_count);
 
-        for (src_base, src_grant) in src_address_space.grants.conflicts(src_span) {
-            let grant_span = PageSpan::new(src_base, src_grant.page_count);
+        for (src_grant_base, src_grant) in src_address_space.grants.conflicts(src_span) {
+            let grant_span = PageSpan::new(src_grant_base, src_grant.page_count);
 
             let common_span = src_span.intersection(grant_span);
             let offset_from_src_base = common_span.base.offset_from(src_base);
 
+            let grant_dst_base = dst_base.next_by(offset_from_src_base);
+
             dst_grants.push(Grant {
-                base: dst_base.next_by(offset_from_src_base),
+                base: grant_dst_base,
                 info: GrantInfo {
                     page_count: common_span.count,
                     flags,
@@ -766,14 +773,11 @@ impl Grant {
             };
             let frame = Frame::containing_address(phys);
 
-            match self.info.provider {
-                Provider::Allocated | Provider::External { .. } => {
-                    /*get_page_info(frame)
-                        .expect("allocated frame did not have an associated PageInfo")
-                        .remove_ref(is_cow);*/
-                }
-                _ => (),
-            }
+            let is_cow = !matches!(self.info.provider, Provider::External { .. });
+
+            get_page_info(frame)
+                .expect("allocated frame did not have an associated PageInfo")
+                .remove_ref(is_cow);
 
 
             flusher.consume(flush);
@@ -897,6 +901,7 @@ impl GrantInfo {
 }
 
 impl Drop for GrantInfo {
+    #[track_caller]
     fn drop(&mut self) {
         // XXX: This will not show the address...
         assert!(!self.mapped, "Grant dropped while still mapped: {:#x?}", self);
@@ -1020,7 +1025,7 @@ fn init_frame() -> Result<Frame, PfError> {
     let new_frame = crate::memory::allocate_frames(1).ok_or(PfError::Oom)?;
     let page_info = get_page_info(new_frame).expect("all allocated frames need an associated page info");
     page_info.refcount.store(1, Ordering::Relaxed);
-    page_info.cow_refcount.store(1, Ordering::Relaxed);
+    page_info.borrowed_refcount.store(0, Ordering::Relaxed);
 
     Ok(new_frame)
 }
@@ -1029,7 +1034,7 @@ fn map_zeroed(mapper: &mut PageMapper, page: Page, page_flags: PageFlags<RmmA>,
     let new_frame = init_frame()?;
 
     unsafe {
-        mapper.map_phys(page.start_address(), new_frame.start_address(), page_flags).ok_or(PfError::Oom)?.flush();
+        mapper.map_phys(page.start_address(), new_frame.start_address(), page_flags).ok_or(PfError::Oom)?.ignore();
     }
 
     Ok(new_frame)
@@ -1047,7 +1052,7 @@ pub unsafe fn copy_frame_to_frame_directly(dst: Frame, src: Frame) {
 
 pub fn try_correcting_page_tables(faulting_page: Page, access: AccessMode) -> Result<(), PfError> {
     let Ok(addr_space) = AddrSpace::current() else {
-        log::warn!("User page fault without address space being set.");
+        log::debug!("User page fault without address space being set.");
         return Err(PfError::Segv);
     };
 
@@ -1055,6 +1060,7 @@ pub fn try_correcting_page_tables(faulting_page: Page, access: AccessMode) -> Re
     let addr_space = &mut *addr_space;
 
     let Some((grant_base, grant_info)) = addr_space.grants.contains(faulting_page) else {
+        log::debug!("Lacks grant");
         return Err(PfError::Segv);
     };
 
@@ -1065,8 +1071,14 @@ pub fn try_correcting_page_tables(faulting_page: Page, access: AccessMode) -> Re
         // TODO: has_read
         AccessMode::Read => (),
 
-        AccessMode::Write if !grant_flags.has_write() => return Err(PfError::Segv),
-        AccessMode::InstrFetch if !grant_flags.has_execute() => return Err(PfError::Segv),
+        AccessMode::Write if !grant_flags.has_write() => {
+            log::debug!("Instuction fetch, but grant was not PROT_WRITE.");
+            return Err(PfError::Segv);
+        }
+        AccessMode::InstrFetch if !grant_flags.has_execute() => {
+            log::debug!("Instuction fetch, but grant was not PROT_EXEC.");
+            return Err(PfError::Segv);
+        }
 
         _ => (),
     }
@@ -1088,11 +1100,13 @@ pub fn try_correcting_page_tables(faulting_page: Page, access: AccessMode) -> Re
 
     let mut allow_writable = true;
 
+    let mut debug = false;
+
     let frame = match grant_info.provider {
         Provider::Allocated if access == AccessMode::Write => {
             match faulting_pageinfo_opt {
                 Some((_, None)) => unreachable!("allocated page needs frame to be valid"),
-                Some((frame, Some(info))) => if info.cow_refcount.load(Ordering::SeqCst) == 1 {
+                Some((frame, Some(info))) => if info.owned_refcount() == 1 {
                     frame
                 } else {
                     cow(&mut addr_space.table.utable, faulting_page, frame, info, grant_flags)?
@@ -1104,7 +1118,7 @@ pub fn try_correcting_page_tables(faulting_page: Page, access: AccessMode) -> Re
             match faulting_pageinfo_opt {
                 Some((_, None)) => unreachable!("allocated page needs frame to be valid"),
                 Some((frame, Some(page_info))) => {
-                    allow_writable = page_info.cow_refcount.load(Ordering::SeqCst) == 1;
+                    allow_writable = page_info.owned_refcount() == 1;
 
                     frame
                 }
@@ -1119,7 +1133,9 @@ pub fn try_correcting_page_tables(faulting_page: Page, access: AccessMode) -> Re
             base.next_by(pages_from_grant_start)
         }
         Provider::External { address_space: ref foreign_address_space, src_base } => {
-            let guard = foreign_address_space.read();
+            debug = true;
+
+            let guard = foreign_address_space.upgradeable_read();
             let src_page = src_base.next_by(pages_from_grant_start);
 
             if let Some((phys, _)) = guard.table.utable.translate(src_page.start_address()) {
@@ -1130,14 +1146,16 @@ pub fn try_correcting_page_tables(faulting_page: Page, access: AccessMode) -> Re
 
                 src_frame
             } else {
+                let mut guard = RwLockUpgradableGuard::upgrade(guard);
+
                 // TODO: Should this be called?
-                map_zeroed(&mut addr_space.table.utable, src_page, grant_flags, access == AccessMode::Write)?
+                map_zeroed(&mut guard.table.utable, src_page, grant_flags, access == AccessMode::Write)?
             }
         }
         Provider::Fmap { ref desc } => todo!(),
     };
 
-    if super::context_id().into() == 3 {
+    if super::context_id().into() == 3 && debug {
         //log::info!("Correcting {:?} => {:?} (base {:?} info {:?})", faulting_page, frame, grant_base, grant_info);
     }
     let new_flags = grant_flags.write(grant_flags.has_write() && allow_writable);
diff --git a/src/memory/mod.rs b/src/memory/mod.rs
index ac071138..5a6c30aa 100644
--- a/src/memory/mod.rs
+++ b/src/memory/mod.rs
@@ -193,7 +193,7 @@ impl Drop for RaiiFrame {
 #[derive(Debug)]
 pub struct PageInfo {
     pub refcount: AtomicUsize,
-    pub cow_refcount: AtomicUsize,
+    pub borrowed_refcount: AtomicUsize,
     // TODO: AtomicFlags?
     pub flags: FrameFlags,
     pub _padding: usize,
@@ -252,27 +252,30 @@ impl PageInfo {
     pub fn new() -> Self {
         Self {
             refcount: AtomicUsize::new(0),
-            cow_refcount: AtomicUsize::new(0),
+            borrowed_refcount: AtomicUsize::new(0),
             flags: FrameFlags::NONE,
             _padding: 0,
         }
     }
     pub fn add_ref(&self, cow: bool) {
-        if cow {
-            self.cow_refcount.fetch_add(1, Ordering::Relaxed);
+        if !cow {
+            self.borrowed_refcount.fetch_add(1, Ordering::Relaxed);
         }
         self.refcount.fetch_add(1, Ordering::Relaxed);
 
         core::sync::atomic::fence(Ordering::Release);
     }
     pub fn remove_ref(&self, cow: bool) {
-        if cow {
-            self.cow_refcount.fetch_sub(1, Ordering::Relaxed);
+        if !cow {
+            self.borrowed_refcount.fetch_sub(1, Ordering::Relaxed);
         }
         self.refcount.fetch_sub(1, Ordering::Relaxed);
 
         core::sync::atomic::fence(Ordering::Release);
     }
+    pub fn owned_refcount(&self) -> usize {
+        self.refcount.load(Ordering::SeqCst) - self.borrowed_refcount.load(Ordering::SeqCst)
+    }
 }
 pub fn get_page_info(frame: Frame) -> Option<&'static PageInfo> {
     let sections = SECTIONS.read();
-- 
GitLab