From 41264dc8b145a844e3ab3977d837b8a63b3f36f1 Mon Sep 17 00:00:00 2001
From: Peter Limkilde Svendsen <peter.limkilde@gmail.com>
Date: Tue, 22 Oct 2024 22:12:24 +0000
Subject: [PATCH] Add NulTerminatedInclusive iterator, rewrite/fix
 strchr/wcschr

---
 src/header/string/mod.rs                      | 28 ++++++----
 src/header/wchar/mod.rs                       | 27 ++++++----
 src/iter.rs                                   | 53 +++++++++++++++++++
 tests/Makefile                                |  1 +
 .../expected/bins_static/string/strchr.stdout |  3 +-
 .../expected/bins_static/wchar/wcschr.stderr  |  0
 .../expected/bins_static/wchar/wcschr.stdout  |  0
 tests/string/strchr.c                         |  3 +-
 tests/wchar/wcschr.c                          | 13 +++++
 9 files changed, 107 insertions(+), 21 deletions(-)
 create mode 100644 tests/expected/bins_static/wchar/wcschr.stderr
 create mode 100644 tests/expected/bins_static/wchar/wcschr.stdout
 create mode 100644 tests/wchar/wcschr.c

diff --git a/src/header/string/mod.rs b/src/header/string/mod.rs
index 9ddc6cb37..6dd354619 100644
--- a/src/header/string/mod.rs
+++ b/src/header/string/mod.rs
@@ -6,7 +6,7 @@ use cbitset::BitSet256;
 
 use crate::{
     header::{errno::*, signal},
-    iter::{NulTerminated, SrcDstPtrIter},
+    iter::{NulTerminated, NulTerminatedInclusive, SrcDstPtrIter},
     platform::{self, types::*},
 };
 
@@ -130,16 +130,26 @@ pub unsafe extern "C" fn memset(s: *mut c_void, c: c_int, n: size_t) -> *mut c_v
     s
 }
 
+/// See <https://pubs.opengroup.org/onlinepubs/7908799/xsh/strchr.html>.
+///
+/// # Safety
+/// The caller is required to ensure that `s` is a valid pointer to a buffer
+/// containing at least one nul value. The pointed-to buffer must not be
+/// modified for the duration of the call.
 #[no_mangle]
 pub unsafe extern "C" fn strchr(mut s: *const c_char, c: c_int) -> *mut c_char {
-    let c = c as c_char;
-    while *s != 0 {
-        if *s == c {
-            return s as *mut c_char;
-        }
-        s = s.offset(1);
-    }
-    ptr::null_mut()
+    let c_as_c_char = c as c_char;
+
+    // We iterate over non-mut references and thus need to coerce the
+    // resulting reference via a *const pointer before we can get our *mut.
+    // SAFETY: the caller is required to ensure that s points to a valid
+    // nul-terminated buffer.
+    let ptr: *const c_char =
+        match unsafe { NulTerminatedInclusive::new(s) }.find(|&&sc| sc == c_as_c_char) {
+            Some(sc_ref) => sc_ref,
+            None => ptr::null(),
+        };
+    ptr.cast_mut()
 }
 
 #[no_mangle]
diff --git a/src/header/wchar/mod.rs b/src/header/wchar/mod.rs
index 67e268e09..c50f12686 100644
--- a/src/header/wchar/mod.rs
+++ b/src/header/wchar/mod.rs
@@ -12,7 +12,7 @@ use crate::{
         time::*,
         wctype::*,
     },
-    iter::NulTerminated,
+    iter::{NulTerminated, NulTerminatedInclusive},
     platform::{self, types::*, ERRNO},
 };
 
@@ -435,17 +435,24 @@ pub unsafe extern "C" fn wcscat(ws1: *mut wchar_t, ws2: *const wchar_t) -> *mut
     wcsncat(ws1, ws2, usize::MAX)
 }
 
+/// See <https://pubs.opengroup.org/onlinepubs/7908799/xsh/wcschr.html>.
+///
+/// # Safety
+/// The caller is required to ensure that `ws` is a valid pointer to a buffer
+/// containing at least one nul value. The pointed-to buffer must not be
+/// modified for the duration of the call.
 #[no_mangle]
 pub unsafe extern "C" fn wcschr(ws: *const wchar_t, wc: wchar_t) -> *mut wchar_t {
-    let mut i = 0;
-    loop {
-        if *ws.add(i) == wc {
-            return ws.add(i) as *mut wchar_t;
-        } else if *ws.add(i) == 0 {
-            return ptr::null_mut();
-        }
-        i += 1;
-    }
+    // We iterate over non-mut references and thus need to coerce the
+    // resulting reference via a *const pointer before we can get our *mut.
+    // SAFETY: the caller is required to ensure that ws points to a valid
+    // nul-terminated buffer.
+    let ptr: *const wchar_t =
+        match unsafe { NulTerminatedInclusive::new(ws) }.find(|&&wsc| wsc == wc) {
+            Some(wsc_ref) => wsc_ref,
+            None => ptr::null(),
+        };
+    ptr.cast_mut()
 }
 
 #[no_mangle]
