303 lines
7.6 KiB
Rust
303 lines
7.6 KiB
Rust
use base64ct::{Base64, Encoding};
|
|
use serde::{Deserialize, Serialize};
|
|
use sha2::{Digest, Sha256};
|
|
use sqlx::{
|
|
sqlite::{SqlitePool, SqliteRow},
|
|
Row,
|
|
};
|
|
use std::ops::Deref;
|
|
use std::path::PathBuf;
|
|
use thiserror::Error;
|
|
use uuid::Uuid;
|
|
|
|
#[derive(Debug, Error)]
|
|
pub enum AuthError {
|
|
#[error("authentication token is duplicated")]
|
|
DuplicateAuthToken,
|
|
|
|
#[error("session token is duplicated")]
|
|
DuplicateSessionToken,
|
|
|
|
#[error("database failed")]
|
|
SqlError(sqlx::Error),
|
|
}
|
|
|
|
impl From<sqlx::Error> for AuthError {
|
|
fn from(err: sqlx::Error) -> AuthError {
|
|
AuthError::SqlError(err)
|
|
}
|
|
}
|
|
|
|
#[derive(Clone, Debug, Serialize, Deserialize, Hash, PartialEq, Eq)]
|
|
pub struct Username(String);
|
|
|
|
impl From<String> for Username {
|
|
fn from(s: String) -> Self {
|
|
Self(s)
|
|
}
|
|
}
|
|
|
|
impl From<&str> for Username {
|
|
fn from(s: &str) -> Self {
|
|
Self(s.to_owned())
|
|
}
|
|
}
|
|
|
|
impl From<Username> for String {
|
|
fn from(s: Username) -> Self {
|
|
Self::from(&s)
|
|
}
|
|
}
|
|
|
|
impl From<&Username> for String {
|
|
fn from(s: &Username) -> Self {
|
|
let Username(s) = s;
|
|
Self::from(s)
|
|
}
|
|
}
|
|
|
|
impl Deref for Username {
|
|
type Target = String;
|
|
fn deref(&self) -> &Self::Target {
|
|
&self.0
|
|
}
|
|
}
|
|
|
|
impl sqlx::FromRow<'_, SqliteRow> for Username {
|
|
fn from_row(row: &SqliteRow) -> sqlx::Result<Self> {
|
|
let name: String = row.try_get("username")?;
|
|
Ok(Username::from(name))
|
|
}
|
|
}
|
|
|
|
#[derive(Clone, Debug, Serialize, Deserialize, Hash, PartialEq, Eq)]
|
|
pub struct AuthToken(String);
|
|
|
|
impl From<String> for AuthToken {
|
|
fn from(s: String) -> Self {
|
|
Self(s)
|
|
}
|
|
}
|
|
|
|
impl From<&str> for AuthToken {
|
|
fn from(s: &str) -> Self {
|
|
Self(s.to_owned())
|
|
}
|
|
}
|
|
|
|
impl From<AuthToken> for PathBuf {
|
|
fn from(s: AuthToken) -> Self {
|
|
Self::from(&s)
|
|
}
|
|
}
|
|
|
|
impl From<&AuthToken> for PathBuf {
|
|
fn from(s: &AuthToken) -> Self {
|
|
let AuthToken(s) = s;
|
|
Self::from(s)
|
|
}
|
|
}
|
|
|
|
impl Deref for AuthToken {
|
|
type Target = String;
|
|
fn deref(&self) -> &Self::Target {
|
|
&self.0
|
|
}
|
|
}
|
|
|
|
#[derive(Clone, Debug, Serialize, Deserialize, Hash, PartialEq, Eq)]
|
|
pub struct SessionToken(String);
|
|
|
|
impl From<String> for SessionToken {
|
|
fn from(s: String) -> Self {
|
|
Self(s)
|
|
}
|
|
}
|
|
|
|
impl From<&str> for SessionToken {
|
|
fn from(s: &str) -> Self {
|
|
Self(s.to_owned())
|
|
}
|
|
}
|
|
|
|
impl From<SessionToken> for PathBuf {
|
|
fn from(s: SessionToken) -> Self {
|
|
Self::from(&s)
|
|
}
|
|
}
|
|
|
|
impl From<&SessionToken> for PathBuf {
|
|
fn from(s: &SessionToken) -> Self {
|
|
let SessionToken(s) = s;
|
|
Self::from(s)
|
|
}
|
|
}
|
|
|
|
impl Deref for SessionToken {
|
|
type Target = String;
|
|
fn deref(&self) -> &Self::Target {
|
|
&self.0
|
|
}
|
|
}
|
|
|
|
#[derive(Clone)]
|
|
pub struct AuthDB {
|
|
pool: SqlitePool,
|
|
}
|
|
|
|
impl AuthDB {
|
|
pub async fn new(path: PathBuf) -> Result<Self, sqlx::Error> {
|
|
let migrator = sqlx::migrate!("./migrations");
|
|
let pool = SqlitePool::connect(&format!("sqlite://{}", path.to_str().unwrap())).await?;
|
|
migrator.run(&pool).await?;
|
|
Ok(Self { pool })
|
|
}
|
|
|
|
pub async fn add_user(&self, username: Username) -> Result<AuthToken, AuthError> {
|
|
let mut hasher = Sha256::new();
|
|
hasher.update(Uuid::new_v4().hyphenated().to_string());
|
|
hasher.update(username.to_string());
|
|
let auth_token = Base64::encode_string(&hasher.finalize());
|
|
|
|
let _ = sqlx::query("INSERT INTO users (username, token) VALUES ($1, $2)")
|
|
.bind(username.to_string())
|
|
.bind(auth_token.clone())
|
|
.execute(&self.pool)
|
|
.await?;
|
|
|
|
Ok(AuthToken::from(auth_token))
|
|
}
|
|
|
|
pub async fn list_users(&self) -> Result<Vec<Username>, AuthError> {
|
|
let usernames = sqlx::query_as::<_, Username>("SELECT (username) FROM users")
|
|
.fetch_all(&self.pool)
|
|
.await?;
|
|
|
|
Ok(usernames)
|
|
}
|
|
|
|
pub async fn authenticate(&self, token: AuthToken) -> Result<Option<SessionToken>, AuthError> {
|
|
let results = sqlx::query("SELECT * FROM users WHERE token = $1")
|
|
.bind(token.to_string())
|
|
.fetch_all(&self.pool)
|
|
.await?;
|
|
|
|
if results.len() > 1 {
|
|
return Err(AuthError::DuplicateAuthToken);
|
|
}
|
|
|
|
if results.is_empty() {
|
|
return Ok(None);
|
|
}
|
|
|
|
let user_id: i64 = results[0].try_get("id")?;
|
|
|
|
let mut hasher = Sha256::new();
|
|
hasher.update(Uuid::new_v4().hyphenated().to_string());
|
|
hasher.update(token.to_string());
|
|
let session_token = Base64::encode_string(&hasher.finalize());
|
|
|
|
let _ = sqlx::query("INSERT INTO sessions (token, user_id) VALUES ($1, $2)")
|
|
.bind(session_token.clone())
|
|
.bind(user_id)
|
|
.execute(&self.pool)
|
|
.await?;
|
|
|
|
Ok(Some(SessionToken::from(session_token)))
|
|
}
|
|
|
|
pub async fn validate_session(
|
|
&self,
|
|
token: SessionToken,
|
|
) -> Result<Option<Username>, AuthError> {
|
|
let rows = sqlx::query(
|
|
"SELECT users.username FROM sessions INNER JOIN users ON sessions.user_id = users.id WHERE sessions.token = $1",
|
|
)
|
|
.bind(token.to_string())
|
|
.fetch_all(&self.pool)
|
|
.await?;
|
|
if rows.len() > 1 {
|
|
return Err(AuthError::DuplicateSessionToken);
|
|
}
|
|
|
|
if rows.is_empty() {
|
|
return Ok(None);
|
|
}
|
|
|
|
let username: String = rows[0].try_get("username")?;
|
|
Ok(Some(Username::from(username)))
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use cool_asserts::assert_matches;
    use std::collections::HashSet;

    /// Opens a fresh, migrated in-memory database for a single test.
    async fn memory_db() -> AuthDB {
        AuthDB::new(PathBuf::from(":memory:"))
            .await
            .expect("a memory-only database will be created")
    }

    #[tokio::test]
    async fn can_create_and_list_users() {
        let db = memory_db().await;
        let _ = db
            .add_user(Username::from("savanni"))
            .await
            .expect("user to be created");
        assert_matches!(db.list_users().await, Ok(names) => {
            let names = names.into_iter().collect::<HashSet<Username>>();
            assert!(names.contains(&Username::from("savanni")));
        })
    }

    #[tokio::test]
    async fn unknown_auth_token_returns_nothing() {
        let db = memory_db().await;
        let _ = db
            .add_user(Username::from("savanni"))
            .await
            .expect("user to be created");

        let token = AuthToken::from("0000000000");

        assert_matches!(db.authenticate(token).await, Ok(None));
    }

    #[tokio::test]
    async fn auth_token_becomes_session_token() {
        let db = memory_db().await;
        let token = db
            .add_user(Username::from("savanni"))
            .await
            .expect("user to be created");

        // `Ok(_)` would also accept `Ok(None)` (token not found); require a session.
        assert_matches!(db.authenticate(token).await, Ok(Some(_)));
    }

    #[tokio::test]
    async fn can_validate_session_token() {
        let db = memory_db().await;
        let token = db
            .add_user(Username::from("savanni"))
            .await
            .expect("user to be created");
        let session = db
            .authenticate(token)
            .await
            .expect("token authentication should succeed")
            .expect("session token should be found");

        assert_matches!(
            db.validate_session(session).await,
            Ok(Some(username)) => {
                assert_eq!(username, Username::from("savanni"));
            });
    }

    #[tokio::test]
    async fn unknown_session_token_returns_nothing() {
        let db = memory_db().await;
        let _ = db
            .add_user(Username::from("savanni"))
            .await
            .expect("user to be created");

        let session = SessionToken::from("0000000000");

        assert_matches!(db.validate_session(session).await, Ok(None));
    }
}
|