Set up a bit of code that rejects requests that have no authorization header

This commit is contained in:
Savanni D'Gerinel 2024-12-29 23:39:43 -05:00
parent 085a82776e
commit e4c5ce0236
6 changed files with 106 additions and 40 deletions

View File

@ -22,6 +22,7 @@
name = "ld-tools-devshell"; name = "ld-tools-devshell";
buildInputs = [ buildInputs = [
pkgs.cargo-nextest pkgs.cargo-nextest
pkgs.cargo-watch
pkgs.clang pkgs.clang
pkgs.crate2nix pkgs.crate2nix
pkgs.glib pkgs.glib

View File

@ -1,3 +1,4 @@
[toolchain] [toolchain]
channel = "1.81.0" channel = "1.81.0"
targets = [ "wasm32-unknown-unknown", "thumbv6m-none-eabi" ] targets = [ "wasm32-unknown-unknown", "thumbv6m-none-eabi" ]
components = [ "rustfmt", "rust-analyzer", "clippy" ]

View File

@ -16,6 +16,16 @@ use crate::{
}, },
}; };
// Per-endpoint Authentication:
//
// If an endpoint requires authentication:
// - check the Authorization header for a token
// - if the token is absent or unknown, return a 403
// - if the admin user is absent, return a 403, with a body that indicates the admin user is absent
//
// The login function does not require authentication, but it should return a session ID
fn cors<H, M>(methods: Vec<M>, headers: Vec<H>) -> Builder fn cors<H, M>(methods: Vec<M>, headers: Vec<H>) -> Builder
where where
M: Into<Method>, M: Into<Method>,

View File

@ -1,5 +1,8 @@
use std::{convert::Infallible, future::Future};
use warp::{ use warp::{
http::{header::CONTENT_TYPE, HeaderName, Method}, http::{header::CONTENT_TYPE, HeaderName, Method, Response, StatusCode},
reject,
reply::Reply, reply::Reply,
Filter, Filter,
}; };
@ -9,14 +12,41 @@ use crate::{
handlers::{handle_check_password, handle_get_users, handle_set_admin_password}, handlers::{handle_check_password, handle_get_users, handle_set_admin_password},
}; };
async fn handle_rejection(err: warp::Rejection) -> Result<impl Reply, Infallible> {
println!("handle_rejection: {:?}", err);
if let Some(Unauthorized) = err.find() {
Ok(warp::reply::with_status(
"".to_owned(),
StatusCode::UNAUTHORIZED,
))
} else {
Ok(warp::reply::with_status(
"".to_owned(),
StatusCode::INTERNAL_SERVER_ERROR,
))
}
}
#[derive(Debug)]
struct Unauthorized;
impl reject::Reject for Unauthorized {}
use super::cors; use super::cors;
fn route_get_users( fn route_get_users(
core: Core, core: Core,
) -> impl Filter<Extract = (impl Reply,), Error = warp::Rejection> + Clone { ) -> impl Filter<Extract = (Response<Vec<u8>>,), Error = warp::Rejection> + Clone {
warp::path!("api" / "v1" / "users") warp::path!("api" / "v1" / "users")
.and(warp::get()) .and(warp::get())
.then(move || handle_get_users(core.clone())) .and(warp::header::optional::<String>("authorization"))
.and(warp::any().map(move || core.clone()))
.and_then(|auth_token, core: Core| async move {
match auth_token {
Some(token) => Ok(handle_get_users(core.clone()).await),
None => Err(warp::reject::custom(Unauthorized)),
}
})
} }
fn route_set_admin_password( fn route_set_admin_password(
@ -25,10 +55,7 @@ fn route_set_admin_password(
warp::path!("api" / "v1" / "admin_password") warp::path!("api" / "v1" / "admin_password")
.and(warp::put()) .and(warp::put())
.and(warp::body::json()) .and(warp::body::json())
.then({ .then({ move |body| handle_set_admin_password(core.clone(), body) })
let core = core.clone();
move |body| handle_set_admin_password(core.clone(), body)
})
.with(cors(vec![Method::PUT], vec![CONTENT_TYPE])) .with(cors(vec![Method::PUT], vec![CONTENT_TYPE]))
} }
@ -44,10 +71,7 @@ pub fn route_check_password(
warp::path!("api" / "v1" / "auth") warp::path!("api" / "v1" / "auth")
.and(warp::put()) .and(warp::put())
.and(warp::body::json()) .and(warp::body::json())
.then({ .then({ move |body| handle_check_password(core.clone(), body) })
let core = core.clone();
move |body| handle_check_password(core.clone(), body)
})
.with(cors::<HeaderName, Method>(vec![Method::PUT], vec![])) .with(cors::<HeaderName, Method>(vec![Method::PUT], vec![]))
} }
@ -56,8 +80,12 @@ mod test {
use std::{collections::HashMap, path::PathBuf}; use std::{collections::HashMap, path::PathBuf};
use result_extended::ResultExt; use result_extended::ResultExt;
use warp::http::StatusCode;
use crate::{asset_db::mocks::MemoryAssets, database::{Database, DbConn, UserId}}; use crate::{
asset_db::mocks::MemoryAssets,
database::{Database, DbConn, UserId},
};
use super::*; use super::*;
@ -79,8 +107,10 @@ mod test {
} }
match core.user_by_username("admin").await { match core.user_by_username("admin").await {
ResultExt::Ok(Some(user)) => { ResultExt::Ok(Some(user)) => {
let _ = core.set_password(UserId::from(user.id), "aoeu".to_owned()).await; let _ = core
}, .set_password(UserId::from(user.id), "aoeu".to_owned())
.await;
}
ResultExt::Ok(None) => panic!("expected user wasn't found"), ResultExt::Ok(None) => panic!("expected user wasn't found"),
ResultExt::Err(err) => panic!("{}", err), ResultExt::Err(err) => panic!("{}", err),
ResultExt::Fatal(err) => panic!("{}", err), ResultExt::Fatal(err) => panic!("{}", err),
@ -101,7 +131,52 @@ mod test {
println!("response: {}", resp.status()); println!("response: {}", resp.status());
assert!(resp.status().is_success()); assert!(resp.status().is_success());
println!("resp.body(): {}", String::from_utf8(resp.body().to_vec()).unwrap()); println!(
"resp.body(): {}",
String::from_utf8(resp.body().to_vec()).unwrap()
);
serde_json::from_slice::<String>(resp.body()).unwrap(); serde_json::from_slice::<String>(resp.body()).unwrap();
} }
/*
#[tokio::test]
async fn handle_check_auth_token() {
let core = setup().await;
let filter = route_get_users(core);
let response = warp::test::request()
.method("GET")
.path("/api/v1/users")
.header("Authorization", "abcdefg")
.reply(&filter)
.await;
println!("response: {}", response.status());
assert!(false);
}
*/
#[tokio::test]
async fn it_rejects_unauthorized_requests() {
let core = setup().await;
let filter = route_get_users(core)
.recover(handle_rejection);
let response = warp::test::request()
.method("GET")
.path("/api/v1/users")
.reply(&filter)
.await;
println!("response: {:?}", response);
assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
}
#[tokio::test]
async fn it_accepts_authorized_requests() {
unimplemented!();
}
#[tokio::test]
async fn it_returns_special_response_for_no_admin() {
unimplemented!();
}
} }

