diff --git a/Cargo.toml b/Cargo.toml index a45f199..33216ae 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,6 +23,7 @@ sqlx = { version = "0.8.6", features = ["runtime-tokio", "sqlite", "migrate"] } argon2 = "0.5.3" base64 = "0.22.1" clap = { version = "4.5.58", features = ["derive"] } +moka = { version = "0.12", features = ["future"] } toml = "0.8" reqwest = { version = "0.13.2", features = ["json"] } rain_orderbook_js_api = { path = "lib/rain.orderbook/crates/js_api", default-features = false } diff --git a/src/cache.rs b/src/cache.rs new file mode 100644 index 0000000..93011eb --- /dev/null +++ b/src/cache.rs @@ -0,0 +1,158 @@ +use moka::future::Cache; +use std::future::Future; +use std::sync::Arc; +use std::time::Duration; + +pub(crate) struct AppCache(Cache) +where + K: std::hash::Hash + Eq + Send + Sync + 'static, + V: Clone + Send + Sync + 'static; + +impl AppCache +where + K: std::hash::Hash + Eq + Send + Sync + 'static, + V: Clone + Send + Sync + 'static, +{ + pub(crate) fn new(max_capacity: u64, ttl: Duration) -> Self { + Self( + Cache::builder() + .max_capacity(max_capacity) + .time_to_live(ttl) + .build(), + ) + } + + pub(crate) async fn get(&self, key: &K) -> Option { + self.0.get(key).await + } + + pub(crate) async fn insert(&self, key: K, value: V) { + self.0.insert(key, value).await + } + + pub(crate) async fn get_or_try_insert(&self, key: K, fetch: F) -> Result> + where + F: FnOnce() -> Fut, + Fut: Future>, + E: Send + Sync + 'static, + { + self.0.try_get_with(key, fetch()).await + } + + pub(crate) fn invalidate_all(&self) { + self.0.invalidate_all() + } +} + +trait Invalidatable: Send + Sync { + fn invalidate_all(&self); +} + +impl Invalidatable for Cache +where + K: std::hash::Hash + Eq + Send + Sync + 'static, + V: Clone + Send + Sync + 'static, +{ + fn invalidate_all(&self) { + Cache::invalidate_all(self) + } +} + +pub(crate) struct CacheGroup { + caches: Vec>, +} + +impl CacheGroup { + pub(crate) fn new() -> Self { + Self { caches: Vec::new() } + } + + pub(crate) fn register(&mut self, cache: &AppCache) + where + K: std::hash::Hash + Eq + Send + Sync + 'static, + V: Clone + Send + Sync + 'static, + { + self.caches.push(Arc::new(cache.0.clone())); + } + + pub(crate) fn invalidate_all(&self) { + for cache in &self.caches { + cache.invalidate_all(); + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[rocket::async_test] + async fn test_app_cache_insert_and_get() { + let cache: AppCache<&str, u32> = AppCache::new(10, Duration::from_secs(60)); + cache.insert("key", 42).await; + assert_eq!(cache.get(&"key").await, Some(42)); + } + + #[rocket::async_test] + async fn test_app_cache_get_returns_none_for_missing_key() { + let cache: AppCache<&str, u32> = AppCache::new(10, Duration::from_secs(60)); + assert!(cache.get(&"missing").await.is_none()); + } + + #[rocket::async_test] + async fn test_app_cache_invalidate_all_clears_entries() { + let cache: AppCache<&str, u32> = AppCache::new(10, Duration::from_secs(60)); + cache.insert("a", 1).await; + cache.insert("b", 2).await; + cache.invalidate_all(); + tokio::task::yield_now().await; + assert!(cache.get(&"a").await.is_none()); + assert!(cache.get(&"b").await.is_none()); + } + + #[rocket::async_test] + async fn test_get_or_try_insert_calls_fetch_on_miss() { + let cache: AppCache<&str, u32> = AppCache::new(10, Duration::from_secs(60)); + let result: Result> = + cache.get_or_try_insert("key", || async { Ok(42) }).await; + assert_eq!(result.unwrap(), 42); + assert_eq!(cache.get(&"key").await, Some(42)); + } + + #[rocket::async_test] + async fn test_get_or_try_insert_returns_cached_on_hit() { + let cache: AppCache<&str, u32> = AppCache::new(10, Duration::from_secs(60)); + cache.insert("key", 42).await; + let result: Result> = cache + .get_or_try_insert("key", || async { panic!("fetch should not be called") }) + .await; + assert_eq!(result.unwrap(), 42); + } + + #[rocket::async_test] + async fn test_get_or_try_insert_does_not_cache_errors() { + let cache: AppCache<&str, u32> = AppCache::new(10, Duration::from_secs(60)); + let result: Result> = cache + .get_or_try_insert("key", || async { Err("fail".to_string()) }) + .await; + assert!(result.is_err()); + assert!(cache.get(&"key").await.is_none()); + } + + #[rocket::async_test] + async fn test_cache_group_invalidate_all_clears_registered_caches() { + let cache_a: AppCache<&str, u32> = AppCache::new(10, Duration::from_secs(60)); + let cache_b: AppCache = AppCache::new(10, Duration::from_secs(60)); + cache_a.insert("x", 10).await; + cache_b.insert(1, "hello".into()).await; + + let mut group = CacheGroup::new(); + group.register(&cache_a); + group.register(&cache_b); + group.invalidate_all(); + + tokio::task::yield_now().await; + assert!(cache_a.get(&"x").await.is_none()); + assert!(cache_b.get(&1).await.is_none()); + } +} diff --git a/src/error.rs b/src/error.rs index 1e78064..01b71eb 100644 --- a/src/error.rs +++ b/src/error.rs @@ -36,6 +36,12 @@ pub enum ApiError { RateLimited(String), } +impl From> for ApiError { + fn from(arc: std::sync::Arc) -> Self { + (*arc).clone() + } +} + impl<'r> Responder<'r, 'static> for ApiError { fn respond_to(self, req: &'r Request<'_>) -> rocket::response::Result<'static> { let (status, code, message) = match &self { diff --git a/src/main.rs b/src/main.rs index a1e4885..2446da2 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,6 +2,7 @@ extern crate rocket; mod auth; +mod cache; mod catchers; mod cli; mod config; @@ -121,10 +122,20 @@ pub(crate) fn rocket( let options = Options::Index | Options::NormalizeDirs; + let order_cache = routes::order::order_detail_cache(); + let swap_cache = routes::swap::swap_quote_cache(); + + let mut registry_caches = cache::CacheGroup::new(); + registry_caches.register(&order_cache); + registry_caches.register(&swap_cache); + Ok(rocket::custom(figment) .manage(pool) .manage(rate_limiter) .manage(raindex_config) + .manage(order_cache) + .manage(swap_cache) + .manage(registry_caches) .mount("/", routes::health::routes()) .mount("/v1/tokens", routes::tokens::routes()) .mount("/v1/swap", routes::swap::routes()) diff --git a/src/routes/admin.rs b/src/routes/admin.rs index e2540f5..b9b2ea8 100644 --- a/src/routes/admin.rs +++ b/src/routes/admin.rs @@ -1,4 +1,5 @@ use crate::auth::AdminKey; +use crate::cache::CacheGroup; use crate::db::{settings, DbPool}; use crate::error::{ApiError, ApiErrorResponse}; use crate::fairings::{GlobalRateLimit, TracingSpan}; @@ -37,6 +38,7 @@ pub async fn put_registry( pool: &State, span: TracingSpan, request: Json, + registry_caches: &State, ) -> Result, ApiError> { let req = request.into_inner(); async move { @@ -67,7 +69,9 @@ pub async fn put_registry( *guard = new_provider; drop(guard); - tracing::info!(registry_url = %req.registry_url, "registry updated"); + registry_caches.invalidate_all(); + + tracing::info!(registry_url = %req.registry_url, "registry updated, caches invalidated"); Ok(Json(RegistryResponse { registry_url: req.registry_url, @@ -225,4 +229,92 @@ mod tests { assert_eq!(response.status(), Status::BadRequest); } + + #[rocket::async_test] + async fn test_put_registry_invalidates_caches() { + use crate::routes::order::OrderDetailCache; + use crate::routes::swap::SwapQuoteCache; + use crate::types::common::TokenRef; + use crate::types::order::{OrderDetail, OrderDetailsInfo, OrderType}; + use crate::types::swap::SwapQuoteResponse; + use alloy::primitives::{address, Address, U256}; + + let client = TestClientBuilder::new().build().await; + let (key_id, secret) = seed_admin_key(&client).await; + let header = basic_auth_header(&key_id, &secret); + + let order_hash = "0x000000000000000000000000000000000000000000000000000000000000abcd" + .parse() + .unwrap(); + let dummy_order = OrderDetail { + order_hash, + owner: Address::ZERO, + order_details: OrderDetailsInfo { + type_: OrderType::Solver, + io_ratio: "1.0".into(), + }, + input_token: TokenRef { + address: Address::ZERO, + symbol: "USDC".into(), + decimals: 6, + }, + output_token: TokenRef { + address: Address::ZERO, + symbol: "WETH".into(), + decimals: 18, + }, + input_vault_id: U256::ZERO, + output_vault_id: U256::ZERO, + input_vault_balance: "0".into(), + output_vault_balance: "0".into(), + io_ratio: "1.0".into(), + created_at: 0, + orderbook_id: Address::ZERO, + trades: vec![], + }; + let order_cache = client + .rocket() + .state::() + .expect("OrderDetailCache in state"); + order_cache.insert(order_hash, dummy_order).await; + assert!(order_cache.get(&order_hash).await.is_some()); + + let usdc = address!("833589fCD6eDb6E08f4c7C32D4f71b54bdA02913"); + let weth = address!("4200000000000000000000000000000000000006"); + let cache_key = (usdc, weth, "100".to_string()); + let dummy_quote = SwapQuoteResponse { + input_token: usdc, + output_token: weth, + output_amount: "100".into(), + estimated_output: "100".into(), + estimated_input: "150".into(), + estimated_io_ratio: "1.5".into(), + }; + let swap_cache = client + .rocket() + .state::() + .expect("SwapQuoteCache in state"); + swap_cache.insert(cache_key.clone(), dummy_quote).await; + assert!(swap_cache.get(&cache_key).await.is_some()); + + let new_url = mock_raindex_registry_url().await; + let response = client + .put("/admin/registry") + .header(Header::new("Authorization", header)) + .header(ContentType::JSON) + .body(format!(r#"{{"registry_url":"{new_url}"}}"#)) + .dispatch() + .await; + assert_eq!(response.status(), Status::Ok); + + tokio::task::yield_now().await; + assert!( + order_cache.get(&order_hash).await.is_none(), + "order cache must be empty after registry update" + ); + assert!( + swap_cache.get(&cache_key).await.is_none(), + "swap cache must be empty after registry update" + ); + } } diff --git a/src/routes/order/get_order.rs b/src/routes/order/get_order.rs index 2044491..951ad40 100644 --- a/src/routes/order/get_order.rs +++ b/src/routes/order/get_order.rs @@ -1,4 +1,4 @@ -use super::{OrderDataSource, RaindexOrderDataSource}; +use super::{OrderDataSource, OrderDetailCache, RaindexOrderDataSource}; use crate::auth::AuthenticatedKey; use crate::error::{ApiError, ApiErrorResponse}; use crate::fairings::{GlobalRateLimit, TracingSpan}; @@ -35,18 +35,25 @@ pub async fn get_order( shared_raindex: &State, span: TracingSpan, order_hash: ValidatedFixedBytes, + order_cache: &State, ) -> Result, ApiError> { async move { tracing::info!(order_hash = ?order_hash, "request received"); let hash = order_hash.0; - let raindex = shared_raindex.read().await; - let detail = raindex - .run_with_client(move |client| async move { - let ds = RaindexOrderDataSource { client: &client }; - process_get_order(&ds, hash).await + + let detail = order_cache + .get_or_try_insert(hash, || async { + let raindex = shared_raindex.read().await; + raindex + .run_with_client(move |client| async move { + let ds = RaindexOrderDataSource { client: &client }; + process_get_order(&ds, hash).await + }) + .await + .map_err(ApiError::from)? }) - .await - .map_err(ApiError::from)??; + .await?; + Ok(Json(detail)) } .instrument(span.0) @@ -303,4 +310,68 @@ mod tests { "failed to initialize orderbook client" ); } + + #[rocket::async_test] + async fn test_get_order_returns_cached_entry() { + use super::OrderDetailCache; + use crate::types::common::TokenRef; + use crate::types::order::{OrderDetailsInfo, OrderType}; + use alloy::primitives::U256; + + let config = mock_invalid_raindex_config().await; + let client = TestClientBuilder::new() + .raindex_config(config) + .build() + .await; + let (key_id, secret) = seed_api_key(&client).await; + let header = basic_auth_header(&key_id, &secret); + + let order_hash: alloy::primitives::B256 = + "0x000000000000000000000000000000000000000000000000000000000000abcd" + .parse() + .unwrap(); + let dummy = OrderDetail { + order_hash, + owner: Address::ZERO, + order_details: OrderDetailsInfo { + type_: OrderType::Solver, + io_ratio: "2.5".into(), + }, + input_token: TokenRef { + address: Address::ZERO, + symbol: "USDC".into(), + decimals: 6, + }, + output_token: TokenRef { + address: Address::ZERO, + symbol: "WETH".into(), + decimals: 18, + }, + input_vault_id: U256::ZERO, + output_vault_id: U256::ZERO, + input_vault_balance: "100".into(), + output_vault_balance: "50".into(), + io_ratio: "2.5".into(), + created_at: 0, + orderbook_id: Address::ZERO, + trades: vec![], + }; + + let cache = client + .rocket() + .state::() + .expect("OrderDetailCache in state"); + cache.insert(order_hash, dummy).await; + + let response = client + .get("/v1/order/0x000000000000000000000000000000000000000000000000000000000000abcd") + .header(Header::new("Authorization", header)) + .dispatch() + .await; + assert_eq!(response.status(), Status::Ok); + let body: serde_json::Value = + serde_json::from_str(&response.into_string().await.unwrap()).unwrap(); + assert_eq!(body["ioRatio"], "2.5"); + assert_eq!(body["inputVaultBalance"], "100"); + } } diff --git a/src/routes/order/mod.rs b/src/routes/order/mod.rs index 620ed23..f6c775a 100644 --- a/src/routes/order/mod.rs +++ b/src/routes/order/mod.rs @@ -3,7 +3,9 @@ mod deploy_dca; mod deploy_solver; mod get_order; +use crate::cache::AppCache; use crate::error::ApiError; +use crate::types::order::OrderDetail; use alloy::primitives::{Bytes, B256}; use async_trait::async_trait; use rain_orderbook_common::raindex_client::order_quotes::RaindexOrderQuote; @@ -11,6 +13,16 @@ use rain_orderbook_common::raindex_client::orders::{GetOrdersFilters, RaindexOrd use rain_orderbook_common::raindex_client::trades::RaindexTrade; use rain_orderbook_common::raindex_client::RaindexClient; use rocket::Route; +use std::time::Duration; + +const ORDER_CACHE_TTL: Duration = Duration::from_secs(60); +const ORDER_CACHE_CAPACITY: u64 = 1_000; + +pub(crate) type OrderDetailCache = AppCache; + +pub(crate) fn order_detail_cache() -> OrderDetailCache { + AppCache::new(ORDER_CACHE_CAPACITY, ORDER_CACHE_TTL) +} #[async_trait(?Send)] pub(crate) trait OrderDataSource { diff --git a/src/routes/swap/mod.rs b/src/routes/swap/mod.rs index 2cffbce..db9d0a8 100644 --- a/src/routes/swap/mod.rs +++ b/src/routes/swap/mod.rs @@ -1,8 +1,9 @@ mod calldata; mod quote; +use crate::cache::AppCache; use crate::error::ApiError; -use crate::types::swap::SwapCalldataResponse; +use crate::types::swap::{SwapCalldataResponse, SwapQuoteResponse}; use alloy::primitives::Address; use async_trait::async_trait; use rain_orderbook_common::raindex_client::orders::{ @@ -15,6 +16,16 @@ use rain_orderbook_common::take_orders::{ build_take_order_candidates_for_pair, TakeOrderCandidate, }; use rocket::Route; +use std::time::Duration; + +const SWAP_QUOTE_CACHE_TTL: Duration = Duration::from_secs(5); +const SWAP_QUOTE_CACHE_CAPACITY: u64 = 500; + +pub(crate) type SwapQuoteCache = AppCache<(Address, Address, String), SwapQuoteResponse>; + +pub(crate) fn swap_quote_cache() -> SwapQuoteCache { + AppCache::new(SWAP_QUOTE_CACHE_CAPACITY, SWAP_QUOTE_CACHE_TTL) +} #[async_trait(?Send)] pub(crate) trait SwapDataSource { diff --git a/src/routes/swap/quote.rs b/src/routes/swap/quote.rs index fb031e2..356680f 100644 --- a/src/routes/swap/quote.rs +++ b/src/routes/swap/quote.rs @@ -1,4 +1,4 @@ -use super::{RaindexSwapDataSource, SwapDataSource}; +use super::{RaindexSwapDataSource, SwapDataSource, SwapQuoteCache}; use crate::auth::AuthenticatedKey; use crate::error::{ApiError, ApiErrorResponse}; use crate::fairings::{GlobalRateLimit, TracingSpan}; @@ -32,18 +32,26 @@ pub async fn post_swap_quote( shared_raindex: &State, span: TracingSpan, request: Json, + swap_cache: &State, ) -> Result, ApiError> { let req = request.into_inner(); async move { tracing::info!(body = ?req, "request received"); - let raindex = shared_raindex.read().await; - let response = raindex - .run_with_client(move |client| async move { - let ds = RaindexSwapDataSource { client: &client }; - process_swap_quote(&ds, req).await + + let cache_key = (req.input_token, req.output_token, req.output_amount.clone()); + let response = swap_cache + .get_or_try_insert(cache_key, || async { + let raindex = shared_raindex.read().await; + raindex + .run_with_client(move |client| async move { + let ds = RaindexSwapDataSource { client: &client }; + process_swap_quote(&ds, req).await + }) + .await + .map_err(ApiError::from)? }) - .await - .map_err(ApiError::from)??; + .await?; + Ok(Json(response)) } .instrument(span.0) @@ -287,4 +295,46 @@ mod tests { "failed to initialize orderbook client" ); } + + #[rocket::async_test] + async fn test_swap_quote_returns_cached_entry() { + use super::SwapQuoteCache; + use crate::types::swap::SwapQuoteResponse; + + let config = mock_invalid_raindex_config().await; + let client = TestClientBuilder::new() + .raindex_config(config) + .build() + .await; + let (key_id, secret) = seed_api_key(&client).await; + let header = basic_auth_header(&key_id, &secret); + + let dummy = SwapQuoteResponse { + input_token: USDC, + output_token: WETH, + output_amount: "100".into(), + estimated_output: "100".into(), + estimated_input: "250".into(), + estimated_io_ratio: "2.5".into(), + }; + + let cache = client + .rocket() + .state::() + .expect("SwapQuoteCache in state"); + cache.insert((USDC, WETH, "100".to_string()), dummy).await; + + let response = client + .post("/v1/swap/quote") + .header(Header::new("Authorization", header)) + .header(ContentType::JSON) + .body(r#"{"inputToken":"0x833589fCD6eDb6E08f4c7C32D4f71b54bdA02913","outputToken":"0x4200000000000000000000000000000000000006","outputAmount":"100"}"#) + .dispatch() + .await; + assert_eq!(response.status(), Status::Ok); + let body: serde_json::Value = + serde_json::from_str(&response.into_string().await.unwrap()).unwrap(); + assert_eq!(body["estimatedInput"], "250"); + assert_eq!(body["estimatedIoRatio"], "2.5"); + } } diff --git a/src/routes/tokens.rs b/src/routes/tokens.rs index c54e92f..ec50ff5 100644 --- a/src/routes/tokens.rs +++ b/src/routes/tokens.rs @@ -1,4 +1,5 @@ use crate::auth::AuthenticatedKey; +use crate::cache::AppCache; use crate::error::{ApiError, ApiErrorResponse}; use crate::fairings::{GlobalRateLimit, TracingSpan}; use crate::types::tokens::{RemoteTokenList, TokenInfo, TokenListResponse}; @@ -11,6 +12,9 @@ use tracing::Instrument; const TOKEN_LIST_URL: &str = "https://raw.githubusercontent.com/S01-Issuer/st0x-tokens/ad1a637a79d5a220ad089aecdc5b7239d3473f6e/src/st0xTokens.json"; const TARGET_CHAIN_ID: u32 = crate::CHAIN_ID; const TOKEN_LIST_TIMEOUT_SECS: u64 = 10; +const TOKEN_CACHE_TTL: Duration = Duration::from_secs(600); + +pub(crate) type TokenCache = AppCache<(), TokenListResponse>; pub(crate) struct TokensConfig { pub(crate) url: String, @@ -38,12 +42,17 @@ impl TokensConfig { pub(crate) fn fairing() -> AdHoc { AdHoc::on_ignite("Tokens Config", |rocket| async { - if rocket.state::().is_some() { + let rocket = if rocket.state::().is_some() { tracing::info!("TokensConfig already managed; skipping default initialization"); rocket } else { tracing::info!(url = %TOKEN_LIST_URL, "initializing default TokensConfig"); rocket.manage(TokensConfig::default()) + }; + if rocket.state::().is_some() { + rocket + } else { + rocket.manage(TokenCache::new(1, TOKEN_CACHE_TTL)) } }) } @@ -83,51 +92,63 @@ pub async fn get_tokens( _key: AuthenticatedKey, span: TracingSpan, tokens_config: &State, + token_cache: &State, ) -> Result, ApiError> { let url = tokens_config.url.clone(); let client = tokens_config.client.clone(); async move { tracing::info!("request received"); - tracing::info!(url = %url, timeout_secs = TOKEN_LIST_TIMEOUT_SECS, "fetching token list"); + let result = token_cache + .get_or_try_insert((), || async { + tracing::info!(url = %url, timeout_secs = TOKEN_LIST_TIMEOUT_SECS, "fetching token list"); - let response = client - .get(&url) - .timeout(Duration::from_secs(TOKEN_LIST_TIMEOUT_SECS)) - .send() - .await - .map_err(TokenError::Fetch)?; - - let status = response.status(); - if !status.is_success() { - return Err(TokenError::BadStatus(status).into()); - } + let response = client + .get(&url) + .timeout(Duration::from_secs(TOKEN_LIST_TIMEOUT_SECS)) + .send() + .await + .map_err(TokenError::Fetch) + .map_err(ApiError::from)?; - let remote: RemoteTokenList = response.json().await.map_err(TokenError::Deserialize)?; - - let tokens: Vec = remote - .tokens - .into_iter() - .filter(|t| t.chain_id == TARGET_CHAIN_ID) - .map(|t| { - let isin = t - .extensions - .get("isin") - .or_else(|| t.extensions.get("ISIN")) - .and_then(|v| v.as_str()) - .map(String::from); - TokenInfo { - address: t.address, - symbol: t.symbol, - name: t.name, - isin, - decimals: t.decimals, + let status = response.status(); + if !status.is_success() { + return Err(ApiError::from(TokenError::BadStatus(status))); } + + let remote: RemoteTokenList = response + .json() + .await + .map_err(TokenError::Deserialize) + .map_err(ApiError::from)?; + + let tokens: Vec = remote + .tokens + .into_iter() + .filter(|t| t.chain_id == TARGET_CHAIN_ID) + .map(|t| { + let isin = t + .extensions + .get("isin") + .or_else(|| t.extensions.get("ISIN")) + .and_then(|v| v.as_str()) + .map(String::from); + TokenInfo { + address: t.address, + symbol: t.symbol, + name: t.name, + isin, + decimals: t.decimals, + } + }) + .collect(); + + Ok(TokenListResponse { tokens }) }) - .collect(); + .await?; - tracing::info!(count = tokens.len(), "returning tokens"); - Ok(Json(TokenListResponse { tokens })) + tracing::info!(count = result.tokens.len(), "returning tokens"); + Ok(Json(result)) } .instrument(span.0) .await @@ -347,4 +368,65 @@ mod tests { .unwrap() .contains("failed to retrieve token list")); } + + #[rocket::async_test] + async fn test_get_tokens_cache_hit_on_second_request() { + let body = r#"{"tokens":[{"chainId":8453,"address":"0x833589fCD6eDb6E08f4c7C32D4f71b54bdA02913","name":"USD Coin","symbol":"USDC","decimals":6}]}"#; + let response_bytes = format!( + "HTTP/1.1 200 OK\r\nConnection: close\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}", + body.len(), + body + ); + let response_bytes: &'static [u8] = + Box::leak(response_bytes.into_bytes().into_boxed_slice()); + let url = mock_server(response_bytes).await; + let client = TestClientBuilder::new().token_list_url(&url).build().await; + let (key_id, secret) = seed_api_key(&client).await; + let header = basic_auth_header(&key_id, &secret); + + let first = client + .get("/v1/tokens") + .header(Header::new("Authorization", header.clone())) + .dispatch() + .await; + assert_eq!(first.status(), Status::Ok); + + let second = client + .get("/v1/tokens") + .header(Header::new("Authorization", header)) + .dispatch() + .await; + assert_eq!(second.status(), Status::Ok); + + let body: serde_json::Value = + serde_json::from_str(&second.into_string().await.unwrap()).unwrap(); + let tokens = body["tokens"].as_array().unwrap(); + assert_eq!(tokens.len(), 1); + assert_eq!(tokens[0]["symbol"], "USDC"); + } + + #[rocket::async_test] + async fn test_get_tokens_error_response_is_not_cached() { + let url = mock_server( + b"HTTP/1.1 500 Internal Server Error\r\nConnection: close\r\nContent-Length: 0\r\n\r\n", + ) + .await; + let client = TestClientBuilder::new().token_list_url(&url).build().await; + let (key_id, secret) = seed_api_key(&client).await; + let header = basic_auth_header(&key_id, &secret); + + let response = client + .get("/v1/tokens") + .header(Header::new("Authorization", header)) + .dispatch() + .await; + assert_eq!(response.status(), Status::InternalServerError); + + let token_cache = client + .rocket() + .state::() + .expect("TokenCache in state"); + let cached = token_cache.get(&()).await; + assert!(cached.is_none(), "error response must not be cached"); + } }