diff --git a/src/string/src/lib.rs b/src/string/src/lib.rs index c272d9211a5714e6ba3aac694b91cb268256650a..c7178b5c3c160d755444b5cb5433ed1a7faa8dae 100644 --- a/src/string/src/lib.rs +++ b/src/string/src/lib.rs @@ -156,25 +156,17 @@ pub unsafe extern "C" fn strncat(s1: *mut c_char, s2: *const c_char, n: usize) - #[no_mangle] pub unsafe extern "C" fn strncmp(s1: *const c_char, s2: *const c_char, n: usize) -> c_int { - let s1 = platform::c_str_n(s1, n); - let s2 = platform::c_str_n(s2, n); - - let min_len = n.min(s1.len()).min(s2.len()); - for i in 0..min_len { - let val = s1[i] - s2[i]; - if val != 0 { - return val as c_int; + let s1 = core::slice::from_raw_parts(s1 as *const c_uchar, n); + let s2 = core::slice::from_raw_parts(s2 as *const c_uchar, n); + + for (&a, &b) in s1.iter().zip(s2.iter()) { + let val = (a as c_int) - (b as c_int); + if a != b || a == 0 { + return val; } } - // we can't just check for the NUL byte in the loop as c_str_n() removes it - if s1.len() > s2.len() { - s1[min_len] as c_int - } else if s1.len() < s2.len() { - -(s2[min_len] as c_int) - } else { - 0 - } + 0 } #[no_mangle] diff --git a/tests/.gitignore b/tests/.gitignore index 620d8c72440cdb7d870a79e4ce161596092d3106..cd04e5ea8e611a129ed17a1db4f718a4441bc8d3 100644 --- a/tests/.gitignore +++ b/tests/.gitignore @@ -27,5 +27,6 @@ /rmdir /setid /stdlib/strtol +/string/strncmp /unlink /write diff --git a/tests/Makefile b/tests/Makefile index ae73d603b19f36217f1cd5eeb74f99d01cfbf736..81c5bfe1f147c4dd698fbd925db5ad1935fb347a 100644 --- a/tests/Makefile +++ b/tests/Makefile @@ -21,7 +21,8 @@ BINS=\ rmdir \ setid \ sleep \ - stdlib/strtol \ + stdlib/strtol \ + string/strncmp \ unlink \ write diff --git a/tests/string/strncmp.c b/tests/string/strncmp.c new file mode 100644 index 0000000000000000000000000000000000000000..efb2bea9783af619a6bdf1995ab3b4feaff6b813 --- /dev/null +++ b/tests/string/strncmp.c @@ -0,0 +1,13 @@ +#include <string.h> +#include <stdio.h> + +int main(int argc, char* argv[]) { + printf("%d\n", strncmp("a", "aa", 2)); + printf("%d\n", strncmp("a", "aä", 2)); + printf("%d\n", strncmp("\xFF", "\xFE", 2)); + printf("%d\n", strncmp("", "\xFF", 1)); + printf("%d\n", strncmp("a", "c", 1)); + printf("%d\n", strncmp("a", "a", 2)); + + return 0; +}