From df75b8d037653e286cbd04b447a38e6ce01b7ae6 Mon Sep 17 00:00:00 2001
From: Jeremy Soller <jackpot51@gmail.com>
Date: Thu, 1 Dec 2022 10:48:53 -0700
Subject: [PATCH] Semaphore improvements

---
 src/platform/pte.rs   | 27 +++++++++++++++++-------
 src/sync/semaphore.rs | 49 ++++++++++++++++++++++++++++++-------------
 2 files changed, 53 insertions(+), 23 deletions(-)

diff --git a/src/platform/pte.rs b/src/platform/pte.rs
index 02a08bbdf..cdbbe3a6b 100644
--- a/src/platform/pte.rs
+++ b/src/platform/pte.rs
@@ -8,7 +8,7 @@ use core::{
 };
 
 use crate::{
-    header::{sys_mman, time::timespec},
+    header::{sys_mman, time::{CLOCK_MONOTONIC, clock_gettime, timespec}},
     ld_so::{
         linker::Linker,
         tcb::{Master, Tcb},
@@ -237,7 +237,7 @@ pub unsafe extern "C" fn pte_osThreadWaitForEnd(handle: pte_osThreadHandle) -> p
 #[no_mangle]
 pub unsafe extern "C" fn pte_osThreadCancel(handle: pte_osThreadHandle) -> pte_osResult {
     //TODO: allow cancel of thread
-    println!("pte_osThreadCancel");
+    eprintln!("pte_osThreadCancel");
     PTE_OS_OK
 }
 
@@ -338,7 +338,7 @@ pub unsafe extern "C" fn pte_osSemaphorePost(
     handle: pte_osSemaphoreHandle,
     count: c_int,
 ) -> pte_osResult {
-    (*handle).post();
+    (*handle).post(count);
     PTE_OS_OK
 }
 
@@ -348,15 +348,26 @@ pub unsafe extern "C" fn pte_osSemaphorePend(
     pTimeout: *mut c_uint,
 ) -> pte_osResult {
     let timeout_opt = if !pTimeout.is_null() {
+        // Get current time
+        let mut time = timespec::default();
+        clock_gettime(CLOCK_MONOTONIC, &mut time);
+
+        // Add timeout to time
         let timeout = *pTimeout as i64;
-        let tv_sec = timeout / 1000;
-        let tv_nsec = (timeout % 1000) * 1000000;
-        Some(timespec { tv_sec, tv_nsec })
+        time.tv_sec += timeout / 1000;
+        time.tv_nsec += (timeout % 1000) * 1_000_000;
+        while time.tv_nsec >= 1_000_000_000 {
+            time.tv_sec += 1;
+            time.tv_nsec -= 1_000_000_000;
+        }
+        Some(time)
     } else {
         None
     };
-    (*handle).wait(timeout_opt.as_ref());
-    PTE_OS_OK
+    match (*handle).wait(timeout_opt.as_ref()) {
+        Ok(()) => PTE_OS_OK,
+        Err(()) => PTE_OS_TIMEOUT,
+    }
 }
 
 #[no_mangle]
diff --git a/src/sync/semaphore.rs b/src/sync/semaphore.rs
index b9b00b4f8..1eea73aa1 100644
--- a/src/sync/semaphore.rs
+++ b/src/sync/semaphore.rs
@@ -21,27 +21,46 @@ impl Semaphore {
         }
     }
 
-    pub fn post(&self) {
-        self.lock.fetch_add(1, Ordering::Release);
+    pub fn post(&self, count: c_int) {
+        self.lock.fetch_add(count, Ordering::SeqCst);
     }
 
-    pub fn wait(&self, timeout_opt: Option<&timespec>) {
-        if let Some(timeout) = timeout_opt {
-            println!(
-                "semaphore wait tv_sec: {}, tv_nsec: {}",
-                timeout.tv_sec, timeout.tv_nsec
-            );
+    pub fn try_wait(&self) -> Result<(), ()> {
+        let mut value = self.lock.load(Ordering::SeqCst);
+        if value > 0 {
+            match self.lock.compare_exchange(
+                value,
+                value - 1,
+                Ordering::SeqCst,
+                Ordering::SeqCst
+            ) {
+                Ok(_) => Ok(()),
+                Err(_) => Err(())
+            }
+        } else {
+            Err(())
         }
+    }
+
+    pub fn wait(&self, timeout_opt: Option<&timespec>) -> Result<(), ()> {
+
         loop {
-            while self.lock.load(Ordering::Acquire) < 1 {
-                //spin_loop();
-                Sys::sched_yield();
+            match self.try_wait() {
+                Ok(()) => {
+                    return Ok(());
+                }
+                Err(()) => ()
             }
-            let tmp = self.lock.fetch_sub(1, Ordering::AcqRel);
-            if tmp >= 1 {
-                break;
+            if let Some(timeout) = timeout_opt {
+                let mut time = timespec::default();
+                clock_gettime(CLOCK_MONOTONIC, &mut time);
+                if (time.tv_sec > timeout.tv_sec) ||
+                   (time.tv_sec == timeout.tv_sec && time.tv_nsec >= timeout.tv_nsec)
+                {
+                    return Err(())
+                }
             }
-            self.lock.fetch_add(1, Ordering::Release);
+            Sys::sched_yield();
         }
     }
 }
-- 
GitLab