diff --git a/src/lib/parser/statement/splitter.rs b/src/lib/parser/statement/splitter.rs index c6cfb9f5ac3e5f90dc9a1c33e084760763ed509f..abf7df0d4b401f5b16130b7a58d928b6c8413a49 100644 --- a/src/lib/parser/statement/splitter.rs +++ b/src/lib/parser/statement/splitter.rs @@ -59,15 +59,6 @@ impl Display for StatementError { } } -/// Returns true if the byte matches [^A-Za-z0-9_] -fn is_invalid(byte: u8) -> bool { - byte <= 47 - || (byte >= 58 && byte <= 64) - || (byte >= 91 && byte <= 94) - || byte == 96 - || (byte >= 123 && byte <= 127) -} - #[derive(Debug, PartialEq)] pub enum StatementVariant<'a> { And(&'a str), @@ -112,14 +103,13 @@ impl<'a> StatementSplitter<'a> { } } - fn get_statement(&mut self) -> StatementVariant<'a> { + fn get_statement(&self, start: usize, end: usize) -> StatementVariant<'a> { if self.logical == LogicalOp::And { - StatementVariant::And(&self.data[self.start + 1..self.read - 1].trim()) + StatementVariant::And(&self.data[start + 1..end].trim()) } else if self.logical == LogicalOp::Or { - StatementVariant::Or(&self.data[self.start + 1..self.read - 1].trim()) + StatementVariant::Or(&self.data[start + 1..end].trim()) } else { - let statement = &self.data[self.start..self.read - 1].trim(); - StatementVariant::Default(statement) + StatementVariant::Default(&self.data[start..end].trim()) } } @@ -140,14 +130,14 @@ impl<'a> Iterator for StatementSplitter<'a> { type Item = Result<StatementVariant<'a>, StatementError>; fn next(&mut self) -> Option<Self::Item> { - self.start = self.read; + let start = self.read; let mut first_arg_found = false; - let mut else_found = false; - let mut else_pos = 0; let mut error = None; let mut bytes = self.data.bytes().skip(self.read).peekable(); let mut last = None; + bytes.peek()?; + while let Some(character) = bytes.next() { self.read += 1; match character { @@ -176,17 +166,13 @@ impl<'a> Iterator for StatementSplitter<'a> { } } // Toggle quotes and stop matching variables. + b'"' if self.quotes == Quotes::Double => self.quotes = Quotes::None, b'"' => { - if self.quotes == Quotes::Double { - self.quotes = Quotes::None; - } else { - self.quotes = Quotes::Double; - self.variable = false; - } + self.quotes = Quotes::Double; + self.variable = false; } // Array expansion - b'@' => self.variable = true, - b'$' => self.variable = true, + b'@' | b'$' => self.variable = true, b'{' if [Some(b'$'), Some(b'@')].contains(&last) => self.vbrace = true, b'{' if self.quotes == Quotes::None => self.brace_level += 1, b'}' if self.vbrace => self.vbrace = false, @@ -201,11 +187,6 @@ impl<'a> Iterator for StatementSplitter<'a> { } } b'(' if self.math_expr => self.math_paren_level += 1, - b'(' if !self.variable => { - if error.is_none() && self.quotes == Quotes::None { - error = Some(StatementError::InvalidCharacter(character as char, self.read)) - } - } b'(' if self.method || last == Some(b'$') => { self.variable = false; if bytes.peek() == Some(&b'(') { @@ -221,6 +202,9 @@ impl<'a> Iterator for StatementSplitter<'a> { self.method = true; self.variable = false; } + b'(' if error.is_none() && self.quotes == Quotes::None => { + error = Some(StatementError::InvalidCharacter(character as char, self.read)) + } b')' if self.math_expr => { if self.math_paren_level == 0 { match bytes.peek() { @@ -253,7 +237,7 @@ impl<'a> Iterator for StatementSplitter<'a> { } b')' => self.paren_level -= 1, b';' if self.quotes == Quotes::None && self.paren_level == 0 => { - let statement = self.get_statement(); + let statement = self.get_statement(start, self.read - 1); self.logical = LogicalOp::None; return match error { @@ -261,77 +245,60 @@ impl<'a> Iterator for StatementSplitter<'a> { None => Some(Ok(statement)), }; } - b'&' | b'|' if self.quotes == Quotes::None && self.paren_level == 0 => { - if bytes.peek() == Some(&character) { - // Detecting if there is a 2nd `&` character - let statement = self.get_statement(); - self.read += 1; - self.logical = - if character == b'&' { LogicalOp::And } else { LogicalOp::Or }; - return match error { - Some(error) => Some(Err(error)), - None => Some(Ok(statement)), - }; - } - } - b' ' if else_found => { - let output = &self.data[else_pos..self.read - 1].trim(); - if !output.is_empty() && &"if" != output { - self.read = else_pos; - self.logical = LogicalOp::None; - return Some(Ok(StatementVariant::Default("else"))); - } - else_found = false; + b'&' | b'|' + if self.quotes == Quotes::None + && self.paren_level == 0 + && last == Some(character) => + { + // Detecting if there is a 2nd `&` character + let statement = self.get_statement(start, self.read - 2); + self.logical = if character == b'&' { LogicalOp::And } else { LogicalOp::Or }; + return match error { + Some(error) => Some(Err(error)), + None => Some(Ok(statement)), + }; } - b' ' if !first_arg_found => { - let output = &self.data[self.start..self.read - 1].trim(); - if !output.is_empty() { - match *output { - "else" => { - else_found = true; - else_pos = self.read; - } - _ => first_arg_found = true, + b' ' if !first_arg_found => match self.data[start..self.read - 1].trim() { + "else" => { + if self.data.len() < self.read + 2 + || &self.data[self.read..self.read + 2] != "if" + { + self.logical = LogicalOp::None; + return Some(Ok(StatementVariant::Default("else"))); } } - } + "" => {} + _ => first_arg_found = true, + }, // [^A-Za-z0-9_] - byte => { - if self.variable && is_invalid(byte) { - self.variable = false - } - } + 0...47 | 58...64 | 91...94 | 96 | 123...127 => self.variable = false, + _ => {} } last = Some(character); } - if self.start == self.read { - None - } else { - self.read = self.data.len(); - match error { - Some(error) => Some(Err(error)), - None if self.paren_level != 0 => Some(Err(StatementError::UnterminatedSubshell)), - None if self.method => Some(Err(StatementError::UnterminatedMethod)), - None if self.vbrace => Some(Err(StatementError::UnterminatedBracedVar)), - None if self.brace_level != 0 => Some(Err(StatementError::UnterminatedBrace)), - None if self.math_expr => Some(Err(StatementError::UnterminatedArithmetic)), - None => { - let output = self.data[self.start..].trim(); - if output.is_empty() { - Some(Ok(self.get_statement_from(output))) - } else { - match output.as_bytes()[0] { - b'>' | b'<' | b'^' => { - Some(Err(StatementError::ExpectedCommandButFound("redirection"))) - } - b'|' => Some(Err(StatementError::ExpectedCommandButFound("pipe"))), - b'&' => Some(Err(StatementError::ExpectedCommandButFound("&"))), - b'*' | b'%' | b'?' | b'{' | b'}' => { - Some(Err(StatementError::IllegalCommandName(String::from(output)))) - } - _ => Some(Ok(self.get_statement_from(output))), + match error { + Some(error) => Some(Err(error)), + None if self.paren_level != 0 => Some(Err(StatementError::UnterminatedSubshell)), + None if self.method => Some(Err(StatementError::UnterminatedMethod)), + None if self.vbrace => Some(Err(StatementError::UnterminatedBracedVar)), + None if self.brace_level != 0 => Some(Err(StatementError::UnterminatedBrace)), + None if self.math_expr => Some(Err(StatementError::UnterminatedArithmetic)), + None => { + let output = self.data[start..].trim(); + if output.is_empty() { + Some(Ok(self.get_statement_from(output))) + } else { + match output.as_bytes()[0] { + b'>' | b'<' | b'^' => { + Some(Err(StatementError::ExpectedCommandButFound("redirection"))) + } + b'|' => Some(Err(StatementError::ExpectedCommandButFound("pipe"))), + b'&' => Some(Err(StatementError::ExpectedCommandButFound("&"))), + b'*' | b'%' | b'?' | b'{' | b'}' => { + Some(Err(StatementError::IllegalCommandName(String::from(output)))) } + _ => Some(Ok(self.get_statement_from(output))), } } }