227 lines
6.9 KiB
Rust
227 lines
6.9 KiB
Rust
|
|
use crate::{services::replicant::ReplicantService, utils::AppError};
|
||
|
|
use base64::prelude::*;
|
||
|
|
use dollhouse_api_types::FirmwareOutputResponse;
|
||
|
|
use dollhouse_db::{Pool, repositories::ReplicantRepository};
|
||
|
|
use mlua::{Lua, Table, Value};
|
||
|
|
use std::sync::Arc;
|
||
|
|
use tokio::time::{Duration, timeout};
|
||
|
|
use uuid::Uuid;
|
||
|
|
|
||
|
|
/// Service that executes a replicant's firmware file as a Lua chunk in a
/// restricted interpreter (see the `impl` block below).
pub struct LuaService;

/// Upper bound on the Lua state's heap allocations: 10 MiB (10 * 1024 * 1024 bytes).
const MEMORY_LIMIT: usize = 10 << 20;

/// Wall-clock budget for a single firmware run, in milliseconds.
const TIME_LIMIT_MS: u64 = 1_000;
|
||
|
|
|
||
|
|
impl LuaService {
|
||
|
|
fn create_lua_instance() -> Result<Lua, AppError> {
|
||
|
|
let lua = Lua::new();
|
||
|
|
lua.set_memory_limit(MEMORY_LIMIT)?;
|
||
|
|
Ok(lua)
|
||
|
|
}
|
||
|
|
|
||
|
|
fn setup_sandbox(lua: &Lua) -> Result<(), AppError> {
|
||
|
|
let globals = lua.globals();
|
||
|
|
|
||
|
|
let dangerous_libs = [
|
||
|
|
"os",
|
||
|
|
"io",
|
||
|
|
"debug",
|
||
|
|
"load",
|
||
|
|
"loadstring",
|
||
|
|
"dofile",
|
||
|
|
"loadfile",
|
||
|
|
];
|
||
|
|
for lib in &dangerous_libs {
|
||
|
|
globals.set(*lib, Value::Nil)?;
|
||
|
|
}
|
||
|
|
|
||
|
|
let g_mt = lua.create_table()?;
|
||
|
|
|
||
|
|
let allowed_globals = vec![
|
||
|
|
"_VERSION",
|
||
|
|
"print",
|
||
|
|
"type",
|
||
|
|
"assert",
|
||
|
|
"error",
|
||
|
|
"pairs",
|
||
|
|
"ipairs",
|
||
|
|
"next",
|
||
|
|
"select",
|
||
|
|
"pcall",
|
||
|
|
"xpcall",
|
||
|
|
"table",
|
||
|
|
"string",
|
||
|
|
"math",
|
||
|
|
"tonumber",
|
||
|
|
"tostring",
|
||
|
|
"setmetatable",
|
||
|
|
"getmetatable",
|
||
|
|
"rawset",
|
||
|
|
"rawget",
|
||
|
|
"rawequal",
|
||
|
|
];
|
||
|
|
|
||
|
|
let allowed_globals_clone1 = allowed_globals.clone();
|
||
|
|
|
||
|
|
g_mt.set(
|
||
|
|
"__newindex",
|
||
|
|
lua.create_function(move |_, (t, name, value): (Table, String, Value)| {
|
||
|
|
if !allowed_globals_clone1.contains(&name.as_str()) {
|
||
|
|
return Err(mlua::Error::RuntimeError(format!(
|
||
|
|
"Security: creating global '{}' is not allowed",
|
||
|
|
name
|
||
|
|
)));
|
||
|
|
}
|
||
|
|
|
||
|
|
t.raw_set(name, value)?;
|
||
|
|
Ok(())
|
||
|
|
})?,
|
||
|
|
)?;
|
||
|
|
|
||
|
|
let allowed_globals_clone2 = allowed_globals.clone();
|
||
|
|
let dangerous = vec!["io", "os", "debug", "package"];
|
||
|
|
|
||
|
|
g_mt.set(
|
||
|
|
"__index",
|
||
|
|
lua.create_function(move |lua, (t, name): (Table, String)| {
|
||
|
|
if dangerous.contains(&name.as_str()) {
|
||
|
|
return Err(mlua::Error::RuntimeError(format!(
|
||
|
|
"Security: access to '{}' is prohibited",
|
||
|
|
name
|
||
|
|
)));
|
||
|
|
}
|
||
|
|
|
||
|
|
if allowed_globals_clone2.contains(&name.as_str()) {
|
||
|
|
let globals = lua.globals();
|
||
|
|
return Ok(globals.raw_get::<Value>(name)?);
|
||
|
|
}
|
||
|
|
|
||
|
|
Ok(Value::Nil)
|
||
|
|
})?,
|
||
|
|
)?;
|
||
|
|
|
||
|
|
globals.set_metatable(Some(g_mt));
|
||
|
|
|
||
|
|
Self::setup_safe_print(lua)?;
|
||
|
|
|
||
|
|
Ok(())
|
||
|
|
}
|
||
|
|
|
||
|
|
fn setup_safe_print(lua: &Lua) -> Result<(), AppError> {
|
||
|
|
let safe_print = lua.create_function(|_, args: mlua::MultiValue| {
|
||
|
|
let output: String = args
|
||
|
|
.into_iter()
|
||
|
|
.map(|v| v.to_string())
|
||
|
|
.collect::<Result<Vec<_>, _>>()?
|
||
|
|
.join("\t");
|
||
|
|
|
||
|
|
println!("{}", output);
|
||
|
|
Ok(())
|
||
|
|
})?;
|
||
|
|
|
||
|
|
lua.globals().set("print", safe_print)?;
|
||
|
|
Ok(())
|
||
|
|
}
|
||
|
|
|
||
|
|
async fn execute_with_timeout(
|
||
|
|
lua: Arc<Lua>,
|
||
|
|
bytecode: Vec<u8>,
|
||
|
|
) -> Result<mlua::Value, AppError> {
|
||
|
|
let task = tokio::task::spawn_blocking(move || {
|
||
|
|
let lua_clone = Arc::clone(&lua);
|
||
|
|
lua_clone
|
||
|
|
.load(&bytecode)
|
||
|
|
.set_name("[[user_firmware]]")
|
||
|
|
.eval()
|
||
|
|
});
|
||
|
|
|
||
|
|
match timeout(Duration::from_millis(TIME_LIMIT_MS), task).await {
|
||
|
|
Ok(Ok(result)) => result.map_err(|e| {
|
||
|
|
let err_str = e.to_string();
|
||
|
|
if err_str.contains("not enough memory") {
|
||
|
|
log::error!("Memory limit exceeded");
|
||
|
|
AppError::InternalServerError
|
||
|
|
} else {
|
||
|
|
AppError::LuaExecutionError(e)
|
||
|
|
}
|
||
|
|
}),
|
||
|
|
Ok(Err(join_err)) => {
|
||
|
|
log::error!("Join error: {}", join_err);
|
||
|
|
Err(AppError::InternalServerError)
|
||
|
|
}
|
||
|
|
Err(_) => {
|
||
|
|
tokio::task::yield_now().await;
|
||
|
|
Err(AppError::InternalServerError)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
fn value_to_string(value: mlua::Value) -> Result<String, AppError> {
|
||
|
|
match value {
|
||
|
|
mlua::Value::String(s) => Ok(s.to_str()?.to_string()),
|
||
|
|
mlua::Value::Nil => Ok("nil".to_string()),
|
||
|
|
mlua::Value::Boolean(b) => Ok(b.to_string()),
|
||
|
|
mlua::Value::Number(n) => Ok(n.to_string()),
|
||
|
|
mlua::Value::Integer(i) => Ok(i.to_string()),
|
||
|
|
mlua::Value::Table(t) => {
|
||
|
|
let mut parts = Vec::new();
|
||
|
|
for pair in t.pairs::<mlua::Value, mlua::Value>() {
|
||
|
|
let (key, value) = pair?;
|
||
|
|
parts.push(format!("{}: {}", key.to_string()?, value.to_string()?));
|
||
|
|
}
|
||
|
|
Ok(format!("{{{}}}", parts.join(", ")))
|
||
|
|
}
|
||
|
|
_ => Ok(format!("{:?}", value)),
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
pub async fn run(
|
||
|
|
pool: &Pool,
|
||
|
|
user_id: Uuid,
|
||
|
|
replicant_id: Uuid,
|
||
|
|
) -> Result<FirmwareOutputResponse, AppError> {
|
||
|
|
let lua = Arc::new(Self::create_lua_instance()?);
|
||
|
|
|
||
|
|
Self::setup_sandbox(&lua)?;
|
||
|
|
|
||
|
|
let mut conn = pool.get().await.map_err(|_| AppError::RepositoryError)?;
|
||
|
|
let replicant = ReplicantRepository::get(&mut conn, replicant_id)
|
||
|
|
.await
|
||
|
|
.map_err(|_| AppError::RepositoryError)?;
|
||
|
|
|
||
|
|
ReplicantService::check_replicant_access(
|
||
|
|
&mut conn,
|
||
|
|
user_id,
|
||
|
|
replicant.is_private,
|
||
|
|
replicant.corp_id,
|
||
|
|
)
|
||
|
|
.await?;
|
||
|
|
|
||
|
|
match replicant.firmware_file {
|
||
|
|
Some(filename) => {
|
||
|
|
let firmware_path = std::path::Path::new("firmware").join(filename);
|
||
|
|
let firmware_data = tokio::fs::read(&firmware_path).await.map_err(|e| {
|
||
|
|
log::error!(
|
||
|
|
"Failed to read firmware file from {:?}: {}",
|
||
|
|
firmware_path,
|
||
|
|
e
|
||
|
|
);
|
||
|
|
AppError::InternalServerError
|
||
|
|
})?;
|
||
|
|
|
||
|
|
if firmware_data.is_empty() {
|
||
|
|
return Err(AppError::InternalServerError);
|
||
|
|
}
|
||
|
|
|
||
|
|
let result = Self::execute_with_timeout(Arc::clone(&lua), firmware_data).await?;
|
||
|
|
let mut output = Self::value_to_string(result)?;
|
||
|
|
|
||
|
|
output = BASE64_STANDARD.encode(output);
|
||
|
|
|
||
|
|
Ok(FirmwareOutputResponse { output })
|
||
|
|
}
|
||
|
|
None => Err(AppError::NotFound),
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|