monorepo/memorycache/src/lib.rs

103 lines
2.9 KiB
Rust

use chrono::{DateTime, Utc};
use futures::Future;
use std::{
collections::HashMap,
sync::{Arc, RwLock},
};
/// A cached value together with the instant at which it stops being served.
///
/// [`MemoryCache::find`] returns `value` from the cache only while `expiration`
/// is strictly later than the current time; otherwise the record is refreshed.
#[derive(Clone, Debug)]
pub struct MemoryCacheRecord<T> {
    /// Moment after which the record is considered stale (compared to [`Utc::now`]).
    pub expiration: DateTime<Utc>,
    /// The cached payload handed back to callers.
    pub value: T,
}
/// An in-memory cache of [`MemoryCacheRecord`]s keyed by `String`.
///
/// The map lives behind an `Arc<RwLock<_>>`, so cloning a `MemoryCache` is cheap
/// and every clone shares the same underlying entries.
pub struct MemoryCache<T>(Arc<RwLock<HashMap<String, MemoryCacheRecord<T>>>>);

// Written by hand instead of `#[derive(Clone)]`: the derive would demand
// `T: Clone`, but cloning a handle only bumps the `Arc` refcount and never
// touches the stored values.
impl<T> Clone for MemoryCache<T> {
    fn clone(&self) -> Self {
        Self(Arc::clone(&self.0))
    }
}
// No `T: Clone` bound: constructing an empty map never clones a value.
impl<T> Default for MemoryCache<T> {
    /// Creates an empty cache with no records.
    fn default() -> Self {
        Self(Arc::new(RwLock::new(HashMap::new())))
    }
}
impl<T: Clone> MemoryCache<T> {
    /// Returns the cached value for `key`, or awaits `f` to produce a fresh one.
    ///
    /// A cached record is served only while its `expiration` is strictly later
    /// than [`Utc::now`]. Otherwise `f` is awaited, and the `(expiration, value)`
    /// pair it resolves to is stored under `key` (replacing any previous record)
    /// before `value` is returned.
    ///
    /// `f` is only polled on a miss; on a hit it is dropped without running.
    ///
    /// No lock is held while `f` is awaited, so concurrent misses for the same
    /// key may each run their own `f`; whichever stores last wins.
    ///
    /// A poisoned lock is recovered rather than propagated: the only mutation
    /// performed here is a single `HashMap::insert`, so the map stays usable.
    pub async fn find(&self, key: &str, f: impl Future<Output = (DateTime<Utc>, T)>) -> T {
        // Look for a fresh hit under the read lock and clone only the value.
        // The guard is dropped at the end of this block, before any `.await`.
        let cached = {
            let cache = self
                .0
                .read()
                .unwrap_or_else(std::sync::PoisonError::into_inner);
            cache
                .get(key)
                .filter(|record| record.expiration > Utc::now())
                .map(|record| record.value.clone())
        };
        if let Some(value) = cached {
            return value;
        }

        let (expiration, value) = f.await;
        let record = MemoryCacheRecord {
            expiration,
            value: value.clone(),
        };
        // The write guard is a temporary and is released at the end of this statement.
        self.0
            .write()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .insert(key.to_owned(), record);
        value
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Duration;
    use std::sync::atomic::{AtomicBool, Ordering};

    #[derive(Clone, Debug, PartialEq)]
    struct Value(i64);

    #[tokio::test]
    async fn it_runs_the_requestor_when_the_value_does_not_exist() {
        let cache = MemoryCache::default();
        let value = cache
            .find("my_key", async { (Utc::now(), Value(15)) })
            .await;
        assert_eq!(value, Value(15));
    }

    #[tokio::test]
    async fn it_runs_the_requestor_when_the_value_is_old() {
        // Flag flipped by the second requestor so we can prove it was polled.
        let ran = AtomicBool::new(false);
        let cache = MemoryCache::default();
        let _ = cache
            .find("my_key", async {
                (Utc::now() - Duration::seconds(10), Value(15))
            })
            .await;
        let value = cache
            .find("my_key", async {
                ran.store(true, Ordering::SeqCst);
                (Utc::now(), Value(16))
            })
            .await;
        assert_eq!(value, Value(16));
        assert!(ran.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn it_returns_the_cached_value_when_the_value_is_new() {
        // Must stay false: a fresh entry means the second requestor is never polled.
        let ran = AtomicBool::new(false);
        let cache = MemoryCache::default();
        let _ = cache
            .find("my_key", async {
                (Utc::now() + Duration::seconds(10), Value(15))
            })
            .await;
        let value = cache
            .find("my_key", async {
                ran.store(true, Ordering::SeqCst);
                (Utc::now(), Value(16))
            })
            .await;
        assert_eq!(value, Value(15));
        assert!(!ran.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn it_keeps_entries_for_different_keys_separate() {
        let cache = MemoryCache::default();
        let expiration = Utc::now() + Duration::seconds(10);
        let _ = cache.find("a", async { (expiration, Value(1)) }).await;
        let _ = cache.find("b", async { (expiration, Value(2)) }).await;
        let a = cache.find("a", async { (expiration, Value(100)) }).await;
        let b = cache.find("b", async { (expiration, Value(200)) }).await;
        assert_eq!(a, Value(1));
        assert_eq!(b, Value(2));
    }
}