From 0237393c0b0ae95d122781292b11d3c51150e012 Mon Sep 17 00:00:00 2001 From: Savanni D'Gerinel Date: Mon, 18 Nov 2024 19:08:49 -0500 Subject: [PATCH] Set up a websocket that relays messages --- Cargo.lock | 11 +++--- visions/server/Cargo.toml | 3 ++ visions/server/src/core.rs | 55 ++++++++++++++++++++++++++-- visions/server/src/handlers.rs | 66 ++++++++++++++++++++++++++++------ visions/server/src/main.rs | 54 +++++++++++++++++++++++----- visions/server/src/types.rs | 20 +++++++++++ 6 files changed, 183 insertions(+), 26 deletions(-) create mode 100644 visions/server/src/types.rs diff --git a/Cargo.lock b/Cargo.lock index e4a52eb..7ed4186 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -522,7 +522,7 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" name = "changeset" version = "0.1.0" dependencies = [ - "uuid 1.10.0", + "uuid 0.8.2", ] [[package]] @@ -2357,7 +2357,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4979f22fdb869068da03c9f7528f8297c6fd2606bc3a4affe42e6a823fdb8da4" dependencies = [ "cfg-if", - "windows-targets 0.48.5", + "windows-targets 0.52.6", ] [[package]] @@ -4815,9 +4815,9 @@ dependencies = [ [[package]] name = "uuid" -version = "1.10.0" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81dfa00651efa65069b0b6b651f4aaa31ba9e3c3ce0137aaad053604ee7e0314" +checksum = "f8c5f0a0af699448548ad1a2fbf920fb4bee257eae39953ba95cb84891a0446a" dependencies = [ "getrandom", ] @@ -4857,12 +4857,15 @@ name = "visions" version = "0.1.0" dependencies = [ "authdb", + "futures", "http 1.1.0", "mime 0.3.17", "mime_guess 2.0.5", "serde 1.0.210", "serde_json", "tokio", + "tokio-stream", + "uuid 1.11.0", "warp", ] diff --git a/visions/server/Cargo.toml b/visions/server/Cargo.toml index d0f2f14..94f0ccb 100644 --- a/visions/server/Cargo.toml +++ b/visions/server/Cargo.toml @@ -14,3 +14,6 @@ tokio = { version = "1", features = [ "full" ] } warp = { version = "0.3" } mime_guess = "2.0.5" mime = "0.3.17" +uuid = { version = "1.11.0", features = ["v4"] } +futures = "0.3.31" +tokio-stream = "0.1.16" diff --git a/visions/server/src/core.rs b/visions/server/src/core.rs index 62ec39f..b86ecaa 100644 --- a/visions/server/src/core.rs +++ b/visions/server/src/core.rs @@ -1,17 +1,25 @@ use std::{ + collections::HashMap, io::Read, path::PathBuf, sync::{Arc, RwLock}, }; +use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender}; +use uuid::Uuid; + +use crate::types::Message; + #[derive(Debug)] -pub enum AppError { - JsonError(serde_json::Error), +struct WebsocketClient { + sender: Option>, } -#[derive(Clone, Debug)] +#[derive(Debug)] pub struct AppState { pub image_base: PathBuf, + + pub clients: HashMap, } #[derive(Clone, Debug)] @@ -21,9 +29,40 @@ impl Core { pub fn new() -> Self { Self(Arc::new(RwLock::new(AppState { image_base: PathBuf::from("/home/savanni/Pictures"), + clients: HashMap::new(), }))) } + pub fn register_client(&self) -> String { + let mut state = self.0.write().unwrap(); + let uuid = Uuid::new_v4().simple().to_string(); + + let client = WebsocketClient { sender: None }; + + state.clients.insert(uuid.clone(), client); + uuid + } + + pub fn unregister_client(&self, client_id: String) { + let mut state = self.0.write().unwrap(); + let _ = state.clients.remove(&client_id); + } + + pub fn connect_client(&self, client_id: String) -> UnboundedReceiver { + let mut state = self.0.write().unwrap(); + + match state.clients.get_mut(&client_id) { + Some(client) => { + let (tx, rx) = unbounded_channel(); + client.sender = Some(tx); + rx + } + None => { + unimplemented!(); + } + } + } + pub fn get_file(&self, file_name: String) -> Vec { let mut full_path = self.0.read().unwrap().image_base.clone(); full_path.push(&file_name); @@ -55,4 +94,14 @@ impl Core { }) .collect() } + + pub fn publish(&self, message: Message) { + let state = self.0.read().unwrap(); + + state.clients.values().for_each(|client| { + if let Some(ref sender) = client.sender { + let _ = sender.send(message.clone()); + } + }); + } } diff --git a/visions/server/src/handlers.rs b/visions/server/src/handlers.rs index 9c178b7..d42efff 100644 --- a/visions/server/src/handlers.rs +++ b/visions/server/src/handlers.rs @@ -1,11 +1,11 @@ -use std::{io::Read, path::PathBuf}; +use std::{pin::Pin, time::Duration}; -use authdb::{AuthDB, AuthToken}; -use http::{Error, StatusCode}; +use futures::{SinkExt, StreamExt}; use serde::{Deserialize, Serialize}; -use warp::{http::Response, reply::Reply}; +use tokio_stream::wrappers::UnboundedReceiverStream; +use warp::{http::Response, reply::Reply, ws::Message}; -use crate::core::Core; +use crate::{core::Core, types::PlayArea}; /* pub async fn handle_auth( @@ -31,17 +31,12 @@ pub async fn handle_auth( } */ -#[derive(Deserialize, Serialize)] -pub struct PlayArea { - pub background_image: PathBuf, -} - pub async fn handle_playing_field() -> impl Reply { Response::builder() .header("application-type", "application/json") .body( serde_json::to_string(&PlayArea { - background_image: PathBuf::from("tower-in-mist.jpg"), + background_image: "tower-in-mist.jpg".to_owned(), }) .unwrap(), ) @@ -64,3 +59,52 @@ pub async fn handle_available_images(core: Core) -> impl Reply { .body(serde_json::to_string(&core.available_images()).unwrap()) .unwrap() } + +#[derive(Deserialize, Serialize)] +pub struct RegisterRequest {} + +#[derive(Deserialize, Serialize)] +pub struct RegisterResponse { + url: String, +} + +pub async fn handle_register_client(core: Core, request: RegisterRequest) -> impl Reply { + let client_id = core.register_client(); + + warp::reply::json(&RegisterResponse { + url: format!("ws://127.0.0.1:8001/ws/{}", client_id), + }) +} + +pub async fn handle_unregister_client(core: Core, client_id: String) -> impl Reply { + core.unregister_client(client_id); + + warp::reply::reply() +} + +pub async fn handle_connect_websocket( + core: Core, + ws: warp::ws::Ws, + client_id: String, +) -> impl Reply { + println!("handle_connect_websocket: {}", client_id); + ws.on_upgrade(move |socket| { + let core = core.clone(); + async move { + let (mut ws_sender, mut ws_recv) = socket.split(); + let mut receiver = core.connect_client(client_id); + + tokio::task::spawn(async move { + let mut i = 0; + ws_sender.send(Message::text(serde_json::to_string(&crate::types::Message::Count(0)).unwrap())).await; + while let Some(msg) = receiver.recv().await { + let _ = ws_sender + .send(Message::text( + serde_json::to_string(&msg).unwrap(), + )) + .await; + } + }); + } + }) +} diff --git a/visions/server/src/main.rs b/visions/server/src/main.rs index 4ba8fa0..cf58966 100644 --- a/visions/server/src/main.rs +++ b/visions/server/src/main.rs @@ -1,5 +1,8 @@ use authdb::AuthError; -use handlers::{handle_available_images, handle_file, handle_playing_field}; +use handlers::{ + handle_available_images, handle_connect_websocket, handle_file, handle_playing_field, + handle_register_client, handle_unregister_client, RegisterRequest, +}; use std::{ convert::Infallible, net::{IpAddr, Ipv4Addr, SocketAddr}, @@ -16,6 +19,8 @@ mod core; mod handlers; // use handlers::handle_auth; +mod types; + #[derive(Debug)] struct Unauthorized; impl warp::reject::Reject for Unauthorized {} @@ -99,21 +104,54 @@ pub async fn main() { .and(warp::get()) .then({ let core = core.clone(); - move |file_name| { - handle_file(core.clone(), file_name) - } + move |file_name| handle_file(core.clone(), file_name) }); let route_available_images = warp::path!("api" / "v1" / "image").and(warp::get()).then({ let core = core.clone(); - move || { - handle_available_images(core.clone()) - } + move || handle_available_images(core.clone()) }); - let filter = route_playing_field + let route_register_client = warp::path!("api" / "v1" / "client") + .and(warp::post()) + .then({ + let core = core.clone(); + move || handle_register_client(core.clone(), RegisterRequest {}) + }); + + let route_unregister_client = warp::path!("api" / "v1" / "client" / String) + .and(warp::delete()) + .then({ + let core = core.clone(); + move |client_id| handle_unregister_client(core.clone(), client_id) + }); + + let route_websocket = warp::path("ws") + .and(warp::ws()) + .and(warp::path::param()) + .then({ + let core = core.clone(); + move |ws, client_id| handle_connect_websocket(core.clone(), ws, client_id) + }); + + let route_publish = warp::path!("api" / "v1" / "message") + .and(warp::post()) + .and(warp::body::json()) + .map({ + let core = core.clone(); + move |body| { + core.publish(body); + warp::reply() + } + }); + + let filter = route_register_client + .or(route_unregister_client) + .or(route_websocket) + .or(route_playing_field) .or(route_image) .or(route_available_images) + .or(route_publish) .recover(handle_rejection); let server = warp::serve(filter); diff --git a/visions/server/src/types.rs b/visions/server/src/types.rs new file mode 100644 index 0000000..5990a8f --- /dev/null +++ b/visions/server/src/types.rs @@ -0,0 +1,20 @@ +use serde::{Deserialize, Serialize}; + + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct PlayArea { + pub background_image: String, +} + +#[derive(Debug)] +pub enum AppError { + JsonError(serde_json::Error), +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub enum Message { + Count(u32), + // PlayArea(PlayArea), +} + +