use std::sync::Arc; use std::time::Duration; use tokio::{task, time}; const UDP_PORT: u16 = 5004; const PATCH_MAX_AGE: Duration = Duration::from_secs(30 * 60); const CLEAN_INTERVAL: Duration = Duration::from_secs(10 * 60); const SYSEX_TIMEOUT: Duration = Duration::from_millis(50); struct Config { storage: storage::BankStorage, udp_sock: tokio::net::UdpSocket, } #[tokio::main] async fn main() -> anyhow::Result<()> { tracing_subscriber::fmt() .with_max_level(tracing::Level::INFO) .init(); let storage = storage::BankStorage::new("data")?; let udp_sock = tokio::net::UdpSocket::bind(("0.0.0.0", UDP_PORT)).await?; let config = Arc::new(Config { storage, udp_sock }); tokio::spawn(cleaner_loop(config.clone())); listen_udp(config).await?; Ok(()) } #[tracing::instrument(skip(state), err)] async fn listen_udp(state: Arc) -> anyhow::Result<()> { tracing::info!("Listening for UDP SysEx on port {}", UDP_PORT); loop { let mut buf = vec![0u8; 2048]; let (len, addr) = state.udp_sock.recv_from(&mut buf).await?; let state_clone = state.clone(); tokio::spawn(async move { if let Err(e) = handle_sysex(&state_clone, addr, &buf[..len]).await { let response = sysex::Response::Error("Internal server error".to_string()); let _ = state_clone.udp_sock.send_to(&response.encode(), addr).await; tracing::warn!(?e, "sysex handling error"); } }); } } #[tracing::instrument(skip(state, frame), err)] async fn handle_sysex( state: &Arc, addr: std::net::SocketAddr, frame: &[u8], ) -> anyhow::Result<()> { let msg = sysex::parse_sysex_request(frame)?; tracing::debug!(?msg, "Handling a request."); let response = match msg { sysex::Request::Diag => sysex::Response::Diag, sysex::Request::Put { data } => { let lua_dsp = String::from_utf8_lossy(&data).to_string(); if let Err(error) = run_blocking_with_timeout(SYSEX_TIMEOUT, move || lua_sandbox::validate(&lua_dsp)) .await { let response = sysex::Response::Error(format!("Invalid lua script: {error:#}")); state.udp_sock.send_to(&response.encode(), addr).await?; return Ok(()); }; let private_key = crypto::PrivateKey::new_random(&crypto::DEFAULT_CRYPTO_PARAMS); let public_key = crypto::PublicKey::from_private_key(&private_key); let encrypted_code = public_key.encrypt(&data); let patch_id = uuid::Uuid::new_v4(); let patch_data = storage::PatchData { private_key: private_key.clone(), encrypted_code, }; state.storage.put_patch(patch_id, &patch_data)?; sysex::Response::PutResp { patch_id, private_key, } } sysex::Request::Get { patch_id } => { let (encrypted_code, private_key) = match state.storage.get_patch(patch_id) { Err(error) => { let response = sysex::Response::Error(format!("Couldn't read the patch: {error:#}")); state.udp_sock.send_to(&response.encode(), addr).await?; return Ok(()); } Ok(storage::PatchData { encrypted_code, private_key, }) => (encrypted_code, private_key), }; sysex::Response::GetResp { public_key: crypto::PublicKey::from_private_key(&private_key), encrypted_code, } } sysex::Request::RenderHash { patch_id, note, velocity, } => { let (encrypted_code, private_key) = match state.storage.get_patch(patch_id) { Err(error) => { let response = sysex::Response::Error(format!("Couldn't read the patch: {error:#}")); state.udp_sock.send_to(&response.encode(), addr).await?; return Ok(()); } Ok(storage::PatchData { encrypted_code, private_key, }) => (encrypted_code, private_key), }; let lua_dsp = String::from_utf8(private_key.decrypt(&encrypted_code)?)?; let resulting_hash = match run_blocking_with_timeout(SYSEX_TIMEOUT, move || { lua_sandbox::render_hash(&lua_dsp, note, velocity) }) .await { Ok(resulting_hash) => resulting_hash, Err(error) => { let response = sysex::Response::Error(format!("Couldn't compute hash: {error:#}")); state.udp_sock.send_to(&response.encode(), addr).await?; return Ok(()); } }; sysex::Response::RenderHashResp { resulting_hash } } sysex::Request::GetCryptoParams => sysex::Response::GetCryptoParamsResp { modulus: **crypto::DEFAULT_CRYPTO_PARAMS.modulus(), base: crypto::DEFAULT_CRYPTO_PARAMS.base().value(), }, }; send_response(state, addr, response).await?; Ok(()) } async fn send_response( state: &Arc, addr: std::net::SocketAddr, response: sysex::Response, ) -> anyhow::Result<()> { tracing::debug!(?response, "Sending a response."); let frame = response.encode(); state.udp_sock.send_to(&frame, addr).await?; Ok(()) } async fn cleaner_loop(config: Arc) { let mut ticker = time::interval(CLEAN_INTERVAL); loop { ticker.tick().await; let config = config.clone(); let cleanup = task::spawn_blocking(move || config.storage.clean_old_patches(PATCH_MAX_AGE)); match cleanup.await { Ok(Ok(removed)) => { if removed > 0 { tracing::info!(removed, "Removed expired patches"); } } Ok(Err(error)) => tracing::warn!(?error, "Patch cleanup failed"), Err(join_error) => tracing::warn!(?join_error, "Patch cleanup panicked"), } } } #[allow(dead_code)] async fn forward_to_postgres(query: String) { let body = query; let req = format!( "POST /postgres HTTP/1.1\r\nHost: localhost\r\nContent-Length: {}\r\n\r\n{}", body.len(), body ); if let Ok(stream) = tokio::net::TcpStream::connect("127.0.0.1:4000").await { let _ = stream.try_write(req.as_bytes()); } } #[allow(dead_code)] fn extract_payload(path: &str) -> String { if let Some(idx) = path.find("payload=") { let rest = &path[idx + "payload=".len()..]; rest.split('&').next().unwrap_or("").to_string() } else { "".into() } } #[allow(dead_code)] async fn handle_conn(socket: tokio::net::TcpStream) { let mut buf = vec![0u8; 4096]; let n = match socket.try_read(&mut buf) { Ok(n) if n > 0 => n, _ => return, }; let req = String::from_utf8_lossy(&buf[..n]); let first_line = req.lines().next().unwrap_or(""); let mut parts = first_line.split_whitespace(); let method = parts.next().unwrap_or(""); let path = parts.next().unwrap_or(""); if method == "GET" { let payload = extract_payload(path); let simulated = format!("SELECT * FROM test WHERE field = '{}'", payload); tokio::spawn(forward_to_postgres(simulated)); } let _ = socket.try_write(b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok"); } async fn run_blocking_with_timeout( duration: Duration, f: impl FnOnce() -> anyhow::Result + Send + 'static, ) -> anyhow::Result { let beginning = time::Instant::now(); match time::timeout(duration, task::spawn_blocking(f)).await { Ok(join_result) => match join_result { Ok(result) => { tracing::debug!( elapsed = ?beginning.elapsed(), "Blocking task completed within timeout" ); result } Err(join_error) => Err(anyhow::anyhow!("Task panicked: {}", join_error)), }, Err(_) => Err(anyhow::anyhow!("Task timed out after {:?}", duration)), } }