use crate::{services::replicant::ReplicantService, utils::AppError}; use base64::prelude::*; use dollhouse_api_types::FirmwareOutputResponse; use dollhouse_db::{Pool, repositories::ReplicantRepository}; use mlua::{Lua, Table, Value}; use std::sync::Arc; use tokio::time::{Duration, timeout}; use uuid::Uuid; pub struct LuaService; const MEMORY_LIMIT: usize = 10 * 1024 * 1024; const TIME_LIMIT_MS: u64 = 1000; impl LuaService { fn create_lua_instance() -> Result { let lua = Lua::new(); lua.set_memory_limit(MEMORY_LIMIT)?; Ok(lua) } fn setup_sandbox(lua: &Lua) -> Result<(), AppError> { let globals = lua.globals(); let dangerous_libs = [ "os", "io", "debug", "load", "loadstring", "dofile", "loadfile", ]; for lib in &dangerous_libs { globals.set(*lib, Value::Nil)?; } let g_mt = lua.create_table()?; let allowed_globals = vec![ "_VERSION", "print", "type", "assert", "error", "pairs", "ipairs", "next", "select", "pcall", "xpcall", "table", "string", "math", "tonumber", "tostring", "setmetatable", "getmetatable", "rawset", "rawget", "rawequal", ]; let allowed_globals_clone1 = allowed_globals.clone(); g_mt.set( "__newindex", lua.create_function(move |_, (t, name, value): (Table, String, Value)| { if !allowed_globals_clone1.contains(&name.as_str()) { return Err(mlua::Error::RuntimeError(format!( "Security: creating global '{}' is not allowed", name ))); } t.raw_set(name, value)?; Ok(()) })?, )?; let allowed_globals_clone2 = allowed_globals.clone(); let dangerous = vec!["io", "os", "debug", "package"]; g_mt.set( "__index", lua.create_function(move |lua, (t, name): (Table, String)| { if dangerous.contains(&name.as_str()) { return Err(mlua::Error::RuntimeError(format!( "Security: access to '{}' is prohibited", name ))); } if allowed_globals_clone2.contains(&name.as_str()) { let globals = lua.globals(); return Ok(globals.raw_get::(name)?); } Ok(Value::Nil) })?, )?; globals.set_metatable(Some(g_mt)); Self::setup_safe_print(lua)?; Ok(()) } fn setup_safe_print(lua: &Lua) -> Result<(), AppError> { let safe_print = lua.create_function(|_, args: mlua::MultiValue| { let output: String = args .into_iter() .map(|v| v.to_string()) .collect::, _>>()? .join("\t"); println!("{}", output); Ok(()) })?; lua.globals().set("print", safe_print)?; Ok(()) } async fn execute_with_timeout( lua: Arc, bytecode: Vec, ) -> Result { let task = tokio::task::spawn_blocking(move || { let lua_clone = Arc::clone(&lua); lua_clone .load(&bytecode) .set_name("[[user_firmware]]") .eval() }); match timeout(Duration::from_millis(TIME_LIMIT_MS), task).await { Ok(Ok(result)) => result.map_err(|e| { let err_str = e.to_string(); if err_str.contains("not enough memory") { log::error!("Memory limit exceeded"); AppError::InternalServerError } else { AppError::LuaExecutionError(e) } }), Ok(Err(join_err)) => { log::error!("Join error: {}", join_err); Err(AppError::InternalServerError) } Err(_) => { tokio::task::yield_now().await; Err(AppError::InternalServerError) } } } fn value_to_string(value: mlua::Value) -> Result { match value { mlua::Value::String(s) => Ok(s.to_str()?.to_string()), mlua::Value::Nil => Ok("nil".to_string()), mlua::Value::Boolean(b) => Ok(b.to_string()), mlua::Value::Number(n) => Ok(n.to_string()), mlua::Value::Integer(i) => Ok(i.to_string()), mlua::Value::Table(t) => { let mut parts = Vec::new(); for pair in t.pairs::() { let (key, value) = pair?; parts.push(format!("{}: {}", key.to_string()?, value.to_string()?)); } Ok(format!("{{{}}}", parts.join(", "))) } _ => Ok(format!("{:?}", value)), } } pub async fn run( pool: &Pool, user_id: Uuid, replicant_id: Uuid, ) -> Result { let lua = Arc::new(Self::create_lua_instance()?); Self::setup_sandbox(&lua)?; let mut conn = pool.get().await.map_err(|_| AppError::RepositoryError)?; let replicant = ReplicantRepository::get(&mut conn, replicant_id) .await .map_err(|_| AppError::RepositoryError)?; ReplicantService::check_replicant_access( &mut conn, user_id, replicant.is_private, replicant.corp_id, ) .await?; match replicant.firmware_file { Some(filename) => { let firmware_path = std::path::Path::new("firmware").join(filename); let firmware_data = tokio::fs::read(&firmware_path).await.map_err(|e| { log::error!( "Failed to read firmware file from {:?}: {}", firmware_path, e ); AppError::InternalServerError })?; if firmware_data.is_empty() { return Err(AppError::InternalServerError); } let result = Self::execute_with_timeout(Arc::clone(&lua), firmware_data).await?; let mut output = Self::value_to_string(result)?; output = BASE64_STANDARD.encode(output); Ok(FirmwareOutputResponse { output }) } None => Err(AppError::NotFound), } } }