diff --git a/src/arch/x86_64/gdt.rs b/src/arch/x86_64/gdt.rs index f162bb57d445ea02966c9d91feb775670dee8dcb..ca846074fefc917d339e1a6845569b86b389eef2 100644 --- a/src/arch/x86_64/gdt.rs +++ b/src/arch/x86_64/gdt.rs @@ -88,6 +88,18 @@ pub static mut TSS: TaskStateSegment = TaskStateSegment { iomap_base: 0xFFFF }; +#[cfg(feature = "pti")] +pub unsafe fn set_tss_stack(stack: usize) { + use arch::x86_64::pti::{PTI_CPU_STACK, PTI_CONTEXT_STACK}; + TSS.rsp[0] = (PTI_CPU_STACK.as_ptr() as usize + PTI_CPU_STACK.len()) as u64; + PTI_CONTEXT_STACK = stack; +} + +#[cfg(not(feature = "pti"))] +pub unsafe fn set_tss_stack(stack: usize) { + TSS.rsp[0] = stack as u64; +} + /// Initialize GDT pub unsafe fn init(tcb_offset: usize, stack_offset: usize) { // Setup the initial GDT with TLS, so we can setup the TLS GDT (a little confusing) @@ -124,13 +136,7 @@ pub unsafe fn init(tcb_offset: usize, stack_offset: usize) { GDT[GDT_TSS].set_limit(mem::size_of::<TaskStateSegment>() as u32); // Set the stack pointer when coming back from userspace - if cfg!(feature = "pti") { - use arch::x86_64::pti::{PTI_CPU_STACK, PTI_CONTEXT_STACK}; - TSS.rsp[0] = (PTI_CPU_STACK.as_ptr() as usize + PTI_CPU_STACK.len()) as u64; - PTI_CONTEXT_STACK = stack_offset; - } else { - TSS.rsp[0] = stack_offset as u64; - } + set_tss_stack(stack_offset); // Load the new GDT, which is correctly located in thread local storage dtables::lgdt(&GDTR); diff --git a/src/arch/x86_64/interrupt/syscall.rs b/src/arch/x86_64/interrupt/syscall.rs index 89a867db27f423f10b954d0f7b9e53e87a974594..359b33b12843ff8e5cbbcc45d4b9fa7fd8da1855 100644 --- a/src/arch/x86_64/interrupt/syscall.rs +++ b/src/arch/x86_64/interrupt/syscall.rs @@ -4,64 +4,65 @@ use syscall; #[naked] pub unsafe extern fn syscall() { #[inline(never)] - unsafe fn inner(stack: &mut SyscallStack) { - let mut a; + unsafe fn inner(stack: &mut SyscallStack) -> usize { let rbp; - asm!("" : "={rax}"(a), "={rbp}"(rbp) - : : : "intel", "volatile"); + asm!("" : "={rbp}"(rbp) : : : "intel", "volatile"); - // Map kernel - pti::map(); + println!("{:X}, {:X}", stack.rax, stack.rbx); - a = syscall::syscall(a, stack.rbx, stack.rcx, stack.rdx, stack.rsi, stack.rdi, rbp, stack); - - // Unmap kernel - pti::unmap(); - - asm!("" : : "{rax}"(a) : : "intel", "volatile"); + syscall::syscall(stack.rax, stack.rbx, stack.rcx, stack.rdx, stack.rsi, stack.rdi, rbp, stack) } - // Push scratch registers, minus rax for the return value - asm!("push rcx - push rdx - push rdi - push rsi - push r8 - push r9 - push r10 - push r11 - push rbx - push fs - mov r11, 0x18 - mov fs, r11" - : : : : "intel", "volatile"); + // Push scratch registers + asm!("push rax + push rbx + push rcx + push rdx + push rdi + push rsi + push r8 + push r9 + push r10 + push r11 + push fs + mov r11, 0x18 + mov fs, r11" + : : : : "intel", "volatile"); // Get reference to stack variables let rsp: usize; asm!("" : "={rsp}"(rsp) : : : "intel", "volatile"); - inner(&mut *(rsp as *mut SyscallStack)); + // Map kernel + pti::map(); + + let a = inner(&mut *(rsp as *mut SyscallStack)); + + // Unmap kernel + pti::unmap(); + + asm!("" : : "{rax}"(a) : : "intel", "volatile"); // Interrupt return asm!("pop fs - pop rbx - pop r11 - pop r10 - pop r9 - pop r8 - pop rsi - pop rdi - pop rdx - pop rcx - iretq" - : : : : "intel", "volatile"); + pop r11 + pop r10 + pop r9 + pop r8 + pop rsi + pop rdi + pop rdx + pop rcx + pop rbx + add rsp, 8 + iretq" + : : : : "intel", "volatile"); } #[allow(dead_code)] #[repr(packed)] pub struct SyscallStack { pub fs: usize, - pub rbx: usize, pub r11: usize, pub r10: usize, pub r9: usize, @@ -70,6 +71,8 @@ pub struct SyscallStack { pub rdi: usize, pub rdx: usize, pub rcx: usize, + pub rbx: usize, + pub rax: usize, pub rip: usize, pub cs: usize, pub rflags: usize, diff --git a/src/arch/x86_64/pti.rs b/src/arch/x86_64/pti.rs index bb38ea5e65e75333a63d65ecd5b346e33e9b736b..9124c92f1b95089b7cf88294cd4d9b119ee43417 100644 --- a/src/arch/x86_64/pti.rs +++ b/src/arch/x86_64/pti.rs @@ -13,20 +13,14 @@ pub static mut PTI_CPU_STACK: [u8; 256] = [0; 256]; pub static mut PTI_CONTEXT_STACK: usize = 0; #[cfg(feature = "pti")] -#[inline(never)] -#[naked] +#[inline(always)] unsafe fn switch_stack(old: usize, new: usize) { - asm!("xchg bx, bx" : : : : "intel", "volatile"); - let old_rsp: usize; - let old_rbp: usize; - asm!("" : "={rsp}"(old_rsp), "={rbp}"(old_rbp) : : : "intel", "volatile"); + asm!("" : "={rsp}"(old_rsp) : : : "intel", "volatile"); let offset_rsp = old - old_rsp; - let offset_rbp = old - old_rbp; let new_rsp = new - offset_rsp; - let new_rbp = new - offset_rbp; ptr::copy_nonoverlapping( old_rsp as *const u8, @@ -34,17 +28,12 @@ unsafe fn switch_stack(old: usize, new: usize) { offset_rsp ); - asm!("xchg bx, bx" : : : : "intel", "volatile"); - - asm!("" : : "{rsp}"(new_rsp), "{rbp}"(new_rbp) : : "intel", "volatile"); + asm!("" : : "{rsp}"(new_rsp) : : "intel", "volatile"); } #[cfg(feature = "pti")] -#[inline(never)] -#[naked] +#[inline(always)] pub unsafe fn map() { - asm!("xchg bx, bx" : : : : "intel", "volatile"); - // { // let mut active_table = unsafe { ActivePageTable::new() }; // @@ -64,11 +53,8 @@ pub unsafe fn map() { } #[cfg(feature = "pti")] -#[inline(never)] -#[naked] +#[inline(always)] pub unsafe fn unmap() { - asm!("xchg bx, bx" : : : : "intel", "volatile"); - // Switch to per-CPU stack switch_stack(PTI_CONTEXT_STACK, PTI_CPU_STACK.as_ptr() as usize + PTI_CPU_STACK.len()); diff --git a/src/arch/x86_64/start.rs b/src/arch/x86_64/start.rs index c98135aeeabd609c508c2388db21f9098c68ee2b..4f6c730090a1fcea85259280fee50fd119233743 100644 --- a/src/arch/x86_64/start.rs +++ b/src/arch/x86_64/start.rs @@ -192,31 +192,52 @@ pub unsafe extern fn kstart_ap(args_ptr: *const KernelArgsAp) -> ! { #[naked] pub unsafe fn usermode(ip: usize, sp: usize, arg: usize) -> ! { - asm!("xchg bx, bx" : : : : "intel", "volatile"); + asm!("push r10 + push r11 + push r12 + push r13 + push r14 + push r15" + : // No output + : "{r10}"(gdt::GDT_USER_DATA << 3 | 3), // Data segment + "{r11}"(sp), // Stack pointer + "{r12}"(1 << 9), // Flags - Set interrupt enable flag + "{r13}"(gdt::GDT_USER_CODE << 3 | 3), // Code segment + "{r14}"(ip), // IP + "{r15}"(arg) // Argument + : // No clobbers + : "intel", "volatile"); // Unmap kernel pti::unmap(); // Go to usermode - asm!("mov ds, r10d - mov es, r10d - mov fs, r11d - mov gs, r10d - push r10 - push r12 - push r13 - push r14 - push r15 - iretq" - : // No output because it never returns - : "{r10}"(gdt::GDT_USER_DATA << 3 | 3), // Data segment - "{r11}"(gdt::GDT_USER_TLS << 3 | 3), // TLS segment - "{r12}"(sp), // Stack pointer - "{r13}"(1 << 9), // Flags - Set interrupt enable flag - "{r14}"(gdt::GDT_USER_CODE << 3 | 3), // Code segment - "{r15}"(ip) // IP - "{rdi}"(arg) // Argument - : // No clobers because it never returns - : "intel", "volatile"); + asm!("mov ds, r14d + mov es, r14d + mov fs, r15d + mov gs, r14d + xor rax, rax + xor rbx, rbx + xor rcx, rcx + xor rdx, rdx + xor rsi, rsi + xor rdi, rdi + xor rbp, rbp + xor r8, r8 + xor r9, r9 + xor r10, r10 + xor r11, r11 + xor r12, r12 + xor r13, r13 + xor r14, r14 + xor r15, r15 + finit + pop rdi + iretq" + : // No output because it never returns + : "{r14}"(gdt::GDT_USER_DATA << 3 | 3), // Data segment + "{r15}"(gdt::GDT_USER_TLS << 3 | 3) // TLS segment + : // No clobbers because it never returns + : "intel", "volatile"); unreachable!(); } diff --git a/src/context/switch.rs b/src/context/switch.rs index 3cc89fcf311f4eb1a1601a90a4fd31d65139b57b..4c0e91109fb428ca0e10e7cd74803b0b51f913d3 100644 --- a/src/context/switch.rs +++ b/src/context/switch.rs @@ -118,13 +118,7 @@ pub unsafe fn switch() -> bool { (&mut *from_ptr).running = false; (&mut *to_ptr).running = true; if let Some(ref stack) = (*to_ptr).kstack { - if cfg!(feature = "pti") { - use arch::x86_64::pti::{PTI_CPU_STACK, PTI_CONTEXT_STACK}; - gdt::TSS.rsp[0] = (PTI_CPU_STACK.as_ptr() as usize + PTI_CPU_STACK.len()) as u64; - PTI_CONTEXT_STACK = stack.as_ptr() as usize + stack.len(); - } else { - gdt::TSS.rsp[0] = (stack.as_ptr() as usize + stack.len()) as u64; - } + gdt::set_tss_stack(stack.as_ptr() as usize + stack.len()); } CONTEXT_ID.store((&mut *to_ptr).id, Ordering::SeqCst); }