diff --git a/result-extended/src/lib.rs b/result-extended/src/lib.rs index e4aba62..fab07ee 100644 --- a/result-extended/src/lib.rs +++ b/result-extended/src/lib.rs @@ -156,6 +156,13 @@ pub fn fatal(err: FE) -> ResultExt { ResultExt::Fatal(err) } +pub fn result_as_fatal(result: Result) -> ResultExt { + match result { + Ok(a) => ResultExt::Ok(a), + Err(err) => ResultExt::Fatal(err), + } +} + /// Return early from the current function if the value is a fatal error. #[macro_export] macro_rules! return_fatal { diff --git a/visions/server/Taskfile.yml b/visions/server/Taskfile.yml index df7c99e..a02f7f4 100644 --- a/visions/server/Taskfile.yml +++ b/visions/server/Taskfile.yml @@ -3,7 +3,7 @@ version: '3' tasks: build: cmds: - - cargo build + - cargo watch -x build test: cmds: diff --git a/visions/server/src/core.rs b/visions/server/src/core.rs index 44396e9..d131afc 100644 --- a/visions/server/src/core.rs +++ b/visions/server/src/core.rs @@ -3,7 +3,7 @@ use std::{collections::HashMap, sync::Arc}; use async_std::sync::RwLock; use chrono::{DateTime, TimeDelta, Utc}; use mime::Mime; -use result_extended::{error, fatal, ok, return_error, ResultExt}; +use result_extended::{error, fatal, ok, result_as_fatal, return_error, ResultExt}; use serde::{Deserialize, Serialize}; use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender}; use typeshare::typeshare; @@ -12,11 +12,8 @@ use uuid::Uuid; use crate::{ asset_db::{self, AssetId, Assets}, database::{CharacterId, Database, GameId, SessionId, UserId}, + types::AccountState, types::{AppError, FatalError, GameOverview, Message, Rgb, Tabletop, User, UserOverview}, - types::{ - AccountState, AppError, FatalError, GameOverview, Message, Rgb, Tabletop, User, - UserOverview, - }, }; const DEFAULT_BACKGROUND_COLOR: Rgb = Rgb { @@ -28,7 +25,7 @@ const DEFAULT_BACKGROUND_COLOR: Rgb = Rgb { #[derive(Clone, Serialize)] #[typeshare] pub struct Status { - pub admin_enabled: bool, + pub ok: bool, } #[derive(Clone, Debug, Deserialize, Serialize)] @@ -36,7 +33,8 @@ pub struct Status { #[typeshare] pub enum AuthResponse { Success(SessionId), - Expired, + PasswordReset(SessionId), + Locked, } #[derive(Debug)] @@ -73,6 +71,7 @@ impl Core { } pub async fn status(&self) -> ResultExt { + /* let state = self.0.write().await; let admin_user = return_error!(match state.db.user(&UserId::from("admin")).await { Ok(Some(admin_user)) => ok(admin_user), @@ -85,8 +84,10 @@ impl Core { }); ok(Status { - admin_enabled: !admin_user.password.is_empty(), + ok: !admin_user.password.is_empty(), }) + */ + ok(Status { ok: true }) } pub async fn register_client(&self) -> String { @@ -351,26 +352,26 @@ impl Core { ) -> ResultExt { let now = Utc::now(); let state = self.0.read().await; - match state.db.user_by_username(username).await { - Ok(Some(row)) if row.password == password => match row.state { - AccountState::Normal => match state.db.create_session(&row.id).await { - Ok(session_id) => ok(AuthResponse::Success(session_id)), - Err(err) => fatal(err), - }, - AccountState::PasswordReset(exp) => { - if exp < now { - ok(AuthResponse::Expired) - } else { - match state.db.create_session(&row.id).await { - Ok(session_id) => ok(AuthResponse::Success(session_id)), - Err(err) => fatal(err), - } - } - } - AccountState::Locked => error(AppError::AuthFailed), - }, + let user_info = return_error!(match state.db.user_by_username(username).await { + Ok(Some(row)) if row.password == password => ok(row), Ok(_) => error(AppError::AuthFailed), Err(err) => fatal(err), + }); + + match user_info.state { + AccountState::Normal => result_as_fatal(state.db.create_session(&user_info.id).await) + .map(|session_id| AuthResponse::Success(session_id)), + + AccountState::PasswordReset(exp) => { + if exp < now { + error(AppError::AuthFailed) + } else { + result_as_fatal(state.db.create_session(&user_info.id).await) + .map(|session_id| AuthResponse::PasswordReset(session_id)) + } + } + + AccountState::Locked => ok(AuthResponse::Locked), } } @@ -517,7 +518,8 @@ mod test { Err(err) => panic!("{}", err), } } - ResultExt::Ok(AuthResponse::Expired) => panic!("user has expired"), + ResultExt::Ok(AuthResponse::PasswordReset(_)) => panic!("user is in password reset state"), + ResultExt::Ok(AuthResponse::Locked) => panic!("user has been locked"), ResultExt::Err(err) => panic!("{}", err), ResultExt::Fatal(err) => panic!("{}", err), } diff --git a/visions/server/src/handlers/mod.rs b/visions/server/src/handlers/mod.rs index 31dda81..b90c483 100644 --- a/visions/server/src/handlers/mod.rs +++ b/visions/server/src/handlers/mod.rs @@ -45,7 +45,7 @@ where pub async fn healthcheck(core: Core) -> Vec { match core.status().await { ResultExt::Ok(s) => serde_json::to_vec(&HealthCheck { - ok: s.admin_enabled, + ok: s.ok, }) .unwrap(), ResultExt::Err(_) => serde_json::to_vec(&HealthCheck { ok: false }).unwrap(), diff --git a/visions/server/src/routes.rs b/visions/server/src/routes.rs index 8f297c2..430bca6 100644 --- a/visions/server/src/routes.rs +++ b/visions/server/src/routes.rs @@ -40,7 +40,11 @@ pub fn routes(core: Core) -> Router { "/api/v1/auth", post({ let core = core.clone(); - move |req: Json| wrap_handler(|| check_password(core, req)) + move |req: Json| wrap_handler(|| async { + let password_result = check_password(core, req).await; + println!("check_auth result: {:?}", password_result); + password_result + }) }) .layer( CorsLayer::new() @@ -117,7 +121,7 @@ pub fn routes(core: Core) -> Router { #[cfg(test)] mod test { - use std::path::PathBuf; + use std::{path::PathBuf, time::Duration}; use axum::http::StatusCode; use axum_test::TestServer; @@ -135,9 +139,10 @@ mod test { }; async fn initialize_test_server() -> (Core, TestServer) { + let password_exp = Utc::now() + Duration::from_secs(5); let memory_db: Option = None; let conn = DbConn::new(memory_db); - let admin_id = conn.create_user("admin", "aoeu", true, AccountState::PasswordReset(Utc::now())).await.unwrap(); + let _admin_id = conn.create_user("admin", "aoeu", true, AccountState::PasswordReset(password_exp)).await.unwrap(); let core = Core::new(FsAssets::new(PathBuf::from("/home/savanni/Pictures")), conn); let app = routes(core.clone()); let server = TestServer::new(app).unwrap(); @@ -223,8 +228,9 @@ mod test { let session_id = response.json::>().unwrap(); let session_id = match session_id { - AuthResponse::Success(session_id) => session_id, - AuthResponse::Expired => panic!("admin user is already expired"), + AuthResponse::PasswordReset(session_id) => session_id, + AuthResponse::Success(_) => panic!("admin user password has already been set"), + AuthResponse::Locked => panic!("admin user is already expired"), }; let response = server @@ -244,7 +250,7 @@ mod test { .await; response.assert_status_ok(); let session = response.json::>().unwrap(); - assert_matches!(session, AuthResponse::Expired); + assert_matches!(session, AuthResponse::PasswordReset(_)); } #[ignore] @@ -291,8 +297,7 @@ mod test { }) .await; response.assert_status_ok(); - let session_id: Option = response.json(); - assert!(session_id.is_some()); + assert_matches!(response.json(), Some(AuthResponse::PasswordReset(_))); } #[tokio::test] @@ -310,8 +315,7 @@ mod test { }) .await; response.assert_status_ok(); - let session_id: Option = response.json(); - let session_id = session_id.unwrap(); + let session_id = assert_matches!(response.json(), Some(AuthResponse::PasswordReset(session_id)) => session_id); println!("it_returns_user_profile: {}", session_id); let response = server @@ -321,7 +325,6 @@ mod test { response.assert_status_ok(); let profile: Option = response.json(); let profile = profile.unwrap(); - assert_eq!(profile.id, UserId::from("admin")); assert_eq!(profile.name, "admin"); } diff --git a/visions/server/src/types.rs b/visions/server/src/types.rs index f3a8bbf..cf7a342 100644 --- a/visions/server/src/types.rs +++ b/visions/server/src/types.rs @@ -76,6 +76,7 @@ pub struct Rgb { } #[derive(Clone, Debug, Deserialize, Serialize)] +#[serde(tag = "type", content = "content")] #[typeshare] pub enum AccountState { Normal, @@ -91,7 +92,6 @@ impl FromSql for AccountState { Ok(AccountState::Normal) } else if text.starts_with("PasswordReset") { let exp_str = text.strip_prefix("PasswordReset ").unwrap(); - println!("{}", exp_str); let exp = NaiveDateTime::parse_from_str(exp_str, "%Y-%m-%d %H:%M:%S") .unwrap() .and_utc();