diff --git a/src/stdlib/src/lib.rs b/src/stdlib/src/lib.rs index 6b6645cf2ab8caa7fcdcc5cd1acd130533dd1eb0..8ff2dada7c1da2ca92bce377096a93084c94df9d 100644 --- a/src/stdlib/src/lib.rs +++ b/src/stdlib/src/lib.rs @@ -13,7 +13,7 @@ extern crate time; extern crate unistd; extern crate wchar; -use core::{ptr, slice, str}; +use core::{iter, mem, ptr, slice, str}; use rand::distributions::Alphanumeric; use rand::prng::XorShiftRng; use rand::rngs::JitterRng; @@ -400,53 +400,64 @@ pub unsafe extern "C" fn mbtowc(pwc: *mut wchar_t, s: *const c_char, n: size_t) mbrtowc(pwc, s, n, &mut state) as c_int } -#[no_mangle] -pub extern "C" fn mktemp(name: *mut c_char) -> *mut c_char { - use core::iter; - use core::mem; - let len = unsafe { strlen(name) }; - if len < 6 { +fn inner_mktemp<T, F>(name: *mut c_char, suffix_len: c_int, mut attempt: F) -> Option<T> + where F: FnMut() -> Option<T> +{ + let len = unsafe { strlen(name) } as c_int; + + if len < 6 || suffix_len > len - 6 { unsafe { platform::errno = errno::EINVAL }; - unsafe { *name = 0 }; - return name; + return None; } - for i in len - 6..len { + + for i in (len - suffix_len - 6)..(len - suffix_len) { if unsafe { *name.offset(i as isize) } != b'X' as c_char { unsafe { platform::errno = errno::EINVAL }; - unsafe { *name = 0 }; - return name; + return None; } } let mut rng = JitterRng::new_with_timer(get_nstime); rng.test_timer(); - let mut retries = 100; - loop { - let char_iter = iter::repeat(()).map(|()| rng.sample(Alphanumeric)).take(6); + for _ in 0..100 { + let char_iter = iter::repeat(()).map(|()| rng.sample(Alphanumeric)).take(6).enumerate(); unsafe { - for (i, c) in char_iter.enumerate() { - *name.offset(len as isize - i as isize - 1) = c as c_char + for (i, c) in char_iter { + *name.offset((len as isize) - (suffix_len as isize) - (i as isize) - 1) = c as c_char } } + if let result @ Some(_) = attempt() { + return result; + } + } + + unsafe { + platform::errno = errno::EEXIST + } + + None +} + +#[no_mangle] +pub extern "C" fn mktemp(name: *mut c_char) -> *mut c_char { + if inner_mktemp(name, 0, || { unsafe { let mut st: stat = mem::uninitialized(); - if platform::stat(name, &mut st) != 0 { - if platform::errno != ENOENT { - *name = 0; - } - return name; - } + let ret = if platform::stat(name, &mut st) != 0 && platform::errno == ENOENT { + Some(()) + } else { + None + }; mem::forget(st); + ret } - retries = retries - 1; - if retries == 0 { - break; + }).is_none() { + unsafe { + *name = 0; } } - unsafe { platform::errno = EEXIST }; - unsafe { *name = 0 }; name } @@ -460,56 +471,32 @@ fn get_nstime() -> u64 { #[no_mangle] pub extern "C" fn mkostemps(name: *mut c_char, suffix_len: c_int, mut flags: c_int) -> c_int { - use core::iter; - let len = unsafe { strlen(name) } as c_int; - - if len < 6 || suffix_len > len - 6 { - unsafe { platform::errno = errno::EINVAL }; - return -1; - } + flags &= !O_ACCMODE; + flags |= O_RDWR | O_CREAT | O_EXCL; - for i in (len - suffix_len - 6)..(len - suffix_len) { - if unsafe { *name.offset(i as isize) } != b'X' as c_char { - unsafe { platform::errno = errno::EINVAL }; - return -1; - } - } - - flags -= flags & O_ACCMODE; - - let mut rng = JitterRng::new_with_timer(get_nstime); - rng.test_timer(); - - for _retries in 0..100 { - let char_iter = iter::repeat(()).map(|()| rng.sample(Alphanumeric)).take(6).enumerate(); - unsafe { - for (i, c) in char_iter { - *name.offset((len as isize) - (suffix_len as isize) - (i as isize) - 1) = c as c_char - } - } - - let fd = platform::open(name, flags | O_RDWR | O_CREAT | O_EXCL, 0600); + inner_mktemp(name, suffix_len, || { + let fd = platform::open(name, flags, 0600); if fd >= 0 { - return fd; - } - - unsafe { platform::errno = errno::EEXIST }; - } - - unsafe { - for i in 0..6 { - *name.offset((len as isize) - (suffix_len as isize) - (i as isize) - 1) = b'X' as c_char; + Some(fd) + } else { + None } - } - - -1 + }).unwrap_or(-1) } #[no_mangle] pub extern "C" fn mkstemp(name: *mut c_char) -> c_int { mkostemps(name, 0, 0) } +#[no_mangle] +pub extern "C" fn mkostemp(name: *mut c_char, flags: c_int) -> c_int { + mkostemps(name, 0, flags) +} +#[no_mangle] +pub extern "C" fn mkstemps(name: *mut c_char, suffix_len: c_int) -> c_int { + mkostemps(name, suffix_len, 0) +} // #[no_mangle] pub extern "C" fn mrand48() -> c_long {