From d82ffd16cbee1eaa203acac9ca44f0b9a73b0b9d Mon Sep 17 00:00:00 2001
From: Jeremy Soller <jackpot51@gmail.com>
Date: Tue, 9 Jan 2018 22:16:14 -0700
Subject: [PATCH] WIP: Add per-cpu interrupt stack used before mapping kernel
 heap

---
 src/arch/x86_64/gdt.rs    |  8 +++-
 src/arch/x86_64/macros.rs | 60 +++++++++++++-------------
 src/arch/x86_64/pti.rs    | 91 ++++++++++++++++++++++++++++++++++-----
 src/arch/x86_64/start.rs  |  3 ++
 src/context/switch.rs     |  8 +++-
 5 files changed, 128 insertions(+), 42 deletions(-)

diff --git a/src/arch/x86_64/gdt.rs b/src/arch/x86_64/gdt.rs
index 96e1b99..f162bb5 100644
--- a/src/arch/x86_64/gdt.rs
+++ b/src/arch/x86_64/gdt.rs
@@ -124,7 +124,13 @@ pub unsafe fn init(tcb_offset: usize, stack_offset: usize) {
     GDT[GDT_TSS].set_limit(mem::size_of::<TaskStateSegment>() as u32);
 
     // Set the stack pointer when coming back from userspace
-    TSS.rsp[0] = stack_offset as u64;
+    if cfg!(feature = "pti") {
+        use arch::x86_64::pti::{PTI_CPU_STACK, PTI_CONTEXT_STACK};
+        TSS.rsp[0] = (PTI_CPU_STACK.as_ptr() as usize + PTI_CPU_STACK.len()) as u64;
+        PTI_CONTEXT_STACK = stack_offset;
+    } else {
+        TSS.rsp[0] = stack_offset as u64;
+    }
 
     // Load the new GDT, which is correctly located in thread local storage
     dtables::lgdt(&GDTR);
diff --git a/src/arch/x86_64/macros.rs b/src/arch/x86_64/macros.rs
index c336dae..a7b5a3b 100644
--- a/src/arch/x86_64/macros.rs
+++ b/src/arch/x86_64/macros.rs
@@ -166,22 +166,22 @@ macro_rules! interrupt {
         pub unsafe extern fn $name () {
             #[inline(never)]
             unsafe fn inner() {
-                // Map kernel
-                $crate::arch::x86_64::pti::map();
-
                 $func
-
-                // Unmap kernel
-                $crate::arch::x86_64::pti::unmap();
             }
 
             // Push scratch registers
             scratch_push!();
             fs_push!();
 
+            // Map kernel
+            $crate::arch::x86_64::pti::map();
+
             // Call inner rust function
             inner();
 
+            // Unmap kernel
+            $crate::arch::x86_64::pti::unmap();
+
             // Pop scratch registers and return
             fs_pop!();
             scratch_pop!();
@@ -213,13 +213,7 @@ macro_rules! interrupt_stack {
         pub unsafe extern fn $name () {
             #[inline(never)]
             unsafe fn inner($stack: &mut $crate::arch::x86_64::macros::InterruptStack) {
-                // Map kernel
-                $crate::arch::x86_64::pti::map();
-
                 $func
-
-                // Unmap kernel
-                $crate::arch::x86_64::pti::unmap();
             }
 
             // Push scratch registers
@@ -230,9 +224,15 @@ macro_rules! interrupt_stack {
             let rsp: usize;
             asm!("" : "={rsp}"(rsp) : : : "intel", "volatile");
 
+            // Map kernel
+            $crate::arch::x86_64::pti::map();
+
             // Call inner rust function
             inner(&mut *(rsp as *mut $crate::arch::x86_64::macros::InterruptStack));
 
+            // Unmap kernel
+            $crate::arch::x86_64::pti::unmap();
+
             // Pop scratch registers and return
             fs_pop!();
             scratch_pop!();
@@ -266,13 +266,7 @@ macro_rules! interrupt_error {
         pub unsafe extern fn $name () {
             #[inline(never)]
             unsafe fn inner($stack: &$crate::arch::x86_64::macros::InterruptErrorStack) {
-                // Map kernel
-                $crate::arch::x86_64::pti::map();
-
                 $func
-
-                // Unmap kernel
-                $crate::arch::x86_64::pti::unmap();
             }
 
             // Push scratch registers
@@ -283,9 +277,15 @@ macro_rules! interrupt_error {
             let rsp: usize;
             asm!("" : "={rsp}"(rsp) : : : "intel", "volatile");
 
+            // Map kernel
+            $crate::arch::x86_64::pti::map();
+
             // Call inner rust function
             inner(&*(rsp as *const $crate::arch::x86_64::macros::InterruptErrorStack));
 
+            // Unmap kernel
+            $crate::arch::x86_64::pti::unmap();
+
             // Pop scratch registers, error code, and return
             fs_pop!();
             scratch_pop!();
@@ -320,13 +320,7 @@ macro_rules! interrupt_stack_p {
         pub unsafe extern fn $name () {
             #[inline(never)]
             unsafe fn inner($stack: &mut $crate::arch::x86_64::macros::InterruptStackP) {
-                // Map kernel
-                $crate::arch::x86_64::pti::map();
-
                 $func
-
-                // Unmap kernel
-                $crate::arch::x86_64::pti::unmap();
             }
 
             // Push scratch registers
@@ -338,9 +332,15 @@ macro_rules! interrupt_stack_p {
             let rsp: usize;
             asm!("" : "={rsp}"(rsp) : : : "intel", "volatile");
 
+            // Map kernel
+            $crate::arch::x86_64::pti::map();
+
             // Call inner rust function
             inner(&mut *(rsp as *mut $crate::arch::x86_64::macros::InterruptStackP));
 
+            // Unmap kernel
+            $crate::arch::x86_64::pti::unmap();
+
             // Pop scratch registers and return
             fs_pop!();
             preserved_pop!();
@@ -377,13 +377,7 @@ macro_rules! interrupt_error_p {
         pub unsafe extern fn $name () {
             #[inline(never)]
             unsafe fn inner($stack: &$crate::arch::x86_64::macros::InterruptErrorStackP) {
-                // Map kernel
-                $crate::arch::x86_64::pti::map();
-
                 $func
-
-                // Unmap kernel
-                $crate::arch::x86_64::pti::unmap();
             }
 
             // Push scratch registers
@@ -395,9 +389,15 @@ macro_rules! interrupt_error_p {
             let rsp: usize;
             asm!("" : "={rsp}"(rsp) : : : "intel", "volatile");
 
+            // Map kernel
+            $crate::arch::x86_64::pti::map();
+
             // Call inner rust function
             inner(&*(rsp as *const $crate::arch::x86_64::macros::InterruptErrorStackP));
 
+            // Unmap kernel
+            $crate::arch::x86_64::pti::unmap();
+
             // Pop scratch registers, error code, and return
             fs_pop!();
             preserved_pop!();
diff --git a/src/arch/x86_64/pti.rs b/src/arch/x86_64/pti.rs
index bee8b3d..bb38ea5 100644
--- a/src/arch/x86_64/pti.rs
+++ b/src/arch/x86_64/pti.rs
@@ -1,19 +1,90 @@
+use core::ptr;
+
+use memory::Frame;
+use paging::ActivePageTable;
+use paging::entry::EntryFlags;
+
 #[cfg(feature = "pti")]
-#[inline(always)]
+#[thread_local]
+pub static mut PTI_CPU_STACK: [u8; 256] = [0; 256];
+
+#[cfg(feature = "pti")]
+#[thread_local]
+pub static mut PTI_CONTEXT_STACK: usize = 0;
+
+#[cfg(feature = "pti")]
+#[inline(never)]
+#[naked]
+unsafe fn switch_stack(old: usize, new: usize) {
+    asm!("xchg bx, bx" : : : : "intel", "volatile");
+
+    let old_rsp: usize;
+    let old_rbp: usize;
+    asm!("" : "={rsp}"(old_rsp), "={rbp}"(old_rbp) : : : "intel", "volatile");
+
+    let offset_rsp = old - old_rsp;
+    let offset_rbp = old - old_rbp;
+
+    let new_rsp = new - offset_rsp;
+    let new_rbp = new - offset_rbp;
+
+    ptr::copy_nonoverlapping(
+        old_rsp as *const u8,
+        new_rsp as *mut u8,
+        offset_rsp
+    );
+
+    asm!("xchg bx, bx" : : : : "intel", "volatile");
+
+    asm!("" : : "{rsp}"(new_rsp), "{rbp}"(new_rbp) : : "intel", "volatile");
+}
+
+#[cfg(feature = "pti")]
+#[inline(never)]
+#[naked]
 pub unsafe fn map() {
-    let _cr3: usize;
-    asm!("mov $0, cr3
-          mov cr3, $0"
-          : "=r"(_cr3) : : "memory" : "intel", "volatile");
+    asm!("xchg bx, bx" : : : : "intel", "volatile");
+
+    // {
+    //     let mut active_table = unsafe { ActivePageTable::new() };
+    //
+    //     // Map kernel heap
+    //     let address = active_table.p4()[::KERNEL_HEAP_PML4].address();
+    //     let frame = Frame::containing_address(address);
+    //     let mut flags = active_table.p4()[::KERNEL_HEAP_PML4].flags();
+    //     flags.remove(EntryFlags::PRESENT);
+    //     active_table.p4_mut()[::KERNEL_HEAP_PML4].set(frame, flags);
+    //
+    //     // Reload page tables
+    //     active_table.flush_all();
+    // }
+
+    // Switch to per-context stack
+    switch_stack(PTI_CPU_STACK.as_ptr() as usize + PTI_CPU_STACK.len(), PTI_CONTEXT_STACK);
 }
 
 #[cfg(feature = "pti")]
-#[inline(always)]
+#[inline(never)]
+#[naked]
 pub unsafe fn unmap() {
-    let _cr3: usize;
-    asm!("mov $0, cr3
-          mov cr3, $0"
-          : "=r"(_cr3) : : "memory" : "intel", "volatile");
+    asm!("xchg bx, bx" : : : : "intel", "volatile");
+
+    // Switch to per-CPU stack
+    switch_stack(PTI_CONTEXT_STACK, PTI_CPU_STACK.as_ptr() as usize + PTI_CPU_STACK.len());
+
+    // {
+    //     let mut active_table = unsafe { ActivePageTable::new() };
+    //
+    //     // Unmap kernel heap
+    //     let address = active_table.p4()[::KERNEL_HEAP_PML4].address();
+    //     let frame = Frame::containing_address(address);
+    //     let mut flags = active_table.p4()[::KERNEL_HEAP_PML4].flags();
+    //     flags.insert(EntryFlags::PRESENT);
+    //     active_table.p4_mut()[::KERNEL_HEAP_PML4].set(frame, flags);
+    //
+    //     // Reload page tables
+    //     active_table.flush_all();
+    // }
 }
 
 #[cfg(not(feature = "pti"))]
diff --git a/src/arch/x86_64/start.rs b/src/arch/x86_64/start.rs
index ddc7b61..c98135a 100644
--- a/src/arch/x86_64/start.rs
+++ b/src/arch/x86_64/start.rs
@@ -190,7 +190,10 @@ pub unsafe extern fn kstart_ap(args_ptr: *const KernelArgsAp) -> ! {
     ::kmain_ap(cpu_id);
 }
 
+#[naked]
 pub unsafe fn usermode(ip: usize, sp: usize, arg: usize) -> ! {
+    asm!("xchg bx, bx" : : : : "intel", "volatile");
+
     // Unmap kernel
     pti::unmap();
 
diff --git a/src/context/switch.rs b/src/context/switch.rs
index 3e0c947..3cc89fc 100644
--- a/src/context/switch.rs
+++ b/src/context/switch.rs
@@ -118,7 +118,13 @@ pub unsafe fn switch() -> bool {
         (&mut *from_ptr).running = false;
         (&mut *to_ptr).running = true;
         if let Some(ref stack) = (*to_ptr).kstack {
-            gdt::TSS.rsp[0] = (stack.as_ptr() as usize + stack.len() - 256) as u64;
+            if cfg!(feature = "pti") {
+                use arch::x86_64::pti::{PTI_CPU_STACK, PTI_CONTEXT_STACK};
+                gdt::TSS.rsp[0] = (PTI_CPU_STACK.as_ptr() as usize + PTI_CPU_STACK.len()) as u64;
+                PTI_CONTEXT_STACK = stack.as_ptr() as usize + stack.len();
+            } else {
+                gdt::TSS.rsp[0] = (stack.as_ptr() as usize + stack.len()) as u64;
+            }
         }
         CONTEXT_ID.store((&mut *to_ptr).id, Ordering::SeqCst);
     }
-- 
GitLab