From 0a0aec502b701a2b62f66d4e312c550149ccb426 Mon Sep 17 00:00:00 2001
From: jD91mZM2 <me@krake.one>
Date: Tue, 23 Apr 2019 07:34:06 +0200
Subject: [PATCH] Fix #140

---
 src/header/stdio/printf.rs | 385 ++++++++++++++++++++-----------------
 tests/stdio/printf.c       |   2 -
 2 files changed, 212 insertions(+), 175 deletions(-)

diff --git a/src/header/stdio/printf.rs b/src/header/stdio/printf.rs
index 1c9dcb01..b825868b 100644
--- a/src/header/stdio/printf.rs
+++ b/src/header/stdio/printf.rs
@@ -1,7 +1,8 @@
 use alloc::string::String;
 use alloc::string::ToString;
+use alloc::collections::BTreeMap;
 use alloc::vec::Vec;
-use core::ffi::VaList as va_list;
+use core::ffi::VaList;
 use core::ops::Range;
 use core::{fmt, slice};
 use io::{self, Write};
@@ -9,7 +10,15 @@ use io::{self, Write};
 use platform;
 use platform::types::*;
 
-#[derive(PartialEq, Eq)]
+//  ____        _ _                 _       _
+// | __ )  ___ (_) | ___ _ __ _ __ | | __ _| |_ ___ _
+// |  _ \ / _ \| | |/ _ \ '__| '_ \| |/ _` | __/ _ (_)
+// | |_) | (_) | | |  __/ |  | |_) | | (_| | ||  __/_
+// |____/ \___/|_|_|\___|_|  | .__/|_|\__,_|\__\___(_)
+//                           |_|
+
+
+#[derive(Clone, Copy, PartialEq, Eq)]
 enum IntKind {
     Byte,
     Short,
@@ -20,6 +29,7 @@ enum IntKind {
     PtrDiff,
     Size,
 }
+#[derive(Clone, Copy, PartialEq, Eq)]
 enum FmtKind {
     Percent,
 
@@ -42,144 +52,113 @@ enum Number {
     Next
 }
 impl Number {
-    unsafe fn resolve(&self, ap: &mut BufferedVaList) -> usize {
-        match *self {
-            Number::Static(num) => num,
-            Number::Index(i) => ap.index(ArgKind::Int, i).int as usize,
-            Number::Next => ap.next(ArgKind::Int).int as usize,
+    unsafe fn resolve(&self, varargs: &mut VaListCache, ap: &mut VaList) -> usize {
+        let arg = match *self {
+            Number::Static(num) => return num,
+            Number::Index(i) => varargs.get(i-1, ap, None),
+            Number::Next => {
+                let i = varargs.i;
+                varargs.i += 1;
+                varargs.get(i, ap, None)
+            }
+        };
+        match arg {
+            VaArg::c_char(i) => i as usize,
+            VaArg::c_double(i) => i as usize,
+            VaArg::c_int(i) => i as usize,
+            VaArg::c_long(i) => i as usize,
+            VaArg::c_longlong(i) => i as usize,
+            VaArg::c_short(i) => i as usize,
+            VaArg::intmax_t(i) => i as usize,
+            VaArg::pointer(i) => i as usize,
+            VaArg::ptrdiff_t(i) => i as usize,
+            VaArg::ssize_t(i) => i as usize
         }
     }
 }
-enum ArgKind {
-    Byte,
-    Short,
-    Int,
-    Long,
-    LongLong,
-    PtrDiff,
-    Size,
-    IntMax,
-    Double,
-    CharPtr,
-    VoidPtr,
-    IntPtr,
-    ArgDefault,
-}
-
 #[derive(Clone, Copy)]
-union VaArg {
-    byte: c_char,
-    short: c_short,
-    int: c_int,
-    long: c_long,
-    longlong: c_longlong,
-    ptrdiff: ptrdiff_t,
-    size: ssize_t,
-    intmax: intmax_t,
-    double: c_double,
-    char_ptr: *const c_char,
-    void_ptr: *const c_void,
-    int_ptr: *mut c_int,
-    arg_default: usize,
-}
-
-struct BufferedVaList<'a> {
-    list: va_list<'a>,
-    buf: Vec<VaArg>,
-    i: usize,
+enum VaArg {
+    c_char(c_char),
+    c_double(c_double),
+    c_int(c_int),
+    c_long(c_long),
+    c_longlong(c_longlong),
+    c_short(c_short),
+    intmax_t(intmax_t),
+    pointer(*const c_void),
+    ptrdiff_t(ptrdiff_t),
+    ssize_t(ssize_t)
 }
-
-impl<'a> BufferedVaList<'a> {
-    fn new(list: va_list<'a>) -> Self {
-        Self {
-            list,
-            buf: Vec::new(),
-            i: 0,
+impl VaArg {
+    unsafe fn arg_from(arg: &PrintfArg, ap: &mut VaList) -> VaArg {
+        // Per the C standard, using va_arg with an integer type smaller
+        // than int, or a floating-point type smaller than double, is
+        // invalid. As a result, any such argument passed to a variadic
+        // function is promoted to int or double respectively (the
+        // default argument promotions). The VaList::arg function will
+        // handle this automagically.
+
+        match (arg.fmtkind, arg.intkind) {
+            (FmtKind::Percent, _) => panic!("Can't call arg_from on %"),
+
+            (FmtKind::Char, _) |
+            (FmtKind::Unsigned, IntKind::Byte) |
+            (FmtKind::Signed, IntKind::Byte) => VaArg::c_char(ap.arg::<c_char>()),
+            (FmtKind::Unsigned, IntKind::Short) |
+            (FmtKind::Signed, IntKind::Short) => VaArg::c_short(ap.arg::<c_short>()),
+            (FmtKind::Unsigned, IntKind::Int) |
+            (FmtKind::Signed, IntKind::Int) => VaArg::c_int(ap.arg::<c_int>()),
+            (FmtKind::Unsigned, IntKind::Long) |
+            (FmtKind::Signed, IntKind::Long) => VaArg::c_long(ap.arg::<c_long>()),
+            (FmtKind::Unsigned, IntKind::LongLong) |
+            (FmtKind::Signed, IntKind::LongLong) => VaArg::c_longlong(ap.arg::<c_longlong>()),
+            (FmtKind::Unsigned, IntKind::IntMax) |
+            (FmtKind::Signed, IntKind::IntMax) => VaArg::intmax_t(ap.arg::<intmax_t>()),
+            (FmtKind::Unsigned, IntKind::PtrDiff) |
+            (FmtKind::Signed, IntKind::PtrDiff) => VaArg::ptrdiff_t(ap.arg::<ptrdiff_t>()),
+            (FmtKind::Unsigned, IntKind::Size) |
+            (FmtKind::Signed, IntKind::Size) => VaArg::ssize_t(ap.arg::<ssize_t>()),
+
+            (FmtKind::AnyNotation, _) | (FmtKind::Decimal, _) | (FmtKind::Scientific, _)
+                => VaArg::c_double(ap.arg::<c_double>()),
+
+            (FmtKind::GetWritten, _) | (FmtKind::Pointer, _) | (FmtKind::String, _)
+                => VaArg::pointer(ap.arg::<*const c_void>()),
         }
     }
-
-    unsafe fn get_arg(&mut self, ty: ArgKind) -> VaArg {
-        match ty {
-            ArgKind::Byte => VaArg {
-                byte: self.list.arg::<c_char>(),
-            },
-            ArgKind::Short => VaArg {
-                short: self.list.arg::<c_short>(),
-            },
-            ArgKind::Int => VaArg {
-                int: self.list.arg::<c_int>(),
-            },
-            ArgKind::Long => VaArg {
-                long: self.list.arg::<c_long>(),
-            },
-            ArgKind::LongLong => VaArg {
-                longlong: self.list.arg::<c_longlong>(),
-            },
-            ArgKind::PtrDiff => VaArg {
-                ptrdiff: self.list.arg::<ptrdiff_t>(),
-            },
-            ArgKind::Size => VaArg {
-                size: self.list.arg::<ssize_t>(),
-            },
-            ArgKind::IntMax => VaArg {
-                intmax: self.list.arg::<intmax_t>(),
-            },
-            ArgKind::Double => VaArg {
-                double: self.list.arg::<c_double>(),
-            },
-            ArgKind::CharPtr => VaArg {
-                char_ptr: self.list.arg::<*const c_char>(),
-            },
-            ArgKind::VoidPtr => VaArg {
-                void_ptr: self.list.arg::<*const c_void>(),
-            },
-            ArgKind::IntPtr => VaArg {
-                int_ptr: self.list.arg::<*mut c_int>(),
-            },
-            ArgKind::ArgDefault => VaArg {
-                arg_default: self.list.arg::<usize>(),
-            },
+}
+#[derive(Default)]
+struct VaListCache {
+    args: Vec<VaArg>,
+    i: usize
+}
+impl VaListCache {
+    unsafe fn get(&mut self, i: usize, ap: &mut VaList, arg: Option<&PrintfArg>) -> VaArg {
+        if let Some(&arg) = self.args.get(i) {
+            return arg;
         }
-    }
-
-    unsafe fn get(&mut self, ty: ArgKind, i: Option<usize>) -> VaArg {
-        match i {
-            None => self.next(ty),
-            Some(i) => self.index(ty, i),
+        while self.args.len() < i {
+            // We can't POSSIBLY know the type if we reach this
+            // point. Reaching here means there are unused gaps in the
+            // arguments. Ultimately we'll have to settle for
+            // defaulting to c_int.
+            self.args.push(VaArg::c_int(ap.arg::<c_int>()))
         }
+        self.args.push(match arg {
+            Some(arg) => VaArg::arg_from(arg, ap),
+            None => VaArg::c_int(ap.arg::<c_int>())
+        });
+        self.args[i]
     }
+}
 
-    unsafe fn next(&mut self, ty: ArgKind) -> VaArg {
-        if self.i >= self.buf.len() {
-            let arg = self.get_arg(ty);
-            self.buf.push(arg);
-        }
-        let arg = self.buf[self.i];
-        self.i += 1;
-        arg
-    }
+//  ___                 _                           _        _   _
+// |_ _|_ __ ___  _ __ | | ___ _ __ ___   ___ _ __ | |_ __ _| |_(_) ___  _ __  _
+//  | || '_ ` _ \| '_ \| |/ _ \ '_ ` _ \ / _ \ '_ \| __/ _` | __| |/ _ \| '_ \(_)
+//  | || | | | | | |_) | |  __/ | | | | |  __/ | | | || (_| | |_| | (_) | | | |_
+// |___|_| |_| |_| .__/|_|\___|_| |_| |_|\___|_| |_|\__\__,_|\__|_|\___/|_| |_(_)
+//               |_|
 
-    unsafe fn index(&mut self, ty: ArgKind, i: usize) -> VaArg {
-        if self.i >= self.buf.len() {
-            while self.buf.len() < (i - 1) {
-                // Getting a usize here most definitely isn't sane, however,
-                // there's no way to know the type!
-                // Just take this for example:
-                //
-                // printf("%*4$d\n", "hi", 0, "hello", 10);
-                //
-                // This chooses the width 10. How does it know the type of 0 and "hello"?
-                // It clearly can't.
-
-                let arg = self.get_arg(ArgKind::ArgDefault);
-                self.buf.push(arg);
-            }
-            let arg = self.get_arg(ty);
-            self.buf.push(arg);
-        }
-        self.buf[i - 1]
-    }
-}
 
 unsafe fn pop_int_raw(format: &mut *const u8) -> Option<usize> {
     let mut int = None;
@@ -342,6 +321,7 @@ fn fmt_float_normal<W: Write>(
 struct PrintfIter {
     format: *const u8
 }
+#[derive(Clone, Copy)]
 struct PrintfArg {
     index: Option<usize>,
     alternate: bool,
@@ -486,14 +466,40 @@ impl Iterator for PrintfIter {
     }
 }
 
-unsafe fn inner_printf<W: Write>(w: W, format: *const c_char, ap: va_list) -> io::Result<c_int> {
+
+unsafe fn inner_printf<W: Write>(w: W, format: *const c_char, mut ap: VaList) -> io::Result<c_int> {
     let w = &mut platform::CountingWriter::new(w);
-    let mut ap = BufferedVaList::new(ap);
 
     let iterator = PrintfIter {
         format: format as *const u8
     };
 
+    // Pre-fetch vararg types
+    let mut varargs = VaListCache::default();
+    let mut positional = BTreeMap::new();
+    // ^ NOTE: The loop below relies on BTreeMap's sorted iteration order; do not change to HashMap.
+
+    for section in iterator {
+        let arg = match section {
+            Ok(PrintfFmt::Plain(text)) => continue,
+            Ok(PrintfFmt::Arg(arg)) => arg,
+            Err(()) => return Ok(-1)
+        };
+        if arg.fmtkind == FmtKind::Percent {
+            continue;
+        }
+        if let Some(i) = arg.index {
+            positional.insert(i-1, arg);
+        } else {
+            varargs.args.push(VaArg::arg_from(&arg, &mut ap));
+        }
+    }
+    // Make sure, in order, the positional arguments exist with the specified type
+    for (i, arg) in positional {
+        varargs.get(i, &mut ap, Some(&arg));
+    }
+
+    // Main loop
     for section in iterator {
         let arg = match section {
             Ok(PrintfFmt::Plain(text)) => {
@@ -503,40 +509,43 @@ unsafe fn inner_printf<W: Write>(w: W, format: *const c_char, ap: va_list) -> io
             Ok(PrintfFmt::Arg(arg)) => arg,
             Err(()) => return Ok(-1)
         };
-        let index = arg.index;
         let alternate = arg.alternate;
         let zero = arg.zero;
         let left = arg.left;
         let sign_reserve = arg.sign_reserve;
         let sign_always = arg.sign_always;
-        let min_width = arg.min_width.resolve(&mut ap);
-        let precision = arg.precision.map(|n| n.resolve(&mut ap));
-        let pad_space = arg.pad_space.resolve(&mut ap);
-        let pad_zero = arg.pad_zero.resolve(&mut ap);
+        let min_width = arg.min_width.resolve(&mut varargs, &mut ap);
+        let precision = arg.precision.map(|n| n.resolve(&mut varargs, &mut ap));
+        let pad_space = arg.pad_space.resolve(&mut varargs, &mut ap);
+        let pad_zero = arg.pad_zero.resolve(&mut varargs, &mut ap);
         let intkind = arg.intkind;
         let fmt = arg.fmt;
         let fmtkind = arg.fmtkind;
 
-        // Finally, type:
+        let index = arg.index
+            .map(|i| i-1)
+            .unwrap_or_else(|| if fmtkind == FmtKind::Percent {
+                0
+            } else {
+                let i = varargs.i;
+                varargs.i += 1;
+                i
+            });
+
         match fmtkind {
             FmtKind::Percent => w.write_all(&[b'%'])?,
             FmtKind::Signed => {
-                let string = match intkind {
-                    // Per the C standard using va_arg with a type with a size
-                    // less than that of an int for integers and double for floats
-                    // is invalid. As a result any arguments smaller than an int or
-                    // double passed to a function will be promoted to the smallest
-                    // possible size. The va_list::arg function will handle this
-                    // automagically.
-                    IntKind::Byte => ap.get(ArgKind::Byte, index).byte.to_string(),
-                    IntKind::Short => ap.get(ArgKind::Short, index).short.to_string(),
-                    // Types that will not be promoted
-                    IntKind::Int => ap.get(ArgKind::Int, index).int.to_string(),
-                    IntKind::Long => ap.get(ArgKind::Long, index).long.to_string(),
-                    IntKind::LongLong => ap.get(ArgKind::LongLong, index).longlong.to_string(),
-                    IntKind::PtrDiff => ap.get(ArgKind::PtrDiff, index).ptrdiff.to_string(),
-                    IntKind::Size => ap.get(ArgKind::Size, index).size.to_string(),
-                    IntKind::IntMax => ap.get(ArgKind::IntMax, index).intmax.to_string(),
+                let string = match varargs.get(index, &mut ap, Some(&arg)) {
+                    VaArg::c_char(i) => i.to_string(),
+                    VaArg::c_double(i) => panic!("this should not be possible"),
+                    VaArg::c_int(i) => i.to_string(),
+                    VaArg::c_long(i) => i.to_string(),
+                    VaArg::c_longlong(i) => i.to_string(),
+                    VaArg::c_short(i) => i.to_string(),
+                    VaArg::intmax_t(i) => i.to_string(),
+                    VaArg::pointer(i) => (i as usize).to_string(),
+                    VaArg::ptrdiff_t(i) => i.to_string(),
+                    VaArg::ssize_t(i) => i.to_string()
                 };
                 let positive = !string.starts_with('-');
                 let zero = precision == Some(0) && string == "0";
@@ -573,18 +582,17 @@ unsafe fn inner_printf<W: Write>(w: W, format: *const c_char, ap: va_list) -> io
                 pad(w, left, b' ', final_len..pad_space)?;
             },
             FmtKind::Unsigned => {
-                let string = match intkind {
-                    // va_list will promote the following two to a c_int
-                    IntKind::Byte => fmt_int(fmt, ap.get(ArgKind::Byte, index).byte),
-                    IntKind::Short => fmt_int(fmt, ap.get(ArgKind::Short, index).short),
-                    IntKind::Int => fmt_int(fmt, ap.get(ArgKind::Int, index).int),
-                    IntKind::Long => fmt_int(fmt, ap.get(ArgKind::Long, index).long),
-                    IntKind::LongLong => {
-                        fmt_int(fmt, ap.get(ArgKind::LongLong, index).longlong)
-                    }
-                    IntKind::PtrDiff => fmt_int(fmt, ap.get(ArgKind::PtrDiff, index).ptrdiff),
-                    IntKind::Size => fmt_int(fmt, ap.get(ArgKind::Size, index).size),
-                    IntKind::IntMax => fmt_int(fmt, ap.get(ArgKind::IntMax, index).intmax),
+                let string = match varargs.get(index, &mut ap, Some(&arg)) {
+                    VaArg::c_char(i) => fmt_int(fmt, i as c_uchar),
+                    VaArg::c_double(i) => panic!("this should not be possible"),
+                    VaArg::c_int(i) => fmt_int(fmt, i as c_uint),
+                    VaArg::c_long(i) => fmt_int(fmt, i as c_ulong),
+                    VaArg::c_longlong(i) => fmt_int(fmt, i as c_ulonglong),
+                    VaArg::c_short(i) => fmt_int(fmt, i as c_ushort),
+                    VaArg::intmax_t(i) => fmt_int(fmt, i as uintmax_t),
+                    VaArg::pointer(i) => fmt_int(fmt, i as usize),
+                    VaArg::ptrdiff_t(i) => fmt_int(fmt, i as size_t),
+                    VaArg::ssize_t(i) => fmt_int(fmt, i as size_t)
                 };
                 let zero = precision == Some(0) && string == "0";
 
@@ -628,20 +636,29 @@ unsafe fn inner_printf<W: Write>(w: W, format: *const c_char, ap: va_list) -> io
                 pad(w, left, b' ', final_len..pad_space)?;
             },
             FmtKind::Scientific => {
-                let mut float = ap.get(ArgKind::Double, index).double;
+                let mut float = match varargs.get(index, &mut ap, Some(&arg)) {
+                    VaArg::c_double(i) => i,
+                    _ => panic!("this should not be possible")
+                };
                 let precision = precision.unwrap_or(6);
 
                 fmt_float_exp(w, fmt, None, false, precision, float, left, pad_space, pad_zero)?;
             },
             FmtKind::Decimal => {
-                let mut float = ap.get(ArgKind::Double, index).double;
+                let mut float = match varargs.get(index, &mut ap, Some(&arg)) {
+                    VaArg::c_double(i) => i,
+                    _ => panic!("this should not be possible")
+                };
                 let precision = precision.unwrap_or(6);
 
                 fmt_float_normal(w, false, precision, float, left, pad_space, pad_zero)?;
             },
             FmtKind::AnyNotation => {
+                let mut float = match varargs.get(index, &mut ap, Some(&arg)) {
+                    VaArg::c_double(i) => i,
+                    _ => panic!("this should not be possible")
+                };
                 let exp_fmt = b'E' | (fmt & 32);
-                let mut float = ap.get(ArgKind::Double, index).double;
                 let precision = precision.unwrap_or(6);
 
                 if !fmt_float_exp(
@@ -661,7 +678,10 @@ unsafe fn inner_printf<W: Write>(w: W, format: *const c_char, ap: va_list) -> io
             FmtKind::String => {
                 // if intkind == IntKind::Long || intkind == IntKind::LongLong, handle *const wchar_t
 
-                let ptr = ap.get(ArgKind::CharPtr, index).char_ptr;
+                let mut ptr = match varargs.get(index, &mut ap, Some(&arg)) {
+                    VaArg::pointer(p) => p,
+                    _ => panic!("this should not be possible")
+                } as *const c_char;
 
                 if ptr.is_null() {
                     w.write_all(b"(null)")?;
@@ -680,14 +700,20 @@ unsafe fn inner_printf<W: Write>(w: W, format: *const c_char, ap: va_list) -> io
             FmtKind::Char => {
                 // if intkind == IntKind::Long || intkind == IntKind::LongLong, handle wint_t
 
-                let c = ap.get(ArgKind::Byte, index).byte;
+                let c = match varargs.get(index, &mut ap, Some(&arg)) {
+                    VaArg::c_char(c) => c,
+                    _ => panic!("this should not be possible")
+                };
 
                 pad(w, !left, b' ', 1..pad_space)?;
                 w.write_all(&[c as u8])?;
                 pad(w, left, b' ', 1..pad_space)?;
             },
             FmtKind::Pointer => {
-                let ptr = ap.get(ArgKind::VoidPtr, index).int_ptr;
+                let mut ptr = match varargs.get(index, &mut ap, Some(&arg)) {
+                    VaArg::pointer(p) => p,
+                    _ => panic!("this should not be possible")
+                };
 
                 let mut len = 1;
                 if ptr.is_null() {
@@ -709,14 +735,27 @@ unsafe fn inner_printf<W: Write>(w: W, format: *const c_char, ap: va_list) -> io
                 pad(w, left, b' ', len..pad_space)?;
             },
             FmtKind::GetWritten => {
-                let ptr = ap.get(ArgKind::IntPtr, index).int_ptr;
-                *ptr = w.written as c_int;
+                let mut ptr = match varargs.get(index, &mut ap, Some(&arg)) {
+                    VaArg::pointer(p) => p,
+                    _ => panic!("this should not be possible")
+                };
+
+                match intkind {
+                    IntKind::Byte => *(ptr as *mut c_char) = w.written as c_char,
+                    IntKind::Short => *(ptr as *mut c_short) = w.written as c_short,
+                    IntKind::Int => *(ptr as *mut c_int) = w.written as c_int,
+                    IntKind::Long => *(ptr as *mut c_long) = w.written as c_long,
+                    IntKind::LongLong => *(ptr as *mut c_longlong) = w.written as c_longlong,
+                    IntKind::IntMax => *(ptr as *mut intmax_t) = w.written as intmax_t,
+                    IntKind::PtrDiff => *(ptr as *mut ptrdiff_t) = w.written as ptrdiff_t,
+                    IntKind::Size => *(ptr as *mut size_t) = w.written as size_t
+                }
             }
         }
     }
     Ok(w.written as c_int)
 }
 
-pub unsafe fn printf<W: Write>(w: W, format: *const c_char, ap: va_list) -> c_int {
+pub unsafe fn printf<W: Write>(w: W, format: *const c_char, ap: VaList) -> c_int {
     inner_printf(w, format, ap).unwrap_or(-1)
 }
diff --git a/tests/stdio/printf.c b/tests/stdio/printf.c
index f36cc970..ea7b8dde 100644
--- a/tests/stdio/printf.c
+++ b/tests/stdio/printf.c
@@ -1,7 +1,5 @@
 #include <stdio.h>
 
-#include "test_helpers.h"
-
 int main(void) {
     int sofar = 0;
     int len = printf(
-- 
GitLab