use actix_web::{ dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform}, Error, HttpMessage, HttpRequest, }; use dashmap::DashMap; use futures_util::future::{ready, LocalBoxFuture, Ready}; use std::sync::Arc; use uuid::Uuid; const SESSION_HEADER: &str = "X-Session-Id"; #[derive(Clone)] pub struct SessionManager { sessions: Arc>, } #[derive(Clone)] pub struct SessionRecord { pub user_id: Uuid, pub username: String, } #[derive(Clone)] pub struct SessionClaims { pub session_id: Uuid, pub user_id: Uuid, pub username: String, } impl SessionManager { pub fn new() -> Self { Self { sessions: Arc::new(DashMap::new()), } } pub fn create_session(&self, user_id: Uuid, username: String) -> Uuid { let session_id = Uuid::new_v4(); self.sessions.insert( session_id, SessionRecord { user_id, username, }, ); session_id } pub fn get(&self, session_id: &Uuid) -> Option { self.sessions.get(session_id).map(|entry| entry.clone()) } } #[derive(Clone)] pub struct SessionLayer { manager: SessionManager, } impl SessionLayer { pub fn new(manager: SessionManager) -> Self { Self { manager } } pub fn claims(req: &HttpRequest) -> Option { req.extensions().get::().cloned() } } impl Transform for SessionLayer where S: Service, Error = Error> + 'static, S::Future: 'static, { type Response = ServiceResponse; type Error = Error; type InitError = (); type Transform = SessionMiddleware; type Future = Ready>; fn new_transform(&self, service: S) -> Self::Future { ready(Ok(SessionMiddleware { service, manager: self.manager.clone(), })) } } pub struct SessionMiddleware { service: S, manager: SessionManager, } impl Service for SessionMiddleware where S: Service, Error = Error> + 'static, S::Future: 'static, { type Response = ServiceResponse; type Error = Error; type Future = LocalBoxFuture<'static, Result>; forward_ready!(service); fn call(&self, mut req: ServiceRequest) -> Self::Future { let manager = self.manager.clone(); if let Some(claims) = extract_claims(&req, &manager) { req.extensions_mut().insert(claims); } let fut = self.service.call(req); Box::pin(async move { let res = fut.await?; Ok(res) }) } } fn extract_claims(req: &ServiceRequest, manager: &SessionManager) -> Option { let session_id = req .headers() .get(SESSION_HEADER) .and_then(|value| value.to_str().ok()) .and_then(|value| Uuid::parse_str(value).ok())?; manager.get(&session_id).map(|record| SessionClaims { session_id, user_id: record.user_id, username: record.username, }) }