From b66df46f33d8c8d0e12ffb50c7dd18688173b85d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Florian=20Mei=C3=9Fner?= <florian.meissner@mailbox.org>
Date: Mon, 5 Jun 2023 23:04:25 +0000
Subject: [PATCH] strtof(), strtod(): handle NaN and Infinity

---
 src/macros.rs                                 | 129 ++++++++++--------
 tests/expected/bins_static/stdlib/atof.stdout |   1 +
 .../expected/bins_static/stdlib/strtod.stdout |  25 ++++
 tests/stdlib/atof.c                           |   3 +
 tests/stdlib/strtod.c                         |  11 ++
 5 files changed, 111 insertions(+), 58 deletions(-)

diff --git a/src/macros.rs b/src/macros.rs
index 88383d659..01f7f3461 100644
--- a/src/macros.rs
+++ b/src/macros.rs
@@ -227,19 +227,19 @@ macro_rules! strto_impl {
         num
     }};
 }
+
 #[macro_export]
 macro_rules! strto_float_impl {
     ($type:ident, $s:expr, $endptr:expr) => {{
         let mut s = $s;
         let endptr = $endptr;
 
-        // TODO: Handle named floats: NaN, Inf...
-
         while ctype::isspace(*s as c_int) != 0 {
             s = s.offset(1);
         }
 
         let mut result: $type = 0.0;
+        let mut exponent: Option<$type> = None;
         let mut radix = 10;
 
         let result_sign = match *s as u8 {
@@ -254,75 +254,88 @@ macro_rules! strto_float_impl {
             _ => 1.0,
         };
 
-        if *s as u8 == b'0' && *s.offset(1) as u8 == b'x' {
-            s = s.offset(2);
-            radix = 16;
-        }
-
-        while let Some(digit) = (*s as u8 as char).to_digit(radix) {
-            result *= radix as $type;
-            result += digit as $type;
-            s = s.offset(1);
-        }
-
-        if *s as u8 == b'.' {
-            s = s.offset(1);
+        let rust_s = CStr::from_ptr(s).to_string_lossy();
+
+        // detect NaN, Inf
+        if rust_s.to_lowercase().starts_with("inf") {
+            result = $type::INFINITY;
+            s = s.offset(3);
+        } else if rust_s.to_lowercase().starts_with("nan") {
+            // A negative NaN cannot be reliably represented in LLVM-backed languages; see
+            // https://github.com/rust-lang/rust/issues/73328 and https://github.com/rust-lang/rust/issues/81261
+            result = $type::NAN;
+            s = s.offset(3);
+        } else {
+            if *s as u8 == b'0' && *s.offset(1) as u8 == b'x' {
+                s = s.offset(2);
+                radix = 16;
+            }
 
-            let mut i = 1.0;
             while let Some(digit) = (*s as u8 as char).to_digit(radix) {
-                i *= radix as $type;
-                result += digit as $type / i;
+                result *= radix as $type;
+                result += digit as $type;
                 s = s.offset(1);
             }
-        }
-
-        let s_before_exponent = s;
 
-        let exponent = match (*s as u8, radix) {
-            (b'e' | b'E', 10) | (b'p' | b'P', 16) => {
+            if *s as u8 == b'.' {
                 s = s.offset(1);
 
-                let is_exponent_positive = match *s as u8 {
-                    b'-' => {
-                        s = s.offset(1);
-                        false
-                    }
-                    b'+' => {
-                        s = s.offset(1);
-                        true
-                    }
-                    _ => true,
-                };
-
-                // Exponent digits are always in base 10.
-                if (*s as u8 as char).is_digit(10) {
-                    let mut exponent_value = 0;
-
-                    while let Some(digit) = (*s as u8 as char).to_digit(10) {
-                        exponent_value *= 10;
-                        exponent_value += digit;
-                        s = s.offset(1);
-                    }
+                let mut i = 1.0;
+                while let Some(digit) = (*s as u8 as char).to_digit(radix) {
+                    i *= radix as $type;
+                    result += digit as $type / i;
+                    s = s.offset(1);
+                }
+            }
 
-                    let exponent_base = match radix {
-                        10 => 10u128,
-                        16 => 2u128,
-                        _ => unreachable!(),
+            let s_before_exponent = s;
+
+            exponent = match (*s as u8, radix) {
+                (b'e' | b'E', 10) | (b'p' | b'P', 16) => {
+                    s = s.offset(1);
+
+                    let is_exponent_positive = match *s as u8 {
+                        b'-' => {
+                            s = s.offset(1);
+                            false
+                        }
+                        b'+' => {
+                            s = s.offset(1);
+                            true
+                        }
+                        _ => true,
                     };
 
-                    if is_exponent_positive {
-                        Some(exponent_base.pow(exponent_value) as $type)
+                    // Exponent digits are always in base 10.
+                    if (*s as u8 as char).is_digit(10) {
+                        let mut exponent_value = 0;
+
+                        while let Some(digit) = (*s as u8 as char).to_digit(10) {
+                            exponent_value *= 10;
+                            exponent_value += digit;
+                            s = s.offset(1);
+                        }
+
+                        let exponent_base = match radix {
+                            10 => 10u128,
+                            16 => 2u128,
+                            _ => unreachable!(),
+                        };
+
+                        if is_exponent_positive {
+                            Some(exponent_base.pow(exponent_value) as $type)
+                        } else {
+                            Some(1.0 / (exponent_base.pow(exponent_value) as $type))
+                        }
                     } else {
-                        Some(1.0 / (exponent_base.pow(exponent_value) as $type))
+                        // Exponent had no valid digits after 'e'/'p' and '+'/'-', rollback
+                        s = s_before_exponent;
+                        None
                     }
-                } else {
-                    // Exponent had no valid digits after 'e'/'p' and '+'/'-', rollback
-                    s = s_before_exponent;
-                    None
                 }
-            }
-            _ => None,
-        };
+                _ => None,
+            };
+        }
 
         if !endptr.is_null() {
             // This is stupid, but apparently strto* functions want
diff --git a/tests/expected/bins_static/stdlib/atof.stdout b/tests/expected/bins_static/stdlib/atof.stdout
index dd1835a6f..3227d8538 100644
--- a/tests/expected/bins_static/stdlib/atof.stdout
+++ b/tests/expected/bins_static/stdlib/atof.stdout
@@ -1 +1,2 @@
 -3.140000
+inf
diff --git a/tests/expected/bins_static/stdlib/strtod.stdout b/tests/expected/bins_static/stdlib/strtod.stdout
index 5ed8e2ec1..a92def04f 100644
--- a/tests/expected/bins_static/stdlib/strtod.stdout
+++ b/tests/expected/bins_static/stdlib/strtod.stdout
@@ -169,3 +169,28 @@ d: -49999999999999998431683053958987776.000000 Endptr: ""
 d: -500000000000000021210318687008980992.000000 Endptr: ""
 d: -4999999999999999769381329101060571136.000000 Endptr: ""
 d: -49999999999999998874404911728017014784.000000 Endptr: ""
+d: -0.000000 Endptr: ""
+d: inf Endptr: ""
+d: inf Endptr: ""
+d: inf Endptr: ""
+d: inf Endptr: " foobarbaz"
+d: inf Endptr: ""
+d: inf Endptr: ""
+d: inf Endptr: ""
+d: inf Endptr: " foobarbaz"
+d: -inf Endptr: ""
+d: -inf Endptr: ""
+d: -inf Endptr: ""
+d: -inf Endptr: " foobarbaz"
+d: nan Endptr: "0.1e5"
+d: nan Endptr: "-37"
+d: nan Endptr: "1.05"
+d: nan Endptr: " foo bar baz"
+d: nan Endptr: "0.1e5"
+d: nan Endptr: "-37"
+d: nan Endptr: "1.05"
+d: nan Endptr: " foo bar baz"
+d: nan Endptr: "0.1e5"
+d: nan Endptr: "-37"
+d: nan Endptr: "1.05"
+d: nan Endptr: " foo bar baz"
diff --git a/tests/stdlib/atof.c b/tests/stdlib/atof.c
index ec945ef69..f9783467b 100644
--- a/tests/stdlib/atof.c
+++ b/tests/stdlib/atof.c
@@ -6,4 +6,7 @@
 int main(void) {
     double d = atof("-3.14");
     printf("%f\n", d);
+
+    d = atof("INF");
+    printf("%f\n", d);
 }
diff --git a/tests/stdlib/strtod.c b/tests/stdlib/strtod.c
index fa1013ed0..44fd371ec 100644
--- a/tests/stdlib/strtod.c
+++ b/tests/stdlib/strtod.c
@@ -61,6 +61,17 @@ int main(void) {
         "-0.5e25", "-0.5e26", "-0.5e27", "-0.5e28", "-0.5e29",
         "-0.5e30", "-0.5e31", "-0.5e32", "-0.5e33", "-0.5e34",
         "-0.5e35", "-0.5e36", "-0.5e37", "-0.5e38",
+
+        "-0",
+
+        "INF", "inf", "iNf", "Inf foobarbaz",
+        "+INF", "+inf", "+iNf", "+Inf foobarbaz",
+        "-INF", "-inf", "-iNf", "-Inf foobarbaz",
+
+        "NaN0.1e5", "nan-37", "nAn1.05", "Nan foo bar baz",
+        "+NaN0.1e5", "+nan-37", "+nAn1.05", "+Nan foo bar baz",
+        "-NaN0.1e5", "-nan-37", "-nAn1.05", "-Nan foo bar baz",
+
     };
     for (int i = 0; i < sizeof(inputs) / sizeof(char*); i += 1) {
         d = strtod(inputs[i], &endptr);
-- 
GitLab