View File

@ -185,7 +185,9 @@ pub async fn handle_set_background_image(core: Core, image_name: String) -> impl
.await .await
} }
pub async fn handle_get_users(core: Core) -> impl Reply { pub async fn handle_get_users(core: Core) -> Response<Vec<u8>> {
unimplemented!()
/*
handler(async move { handler(async move {
let users = match core.list_users().await { let users = match core.list_users().await {
ResultExt::Ok(users) => users, ResultExt::Ok(users) => users,
@ -200,6 +202,7 @@ pub async fn handle_get_users(core: Core) -> impl Reply {
.unwrap()) .unwrap())
}) })
.await .await
*/
} }
pub async fn handle_get_games(core: Core) -> impl Reply { pub async fn handle_get_games(core: Core) -> impl Reply {

View File

@ -17,14 +17,6 @@ mod filters;
mod handlers; mod handlers;
mod types; mod types;
#[derive(Debug)]
struct Unauthorized;
impl warp::reject::Reject for Unauthorized {}
#[derive(Debug)]
struct AuthDBError(AuthError);
impl warp::reject::Reject for AuthDBError {}
/* /*
fn with_session( fn with_session(
auth_ctx: Arc<AuthDB>, auth_ctx: Arc<AuthDB>,
@ -76,21 +68,6 @@ fn route_echo_authenticated(
} }
*/ */
async fn handle_rejection(err: warp::Rejection) -> Result<impl Reply, Infallible> {
println!("handle_rejection: {:?}", err);
if let Some(Unauthorized) = err.find() {
Ok(warp::reply::with_status(
"".to_owned(),
StatusCode::UNAUTHORIZED,
))
} else {
Ok(warp::reply::with_status(
"".to_owned(),
StatusCode::INTERNAL_SERVER_ERROR,
))
}
}
#[tokio::main] #[tokio::main]
pub async fn main() { pub async fn main() {
pretty_env_logger::init(); pretty_env_logger::init();
@ -106,7 +83,6 @@ pub async fn main() {
unauthenticated_endpoints unauthenticated_endpoints
.or(authenticated_endpoints) .or(authenticated_endpoints)
.with(warp::log("visions")) .with(warp::log("visions"))
.recover(handle_rejection),
); );
server server
.run(SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 8001)) .run(SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 8001))