diff --git a/Makefile b/Makefile index 6f4e8a3..83529d0 100644 --- a/Makefile +++ b/Makefile @@ -1,9 +1,16 @@ -.PHONY: server-dev client-dev client-test +.PHONY: server-dev server-test client-dev client-test + +test: + cd server && make test-oneshot + cd v-client && make test server-dev: cd server && make dev +server-test: + cd server && make test + client-dev: cd v-client && make dev diff --git a/server/Cargo.lock b/server/Cargo.lock index b37df9e..c7ad80e 100644 --- a/server/Cargo.lock +++ b/server/Cargo.lock @@ -705,6 +705,7 @@ dependencies = [ "rand", "rusqlite", "sha2", + "tempfile", "thiserror", "tokio", "uuid", diff --git a/server/Cargo.toml b/server/Cargo.toml index 205751c..fefa208 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -15,3 +15,5 @@ tokio = { version = "1", features = ["full"] } uuid = { version = "0.8", features = ["v4"] } warp = { version = "0.3" } +[dev-dependencies] +tempfile = { version = "3" } diff --git a/server/Makefile b/server/Makefile index a331ce0..4bad9a0 100644 --- a/server/Makefile +++ b/server/Makefile @@ -3,3 +3,9 @@ dev: cargo watch -x run + +test: + cargo watch -x test + +test-oneshot: + cargo test diff --git a/server/src/bin/create-invitation.rs b/server/src/bin/create-invitation.rs new file mode 100644 index 0000000..3af16a2 --- /dev/null +++ b/server/src/bin/create-invitation.rs @@ -0,0 +1,3 @@ +fn main() { + println!("There is a tool"); +} diff --git a/server/src/database.rs b/server/src/database.rs new file mode 100644 index 0000000..2195fdc --- /dev/null +++ b/server/src/database.rs @@ -0,0 +1,111 @@ +use rusqlite::{params, Connection}; +use std::{ + ops::{Deref, DerefMut}, + path::PathBuf, + sync::{Arc, Mutex}, +}; + +pub struct ManagedConnection<'a> { + pool: &'a Database, + connection: Option, +} + +pub struct Database { + file_path: PathBuf, + pool: Arc>>, +} + +impl Database { + pub fn new(file_path: PathBuf) -> Result { + let mut connection = Connection::open(file_path.clone())?; + + let tx = connection.transaction()?; + let version: i32 = tx.pragma_query_value(None, "user_version", |r| r.get(0))?; + if version == 0 { + tx.execute_batch( + "CREATE TABLE users (id STRING PRIMARY KEY, name TEXT); + CREATE TABLE invitations (token STRING PRIMARY KEY, user_id STRING, FOREIGN KEY(user_id) REFERENCES users(id)); + CREATE TABLE sessions (token STRING PRIMARY KEY, user_id STRING, FOREIGN KEY(user_id) REFERENCES users(id)); + PRAGMA user_version = 1;", + )?; + } + tx.commit()?; + + Ok(Database { + file_path, + pool: Arc::new(Mutex::new(vec![connection])), + }) + } + + pub fn connect<'a>(&'a self) -> Result, anyhow::Error> { + let mut pool = self.pool.lock().unwrap(); + match pool.pop() { + Some(connection) => Ok(ManagedConnection { + pool: &self, + connection: Some(connection), + }), + None => { + let connection = Connection::open(self.file_path.clone())?; + Ok(ManagedConnection { + pool: &self, + connection: Some(connection), + }) + } + } + } + + pub fn release(&self, connection: Connection) { + let mut pool = self.pool.lock().unwrap(); + pool.push(connection); + } +} + +impl Deref for ManagedConnection<'_> { + type Target = Connection; + + fn deref(&self) -> &Connection { + self.connection.as_ref().unwrap() + } +} + +impl DerefMut for ManagedConnection<'_> { + fn deref_mut(&mut self) -> &mut Connection { + self.connection.as_mut().unwrap() + } +} + +impl Drop for ManagedConnection<'_> { + fn drop(&mut self) { + self.pool.release(self.connection.take().unwrap()); + } +} + +#[cfg(test)] +mod test { + use super::*; + use tempfile::NamedTempFile; + + #[test] + fn it_can_create_users() { + let path = NamedTempFile::new().unwrap().into_temp_path(); + let database = Database::new(path.to_path_buf()).unwrap(); + let mut connection = database.connect().unwrap(); + let tr = connection.transaction().unwrap(); + tr.execute( + "INSERT INTO users VALUES(?, ?)", + params!["abcdefg", "mercer"], + ) + .unwrap(); + tr.commit().unwrap(); + + let connection = database.connect().unwrap(); + let id: Option = connection + .query_row( + "SELECT id FROM users WHERE name = ?", + [String::from("mercer")], + |row| row.get("id"), + ) + .unwrap(); + assert_eq!(id, Some(String::from("abcdefg"))); + } +} diff --git a/server/src/main.rs b/server/src/main.rs index b39350e..95b8075 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -1,6 +1,8 @@ use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use warp::Filter; +mod database; + #[tokio::main] pub async fn main() { let echo_unauthenticated = warp::path!("api" / "v1" / "echo" / String).map(|param: String| {