diff --git a/src/header/string/mod.rs b/src/header/string/mod.rs index 8e8169be1a9af9f51a3a28a8cda9a2a6eb9fae8a..9a6e4ddd225e182f39b5771c9dbd1a3f8777af23 100644 --- a/src/header/string/mod.rs +++ b/src/header/string/mod.rs @@ -348,25 +348,35 @@ pub unsafe extern "C" fn strspn(s1: *const c_char, s2: *const c_char) -> size_t inner_strspn(s1, s2, true) } -#[no_mangle] -pub unsafe extern "C" fn strstr(s1: *const c_char, s2: *const c_char) -> *mut c_char { - let mut i = 0; - while *s1.offset(i) != 0 { - let mut j = 0; - while *s2.offset(j) != 0 && *s1.offset(j + i) != 0 { - if *s2.offset(j) != *s1.offset(j + i) { - break; +unsafe fn inner_strstr(mut haystack: *const c_char, mut needle: *const c_char, mask: c_char) -> *mut c_char { + while *haystack != 0 { + let mut i = 0; + loop { + if *needle.offset(i) == 0 { + // We reached the end of the needle, everything matches this far + return haystack as *mut c_char; } - j += 1; - if *s2.offset(j) == 0 { - return s1.offset(i) as *mut c_char; + if *haystack.offset(i) & mask != *needle.offset(i) & mask { + break; } + + i += 1; } - i += 1; + + haystack = haystack.offset(1); } ptr::null_mut() } +#[no_mangle] +pub unsafe extern "C" fn strstr(haystack: *const c_char, needle: *const c_char) -> *mut c_char { + inner_strstr(haystack, needle, !0) +} +#[no_mangle] +pub unsafe extern "C" fn strcasestr(haystack: *const c_char, needle: *const c_char) -> *mut c_char { + inner_strstr(haystack, needle, !32) +} + #[no_mangle] pub extern "C" fn strtok(s1: *mut c_char, delimiter: *const c_char) -> *mut c_char { static mut HAYSTACK: *mut c_char = ptr::null_mut(); diff --git a/tests/expected/string/strstr.stdout b/tests/expected/string/strstr.stdout index e978edff7f142deb5af595baa3d297e01180e911..c5a1f6e4463f8111c6f17ee62293508da281848d 100644 --- a/tests/expected/string/strstr.stdout +++ b/tests/expected/string/strstr.stdout @@ -1,3 +1,5 @@ rust libc we trust NULL +NULL +RUST diff --git a/tests/string/strstr.c b/tests/string/strstr.c index 6a074d18d83ca3c3fec9aa676866f171c84335de..cc23f070fd365170a76548f5e02f66d7b196e7ed 100644 --- a/tests/string/strstr.c +++ b/tests/string/strstr.c @@ -2,17 +2,11 @@ #include <stdio.h> int main(int argc, char* argv[]) { - // should be "rust" - char* res1 = strstr("In relibc we trust", "rust"); - printf("%s\n", (res1) ? res1 : "NULL"); - - // should be "libc we trust" - char* res2 = strstr("In relibc we trust", "libc"); - printf("%s\n", (res2) ? res2 : "NULL"); - - // should be "NULL" - char* res3 = strstr("In relibc we trust", "bugs"); - printf("%s\n", (res3) ? res3 : "NULL"); + printf("%s\n", strstr("In relibc we trust", "rust")); + printf("%s\n", strstr("In relibc we trust", "libc")); + printf("%s\n", strstr("In relibc we trust", "bugs")); + printf("%s\n", strstr("IN RELIBC WE TRUST", "rust")); + printf("%s\n", strcasestr("IN RELIBC WE TRUST", "rust")); return 0; }