use chrono::{DateTime, Utc}; use futures::Future; use std::{ collections::HashMap, sync::{Arc, RwLock}, }; #[derive(Clone)] pub struct MemoryCacheRecord { pub expiration: DateTime, pub value: T, } pub struct MemoryCache(Arc>>>); impl Default for MemoryCache { fn default() -> Self { Self(Arc::new(RwLock::new(HashMap::new()))) } } impl MemoryCache { pub async fn find(&self, key: &str, f: impl Future, T)>) -> T { let val = { let cache = self.0.read().unwrap(); cache.get(key).cloned() }; match val { Some(ref val) if val.expiration > Utc::now() => val.value.clone(), _ => { let response = f.await; let mut cache = self.0.write().unwrap(); let record = MemoryCacheRecord { expiration: response.0, value: response.1.clone(), }; cache .entry(key.to_owned()) .and_modify(|rec| *rec = record.clone()) .or_insert(record); response.1 } } } } #[cfg(test)] mod tests { use super::*; use chrono::Duration; use std::sync::{Arc, RwLock}; #[derive(Clone, Debug, PartialEq)] struct Value(i64); #[tokio::test] async fn it_runs_the_requestor_when_the_value_does_not_exist() { let cache = MemoryCache::default(); let value = cache .find("my_key", async { (Utc::now(), Value(15)) }) .await; assert_eq!(value, Value(15)); } #[tokio::test] async fn it_runs_the_requestor_when_the_value_is_old() { let run = Arc::new(RwLock::new(false)); let cache = MemoryCache::default(); let _ = cache .find("my_key", async { (Utc::now() - Duration::seconds(10), Value(15)) }) .await; let value = cache .find("my_key", async { *run.write().unwrap() = true; (Utc::now(), Value(16)) }) .await; assert_eq!(value, Value(16)); assert_eq!(*run.read().unwrap(), true); } #[tokio::test] async fn it_returns_the_cached_value_when_the_value_is_new() { let run = Arc::new(RwLock::new(false)); let cache = MemoryCache::default(); let _ = cache .find("my_key", async { (Utc::now() + Duration::seconds(10), Value(15)) }) .await; let value = cache .find("my_key", async { *run.write().unwrap() = true; (Utc::now(), Value(16)) }) .await; assert_eq!(value, Value(15)); assert_eq!(*run.read().unwrap(), false); } }