Files
CTFCUP-25/aws_sigma_service/executor/src/builtins/sql.rs
2025-12-05 07:14:11 +00:00

454 lines
14 KiB
Rust

use core::cell::RefCell;
use std::sync::OnceLock;
use base64::Engine as _;
use base64::engine::general_purpose::STANDARD as BASE64_STANDARD;
use boa_engine::error::JsNativeError;
use boa_engine::js_string;
use boa_engine::native_function::NativeFunction;
use boa_engine::object::{JsObject, ObjectInitializer};
use boa_engine::property::Attribute;
use boa_engine::{Context, JsError, JsResult, JsString, JsValue};
use chrono::{DateTime, NaiveDate, NaiveDateTime, NaiveTime, Utc};
use serde_json::Value as JsonValue;
use sqlx::postgres::{PgArguments, PgQueryResult, PgRow, PgTypeInfo};
use sqlx::types::Json;
use sqlx::{Column, PgPool, Postgres, Row};
/// Native `sql` builtin for the JS runtime; its functionality lives in the
/// associated functions of `impl Sql`, which build the script-visible object.
pub(crate) struct Sql;
/// Connection pool stored once by `Sql::init` and handed out by `Sql::pool`.
static POOL: OnceLock<PgPool> = OnceLock::new();
/// Scope string stored once by `Sql::init`; when non-empty, `Sql::check_query`
/// requires every query text to mention it.
static SCOPE: OnceLock<String> = OnceLock::new();
/// PostgreSQL system catalogs and views that scripts may not mention.
///
/// Every entry is lowercase because `Sql::check_query` searches for it in the
/// ASCII-lowercased query text.
static BLACKLIST: &[&str] = &[
    // Roles, passwords and role membership.
    "pg_authid", "pg_shadow", "pg_user", "pg_roles", "pg_auth_members",
    // Cluster-wide objects.
    "pg_database", "pg_tablespace",
    // Server configuration and authentication rules.
    "pg_settings", "pg_file_settings", "pg_hba_file_rules",
    // Live sessions and replication state.
    "pg_stat_activity", "pg_stat_replication", "pg_replication_slots",
    // Build configuration and backend memory diagnostics.
    "pg_config", "pg_backend_memory_contexts",
];
impl Sql {
pub const NAME: JsString = js_string!("sql");
pub fn init(pool: PgPool, scope: String, context: &mut Context) -> JsObject {
POOL.set(pool).unwrap();
SCOPE.set(scope).unwrap();
ObjectInitializer::new(context)
.property(
js_string!("name"),
JsString::from(Self::NAME),
Attribute::READONLY,
)
.function(
NativeFunction::from_async_fn(Self::query),
js_string!("query"),
2,
)
.function(
NativeFunction::from_async_fn(Self::execute),
js_string!("execute"),
2,
)
.build()
}
fn pool() -> JsResult<PgPool> {
POOL.get().cloned().ok_or_else(|| {
JsNativeError::error()
.with_message("sql: module not initialized")
.into()
})
}
async fn query(
_: &JsValue,
args: &[JsValue],
context: &RefCell<&mut Context>,
) -> JsResult<JsValue> {
let sql = Self::string_from_arg(args, 0, "query", "sql.query", &mut context.borrow_mut())?;
Self::check_query(&*sql)?;
let params = Self::params_from_js_value(args.get(1), &mut context.borrow_mut())?;
let rows = Self::fetch_rows(Self::pool()?, sql, params).await?;
Self::rows_to_js(rows, &mut context.borrow_mut())
}
async fn execute(
_: &JsValue,
args: &[JsValue],
context: &RefCell<&mut Context>,
) -> JsResult<JsValue> {
let sql =
Self::string_from_arg(args, 0, "query", "sql.execute", &mut context.borrow_mut())?;
Self::check_query(&*sql)?;
let params = Self::params_from_js_value(args.get(1), &mut context.borrow_mut())?;
let result = Self::execute_query(Self::pool()?, sql, params).await?;
Self::result_to_js(result, &mut context.borrow_mut())
}
fn check_query(query: &str) -> JsResult<()> {
let lowered = query.to_ascii_lowercase();
for &blacklisted in BLACKLIST {
if lowered.contains(blacklisted) {
return Err(JsNativeError::error()
.with_message(format!(
"sql: use of the system table `{blacklisted}` is prohibited"
))
.into());
}
}
if let Some(scope) = SCOPE.get() {
if !scope.is_empty() && !lowered.contains(scope) {
return Err(JsNativeError::error()
.with_message(format!(
"sql: query must only reference the configured scope `{scope}`"
))
.into());
}
}
Ok(())
}
async fn fetch_rows(
pool: PgPool,
sql: String,
params: Vec<SqlParam>,
) -> Result<Vec<PgRow>, SqlExecutionError> {
let mut query = sqlx::query(&sql);
for param in params {
query = param.bind(query);
}
query
.fetch_all(&pool)
.await
.map_err(SqlExecutionError::from)
}
async fn execute_query(
pool: PgPool,
sql: String,
params: Vec<SqlParam>,
) -> Result<PgQueryResult, SqlExecutionError> {
let mut query = sqlx::query(&sql);
for param in params {
query = param.bind(query);
}
query.execute(&pool).await.map_err(SqlExecutionError::from)
}
fn params_from_js_value(
arg: Option<&JsValue>,
context: &mut Context,
) -> JsResult<Vec<SqlParam>> {
let Some(value) = arg else {
return Ok(Vec::new());
};
if value.is_undefined() || value.is_null() {
return Ok(Vec::new());
}
let object = value.as_object().ok_or_else(|| {
JsError::from(
JsNativeError::typ().with_message("sql: parameters must be provided as an array"),
)
})?;
let length_value = object.get(js_string!("length"), context)?;
let length_number = length_value.to_number(context)?;
let length = if length_number.is_nan() || length_number.is_sign_negative() {
0
} else {
length_number.floor().min(u32::MAX as f64) as u32
};
let mut params = Vec::with_capacity(length as usize);
for index in 0..length {
let element = object.get(index, context)?;
params.push(Self::param_from_js_value(&element, context)?);
}
Ok(params)
}
fn param_from_js_value(value: &JsValue, context: &mut Context) -> JsResult<SqlParam> {
if value.is_undefined() || value.is_null() {
return Ok(SqlParam::Null);
}
if value.is_boolean() {
return Ok(SqlParam::Bool(value.to_boolean()));
}
if value.is_number() {
let number = value.to_number(context)?;
if number.fract() == 0.0 && number >= i64::MIN as f64 && number <= i64::MAX as f64 {
return Ok(SqlParam::Int(number as i64));
}
return Ok(SqlParam::Float(number));
}
if value.is_string() {
return Ok(SqlParam::Text(
value.to_string(context)?.to_std_string_escaped(),
));
}
if value.is_bigint() {
return Ok(SqlParam::Text(
value.to_string(context)?.to_std_string_escaped(),
));
}
if value.is_symbol() {
return Err(JsNativeError::typ()
.with_message("sql: Symbols cannot be sent as parameters")
.into());
}
let Some(json) = value.to_json(context)? else {
return Ok(SqlParam::Null);
};
Ok(SqlParam::Json(json))
}
fn string_from_arg(
args: &[JsValue],
index: usize,
name: &str,
method: &str,
context: &mut Context,
) -> JsResult<String> {
let value = args
.get(index)
.ok_or_else(|| Self::missing_argument(name, method))?;
if value.is_undefined() || value.is_null() {
return Err(Self::missing_argument(name, method));
}
Ok(value.to_string(context)?.to_std_string_escaped())
}
fn rows_to_js(rows: Vec<PgRow>, context: &mut Context) -> JsResult<JsValue> {
let constructor = context
.intrinsics()
.constructors()
.array()
.constructor()
.clone();
let array_value = constructor.construct(&[], None, context)?;
for (index, row) in rows.iter().enumerate() {
let js_row = Self::row_to_object(row, context)?;
let row_value: JsValue = js_row.into();
array_value.create_data_property_or_throw(index, row_value, context)?;
}
Ok(array_value.into())
}
fn row_to_object(row: &PgRow, context: &mut Context) -> JsResult<JsObject> {
let object = JsObject::with_null_proto();
for (index, column) in row.columns().iter().enumerate() {
let value = Self::value_to_js(row, index, column.type_info(), context)?;
object.create_data_property_or_throw(JsString::from(column.name()), value, context)?;
}
Ok(object)
}
fn value_to_js(
row: &PgRow,
index: usize,
type_info: &PgTypeInfo,
context: &mut Context,
) -> JsResult<JsValue> {
let type_name = type_info.to_string().to_ascii_uppercase();
macro_rules! optional_number {
($ty:ty) => {{
let value: Option<$ty> = row.try_get(index).map_err(Self::column_access_error)?;
Ok(value
.map(|inner| JsValue::new(inner as f64))
.unwrap_or_else(JsValue::null))
}};
}
match type_name.as_str() {
"BOOL" => {
let value: Option<bool> = row.try_get(index).map_err(Self::column_access_error)?;
Ok(value.map(JsValue::new).unwrap_or_else(JsValue::null))
}
"INT2" | "INT4" => optional_number!(i32),
"INT8" => optional_number!(i64),
"FLOAT4" | "FLOAT8" => {
let value: Option<f64> = row.try_get(index).map_err(Self::column_access_error)?;
Ok(value.map(JsValue::new).unwrap_or_else(JsValue::null))
}
"NUMERIC" | "DECIMAL" => {
let value: Option<f64> = row.try_get(index).map_err(Self::column_access_error)?;
Ok(value.map(JsValue::new).unwrap_or_else(JsValue::null))
}
"TEXT" | "VARCHAR" | "BPCHAR" | "CHAR" | "UUID" | "INET" | "CIDR" => {
let value: Option<String> =
row.try_get(index).map_err(Self::column_access_error)?;
Ok(value
.map(|text| JsValue::from(JsString::from(text)))
.unwrap_or_else(JsValue::null))
}
"JSON" | "JSONB" => {
let value: Option<JsonValue> =
row.try_get(index).map_err(Self::column_access_error)?;
match value {
Some(json) => JsValue::from_json(&json, context),
None => Ok(JsValue::null()),
}
}
"TIMESTAMP" => {
let value: Option<NaiveDateTime> =
row.try_get(index).map_err(Self::column_access_error)?;
Ok(value
.map(|ts| {
let dt = DateTime::<Utc>::from_naive_utc_and_offset(ts, Utc);
JsValue::from(JsString::from(dt.to_rfc3339()))
})
.unwrap_or_else(JsValue::null))
}
"TIMESTAMPTZ" => {
let value: Option<DateTime<Utc>> =
row.try_get(index).map_err(Self::column_access_error)?;
Ok(value
.map(|ts| JsValue::from(JsString::from(ts.to_rfc3339())))
.unwrap_or_else(JsValue::null))
}
"DATE" => {
let value: Option<NaiveDate> =
row.try_get(index).map_err(Self::column_access_error)?;
Ok(value
.map(|date| JsValue::from(JsString::from(date.to_string())))
.unwrap_or_else(JsValue::null))
}
"TIME" | "TIMETZ" => {
let value: Option<NaiveTime> =
row.try_get(index).map_err(Self::column_access_error)?;
Ok(value
.map(|time| JsValue::from(JsString::from(time.to_string())))
.unwrap_or_else(JsValue::null))
}
"BYTEA" => {
let value: Option<Vec<u8>> =
row.try_get(index).map_err(Self::column_access_error)?;
Ok(value
.map(|bytes| {
let encoded = BASE64_STANDARD.encode(bytes);
JsValue::from(JsString::from(encoded))
})
.unwrap_or_else(JsValue::null))
}
_ => {
let value: Option<String> =
row.try_get(index).map_err(Self::column_access_error)?;
Ok(value
.map(|text| JsValue::from(JsString::from(text)))
.unwrap_or_else(JsValue::null))
}
}
}
fn result_to_js(result: PgQueryResult, context: &mut Context) -> JsResult<JsValue> {
let mut initializer = ObjectInitializer::new(context);
initializer.property(
js_string!("rowsAffected"),
result.rows_affected() as f64,
Attribute::READONLY | Attribute::ENUMERABLE,
);
Ok(initializer.build().into())
}
fn missing_argument(name: &str, method: &str) -> JsError {
JsNativeError::typ()
.with_message(format!("{method}: missing required argument `{name}`"))
.into()
}
fn column_access_error(err: sqlx::Error) -> JsError {
JsNativeError::error()
.with_message(format!("sql: failed to read column value: {err}"))
.into()
}
}
/// A positional bind parameter converted from a JS value.
#[derive(Debug, Clone)]
enum SqlParam {
    /// JS `undefined` / `null`, or a value with no JSON representation.
    Null,
    /// JS boolean.
    Bool(bool),
    /// Integral JS number that fits in an `i64`.
    Int(i64),
    /// Any other JS number.
    Float(f64),
    /// JS string, or a BigInt rendered as decimal text.
    Text(String),
    /// Any other JS value, serialized to JSON.
    Json(JsonValue),
}
impl SqlParam {
    /// Consumes the parameter and appends it as the next positional argument
    /// (`$1`, `$2`, ...) of `query`, returning the extended query.
    ///
    /// `Null` is bound as `Option::<String>::None`. NOTE(review): that NULL
    /// presumably carries a text type, so comparing it against non-text
    /// columns may need an explicit cast in SQL. Confirm.
    fn bind<'q>(
        self,
        query: sqlx::query::Query<'q, Postgres, PgArguments>,
    ) -> sqlx::query::Query<'q, Postgres, PgArguments> {
        match self {
            Self::Null => query.bind(None::<String>),
            Self::Bool(flag) => query.bind(flag),
            Self::Int(integer) => query.bind(integer),
            Self::Float(float) => query.bind(float),
            Self::Text(text) => query.bind(text),
            Self::Json(document) => query.bind(Json(document)),
        }
    }
}
#[derive(Debug)]
enum SqlExecutionError {
Sql(sqlx::Error),
}
impl From<sqlx::Error> for SqlExecutionError {
fn from(value: sqlx::Error) -> Self {
Self::Sql(value)
}
}
impl SqlExecutionError {
fn into_js_error(self, method: &str) -> JsError {
match self {
Self::Sql(err) => JsNativeError::error()
.with_message(format!("{method}: database error: {err}"))
.into(),
}
}
}
impl From<SqlExecutionError> for JsError {
fn from(value: SqlExecutionError) -> Self {
value.into_js_error("sql")
}
}