Fixed sql

This commit is contained in:
pwn
2025-12-05 12:41:36 +03:00
parent 4688603d80
commit de46f09b6a

View File

@@ -100,7 +100,8 @@ impl Sql {
} }
fn check_query(query: &str) -> JsResult<()> { fn check_query(query: &str) -> JsResult<()> {
let lowered = query.to_ascii_lowercase(); let sanitized = Self::strip_sql_comments(query);
let lowered = sanitized.to_ascii_lowercase();
for &blacklisted in BLACKLIST { for &blacklisted in BLACKLIST {
if lowered.contains(blacklisted) { if lowered.contains(blacklisted) {
@@ -113,7 +114,9 @@ impl Sql {
} }
if let Some(scope) = SCOPE.get() { if let Some(scope) = SCOPE.get() {
if !scope.is_empty() && !lowered.contains(scope) { if !scope.is_empty()
&& !Self::contains_scope_identifier(&lowered, &scope.to_ascii_lowercase())
{
return Err(JsNativeError::error() return Err(JsNativeError::error()
.with_message(format!( .with_message(format!(
"sql: query must only reference the configured scope `{scope}`" "sql: query must only reference the configured scope `{scope}`"
@@ -125,6 +128,96 @@ impl Sql {
Ok(()) Ok(())
} }
fn strip_sql_comments(query: &str) -> String {
let mut without_block = String::with_capacity(query.len());
let mut i = 0;
let bytes = query.as_bytes();
while i < bytes.len() {
if i + 1 < bytes.len() && bytes[i] == b'/' && bytes[i + 1] == b'*' {
if let Some(end) = query[i + 2..].find("*/") {
i += 2 + end + 2;
continue;
} else {
break;
}
}
without_block.push(bytes[i] as char);
i += 1;
}
let mut without_line = String::with_capacity(without_block.len());
for line in without_block.lines() {
if let Some(idx) = line.find("--") {
without_line.push_str(&line[..idx]);
} else {
without_line.push_str(line);
}
without_line.push('\n');
}
while without_line.ends_with('\n') {
without_line.pop();
}
without_line
}
fn contains_scope_identifier(query: &str, scope: &str) -> bool {
if scope.is_empty() || query.len() < scope.len() {
return false;
}
let scope_bytes = scope.as_bytes();
let bytes = query.as_bytes();
let mut in_single_quote = false;
let mut in_double_quote = false;
let mut i = 0;
while i + scope_bytes.len() <= bytes.len() {
let current = bytes[i];
match current {
b'\'' if !in_double_quote => {
in_single_quote = !in_single_quote;
i += 1;
continue;
}
b'"' if !in_single_quote => {
in_double_quote = !in_double_quote;
i += 1;
continue;
}
_ => {}
}
if in_single_quote || in_double_quote {
i += 1;
continue;
}
if &bytes[i..i + scope_bytes.len()] == scope_bytes {
let before = if i == 0 { None } else { Some(bytes[i - 1]) };
let after = bytes.get(i + scope_bytes.len()).copied();
let is_boundary = |ch: Option<u8>| match ch {
None => true,
Some(b'a'..=b'z') | Some(b'0'..=b'9') | Some(b'_') => false,
_ => true,
};
if is_boundary(before) && is_boundary(after) {
return true;
}
}
i += 1;
}
false
}
async fn fetch_rows( async fn fetch_rows(
pool: PgPool, pool: PgPool,
sql: String, sql: String,