use base64ct::{Base64, Encoding}; use serde::{Deserialize, Serialize}; use sha2::{Digest, Sha256}; use sqlx::{ sqlite::{SqlitePool, SqliteRow}, Row, }; use std::ops::Deref; use std::path::PathBuf; use thiserror::Error; use uuid::Uuid; #[derive(Debug, Error)] pub enum AuthError { #[error("authentication token is duplicated")] DuplicateAuthToken, #[error("session token is duplicated")] DuplicateSessionToken, #[error("database failed")] SqlError(sqlx::Error), } impl From for AuthError { fn from(err: sqlx::Error) -> AuthError { AuthError::SqlError(err) } } #[derive(Clone, Debug, Serialize, Deserialize, Hash, PartialEq, Eq)] pub struct Username(String); impl From for Username { fn from(s: String) -> Self { Self(s) } } impl From<&str> for Username { fn from(s: &str) -> Self { Self(s.to_owned()) } } impl From for String { fn from(s: Username) -> Self { Self::from(&s) } } impl From<&Username> for String { fn from(s: &Username) -> Self { let Username(s) = s; Self::from(s) } } impl Deref for Username { type Target = String; fn deref(&self) -> &Self::Target { &self.0 } } impl sqlx::FromRow<'_, SqliteRow> for Username { fn from_row(row: &SqliteRow) -> sqlx::Result { let name: String = row.try_get("username")?; Ok(Username::from(name)) } } #[derive(Clone, Debug, Serialize, Deserialize, Hash, PartialEq, Eq)] pub struct AuthToken(String); impl From for AuthToken { fn from(s: String) -> Self { Self(s) } } impl From<&str> for AuthToken { fn from(s: &str) -> Self { Self(s.to_owned()) } } impl From for PathBuf { fn from(s: AuthToken) -> Self { Self::from(&s) } } impl From<&AuthToken> for PathBuf { fn from(s: &AuthToken) -> Self { let AuthToken(s) = s; Self::from(s) } } impl Deref for AuthToken { type Target = String; fn deref(&self) -> &Self::Target { &self.0 } } #[derive(Clone, Debug, Serialize, Deserialize, Hash, PartialEq, Eq)] pub struct SessionToken(String); impl From for SessionToken { fn from(s: String) -> Self { Self(s) } } impl From<&str> for SessionToken { fn from(s: &str) -> Self { Self(s.to_owned()) } } impl From for PathBuf { fn from(s: SessionToken) -> Self { Self::from(&s) } } impl From<&SessionToken> for PathBuf { fn from(s: &SessionToken) -> Self { let SessionToken(s) = s; Self::from(s) } } impl Deref for SessionToken { type Target = String; fn deref(&self) -> &Self::Target { &self.0 } } #[derive(Clone)] pub struct AuthDB { pool: SqlitePool, } impl AuthDB { pub async fn new(path: PathBuf) -> Result { let migrator = sqlx::migrate!("./migrations"); let pool = SqlitePool::connect(&format!("sqlite://{}", path.to_str().unwrap())).await?; migrator.run(&pool).await?; Ok(Self { pool }) } pub async fn add_user(&self, username: Username) -> Result { let mut hasher = Sha256::new(); hasher.update(Uuid::new_v4().hyphenated().to_string()); hasher.update(username.to_string()); let auth_token = Base64::encode_string(&hasher.finalize()); let _ = sqlx::query("INSERT INTO users (username, token) VALUES ($1, $2)") .bind(username.to_string()) .bind(auth_token.clone()) .execute(&self.pool) .await?; Ok(AuthToken::from(auth_token)) } pub async fn list_users(&self) -> Result, AuthError> { let usernames = sqlx::query_as::<_, Username>("SELECT (username) FROM users") .fetch_all(&self.pool) .await?; Ok(usernames) } pub async fn authenticate(&self, token: AuthToken) -> Result, AuthError> { let results = sqlx::query("SELECT * FROM users WHERE token = $1") .bind(token.to_string()) .fetch_all(&self.pool) .await?; if results.len() > 1 { return Err(AuthError::DuplicateAuthToken); } if results.is_empty() { return Ok(None); } let user_id: i64 = results[0].try_get("id")?; let mut hasher = Sha256::new(); hasher.update(Uuid::new_v4().hyphenated().to_string()); hasher.update(token.to_string()); let session_token = Base64::encode_string(&hasher.finalize()); let _ = sqlx::query("INSERT INTO sessions (token, user_id) VALUES ($1, $2)") .bind(session_token.clone()) .bind(user_id) .execute(&self.pool) .await?; Ok(Some(SessionToken::from(session_token))) } pub async fn validate_session( &self, token: SessionToken, ) -> Result, AuthError> { let rows = sqlx::query( "SELECT users.username FROM sessions INNER JOIN users ON sessions.user_id = users.id WHERE sessions.token = $1", ) .bind(token.to_string()) .fetch_all(&self.pool) .await?; if rows.len() > 1 { return Err(AuthError::DuplicateSessionToken); } if rows.is_empty() { return Ok(None); } let username: String = rows[0].try_get("username")?; Ok(Some(Username::from(username))) } } #[cfg(test)] mod tests { use super::*; use cool_asserts::assert_matches; use std::collections::HashSet; #[tokio::test] async fn can_create_and_list_users() { let db = AuthDB::new(PathBuf::from(":memory:")) .await .expect("a memory-only database will be created"); let _ = db .add_user(Username::from("savanni")) .await .expect("user to be created"); assert_matches!(db.list_users().await, Ok(names) => { let names = names.into_iter().collect::>(); assert!(names.contains(&Username::from("savanni"))); }) } #[tokio::test] async fn unknown_auth_token_returns_nothing() { let db = AuthDB::new(PathBuf::from(":memory:")) .await .expect("a memory-only database will be created"); let _ = db .add_user(Username::from("savanni")) .await .expect("user to be created"); let token = AuthToken::from("0000000000"); assert_matches!(db.authenticate(token).await, Ok(None)); } #[tokio::test] async fn auth_token_becomes_session_token() { let db = AuthDB::new(PathBuf::from(":memory:")) .await .expect("a memory-only database will be created"); let token = db .add_user(Username::from("savanni")) .await .expect("user to be created"); assert_matches!(db.authenticate(token).await, Ok(_)); } #[tokio::test] async fn can_validate_session_token() { let db = AuthDB::new(PathBuf::from(":memory:")) .await .expect("a memory-only database will be created"); let token = db .add_user(Username::from("savanni")) .await .expect("user to be created"); let session = db .authenticate(token) .await .expect("token authentication should succeed") .expect("session token should be found"); assert_matches!( db.validate_session(session).await, Ok(Some(username)) => { assert_eq!(username, Username::from("savanni")); }); } }