diff --git a/src/iter.rs b/src/iter.rs
index 5b9f7cb1c..c383d81f7 100644
--- a/src/iter.rs
+++ b/src/iter.rs
@@ -77,6 +77,59 @@ impl<'a, T: Zero> NulTerminated<'a, T> {
     }
 }
 
+/// An iterator over a nul-terminated buffer, including the terminating nul.
+///
+/// Similar to [`NulTerminated`], but includes the terminating nul.
+pub struct NulTerminatedInclusive<'a, T: Zero> {
+    ptr_opt: Option<NonNull<T>>,
+    phantom: PhantomData<&'a T>,
+}
+
+impl<'a, T: Zero> Iterator for NulTerminatedInclusive<'a, T> {
+    type Item = &'a T;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        if let Some(old_ptr) = self.ptr_opt {
+            // SAFETY: the caller is required to ensure a valid pointer to a
+            // 0-terminated buffer is provided, and the zero-check below
+            // ensures that iteration and pointer increments will stop in
+            // time.
+            let val_ref = unsafe { old_ptr.as_ref() };
+            self.ptr_opt = if val_ref.is_zero() {
+                None
+            } else {
+                // SAFETY: if a terminating nul value has been encountered,
+                // this will not be called
+                Some(unsafe { old_ptr.add(1) })
+            };
+            Some(val_ref)
+        } else {
+            None
+        }
+    }
+}
+
+impl<'a, T: Zero> NulTerminatedInclusive<'a, T> {
+    /// Constructs a new iterator, starting at `ptr`, yielding elements of
+    /// type `&T` up to and including the terminating nul.
+    ///
+    /// The iterator returns `None` after the terminating nul has been
+    /// encountered.
+    ///
+    /// # Safety
+    /// The provided pointer must be a valid pointer to a buffer of contiguous
+    /// elements of type `T`, and the value 0 must be present within the
+    /// buffer at or after `ptr` (not necessarily at the end). The buffer must
+    /// not be written to for the lifetime of the iterator.
+    pub unsafe fn new(ptr: *const T) -> Self {
+        NulTerminatedInclusive {
+            // NonNull can only wrap only *mut pointers...
+            ptr_opt: NonNull::new(ptr.cast_mut()),
+            phantom: PhantomData,
+        }
+    }
+}
+
 /// A zipped iterator mapping an input iterator to an "out" pointer.
 ///
 /// This is intended to allow safe, iterative writing to an "out pointer".
diff --git a/tests/Makefile b/tests/Makefile
index 0ac748401..dcbcdf761 100644
--- a/tests/Makefile
+++ b/tests/Makefile
@@ -127,6 +127,7 @@ EXPECT_NAMES=\
 	wchar/wcrtomb \
 	wchar/wcpcpy \
 	wchar/wcpncpy \
+	wchar/wcschr \
 	wchar/wcscspn \
 	wchar/wcsdup \
 	wchar/wcsrchr \
diff --git a/tests/expected/bins_static/string/strchr.stdout b/tests/expected/bins_static/string/strchr.stdout
index 02c9cb955..43f3dcfdf 100644
--- a/tests/expected/bins_static/string/strchr.stdout
+++ b/tests/expected/bins_static/string/strchr.stdout
@@ -1,3 +1,4 @@
 ello
 ld
-1
+
+(nil)
diff --git a/tests/expected/bins_static/wchar/wcschr.stderr b/tests/expected/bins_static/wchar/wcschr.stderr
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/expected/bins_static/wchar/wcschr.stdout b/tests/expected/bins_static/wchar/wcschr.stdout
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/string/strchr.c b/tests/string/strchr.c
index ebe973a51..51f4b7aa5 100644
--- a/tests/string/strchr.c
+++ b/tests/string/strchr.c
@@ -6,5 +6,6 @@
 int main(void) {
 	printf("%s\n", strchr("hello", 'e')); // should be ello
 	printf("%s\n", strchr("world", 'l')); // should be ld
-	printf("%i\n", strchr("world", 0) == NULL); // should be 1
+	printf("%s\n", strchr("world", '\0')); // should be an empty, nul-terminated string
+	printf("%p\n", strchr("world", 'x')); // should be a null pointer
 }
diff --git a/tests/wchar/wcschr.c b/tests/wchar/wcschr.c
new file mode 100644
index 000000000..2b396d476
--- /dev/null
+++ b/tests/wchar/wcschr.c
@@ -0,0 +1,13 @@
+#include <assert.h>
+#include <wchar.h>
+
+int main(void) {
+    wchar_t *haystack = L"Hello World!";
+
+    assert(wcschr(haystack, L'H') == haystack);
+    assert(wcschr(haystack, L'W') == &haystack[6]);
+    assert(wcschr(haystack, L'\0') == &haystack[12]); // the terminating nul is considered part of the string
+    assert(wcschr(haystack, L'X') == NULL);
+
+    return 0;
+}
-- 
GitLab