454 lines
14 KiB
Rust
454 lines
14 KiB
Rust
use core::cell::RefCell;
|
|
use std::sync::OnceLock;
|
|
|
|
use base64::Engine as _;
|
|
use base64::engine::general_purpose::STANDARD as BASE64_STANDARD;
|
|
use boa_engine::error::JsNativeError;
|
|
use boa_engine::js_string;
|
|
use boa_engine::native_function::NativeFunction;
|
|
use boa_engine::object::{JsObject, ObjectInitializer};
|
|
use boa_engine::property::Attribute;
|
|
use boa_engine::{Context, JsError, JsResult, JsString, JsValue};
|
|
use chrono::{DateTime, NaiveDate, NaiveDateTime, NaiveTime, Utc};
|
|
use serde_json::Value as JsonValue;
|
|
use sqlx::postgres::{PgArguments, PgQueryResult, PgRow, PgTypeInfo};
|
|
use sqlx::types::Json;
|
|
use sqlx::{Column, PgPool, Postgres, Row};
|
|
|
|
pub(crate) struct Sql;
|
|
|
|
static POOL: OnceLock<PgPool> = OnceLock::new();
|
|
static SCOPE: OnceLock<String> = OnceLock::new();
|
|
|
|
/// Postgres catalog tables/views whose names may not appear in script SQL.
///
/// Matching is done on the ASCII-lowercased query text with a plain substring
/// test, so any identifier that merely *contains* one of these names is
/// rejected too. Entries are checked in order; the first hit is the one
/// reported in the error message.
static BLACKLIST: &[&str] = &[
    // Roles and authentication.
    "pg_authid",
    "pg_shadow",
    "pg_user",
    "pg_roles",
    "pg_auth_members",
    // Cluster layout.
    "pg_database",
    "pg_tablespace",
    // Server configuration.
    "pg_settings",
    "pg_file_settings",
    "pg_hba_file_rules",
    // Server activity and replication.
    "pg_stat_activity",
    "pg_stat_replication",
    "pg_replication_slots",
    // Build configuration and memory internals.
    "pg_config",
    "pg_backend_memory_contexts",
];
|
|
|
|
impl Sql {
|
|
pub const NAME: JsString = js_string!("sql");
|
|
|
|
pub fn init(pool: PgPool, scope: String, context: &mut Context) -> JsObject {
|
|
POOL.set(pool).unwrap();
|
|
SCOPE.set(scope).unwrap();
|
|
|
|
ObjectInitializer::new(context)
|
|
.property(
|
|
js_string!("name"),
|
|
JsString::from(Self::NAME),
|
|
Attribute::READONLY,
|
|
)
|
|
.function(
|
|
NativeFunction::from_async_fn(Self::query),
|
|
js_string!("query"),
|
|
2,
|
|
)
|
|
.function(
|
|
NativeFunction::from_async_fn(Self::execute),
|
|
js_string!("execute"),
|
|
2,
|
|
)
|
|
.build()
|
|
}
|
|
|
|
fn pool() -> JsResult<PgPool> {
|
|
POOL.get().cloned().ok_or_else(|| {
|
|
JsNativeError::error()
|
|
.with_message("sql: module not initialized")
|
|
.into()
|
|
})
|
|
}
|
|
|
|
async fn query(
|
|
_: &JsValue,
|
|
args: &[JsValue],
|
|
context: &RefCell<&mut Context>,
|
|
) -> JsResult<JsValue> {
|
|
let sql = Self::string_from_arg(args, 0, "query", "sql.query", &mut context.borrow_mut())?;
|
|
Self::check_query(&*sql)?;
|
|
let params = Self::params_from_js_value(args.get(1), &mut context.borrow_mut())?;
|
|
|
|
let rows = Self::fetch_rows(Self::pool()?, sql, params).await?;
|
|
Self::rows_to_js(rows, &mut context.borrow_mut())
|
|
}
|
|
|
|
async fn execute(
|
|
_: &JsValue,
|
|
args: &[JsValue],
|
|
context: &RefCell<&mut Context>,
|
|
) -> JsResult<JsValue> {
|
|
let sql =
|
|
Self::string_from_arg(args, 0, "query", "sql.execute", &mut context.borrow_mut())?;
|
|
Self::check_query(&*sql)?;
|
|
let params = Self::params_from_js_value(args.get(1), &mut context.borrow_mut())?;
|
|
|
|
let result = Self::execute_query(Self::pool()?, sql, params).await?;
|
|
Self::result_to_js(result, &mut context.borrow_mut())
|
|
}
|
|
|
|
fn check_query(query: &str) -> JsResult<()> {
|
|
let lowered = query.to_ascii_lowercase();
|
|
|
|
for &blacklisted in BLACKLIST {
|
|
if lowered.contains(blacklisted) {
|
|
return Err(JsNativeError::error()
|
|
.with_message(format!(
|
|
"sql: use of the system table `{blacklisted}` is prohibited"
|
|
))
|
|
.into());
|
|
}
|
|
}
|
|
|
|
if let Some(scope) = SCOPE.get() {
|
|
if !scope.is_empty() && !lowered.contains(scope) {
|
|
return Err(JsNativeError::error()
|
|
.with_message(format!(
|
|
"sql: query must only reference the configured scope `{scope}`"
|
|
))
|
|
.into());
|
|
}
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
async fn fetch_rows(
|
|
pool: PgPool,
|
|
sql: String,
|
|
params: Vec<SqlParam>,
|
|
) -> Result<Vec<PgRow>, SqlExecutionError> {
|
|
let mut query = sqlx::query(&sql);
|
|
for param in params {
|
|
query = param.bind(query);
|
|
}
|
|
|
|
query
|
|
.fetch_all(&pool)
|
|
.await
|
|
.map_err(SqlExecutionError::from)
|
|
}
|
|
|
|
async fn execute_query(
|
|
pool: PgPool,
|
|
sql: String,
|
|
params: Vec<SqlParam>,
|
|
) -> Result<PgQueryResult, SqlExecutionError> {
|
|
let mut query = sqlx::query(&sql);
|
|
for param in params {
|
|
query = param.bind(query);
|
|
}
|
|
|
|
query.execute(&pool).await.map_err(SqlExecutionError::from)
|
|
}
|
|
|
|
fn params_from_js_value(
|
|
arg: Option<&JsValue>,
|
|
context: &mut Context,
|
|
) -> JsResult<Vec<SqlParam>> {
|
|
let Some(value) = arg else {
|
|
return Ok(Vec::new());
|
|
};
|
|
|
|
if value.is_undefined() || value.is_null() {
|
|
return Ok(Vec::new());
|
|
}
|
|
|
|
let object = value.as_object().ok_or_else(|| {
|
|
JsError::from(
|
|
JsNativeError::typ().with_message("sql: parameters must be provided as an array"),
|
|
)
|
|
})?;
|
|
|
|
let length_value = object.get(js_string!("length"), context)?;
|
|
let length_number = length_value.to_number(context)?;
|
|
let length = if length_number.is_nan() || length_number.is_sign_negative() {
|
|
0
|
|
} else {
|
|
length_number.floor().min(u32::MAX as f64) as u32
|
|
};
|
|
let mut params = Vec::with_capacity(length as usize);
|
|
|
|
for index in 0..length {
|
|
let element = object.get(index, context)?;
|
|
params.push(Self::param_from_js_value(&element, context)?);
|
|
}
|
|
|
|
Ok(params)
|
|
}
|
|
|
|
fn param_from_js_value(value: &JsValue, context: &mut Context) -> JsResult<SqlParam> {
|
|
if value.is_undefined() || value.is_null() {
|
|
return Ok(SqlParam::Null);
|
|
}
|
|
|
|
if value.is_boolean() {
|
|
return Ok(SqlParam::Bool(value.to_boolean()));
|
|
}
|
|
|
|
if value.is_number() {
|
|
let number = value.to_number(context)?;
|
|
if number.fract() == 0.0 && number >= i64::MIN as f64 && number <= i64::MAX as f64 {
|
|
return Ok(SqlParam::Int(number as i64));
|
|
}
|
|
return Ok(SqlParam::Float(number));
|
|
}
|
|
|
|
if value.is_string() {
|
|
return Ok(SqlParam::Text(
|
|
value.to_string(context)?.to_std_string_escaped(),
|
|
));
|
|
}
|
|
|
|
if value.is_bigint() {
|
|
return Ok(SqlParam::Text(
|
|
value.to_string(context)?.to_std_string_escaped(),
|
|
));
|
|
}
|
|
|
|
if value.is_symbol() {
|
|
return Err(JsNativeError::typ()
|
|
.with_message("sql: Symbols cannot be sent as parameters")
|
|
.into());
|
|
}
|
|
|
|
let Some(json) = value.to_json(context)? else {
|
|
return Ok(SqlParam::Null);
|
|
};
|
|
Ok(SqlParam::Json(json))
|
|
}
|
|
|
|
fn string_from_arg(
|
|
args: &[JsValue],
|
|
index: usize,
|
|
name: &str,
|
|
method: &str,
|
|
context: &mut Context,
|
|
) -> JsResult<String> {
|
|
let value = args
|
|
.get(index)
|
|
.ok_or_else(|| Self::missing_argument(name, method))?;
|
|
|
|
if value.is_undefined() || value.is_null() {
|
|
return Err(Self::missing_argument(name, method));
|
|
}
|
|
|
|
Ok(value.to_string(context)?.to_std_string_escaped())
|
|
}
|
|
|
|
fn rows_to_js(rows: Vec<PgRow>, context: &mut Context) -> JsResult<JsValue> {
|
|
let constructor = context
|
|
.intrinsics()
|
|
.constructors()
|
|
.array()
|
|
.constructor()
|
|
.clone();
|
|
let array_value = constructor.construct(&[], None, context)?;
|
|
|
|
for (index, row) in rows.iter().enumerate() {
|
|
let js_row = Self::row_to_object(row, context)?;
|
|
let row_value: JsValue = js_row.into();
|
|
array_value.create_data_property_or_throw(index, row_value, context)?;
|
|
}
|
|
|
|
Ok(array_value.into())
|
|
}
|
|
|
|
fn row_to_object(row: &PgRow, context: &mut Context) -> JsResult<JsObject> {
|
|
let object = JsObject::with_null_proto();
|
|
|
|
for (index, column) in row.columns().iter().enumerate() {
|
|
let value = Self::value_to_js(row, index, column.type_info(), context)?;
|
|
object.create_data_property_or_throw(JsString::from(column.name()), value, context)?;
|
|
}
|
|
|
|
Ok(object)
|
|
}
|
|
|
|
fn value_to_js(
|
|
row: &PgRow,
|
|
index: usize,
|
|
type_info: &PgTypeInfo,
|
|
context: &mut Context,
|
|
) -> JsResult<JsValue> {
|
|
let type_name = type_info.to_string().to_ascii_uppercase();
|
|
|
|
macro_rules! optional_number {
|
|
($ty:ty) => {{
|
|
let value: Option<$ty> = row.try_get(index).map_err(Self::column_access_error)?;
|
|
Ok(value
|
|
.map(|inner| JsValue::new(inner as f64))
|
|
.unwrap_or_else(JsValue::null))
|
|
}};
|
|
}
|
|
|
|
match type_name.as_str() {
|
|
"BOOL" => {
|
|
let value: Option<bool> = row.try_get(index).map_err(Self::column_access_error)?;
|
|
Ok(value.map(JsValue::new).unwrap_or_else(JsValue::null))
|
|
}
|
|
"INT2" | "INT4" => optional_number!(i32),
|
|
"INT8" => optional_number!(i64),
|
|
"FLOAT4" | "FLOAT8" => {
|
|
let value: Option<f64> = row.try_get(index).map_err(Self::column_access_error)?;
|
|
Ok(value.map(JsValue::new).unwrap_or_else(JsValue::null))
|
|
}
|
|
"NUMERIC" | "DECIMAL" => {
|
|
let value: Option<f64> = row.try_get(index).map_err(Self::column_access_error)?;
|
|
Ok(value.map(JsValue::new).unwrap_or_else(JsValue::null))
|
|
}
|
|
"TEXT" | "VARCHAR" | "BPCHAR" | "CHAR" | "UUID" | "INET" | "CIDR" => {
|
|
let value: Option<String> =
|
|
row.try_get(index).map_err(Self::column_access_error)?;
|
|
Ok(value
|
|
.map(|text| JsValue::from(JsString::from(text)))
|
|
.unwrap_or_else(JsValue::null))
|
|
}
|
|
"JSON" | "JSONB" => {
|
|
let value: Option<JsonValue> =
|
|
row.try_get(index).map_err(Self::column_access_error)?;
|
|
match value {
|
|
Some(json) => JsValue::from_json(&json, context),
|
|
None => Ok(JsValue::null()),
|
|
}
|
|
}
|
|
"TIMESTAMP" => {
|
|
let value: Option<NaiveDateTime> =
|
|
row.try_get(index).map_err(Self::column_access_error)?;
|
|
Ok(value
|
|
.map(|ts| {
|
|
let dt = DateTime::<Utc>::from_naive_utc_and_offset(ts, Utc);
|
|
JsValue::from(JsString::from(dt.to_rfc3339()))
|
|
})
|
|
.unwrap_or_else(JsValue::null))
|
|
}
|
|
"TIMESTAMPTZ" => {
|
|
let value: Option<DateTime<Utc>> =
|
|
row.try_get(index).map_err(Self::column_access_error)?;
|
|
Ok(value
|
|
.map(|ts| JsValue::from(JsString::from(ts.to_rfc3339())))
|
|
.unwrap_or_else(JsValue::null))
|
|
}
|
|
"DATE" => {
|
|
let value: Option<NaiveDate> =
|
|
row.try_get(index).map_err(Self::column_access_error)?;
|
|
Ok(value
|
|
.map(|date| JsValue::from(JsString::from(date.to_string())))
|
|
.unwrap_or_else(JsValue::null))
|
|
}
|
|
"TIME" | "TIMETZ" => {
|
|
let value: Option<NaiveTime> =
|
|
row.try_get(index).map_err(Self::column_access_error)?;
|
|
Ok(value
|
|
.map(|time| JsValue::from(JsString::from(time.to_string())))
|
|
.unwrap_or_else(JsValue::null))
|
|
}
|
|
"BYTEA" => {
|
|
let value: Option<Vec<u8>> =
|
|
row.try_get(index).map_err(Self::column_access_error)?;
|
|
Ok(value
|
|
.map(|bytes| {
|
|
let encoded = BASE64_STANDARD.encode(bytes);
|
|
JsValue::from(JsString::from(encoded))
|
|
})
|
|
.unwrap_or_else(JsValue::null))
|
|
}
|
|
_ => {
|
|
let value: Option<String> =
|
|
row.try_get(index).map_err(Self::column_access_error)?;
|
|
Ok(value
|
|
.map(|text| JsValue::from(JsString::from(text)))
|
|
.unwrap_or_else(JsValue::null))
|
|
}
|
|
}
|
|
}
|
|
|
|
fn result_to_js(result: PgQueryResult, context: &mut Context) -> JsResult<JsValue> {
|
|
let mut initializer = ObjectInitializer::new(context);
|
|
initializer.property(
|
|
js_string!("rowsAffected"),
|
|
result.rows_affected() as f64,
|
|
Attribute::READONLY | Attribute::ENUMERABLE,
|
|
);
|
|
|
|
Ok(initializer.build().into())
|
|
}
|
|
|
|
fn missing_argument(name: &str, method: &str) -> JsError {
|
|
JsNativeError::typ()
|
|
.with_message(format!("{method}: missing required argument `{name}`"))
|
|
.into()
|
|
}
|
|
|
|
fn column_access_error(err: sqlx::Error) -> JsError {
|
|
JsNativeError::error()
|
|
.with_message(format!("sql: failed to read column value: {err}"))
|
|
.into()
|
|
}
|
|
}
|
|
|
|
#[derive(Debug, Clone)]
|
|
enum SqlParam {
|
|
Int(i64),
|
|
Float(f64),
|
|
Bool(bool),
|
|
Text(String),
|
|
Json(JsonValue),
|
|
Null,
|
|
}
|
|
|
|
impl SqlParam {
|
|
fn bind<'q>(
|
|
self,
|
|
query: sqlx::query::Query<'q, Postgres, PgArguments>,
|
|
) -> sqlx::query::Query<'q, Postgres, PgArguments> {
|
|
match self {
|
|
SqlParam::Int(value) => query.bind(value),
|
|
SqlParam::Float(value) => query.bind(value),
|
|
SqlParam::Bool(value) => query.bind(value),
|
|
SqlParam::Text(value) => query.bind(value),
|
|
SqlParam::Json(value) => query.bind(Json(value)),
|
|
SqlParam::Null => query.bind(Option::<String>::None),
|
|
}
|
|
}
|
|
}
|
|
|
|
#[derive(Debug)]
|
|
enum SqlExecutionError {
|
|
Sql(sqlx::Error),
|
|
}
|
|
|
|
impl From<sqlx::Error> for SqlExecutionError {
|
|
fn from(value: sqlx::Error) -> Self {
|
|
Self::Sql(value)
|
|
}
|
|
}
|
|
|
|
impl SqlExecutionError {
|
|
fn into_js_error(self, method: &str) -> JsError {
|
|
match self {
|
|
Self::Sql(err) => JsNativeError::error()
|
|
.with_message(format!("{method}: database error: {err}"))
|
|
.into(),
|
|
}
|
|
}
|
|
}
|
|
|
|
impl From<SqlExecutionError> for JsError {
|
|
fn from(value: SqlExecutionError) -> Self {
|
|
value.into_js_error("sql")
|
|
}
|
|
}
|