use core::cell::RefCell; use std::sync::OnceLock; use base64::Engine as _; use base64::engine::general_purpose::STANDARD as BASE64_STANDARD; use boa_engine::error::JsNativeError; use boa_engine::js_string; use boa_engine::native_function::NativeFunction; use boa_engine::object::{JsObject, ObjectInitializer}; use boa_engine::property::Attribute; use boa_engine::{Context, JsError, JsResult, JsString, JsValue}; use chrono::{DateTime, NaiveDate, NaiveDateTime, NaiveTime, Utc}; use serde_json::Value as JsonValue; use sqlx::postgres::{PgArguments, PgQueryResult, PgRow, PgTypeInfo}; use sqlx::types::Json; use sqlx::{Column, PgPool, Postgres, Row}; pub(crate) struct Sql; static POOL: OnceLock = OnceLock::new(); static SCOPE: OnceLock = OnceLock::new(); static BLACKLIST: &'static [&str] = &[ "pg_authid", "pg_shadow", "pg_user", "pg_roles", "pg_auth_members", "pg_database", "pg_tablespace", "pg_settings", "pg_file_settings", "pg_hba_file_rules", "pg_stat_activity", "pg_stat_replication", "pg_replication_slots", "pg_config", "pg_backend_memory_contexts", ]; impl Sql { pub const NAME: JsString = js_string!("sql"); pub fn init(pool: PgPool, scope: String, context: &mut Context) -> JsObject { POOL.set(pool).unwrap(); SCOPE.set(scope).unwrap(); ObjectInitializer::new(context) .property( js_string!("name"), JsString::from(Self::NAME), Attribute::READONLY, ) .function( NativeFunction::from_async_fn(Self::query), js_string!("query"), 2, ) .function( NativeFunction::from_async_fn(Self::execute), js_string!("execute"), 2, ) .build() } fn pool() -> JsResult { POOL.get().cloned().ok_or_else(|| { JsNativeError::error() .with_message("sql: module not initialized") .into() }) } async fn query( _: &JsValue, args: &[JsValue], context: &RefCell<&mut Context>, ) -> JsResult { let sql = Self::string_from_arg(args, 0, "query", "sql.query", &mut context.borrow_mut())?; Self::check_query(&*sql)?; let params = Self::params_from_js_value(args.get(1), &mut context.borrow_mut())?; let rows = Self::fetch_rows(Self::pool()?, sql, params).await?; Self::rows_to_js(rows, &mut context.borrow_mut()) } async fn execute( _: &JsValue, args: &[JsValue], context: &RefCell<&mut Context>, ) -> JsResult { let sql = Self::string_from_arg(args, 0, "query", "sql.execute", &mut context.borrow_mut())?; Self::check_query(&*sql)?; let params = Self::params_from_js_value(args.get(1), &mut context.borrow_mut())?; let result = Self::execute_query(Self::pool()?, sql, params).await?; Self::result_to_js(result, &mut context.borrow_mut()) } fn check_query(query: &str) -> JsResult<()> { let lowered = query.to_ascii_lowercase(); for &blacklisted in BLACKLIST { if lowered.contains(blacklisted) { return Err(JsNativeError::error() .with_message(format!( "sql: use of the system table `{blacklisted}` is prohibited" )) .into()); } } if let Some(scope) = SCOPE.get() { if !scope.is_empty() && !lowered.contains(scope) { return Err(JsNativeError::error() .with_message(format!( "sql: query must only reference the configured scope `{scope}`" )) .into()); } } Ok(()) } async fn fetch_rows( pool: PgPool, sql: String, params: Vec, ) -> Result, SqlExecutionError> { let mut query = sqlx::query(&sql); for param in params { query = param.bind(query); } query .fetch_all(&pool) .await .map_err(SqlExecutionError::from) } async fn execute_query( pool: PgPool, sql: String, params: Vec, ) -> Result { let mut query = sqlx::query(&sql); for param in params { query = param.bind(query); } query.execute(&pool).await.map_err(SqlExecutionError::from) } fn params_from_js_value( arg: Option<&JsValue>, context: &mut Context, ) -> JsResult> { let Some(value) = arg else { return Ok(Vec::new()); }; if value.is_undefined() || value.is_null() { return Ok(Vec::new()); } let object = value.as_object().ok_or_else(|| { JsError::from( JsNativeError::typ().with_message("sql: parameters must be provided as an array"), ) })?; let length_value = object.get(js_string!("length"), context)?; let length_number = length_value.to_number(context)?; let length = if length_number.is_nan() || length_number.is_sign_negative() { 0 } else { length_number.floor().min(u32::MAX as f64) as u32 }; let mut params = Vec::with_capacity(length as usize); for index in 0..length { let element = object.get(index, context)?; params.push(Self::param_from_js_value(&element, context)?); } Ok(params) } fn param_from_js_value(value: &JsValue, context: &mut Context) -> JsResult { if value.is_undefined() || value.is_null() { return Ok(SqlParam::Null); } if value.is_boolean() { return Ok(SqlParam::Bool(value.to_boolean())); } if value.is_number() { let number = value.to_number(context)?; if number.fract() == 0.0 && number >= i64::MIN as f64 && number <= i64::MAX as f64 { return Ok(SqlParam::Int(number as i64)); } return Ok(SqlParam::Float(number)); } if value.is_string() { return Ok(SqlParam::Text( value.to_string(context)?.to_std_string_escaped(), )); } if value.is_bigint() { return Ok(SqlParam::Text( value.to_string(context)?.to_std_string_escaped(), )); } if value.is_symbol() { return Err(JsNativeError::typ() .with_message("sql: Symbols cannot be sent as parameters") .into()); } let Some(json) = value.to_json(context)? else { return Ok(SqlParam::Null); }; Ok(SqlParam::Json(json)) } fn string_from_arg( args: &[JsValue], index: usize, name: &str, method: &str, context: &mut Context, ) -> JsResult { let value = args .get(index) .ok_or_else(|| Self::missing_argument(name, method))?; if value.is_undefined() || value.is_null() { return Err(Self::missing_argument(name, method)); } Ok(value.to_string(context)?.to_std_string_escaped()) } fn rows_to_js(rows: Vec, context: &mut Context) -> JsResult { let constructor = context .intrinsics() .constructors() .array() .constructor() .clone(); let array_value = constructor.construct(&[], None, context)?; for (index, row) in rows.iter().enumerate() { let js_row = Self::row_to_object(row, context)?; let row_value: JsValue = js_row.into(); array_value.create_data_property_or_throw(index, row_value, context)?; } Ok(array_value.into()) } fn row_to_object(row: &PgRow, context: &mut Context) -> JsResult { let object = JsObject::with_null_proto(); for (index, column) in row.columns().iter().enumerate() { let value = Self::value_to_js(row, index, column.type_info(), context)?; object.create_data_property_or_throw(JsString::from(column.name()), value, context)?; } Ok(object) } fn value_to_js( row: &PgRow, index: usize, type_info: &PgTypeInfo, context: &mut Context, ) -> JsResult { let type_name = type_info.to_string().to_ascii_uppercase(); macro_rules! optional_number { ($ty:ty) => {{ let value: Option<$ty> = row.try_get(index).map_err(Self::column_access_error)?; Ok(value .map(|inner| JsValue::new(inner as f64)) .unwrap_or_else(JsValue::null)) }}; } match type_name.as_str() { "BOOL" => { let value: Option = row.try_get(index).map_err(Self::column_access_error)?; Ok(value.map(JsValue::new).unwrap_or_else(JsValue::null)) } "INT2" | "INT4" => optional_number!(i32), "INT8" => optional_number!(i64), "FLOAT4" | "FLOAT8" => { let value: Option = row.try_get(index).map_err(Self::column_access_error)?; Ok(value.map(JsValue::new).unwrap_or_else(JsValue::null)) } "NUMERIC" | "DECIMAL" => { let value: Option = row.try_get(index).map_err(Self::column_access_error)?; Ok(value.map(JsValue::new).unwrap_or_else(JsValue::null)) } "TEXT" | "VARCHAR" | "BPCHAR" | "CHAR" | "UUID" | "INET" | "CIDR" => { let value: Option = row.try_get(index).map_err(Self::column_access_error)?; Ok(value .map(|text| JsValue::from(JsString::from(text))) .unwrap_or_else(JsValue::null)) } "JSON" | "JSONB" => { let value: Option = row.try_get(index).map_err(Self::column_access_error)?; match value { Some(json) => JsValue::from_json(&json, context), None => Ok(JsValue::null()), } } "TIMESTAMP" => { let value: Option = row.try_get(index).map_err(Self::column_access_error)?; Ok(value .map(|ts| { let dt = DateTime::::from_naive_utc_and_offset(ts, Utc); JsValue::from(JsString::from(dt.to_rfc3339())) }) .unwrap_or_else(JsValue::null)) } "TIMESTAMPTZ" => { let value: Option> = row.try_get(index).map_err(Self::column_access_error)?; Ok(value .map(|ts| JsValue::from(JsString::from(ts.to_rfc3339()))) .unwrap_or_else(JsValue::null)) } "DATE" => { let value: Option = row.try_get(index).map_err(Self::column_access_error)?; Ok(value .map(|date| JsValue::from(JsString::from(date.to_string()))) .unwrap_or_else(JsValue::null)) } "TIME" | "TIMETZ" => { let value: Option = row.try_get(index).map_err(Self::column_access_error)?; Ok(value .map(|time| JsValue::from(JsString::from(time.to_string()))) .unwrap_or_else(JsValue::null)) } "BYTEA" => { let value: Option> = row.try_get(index).map_err(Self::column_access_error)?; Ok(value .map(|bytes| { let encoded = BASE64_STANDARD.encode(bytes); JsValue::from(JsString::from(encoded)) }) .unwrap_or_else(JsValue::null)) } _ => { let value: Option = row.try_get(index).map_err(Self::column_access_error)?; Ok(value .map(|text| JsValue::from(JsString::from(text))) .unwrap_or_else(JsValue::null)) } } } fn result_to_js(result: PgQueryResult, context: &mut Context) -> JsResult { let mut initializer = ObjectInitializer::new(context); initializer.property( js_string!("rowsAffected"), result.rows_affected() as f64, Attribute::READONLY | Attribute::ENUMERABLE, ); Ok(initializer.build().into()) } fn missing_argument(name: &str, method: &str) -> JsError { JsNativeError::typ() .with_message(format!("{method}: missing required argument `{name}`")) .into() } fn column_access_error(err: sqlx::Error) -> JsError { JsNativeError::error() .with_message(format!("sql: failed to read column value: {err}")) .into() } } #[derive(Debug, Clone)] enum SqlParam { Int(i64), Float(f64), Bool(bool), Text(String), Json(JsonValue), Null, } impl SqlParam { fn bind<'q>( self, query: sqlx::query::Query<'q, Postgres, PgArguments>, ) -> sqlx::query::Query<'q, Postgres, PgArguments> { match self { SqlParam::Int(value) => query.bind(value), SqlParam::Float(value) => query.bind(value), SqlParam::Bool(value) => query.bind(value), SqlParam::Text(value) => query.bind(value), SqlParam::Json(value) => query.bind(Json(value)), SqlParam::Null => query.bind(Option::::None), } } } #[derive(Debug)] enum SqlExecutionError { Sql(sqlx::Error), } impl From for SqlExecutionError { fn from(value: sqlx::Error) -> Self { Self::Sql(value) } } impl SqlExecutionError { fn into_js_error(self, method: &str) -> JsError { match self { Self::Sql(err) => JsNativeError::error() .with_message(format!("{method}: database error: {err}")) .into(), } } } impl From for JsError { fn from(value: SqlExecutionError) -> Self { value.into_js_error("sql") } }