diff --git a/aws_sigma_service/executor/src/builtins/sql.rs b/aws_sigma_service/executor/src/builtins/sql.rs index b0073b1..33db722 100644 --- a/aws_sigma_service/executor/src/builtins/sql.rs +++ b/aws_sigma_service/executor/src/builtins/sql.rs @@ -100,8 +100,7 @@ impl Sql { } fn check_query(query: &str) -> JsResult<()> { - let sanitized = Self::strip_sql_comments(query); - let lowered = sanitized.to_ascii_lowercase(); + let lowered = query.to_ascii_lowercase(); for &blacklisted in BLACKLIST { if lowered.contains(blacklisted) { @@ -114,9 +113,7 @@ impl Sql { } if let Some(scope) = SCOPE.get() { - if !scope.is_empty() - && !Self::contains_scope_identifier(&lowered, &scope.to_ascii_lowercase()) - { + if !scope.is_empty() && !lowered.contains(scope) { return Err(JsNativeError::error() .with_message(format!( "sql: query must only reference the configured scope `{scope}`" @@ -128,96 +125,6 @@ impl Sql { Ok(()) } - fn strip_sql_comments(query: &str) -> String { - let mut without_block = String::with_capacity(query.len()); - let mut i = 0; - let bytes = query.as_bytes(); - - while i < bytes.len() { - if i + 1 < bytes.len() && bytes[i] == b'/' && bytes[i + 1] == b'*' { - if let Some(end) = query[i + 2..].find("*/") { - i += 2 + end + 2; - continue; - } else { - break; - } - } - - without_block.push(bytes[i] as char); - i += 1; - } - - let mut without_line = String::with_capacity(without_block.len()); - for line in without_block.lines() { - if let Some(idx) = line.find("--") { - without_line.push_str(&line[..idx]); - } else { - without_line.push_str(line); - } - without_line.push('\n'); - } - - while without_line.ends_with('\n') { - without_line.pop(); - } - - without_line - } - - fn contains_scope_identifier(query: &str, scope: &str) -> bool { - if scope.is_empty() || query.len() < scope.len() { - return false; - } - - let scope_bytes = scope.as_bytes(); - let bytes = query.as_bytes(); - let mut in_single_quote = false; - let mut in_double_quote = false; - let mut i = 0; - - while i + scope_bytes.len() <= bytes.len() { - let current = bytes[i]; - - match current { - b'\'' if !in_double_quote => { - in_single_quote = !in_single_quote; - i += 1; - continue; - } - b'"' if !in_single_quote => { - in_double_quote = !in_double_quote; - i += 1; - continue; - } - _ => {} - } - - if in_single_quote || in_double_quote { - i += 1; - continue; - } - - if &bytes[i..i + scope_bytes.len()] == scope_bytes { - let before = if i == 0 { None } else { Some(bytes[i - 1]) }; - let after = bytes.get(i + scope_bytes.len()).copied(); - - let is_boundary = |ch: Option| match ch { - None => true, - Some(b'a'..=b'z') | Some(b'0'..=b'9') | Some(b'_') => false, - _ => true, - }; - - if is_boundary(before) && is_boundary(after) { - return true; - } - } - - i += 1; - } - - false - } - async fn fetch_rows( pool: PgPool, sql: String,