From 772d2c01f42689f2e69c9248c0d9544c5c33957b Mon Sep 17 00:00:00 2001 From: Savanni D'Gerinel Date: Fri, 18 Nov 2022 10:02:30 -0500 Subject: [PATCH] Finish mocking the user -> session flow --- server/src/authentication.rs | 45 ++++++++++++----- server/src/errors.rs | 2 + server/src/main.rs | 93 +++++++++++++++++++++++++++++------- 3 files changed, 110 insertions(+), 30 deletions(-) diff --git a/server/src/authentication.rs b/server/src/authentication.rs index e356f23..568f269 100644 --- a/server/src/authentication.rs +++ b/server/src/authentication.rs @@ -1,5 +1,6 @@ -use crate::errors::{error, fatal, ok, AppResult}; +use crate::errors::{error, fatal, ok, AppResult, FatalError}; use rusqlite::types::{FromSql, FromSqlError, FromSqlResult, ValueRef}; +use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::{convert::Infallible, str::FromStr}; use thiserror::Error; @@ -20,7 +21,7 @@ pub enum AuthenticationError { UserNotFound, } -#[derive(Clone, Debug, PartialEq, Eq, Hash)] +#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)] pub struct SessionToken(String); impl From<&str> for SessionToken { @@ -43,7 +44,7 @@ impl From for String { } } -#[derive(Clone, Debug, PartialEq, Eq, Hash)] +#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)] pub struct Invitation(String); impl From<&str> for Invitation { @@ -66,7 +67,7 @@ impl From for String { } } -#[derive(Clone, Debug, PartialEq, Eq, Hash)] +#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)] pub struct UserId(String); impl From<&str> for UserId { @@ -100,7 +101,7 @@ impl FromSql for UserId { } } -#[derive(Clone, Debug, PartialEq, Eq, Hash)] +#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)] pub struct Username(String); impl From<&str> for Username { @@ -156,9 +157,14 @@ pub trait AuthenticationDB: Send + Sync { fn delete_user(&mut self, user: UserId) -> AppResult<(), AuthenticationError>; - fn validate_session(&self, session: SessionToken) -> AppResult<(), AuthenticationError>; + fn validate_session( + &self, + session: SessionToken, + ) -> AppResult<(Username, UserId), AuthenticationError>; - fn get_user_id(&self, username: Username) -> AppResult; + fn get_userid(&self, username: Username) -> AppResult; + + fn get_username(&self, userid: UserId) -> AppResult; fn list_users(&self) -> AppResult, AuthenticationError>; } @@ -261,15 +267,22 @@ impl AuthenticationDB for MemoryAuth { } } - fn validate_session(&self, session: SessionToken) -> AppResult<(), AuthenticationError> { - if self.sessions.contains_key(&session) { - ok(()) + fn validate_session( + &self, + session: SessionToken, + ) -> AppResult<(Username, UserId), AuthenticationError> { + if let Some(userid) = self.sessions.get(&session) { + if let Some(username) = self.inverse_users.get(&userid) { + ok((username.clone(), userid.clone())) + } else { + fatal(FatalError::DatabaseInconsistency) + } } else { - error::<(), AuthenticationError>(AuthenticationError::InvalidSession) + error(AuthenticationError::InvalidSession) } } - fn get_user_id(&self, username: Username) -> AppResult { + fn get_userid(&self, username: Username) -> AppResult { Ok(self .users .get(&username) @@ -277,6 +290,14 @@ impl AuthenticationDB for MemoryAuth { .ok_or(AuthenticationError::UserNotFound)) } + fn get_username(&self, userid: UserId) -> AppResult { + Ok(self + .inverse_users + .get(&userid) + .map(|u| u.clone()) + .ok_or(AuthenticationError::UserNotFound)) + } + fn list_users(&self) -> AppResult, AuthenticationError> { ok(self.users.keys().cloned().collect::>()) } diff --git a/server/src/errors.rs b/server/src/errors.rs index 6a155de..36d4e02 100644 --- a/server/src/errors.rs +++ b/server/src/errors.rs @@ -5,6 +5,8 @@ use thiserror::Error; /// down and that the administrator fix a problem. #[derive(Debug, Error)] pub enum FatalError { + #[error("database is inconsistent")] + DatabaseInconsistency, #[error("disk is full")] DiskFull, #[error("io error: {0}")] diff --git a/server/src/main.rs b/server/src/main.rs index 83a8b7e..34820d8 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -1,3 +1,4 @@ +use errors::{ok, AppResult}; use serde::{Deserialize, Serialize}; use std::{ net::{IpAddr, Ipv4Addr, SocketAddr}, @@ -6,7 +7,9 @@ use std::{ use warp::Filter; mod authentication; -use authentication::{AuthenticationDB, AuthenticationError, MemoryAuth, UserId, Username}; +use authentication::{ + AuthenticationDB, AuthenticationError, Invitation, MemoryAuth, SessionToken, UserId, Username, +}; mod database; mod errors; @@ -15,7 +18,7 @@ mod errors; struct AuthenticationRefused; impl warp::reject::Reject for AuthenticationRefused {} -fn with_authentication( +fn with_session( auth_ctx: Arc>, ) -> impl Filter + Clone { let auth_ctx = auth_ctx.clone(); @@ -25,13 +28,11 @@ fn with_authentication( let auth_ctx = auth_ctx.clone(); async move { if auth_header.starts_with("Basic ") { - let username = auth_header.split(" ").skip(1).collect::(); - match auth_ctx - .read() - .unwrap() - .get_user_id(Username::from(username.as_str())) - { - Ok(Ok(userid)) => Ok((Username::from(username.as_str()), userid)), + let session_token = SessionToken::from( + auth_header.split(" ").skip(1).collect::().as_str(), + ); + match auth_ctx.read().unwrap().validate_session(session_token) { + Ok(Ok((username, userid))) => Ok((username, userid)), Ok(Err(_)) => Err(warp::reject::custom(AuthenticationRefused)), Err(err) => panic!("{}", err), } @@ -49,13 +50,13 @@ struct ErrorResponse { } #[derive(Deserialize)] -struct MakeUserParameters { - username: String, +struct MakeUserParams { + username: Username, } #[derive(Serialize)] struct MakeUserResponse { - userid: String, + userid: UserId, } fn make_user( @@ -64,13 +65,67 @@ fn make_user( warp::path!("api" / "v1" / "users") .and(warp::put()) .and(warp::body::json()) - .map(move |params: MakeUserParameters| { + .map(move |params: MakeUserParams| { let mut auth_ctx = auth_ctx.write().unwrap(); - match (*auth_ctx).create_user(Username::from(params.username.as_str())) { - Ok(Ok(userid)) => warp::reply::json(&MakeUserResponse { - userid: String::from(userid), + match (*auth_ctx).create_user(Username::from(params.username)) { + Ok(Ok(userid)) => warp::reply::json(&MakeUserResponse { userid }), + Ok(Err(auth_error)) => warp::reply::json(&ErrorResponse { + error: format!("{:?}", auth_error), }), - Ok(auth_error) => warp::reply::json(&ErrorResponse { + Err(err) => panic!("{}", err), + } + }) +} + +#[derive(Deserialize)] +struct MakeInvitationParams { + userid: UserId, +} + +#[derive(Serialize)] +struct MakeInvitationResponse { + invitation: Invitation, +} + +fn make_invitation( + auth_ctx: Arc>, +) -> impl Filter + Clone { + warp::path!("api" / "v1" / "invitations") + .and(warp::put()) + .and(warp::body::json()) + .map(move |params: MakeInvitationParams| { + let mut auth_ctx = auth_ctx.write().unwrap(); + match (*auth_ctx).create_invitation(params.userid) { + Ok(Ok(invitation)) => warp::reply::json(&MakeInvitationResponse { invitation }), + Ok(Err(auth_error)) => warp::reply::json(&ErrorResponse { + error: format!("{:?}", auth_error), + }), + Err(err) => panic!("{}", err), + } + }) +} + +#[derive(Deserialize)] +struct AuthenticateParams { + invitation: Invitation, +} + +#[derive(Serialize)] +struct AuthenticateResponse { + session_token: SessionToken, +} + +fn authenticate( + auth_ctx: Arc>, +) -> impl Filter + Clone { + warp::path!("api" / "v1" / "authenticate") + .and(warp::put()) + .and(warp::body::json()) + .map(move |params: AuthenticateParams| { + let mut auth_ctx = auth_ctx.write().unwrap(); + match (*auth_ctx).authenticate(params.invitation) { + Ok(Ok(session_token)) => warp::reply::json(&AuthenticateResponse { session_token }), + Ok(Err(auth_error)) => warp::reply::json(&ErrorResponse { error: format!("{:?}", auth_error), }), Err(err) => panic!("{}", err), @@ -117,7 +172,7 @@ pub async fn main() { */ let echo_authenticated = warp::path!("api" / "v1" / "echo" / String) - .and(with_authentication(auth_ctx.clone())) + .and(with_session(auth_ctx.clone())) .map(|param: String, (username, userid)| { println!("param: {:?}", username); println!("param: {:?}", userid); @@ -127,6 +182,8 @@ pub async fn main() { let filter = list_users(auth_ctx.clone()) .or(make_user(auth_ctx.clone())) + .or(make_invitation(auth_ctx.clone())) + .or(authenticate(auth_ctx.clone())) .or(echo_authenticated) .or(echo_unauthenticated);