Add the ability to create and list users

This commit is contained in:
Savanni D'Gerinel 2023-10-03 13:31:58 -04:00
parent 6aedff8cda
commit 4a7d741224

View File

@ -1,9 +1,15 @@
use base64ct::{Base64, Encoding};
use serde::{Deserialize, Serialize};
use sqlx::sqlite::SqlitePool;
use sha2::{Digest, Sha256};
use sqlx::{
sqlite::{SqlitePool, SqliteRow},
Executor, Row,
};
use std::collections::HashSet;
use std::{ops::Deref, path::PathBuf, sync::Arc};
use thiserror::Error;
use tokio::sync::RwLock;
use uuid::Uuid;
mod filehandle;
mod fileinfo;
@ -62,7 +68,13 @@ pub enum ReadFileError {
#[derive(Debug, Error)]
pub enum AuthError {
#[error("database failed")]
SqlError,
SqlError(sqlx::Error),
}
impl From<sqlx::Error> for AuthError {
fn from(err: sqlx::Error) -> AuthError {
AuthError::SqlError(err)
}
}
#[derive(Clone, Debug, Serialize, Deserialize, Hash, PartialEq, Eq)]
@ -80,13 +92,13 @@ impl From<&str> for Username {
}
}
impl From<Username> for PathBuf {
impl From<Username> for String {
fn from(s: Username) -> Self {
Self::from(&s)
}
}
impl From<&Username> for PathBuf {
impl From<&Username> for String {
fn from(s: &Username) -> Self {
let Username(s) = s;
Self::from(s)
@ -100,6 +112,13 @@ impl Deref for Username {
}
}
impl sqlx::FromRow<'_, SqliteRow> for Username {
fn from_row(row: &SqliteRow) -> sqlx::Result<Self> {
let name: String = row.try_get("username")?;
Ok(Username::from(name))
}
}
#[derive(Clone, Debug, Serialize, Deserialize, Hash, PartialEq, Eq)]
pub struct AuthToken(String);
@ -263,19 +282,47 @@ pub struct AuthDB {
impl AuthDB {
pub async fn new(path: PathBuf) -> Result<Self, sqlx::Error> {
let migrator = sqlx::migrate!("./migrations");
let pool = SqlitePool::connect(&format!("sqlite://{}", path.to_str().unwrap())).await?;
migrator.run(&pool).await?;
Ok(Self { pool })
}
async fn add_user(&self, username: Username) -> Result<AuthToken, AuthError> {
let auth_plaintext = format!("{}:{}", Uuid::new_v4(), username.to_string());
let mut hasher = Sha256::new();
hasher.update(auth_plaintext);
let auth_token = Base64::encode_string(&hasher.finalize());
let _ = sqlx::query("INSERT INTO users (username, token) VALUES ($1, $2)")
.bind(username.to_string())
.bind(auth_token.clone())
.execute(&self.pool)
.await?;
Ok(AuthToken::from(auth_token))
}
async fn list_users(&self) -> Result<Vec<Username>, AuthError> {
let usernames = sqlx::query_as::<_, Username>("SELECT (username) FROM users")
.fetch_all(&self.pool)
.await?;
/*
let usernames = result
.into_iter()
.map(|row| Username::from(row.column(0)))
.collect::<Vec<Username>>();
*/
Ok(usernames)
}
async fn auth_token(&self, _token: AuthToken) -> Result<SessionToken, AuthError> {
unimplemented!()
}
async fn auth_session(&self, _token: SessionToken) -> Result<Username, AuthError> {
/*
let conn = self.pool.acquire().await.map_err(|_| AuthError::SqlError)?;
conn.transaction(|tr| {})
*/
unimplemented!()
}
}
@ -417,3 +464,34 @@ mod test {
});
}
}
#[cfg(test)]
mod authdb_test {
use super::*;
use cool_asserts::assert_matches;
#[tokio::test]
async fn can_create_and_list_users() {
let db = AuthDB::new(PathBuf::from(":memory:"))
.await
.expect("a memory-only database will be created");
let _ = db
.add_user(Username::from("savanni"))
.await
.expect("user to be created");
assert_matches!(db.list_users().await, Ok(names) => {
let names = names.into_iter().collect::<HashSet<Username>>();
assert!(names.contains(&Username::from("savanni")));
})
}
#[test]
fn can_authenticate_token() {
unimplemented!()
}
#[test]
fn can_validate_session_token() {
unimplemented!()
}
}