diff --git a/Cargo.lock b/Cargo.lock index 161ae27..8e008d1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4223,7 +4223,7 @@ checksum = "ab16f14aed21ee8bfd8ec22513f7287cd4a91aa92e44edfe2c17ddd004e92607" [[package]] name = "torrust-actix" -version = "4.1.0" +version = "4.1.1" dependencies = [ "actix", "actix-cors", diff --git a/Cargo.toml b/Cargo.toml index adefbb5..740959c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "torrust-actix" -version = "4.1.0" +version = "4.1.1" edition = "2024" license = "AGPL-3.0" authors = [ diff --git a/README.md b/README.md index cd90bde..a7276a6 100644 --- a/README.md +++ b/README.md @@ -202,6 +202,11 @@ UDP_0_SIMPLE_PROXY_PROTOCOL ### ChangeLog +#### v4.1.1 +* Added hot reloading of SSL certificates for renewal +* API has an extra endpoint to run the hot reloading +* Some more code optimizations + #### v4.1.0 * Added a full Cluster first version through WebSockets * Option to run the app in Stand-Alone (which is default, as single server), or using the cluster mode diff --git a/docker/Dockerfile b/docker/Dockerfile index 190a370..0a0dbd3 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -3,7 +3,7 @@ FROM rust:alpine RUN apk update --no-interactive RUN apk add git musl-dev curl pkgconfig openssl-dev openssl-libs-static --no-interactive RUN git clone https://github.com/Power2All/torrust-actix.git /app/torrust-actix -RUN cd /app/torrust-actix && git checkout tags/v4.1.0 +RUN cd /app/torrust-actix && git checkout tags/v4.1.1 WORKDIR /app/torrust-actix RUN cd /app/torrust-actix RUN cargo build --release && rm -Rf target/release/.fingerprint target/release/build target/release/deps target/release/examples target/release/incremental diff --git a/docker/build.bat b/docker/build.bat index 69cb171..b44e3fd 100644 --- a/docker/build.bat +++ b/docker/build.bat @@ -1,5 +1,5 @@ @echo off -docker build --no-cache -t power2all/torrust-actix:v4.1.0 -t power2all/torrust-actix:latest . -docker push power2all/torrust-actix:v4.1.0 +docker build --no-cache -t power2all/torrust-actix:v4.1.1 -t power2all/torrust-actix:latest . +docker push power2all/torrust-actix:v4.1.1 docker push power2all/torrust-actix:latest \ No newline at end of file diff --git a/src/api/api.rs b/src/api/api.rs index 0a929d3..34c503e 100644 --- a/src/api/api.rs +++ b/src/api/api.rs @@ -1,4 +1,5 @@ use crate::api::api_blacklists::{api_service_blacklist_delete, api_service_blacklist_get, api_service_blacklist_post, api_service_blacklists_delete, api_service_blacklists_get, api_service_blacklists_post}; +use crate::api::api_certificate::{api_service_certificate_reload, api_service_certificate_status}; use crate::api::api_keys::{api_service_key_delete, api_service_key_get, api_service_key_post, api_service_keys_delete, api_service_keys_get, api_service_keys_post}; use crate::api::api_stats::{api_service_prom_get, api_service_stats_get}; use crate::api::api_torrents::{api_service_torrent_delete, api_service_torrent_get, api_service_torrent_post, api_service_torrents_delete, api_service_torrents_get, api_service_torrents_post}; @@ -8,6 +9,8 @@ use crate::api::structs::api_service_data::ApiServiceData; use crate::common::structs::custom_error::CustomError; use crate::config::structs::api_trackers_config::ApiTrackersConfig; use crate::config::structs::configuration::Configuration; +use crate::ssl::certificate_resolver::DynamicCertificateResolver; +use crate::ssl::certificate_store::ServerIdentifier; use crate::stats::enums::stats_event::StatsEvent; use crate::tracker::structs::torrent_tracker::TorrentTracker; use actix_cors::Cors; @@ -18,9 +21,7 @@ use actix_web::{http, web, App, HttpRequest, HttpResponse, HttpServer}; use futures_util::StreamExt; use log::{error, info}; use serde_json::json; -use std::fs::File; use std::future::Future; -use std::io::BufReader; use std::net::{IpAddr, SocketAddr}; use std::process::exit; use std::str::FromStr; @@ -101,6 +102,12 @@ pub fn api_service_routes(data: Arc) -> Box, _>>().unwrap_or_else(|data| { - sentry::capture_error(&data); - panic!("[APIS] SSL cert couldn't be extracted: {data}"); - }); - let tls_key = rustls_pemfile::pkcs8_private_keys(key_file).next().unwrap().unwrap_or_else(|data| { - sentry::capture_error(&data); - panic!("[APIS] SSL key couldn't be extracted: {data}"); - }); + let server_id = ServerIdentifier::ApiServer(addr.to_string()); + if let Err(e) = data.certificate_store.load_certificate( + server_id.clone(), + &api_server_object.ssl_cert, + &api_server_object.ssl_key, + ) { + panic!("[APIS] Failed to load SSL certificate: {}", e); + } + let resolver = match DynamicCertificateResolver::new( + Arc::clone(&data.certificate_store), + server_id, + ) { + Ok(resolver) => Arc::new(resolver), + Err(e) => panic!("[APIS] Failed to create certificate resolver: {}", e), + }; let tls_config = rustls::ServerConfig::builder() .with_no_client_auth() - .with_single_cert(tls_certs, rustls::pki_types::PrivateKeyDer::Pkcs8(tls_key)) - .unwrap_or_else(|data| { - sentry::capture_error(&data); - panic!("[APIS] SSL config couldn't be created: {data}"); - }); + .with_cert_resolver(resolver); let server = HttpServer::new(app_factory) .keep_alive(Duration::from_secs(keep_alive)) .client_request_timeout(Duration::from_secs(request_timeout)) diff --git a/src/api/api_certificate.rs b/src/api/api_certificate.rs new file mode 100644 index 0000000..fe7b333 --- /dev/null +++ b/src/api/api_certificate.rs @@ -0,0 +1,153 @@ +use crate::api::api::{api_service_token, api_validation}; +use crate::api::structs::api_service_data::ApiServiceData; +use crate::api::structs::query_token::QueryToken; +use crate::ssl::certificate_store::ServerIdentifier; +use actix_web::http::header::ContentType; +use actix_web::web::Data; +use actix_web::{web, HttpRequest, HttpResponse}; +use serde::{Deserialize, Serialize}; +use serde_json::json; +use std::sync::Arc; + +#[derive(Debug, Deserialize)] +pub struct CertificateReloadRequest { + pub server_type: Option, + pub bind_address: Option, +} + +#[derive(Debug, Serialize)] +pub struct CertificateStatusItem { + pub server_type: String, + pub bind_address: String, + pub cert_path: String, + pub key_path: String, + pub loaded_at: String, +} + +#[derive(Debug, Serialize)] +pub struct CertificateReloadResult { + pub server_type: String, + pub bind_address: String, + pub loaded_at: String, +} + +#[derive(Debug, Serialize)] +pub struct CertificateReloadError { + pub server_type: String, + pub bind_address: String, + pub error: String, +} + +#[tracing::instrument(level = "debug")] +pub async fn api_service_certificate_reload( + request: HttpRequest, + data: Data>, + body: Option>, +) -> HttpResponse { + if let Some(error_return) = api_validation(&request, &data).await { + return error_return; + } + let params = web::Query::::from_query(request.query_string()).unwrap(); + if let Some(response) = api_service_token(params.token.clone(), Arc::clone(&data.torrent_tracker.config)).await { + return response; + } + let certificate_store = &data.torrent_tracker.certificate_store; + let (server_type_filter, bind_address_filter) = match body { + Some(req) => (req.server_type.clone(), req.bind_address.clone()), + None => (None, None), + }; + let certificates_to_reload: Vec = { + let certificates = certificate_store.get_all_certificates(); + certificates + .into_iter() + .filter(|(server_id, _)| { + if let Some(ref filter) = server_type_filter + && server_id.server_type() != filter.to_lowercase() + { + return false; + } + if let Some(ref filter) = bind_address_filter + && server_id.bind_address() != filter + { + return false; + } + true + }) + .map(|(server_id, _)| server_id) + .collect() + }; + if certificates_to_reload.is_empty() { + return HttpResponse::Ok().content_type(ContentType::json()).json(json!({ + "status": "no_certificates", + "message": "No SSL certificates found to reload" + })); + } + let mut reloaded: Vec = Vec::with_capacity(certificates_to_reload.len()); + let mut errors: Vec = Vec::new(); + for server_id in certificates_to_reload { + match certificate_store.reload_certificate(&server_id) { + Ok(()) => { + let loaded_at = certificate_store + .get_certificate(&server_id) + .map(|bundle| bundle.loaded_at.to_rfc3339()) + .unwrap_or_else(|| "unknown".to_string()); + reloaded.push(CertificateReloadResult { + server_type: server_id.server_type().to_string(), + bind_address: server_id.bind_address().to_string(), + loaded_at, + }); + } + Err(e) => { + errors.push(CertificateReloadError { + server_type: server_id.server_type().to_string(), + bind_address: server_id.bind_address().to_string(), + error: e.to_string(), + }); + } + } + } + let status = if errors.is_empty() { + "ok" + } else if reloaded.is_empty() { + "failed" + } else { + "partial" + }; + HttpResponse::Ok().content_type(ContentType::json()).json(json!({ + "status": status, + "reloaded": reloaded, + "errors": errors + })) +} + +#[tracing::instrument(level = "debug")] +pub async fn api_service_certificate_status( + request: HttpRequest, + data: Data>, +) -> HttpResponse { + if let Some(error_return) = api_validation(&request, &data).await { + return error_return; + } + let params = web::Query::::from_query(request.query_string()).unwrap(); + if let Some(response) = api_service_token(params.token.clone(), Arc::clone(&data.torrent_tracker.config)).await { + return response; + } + let certificate_store = &data.torrent_tracker.certificate_store; + let certificates = certificate_store.get_all_certificates(); + let status_items: Vec = certificates + .into_iter() + .map(|(server_id, bundle)| { + CertificateStatusItem { + server_type: server_id.server_type().to_string(), + bind_address: server_id.bind_address().to_string(), + cert_path: bundle.cert_path.clone(), + key_path: bundle.key_path.clone(), + loaded_at: bundle.loaded_at.to_rfc3339(), + } + }) + .collect(); + HttpResponse::Ok().content_type(ContentType::json()).json(json!({ + "status": "ok", + "certificates": status_items + })) +} \ No newline at end of file diff --git a/src/api/mod.rs b/src/api/mod.rs index 6ab1df2..1b64ea4 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -2,6 +2,7 @@ pub mod structs; #[allow(clippy::module_inception)] pub mod api; pub mod api_blacklists; +pub mod api_certificate; pub mod api_keys; pub mod api_torrents; pub mod api_users; diff --git a/src/common/common.rs b/src/common/common.rs index a78d4c0..07fa65a 100644 --- a/src/common/common.rs +++ b/src/common/common.rs @@ -1,14 +1,12 @@ use crate::common::structs::custom_error::CustomError; use crate::config::structs::configuration::Configuration; use async_std::future; -use byteorder::{BigEndian, ReadBytesExt}; use fern::colors::{Color, ColoredLevelConfig}; use log::info; use smallvec::SmallVec; use std::collections::HashMap; use std::fmt; use std::fmt::Formatter; -use std::io::Cursor; use std::time::{Duration, SystemTime}; use tokio_shutdown::Shutdown; @@ -162,9 +160,9 @@ pub fn convert_int_to_bytes(number: &u64) -> Vec { #[inline] pub fn convert_bytes_to_int(array: &[u8]) -> u64 { let mut array_fixed = [0u8; 8]; - let start_idx = 8 - array.len(); - array_fixed[start_idx..].copy_from_slice(array); - Cursor::new(array_fixed).read_u64::().unwrap() + let len = array.len().min(8); + array_fixed[8 - len..].copy_from_slice(&array[..len]); + u64::from_be_bytes(array_fixed) } pub async fn shutdown_waiting(timeout: Duration, shutdown_handler: Shutdown) -> bool { diff --git a/src/http/http.rs b/src/http/http.rs index b5c5fcb..39f101f 100644 --- a/src/http/http.rs +++ b/src/http/http.rs @@ -4,6 +4,8 @@ use crate::config::enums::cluster_mode::ClusterMode; use crate::config::structs::http_trackers_config::HttpTrackersConfig; use crate::http::structs::http_service_data::HttpServiceData; use crate::http::types::{HttpServiceQueryHashingMapErr, HttpServiceQueryHashingMapOk}; +use crate::ssl::certificate_resolver::DynamicCertificateResolver; +use crate::ssl::certificate_store::ServerIdentifier; use crate::stats::enums::stats_event::StatsEvent; use crate::tracker::enums::torrent_peers_type::TorrentPeersType; use crate::tracker::structs::info_hash::InfoHash; @@ -21,9 +23,8 @@ use bip_bencode::{ben_bytes, ben_int, ben_list, ben_map, BMutAccess}; use lazy_static::lazy_static; use log::{debug, error, info}; use std::borrow::Cow; -use std::fs::File; use std::future::Future; -use std::io::{BufReader, Write}; +use std::io::Write; use std::net::{IpAddr, SocketAddr}; use std::process::exit; use std::str::FromStr; @@ -96,29 +97,24 @@ pub async fn http_service( error!("[HTTPS] No SSL key or SSL certificate given, exiting..."); exit(1); } - let key_file = &mut BufReader::new(match File::open(http_server_object.ssl_key.clone()) { - Ok(data) => { data } - Err(data) => { - sentry::capture_error(&data); - panic!("[HTTPS] SSL key unreadable: {data}"); - } - }); - let certs_file = &mut BufReader::new(match File::open(http_server_object.ssl_cert.clone()) { - Ok(data) => { data } - Err(data) => { panic!("[HTTPS] SSL cert unreadable: {data}"); } - }); - let tls_certs = match rustls_pemfile::certs(certs_file).collect::, _>>() { - Ok(data) => { data } - Err(data) => { panic!("[HTTPS] SSL cert couldn't be extracted: {data}"); } - }; - let tls_key = match rustls_pemfile::pkcs8_private_keys(key_file).next().unwrap() { - Ok(data) => { data } - Err(data) => { panic!("[HTTPS] SSL key couldn't be extracted: {data}"); } - }; - let tls_config = match rustls::ServerConfig::builder().with_no_client_auth().with_single_cert(tls_certs, rustls::pki_types::PrivateKeyDer::Pkcs8(tls_key)) { - Ok(data) => { data } - Err(data) => { panic!("[HTTPS] SSL config couldn't be created: {data}"); } + let server_id = ServerIdentifier::HttpTracker(addr.to_string()); + if let Err(e) = data.certificate_store.load_certificate( + server_id.clone(), + &http_server_object.ssl_cert, + &http_server_object.ssl_key, + ) { + panic!("[HTTPS] Failed to load SSL certificate: {}", e); + } + let resolver = match DynamicCertificateResolver::new( + Arc::clone(&data.certificate_store), + server_id, + ) { + Ok(resolver) => Arc::new(resolver), + Err(e) => panic!("[HTTPS] Failed to create certificate resolver: {}", e), }; + let tls_config = rustls::ServerConfig::builder() + .with_no_client_auth() + .with_cert_resolver(resolver); let service_data = Arc::new(HttpServiceData { torrent_tracker: data.clone(), http_trackers_config: Arc::new(http_server_object.clone()) @@ -680,54 +676,37 @@ pub fn http_service_stats_log(ip: IpAddr, tracker: &TorrentTracker) } #[tracing::instrument(level = "debug")] +#[inline] pub async fn http_service_decode_hex_hash(hash: String) -> Result { - match hex::decode(hash) { - Ok(hash_result) => { Ok(InfoHash(<[u8; 20]>::try_from(hash_result[0..20].as_ref()).unwrap())) } - Err(_) => { - Err(HttpResponse::InternalServerError().content_type(ContentType::plaintext()).body(ERR_UNABLE_DECODE_HEX.clone())) - } - } + hex::decode(&hash) + .ok() + .and_then(|bytes| bytes.get(..20).and_then(|slice| <[u8; 20]>::try_from(slice).ok())) + .map(InfoHash) + .ok_or_else(|| HttpResponse::InternalServerError().content_type(ContentType::plaintext()).body(ERR_UNABLE_DECODE_HEX.clone())) } #[tracing::instrument(level = "debug")] +#[inline] pub async fn http_service_decode_hex_user_id(hash: String) -> Result { - match hex::decode(hash) { - Ok(hash_result) => { Ok(UserId(<[u8; 20]>::try_from(hash_result[0..20].as_ref()).unwrap())) } - Err(_) => { - Err(HttpResponse::InternalServerError().content_type(ContentType::plaintext()).body(ERR_UNABLE_DECODE_HEX.clone())) - } - } + hex::decode(&hash) + .ok() + .and_then(|bytes| bytes.get(..20).and_then(|slice| <[u8; 20]>::try_from(slice).ok())) + .map(UserId) + .ok_or_else(|| HttpResponse::InternalServerError().content_type(ContentType::plaintext()).body(ERR_UNABLE_DECODE_HEX.clone())) } #[tracing::instrument(level = "debug")] pub async fn http_service_retrieve_remote_ip(request: HttpRequest, data: Arc) -> Result { - let origin_ip = match request.peer_addr() { - None => { - return Err(()); - } - Some(ip) => { - ip.ip() - } - }; - match request.headers().get(data.real_ip.clone()) { - Some(header) => { - if header.to_str().is_ok() { - if let Ok(ip) = IpAddr::from_str(header.to_str().unwrap()) { - Ok(ip) - } else { - Err(()) - } - } else { - Err(()) - } - } - None => { - Ok(origin_ip) - } - } + let origin_ip = request.peer_addr().map(|addr| addr.ip()).ok_or(())?; + request.headers() + .get(&data.real_ip) + .and_then(|header| header.to_str().ok()) + .and_then(|ip_str| IpAddr::from_str(ip_str).ok()) + .map(Ok) + .unwrap_or(Ok(origin_ip)) } #[tracing::instrument(level = "debug")] diff --git a/src/lib.rs b/src/lib.rs index 63904e3..00ed3e0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,6 +4,7 @@ pub mod common; pub mod config; pub mod database; pub mod http; +pub mod ssl; pub mod stats; pub mod structs; pub mod tracker; diff --git a/src/ssl/certificate_resolver.rs b/src/ssl/certificate_resolver.rs new file mode 100644 index 0000000..3673d16 --- /dev/null +++ b/src/ssl/certificate_resolver.rs @@ -0,0 +1,98 @@ +use super::certificate_store::{CertificateBundle, CertificateError, CertificateStore, ServerIdentifier}; +use parking_lot::RwLock; +use rustls::server::{ClientHello, ResolvesServerCert}; +use rustls::sign::CertifiedKey; +use std::sync::Arc; + +pub struct DynamicCertificateResolver { + store: Arc, + server_id: ServerIdentifier, + cached_key: RwLock>>, +} + +impl std::fmt::Debug for DynamicCertificateResolver { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("DynamicCertificateResolver") + .field("server_id", &self.server_id) + .field("has_cached_key", &self.cached_key.read().is_some()) + .finish() + } +} + +impl DynamicCertificateResolver { + pub fn new( + store: Arc, + server_id: ServerIdentifier, + ) -> Result { + let resolver = Self { + store, + server_id, + cached_key: RwLock::new(None), + }; + resolver.refresh_cache()?; + Ok(resolver) + } + + pub fn server_id(&self) -> &ServerIdentifier { + &self.server_id + } + + pub fn refresh_cache(&self) -> Result<(), CertificateError> { + let bundle = self + .store + .get_certificate(&self.server_id) + .ok_or_else(|| CertificateError::ServerNotFound(self.server_id.clone()))?; + let certified_key = Self::bundle_to_certified_key(&bundle)?; + *self.cached_key.write() = Some(Arc::new(certified_key)); + log::info!( + "[CERTIFICATE] Refreshed certificate cache for {}", + self.server_id + ); + Ok(()) + } + + pub fn has_certificate(&self) -> bool { + self.cached_key.read().is_some() + } + + fn bundle_to_certified_key(bundle: &Arc) -> Result { + let signing_key = rustls::crypto::ring::sign::any_supported_type(&bundle.key) + .map_err(|e| CertificateError::CertifiedKeyError(format!("{}", e)))?; + Ok(CertifiedKey::new(bundle.certs.clone(), signing_key)) + } +} + +impl ResolvesServerCert for DynamicCertificateResolver { + fn resolve(&self, _client_hello: ClientHello<'_>) -> Option> { + self.cached_key.read().clone() + } +} + +pub fn create_server_config_with_resolver( + resolver: Arc, +) -> rustls::ServerConfig { + rustls::ServerConfig::builder() + .with_no_client_auth() + .with_cert_resolver(resolver) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ssl::certificate_store::create_certificate_store; + + #[test] + fn test_resolver_creation_without_cert() { + let store = create_certificate_store(); + let server_id = ServerIdentifier::HttpTracker("0.0.0.0:443".to_string()); + let result = DynamicCertificateResolver::new(store, server_id); + assert!(result.is_err()); + } + + #[test] + fn test_server_identifier_methods() { + let server_id = ServerIdentifier::ApiServer("0.0.0.0:8443".to_string()); + assert_eq!(server_id.bind_address(), "0.0.0.0:8443"); + assert_eq!(server_id.server_type(), "api"); + } +} \ No newline at end of file diff --git a/src/ssl/certificate_store.rs b/src/ssl/certificate_store.rs new file mode 100644 index 0000000..f3a25fe --- /dev/null +++ b/src/ssl/certificate_store.rs @@ -0,0 +1,319 @@ +use parking_lot::RwLock; +use rustls::pki_types::{CertificateDer, PrivateKeyDer}; +use std::collections::HashMap; +use std::fs::File; +use std::io::BufReader; +use std::sync::Arc; +use thiserror::Error; + +pub struct CertificateBundle { + pub certs: Vec>, + pub key: PrivateKeyDer<'static>, + pub loaded_at: chrono::DateTime, + pub cert_path: String, + pub key_path: String, +} + +impl std::fmt::Debug for CertificateBundle { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("CertificateBundle") + .field("certs_count", &self.certs.len()) + .field("cert_path", &self.cert_path) + .field("key_path", &self.key_path) + .field("loaded_at", &self.loaded_at) + .finish() + } +} + +#[derive(Debug, Clone, Hash, Eq, PartialEq)] +pub enum ServerIdentifier { + HttpTracker(String), + ApiServer(String), + WebSocketMaster(String), +} + +impl ServerIdentifier { + pub fn bind_address(&self) -> &str { + match self { + ServerIdentifier::HttpTracker(addr) => addr, + ServerIdentifier::ApiServer(addr) => addr, + ServerIdentifier::WebSocketMaster(addr) => addr, + } + } + + pub fn server_type(&self) -> &'static str { + match self { + ServerIdentifier::HttpTracker(_) => "http", + ServerIdentifier::ApiServer(_) => "api", + ServerIdentifier::WebSocketMaster(_) => "websocket", + } + } +} + +impl std::fmt::Display for ServerIdentifier { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ServerIdentifier::HttpTracker(addr) => { + write!(f, "HttpTracker({})", addr) + } + ServerIdentifier::ApiServer(addr) => { + write!(f, "ApiServer({})", addr) + } + ServerIdentifier::WebSocketMaster(addr) => { + write!(f, "WebSocketMaster({})", addr) + } + } + } +} + +#[derive(Debug, Error)] +pub enum CertificateError { + #[error("Certificate file not found: {0}")] + CertFileNotFound(String), + #[error("Key file not found: {0}")] + KeyFileNotFound(String), + #[error("Failed to parse certificate: {0}")] + CertParseError(String), + #[error("Failed to parse key: {0}")] + KeyParseError(String), + #[error("No private key found in file")] + NoKeyFound, + #[error("Failed to build certified key: {0}")] + CertifiedKeyError(String), + #[error("Server not found: {0}")] + ServerNotFound(ServerIdentifier), +} + +#[derive(Debug, Clone)] +pub struct CertificatePaths { + pub cert_path: String, + pub key_path: String, +} + +pub struct CertificateStore { + bundles: RwLock>>, + paths: RwLock>, +} + +impl std::fmt::Debug for CertificateStore { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let bundles = self.bundles.read(); + f.debug_struct("CertificateStore") + .field("certificates_count", &bundles.len()) + .field("servers", &bundles.keys().collect::>()) + .finish() + } +} + +impl Default for CertificateStore { + fn default() -> Self { + Self::new() + } +} + +impl CertificateStore { + pub fn new() -> Self { + Self { + bundles: RwLock::new(HashMap::new()), + paths: RwLock::new(HashMap::new()), + } + } + + pub fn load_certificate( + &self, + server_id: ServerIdentifier, + cert_path: &str, + key_path: &str, + ) -> Result<(), CertificateError> { + let bundle = Self::load_bundle_from_files(cert_path, key_path)?; + self.paths.write().insert( + server_id.clone(), + CertificatePaths { + cert_path: cert_path.to_string(), + key_path: key_path.to_string(), + }, + ); + self.bundles.write().insert(server_id, Arc::new(bundle)); + Ok(()) + } + + pub fn get_certificate(&self, server_id: &ServerIdentifier) -> Option> { + self.bundles.read().get(server_id).cloned() + } + + pub fn get_paths(&self, server_id: &ServerIdentifier) -> Option { + self.paths.read().get(server_id).cloned() + } + + pub fn reload_certificate( + &self, + server_id: &ServerIdentifier, + ) -> Result<(), CertificateError> { + let paths = self + .paths + .read() + .get(server_id) + .cloned() + .ok_or_else(|| CertificateError::ServerNotFound(server_id.clone()))?; + let bundle = Self::load_bundle_from_files(&paths.cert_path, &paths.key_path)?; + self.bundles + .write() + .insert(server_id.clone(), Arc::new(bundle)); + Ok(()) + } + + pub fn reload_certificate_with_paths( + &self, + server_id: &ServerIdentifier, + cert_path: &str, + key_path: &str, + ) -> Result<(), CertificateError> { + let bundle = Self::load_bundle_from_files(cert_path, key_path)?; + self.paths.write().insert( + server_id.clone(), + CertificatePaths { + cert_path: cert_path.to_string(), + key_path: key_path.to_string(), + }, + ); + self.bundles + .write() + .insert(server_id.clone(), Arc::new(bundle)); + Ok(()) + } + + pub fn all_servers(&self) -> Vec<(ServerIdentifier, CertificatePaths)> { + self.paths + .read() + .iter() + .map(|(k, v)| (k.clone(), v.clone())) + .collect() + } + + pub fn get_all_certificates(&self) -> Vec<(ServerIdentifier, Arc)> { + self.bundles + .read() + .iter() + .map(|(k, v)| (k.clone(), Arc::clone(v))) + .collect() + } + + pub fn reload_all(&self) -> Vec<(ServerIdentifier, Result<(), CertificateError>)> { + let servers: Vec<_> = self.paths.read().keys().cloned().collect(); + servers + .into_iter() + .map(|server_id| { + let result = self.reload_certificate(&server_id); + (server_id, result) + }) + .collect() + } + + fn load_bundle_from_files( + cert_path: &str, + key_path: &str, + ) -> Result { + let key_file = File::open(key_path) + .map_err(|e| CertificateError::KeyFileNotFound(format!("{}: {}", key_path, e)))?; + let mut key_reader = BufReader::new(key_file); + let certs_file = File::open(cert_path) + .map_err(|e| CertificateError::CertFileNotFound(format!("{}: {}", cert_path, e)))?; + let mut certs_reader = BufReader::new(certs_file); + let tls_certs: Vec> = rustls_pemfile::certs(&mut certs_reader) + .collect::, _>>() + .map_err(|e| CertificateError::CertParseError(e.to_string()))?; + if tls_certs.is_empty() { + return Err(CertificateError::CertParseError( + "No certificates found in file".to_string(), + )); + } + let tls_key = Self::parse_private_key(&mut key_reader, key_path)?; + Ok(CertificateBundle { + certs: tls_certs, + key: tls_key, + loaded_at: chrono::Utc::now(), + cert_path: cert_path.to_string(), + key_path: key_path.to_string(), + }) + } + + fn parse_private_key( + reader: &mut BufReader, + key_path: &str, + ) -> Result, CertificateError> { + if let Some(key_result) = rustls_pemfile::pkcs8_private_keys(reader).next() { + return key_result + .map(PrivateKeyDer::Pkcs8) + .map_err(|e| CertificateError::KeyParseError(e.to_string())); + } + let key_file = File::open(key_path) + .map_err(|e| CertificateError::KeyFileNotFound(format!("{}: {}", key_path, e)))?; + let mut reader = BufReader::new(key_file); + if let Some(key_result) = rustls_pemfile::rsa_private_keys(&mut reader).next() { + return key_result + .map(PrivateKeyDer::Pkcs1) + .map_err(|e| CertificateError::KeyParseError(e.to_string())); + } + let key_file = File::open(key_path) + .map_err(|e| CertificateError::KeyFileNotFound(format!("{}: {}", key_path, e)))?; + let mut reader = BufReader::new(key_file); + if let Some(key_result) = rustls_pemfile::ec_private_keys(&mut reader).next() { + return key_result + .map(PrivateKeyDer::Sec1) + .map_err(|e| CertificateError::KeyParseError(e.to_string())); + } + Err(CertificateError::NoKeyFound) + } +} + +pub fn create_certificate_store() -> Arc { + Arc::new(CertificateStore::new()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_server_identifier_display() { + let http = ServerIdentifier::HttpTracker("0.0.0.0:443".to_string()); + assert_eq!(format!("{}", http), "HttpTracker(0.0.0.0:443)"); + let api = ServerIdentifier::ApiServer("0.0.0.0:8443".to_string()); + assert_eq!(format!("{}", api), "ApiServer(0.0.0.0:8443)"); + let ws = ServerIdentifier::WebSocketMaster("0.0.0.0:9443".to_string()); + assert_eq!(format!("{}", ws), "WebSocketMaster(0.0.0.0:9443)"); + } + + #[test] + fn test_certificate_store_new() { + let store = CertificateStore::new(); + assert!(store.all_servers().is_empty()); + } + + #[test] + fn test_server_identifier_equality() { + let id1 = ServerIdentifier::HttpTracker("0.0.0.0:443".to_string()); + let id2 = ServerIdentifier::HttpTracker("0.0.0.0:443".to_string()); + let id3 = ServerIdentifier::HttpTracker("0.0.0.0:8443".to_string()); + assert_eq!(id1, id2); + assert_ne!(id1, id3); + } + + #[test] + fn test_certificate_error_display() { + let err = CertificateError::CertFileNotFound("/path/to/cert.pem".to_string()); + assert!(err.to_string().contains("Certificate file not found")); + let err = CertificateError::NoKeyFound; + assert!(err.to_string().contains("No private key found")); + } + + #[test] + fn test_server_identifier_methods() { + let http = ServerIdentifier::HttpTracker("0.0.0.0:443".to_string()); + assert_eq!(http.bind_address(), "0.0.0.0:443"); + assert_eq!(http.server_type(), "http"); + let api = ServerIdentifier::ApiServer("0.0.0.0:8443".to_string()); + assert_eq!(api.bind_address(), "0.0.0.0:8443"); + assert_eq!(api.server_type(), "api"); + } +} \ No newline at end of file diff --git a/src/ssl/mod.rs b/src/ssl/mod.rs new file mode 100644 index 0000000..eca0c0d --- /dev/null +++ b/src/ssl/mod.rs @@ -0,0 +1,2 @@ +pub mod certificate_store; +pub mod certificate_resolver; \ No newline at end of file diff --git a/src/tracker/impls/torrent_tracker.rs b/src/tracker/impls/torrent_tracker.rs index 07fb6d6..fb24e42 100644 --- a/src/tracker/impls/torrent_tracker.rs +++ b/src/tracker/impls/torrent_tracker.rs @@ -1,6 +1,7 @@ use crate::cache::structs::cache_connector::CacheConnector; use crate::config::structs::configuration::Configuration; use crate::database::structs::database_connector::DatabaseConnector; +use crate::ssl::certificate_store::CertificateStore; use crate::stats::structs::stats_atomics::StatsAtomics; use crate::tracker::structs::torrent_tracker::TorrentTracker; use chrono::Utc; @@ -40,6 +41,7 @@ impl TorrentTracker { TorrentTracker { config: config.clone(), cache, + certificate_store: Arc::new(CertificateStore::new()), torrents_sharding: Arc::new(Default::default()), torrents_updates: Arc::new(RwLock::new(HashMap::new())), torrents_whitelist: Arc::new(RwLock::new(HashSet::new())), diff --git a/src/tracker/structs/torrent_tracker.rs b/src/tracker/structs/torrent_tracker.rs index 8893ae1..bb42f45 100644 --- a/src/tracker/structs/torrent_tracker.rs +++ b/src/tracker/structs/torrent_tracker.rs @@ -1,6 +1,7 @@ use crate::cache::structs::cache_connector::CacheConnector; use crate::config::structs::configuration::Configuration; use crate::database::structs::database_connector::DatabaseConnector; +use crate::ssl::certificate_store::CertificateStore; use crate::stats::structs::stats_atomics::StatsAtomics; use crate::tracker::enums::updates_action::UpdatesAction; use crate::tracker::structs::info_hash::InfoHash; @@ -19,6 +20,7 @@ pub struct TorrentTracker { pub config: Arc, pub sqlx: DatabaseConnector, pub cache: Option, + pub certificate_store: Arc, pub torrents_sharding: Arc, pub torrents_updates: TorrentsUpdates, pub torrents_whitelist: Arc>>, diff --git a/tests/ssl_tests.rs b/tests/ssl_tests.rs new file mode 100644 index 0000000..29e73cb --- /dev/null +++ b/tests/ssl_tests.rs @@ -0,0 +1,420 @@ +mod common; + +use std::sync::Arc; +use torrust_actix::ssl::certificate_store::{ + CertificateError, CertificatePaths, CertificateStore, ServerIdentifier, +}; + +#[tokio::test] +async fn test_server_identifier_http_tracker() { + let id = ServerIdentifier::HttpTracker("0.0.0.0:443".to_string()); + assert_eq!(id.bind_address(), "0.0.0.0:443"); + assert_eq!(id.server_type(), "http"); + assert_eq!(format!("{}", id), "HttpTracker(0.0.0.0:443)"); +} + +#[tokio::test] +async fn test_server_identifier_api_server() { + let id = ServerIdentifier::ApiServer("127.0.0.1:8443".to_string()); + assert_eq!(id.bind_address(), "127.0.0.1:8443"); + assert_eq!(id.server_type(), "api"); + assert_eq!(format!("{}", id), "ApiServer(127.0.0.1:8443)"); +} + +#[tokio::test] +async fn test_server_identifier_websocket_master() { + let id = ServerIdentifier::WebSocketMaster("[::]:9443".to_string()); + assert_eq!(id.bind_address(), "[::]:9443"); + assert_eq!(id.server_type(), "websocket"); + assert_eq!(format!("{}", id), "WebSocketMaster([::]:9443)"); +} + +#[tokio::test] +async fn test_server_identifier_equality() { + let id1 = ServerIdentifier::HttpTracker("0.0.0.0:443".to_string()); + let id2 = ServerIdentifier::HttpTracker("0.0.0.0:443".to_string()); + let id3 = ServerIdentifier::HttpTracker("0.0.0.0:8443".to_string()); + let id4 = ServerIdentifier::ApiServer("0.0.0.0:443".to_string()); + assert_eq!(id1, id2); + assert_ne!(id1, id3); + assert_ne!(id1, id4); +} + +#[tokio::test] +async fn test_server_identifier_clone() { + let id1 = ServerIdentifier::HttpTracker("0.0.0.0:443".to_string()); + let id2 = id1.clone(); + assert_eq!(id1, id2); +} + +#[tokio::test] +async fn test_server_identifier_hash() { + use std::collections::HashSet; + + let mut set = HashSet::new(); + set.insert(ServerIdentifier::HttpTracker("0.0.0.0:443".to_string())); + set.insert(ServerIdentifier::ApiServer("0.0.0.0:443".to_string())); + set.insert(ServerIdentifier::HttpTracker("0.0.0.0:443".to_string())); + assert_eq!(set.len(), 2); +} + +#[tokio::test] +async fn test_certificate_store_new() { + let store = CertificateStore::new(); + assert!(store.all_servers().is_empty()); + assert!(store.get_all_certificates().is_empty()); +} + +#[tokio::test] +async fn test_certificate_store_default() { + let store: CertificateStore = Default::default(); + assert!(store.all_servers().is_empty()); +} + +#[tokio::test] +async fn test_certificate_store_get_nonexistent() { + let store = CertificateStore::new(); + let server_id = ServerIdentifier::HttpTracker("0.0.0.0:443".to_string()); + assert!(store.get_certificate(&server_id).is_none()); + assert!(store.get_paths(&server_id).is_none()); +} + +#[tokio::test] +async fn test_certificate_store_reload_nonexistent() { + let store = CertificateStore::new(); + let server_id = ServerIdentifier::HttpTracker("0.0.0.0:443".to_string()); + let result = store.reload_certificate(&server_id); + assert!(result.is_err()); + match result { + Err(CertificateError::ServerNotFound(id)) => { + assert_eq!(id, server_id); + } + _ => panic!("Expected ServerNotFound error"), + } +} + +#[tokio::test] +async fn test_certificate_store_load_invalid_cert_path() { + let store = CertificateStore::new(); + let server_id = ServerIdentifier::HttpTracker("0.0.0.0:443".to_string()); + let result = store.load_certificate( + server_id, + "/nonexistent/path/cert.pem", + "/nonexistent/path/key.pem", + ); + assert!(result.is_err()); + match result { + Err(CertificateError::KeyFileNotFound(_)) => {} + Err(e) => panic!("Expected KeyFileNotFound, got: {:?}", e), + Ok(_) => panic!("Expected error, got Ok"), + } +} + +#[tokio::test] +async fn test_certificate_store_debug() { + let store = CertificateStore::new(); + let debug_str = format!("{:?}", store); + assert!(debug_str.contains("CertificateStore")); + assert!(debug_str.contains("certificates_count")); +} + +#[tokio::test] +async fn test_certificate_store_thread_safety() { + use std::thread; + + let store = Arc::new(CertificateStore::new()); + let mut handles = vec![]; + for i in 0..10 { + let store_clone = Arc::clone(&store); + let handle = thread::spawn(move || { + let server_id = ServerIdentifier::HttpTracker(format!("0.0.0.0:{}", 443 + i)); + let _ = store_clone.get_certificate(&server_id); + let _ = store_clone.get_paths(&server_id); + let _ = store_clone.all_servers(); + let _ = store_clone.get_all_certificates(); + }); + handles.push(handle); + } + for handle in handles { + handle.join().expect("Thread should not panic"); + } +} + +#[tokio::test] +async fn test_certificate_store_reload_all_empty() { + let store = CertificateStore::new(); + let results = store.reload_all(); + assert!(results.is_empty()); +} + +#[tokio::test] +async fn test_certificate_paths_clone() { + let paths = CertificatePaths { + cert_path: "/path/to/cert.pem".to_string(), + key_path: "/path/to/key.pem".to_string(), + }; + let cloned = paths.clone(); + assert_eq!(paths.cert_path, cloned.cert_path); + assert_eq!(paths.key_path, cloned.key_path); +} + +#[tokio::test] +async fn test_certificate_paths_debug() { + let paths = CertificatePaths { + cert_path: "/path/to/cert.pem".to_string(), + key_path: "/path/to/key.pem".to_string(), + }; + let debug_str = format!("{:?}", paths); + assert!(debug_str.contains("cert_path")); + assert!(debug_str.contains("key_path")); +} + +#[tokio::test] +async fn test_certificate_error_cert_file_not_found() { + let err = CertificateError::CertFileNotFound("/path/to/cert.pem".to_string()); + let msg = err.to_string(); + assert!(msg.contains("Certificate file not found")); + assert!(msg.contains("/path/to/cert.pem")); +} + +#[tokio::test] +async fn test_certificate_error_key_file_not_found() { + let err = CertificateError::KeyFileNotFound("/path/to/key.pem".to_string()); + let msg = err.to_string(); + assert!(msg.contains("Key file not found")); + assert!(msg.contains("/path/to/key.pem")); +} + +#[tokio::test] +async fn test_certificate_error_cert_parse_error() { + let err = CertificateError::CertParseError("invalid PEM format".to_string()); + let msg = err.to_string(); + assert!(msg.contains("Failed to parse certificate")); + assert!(msg.contains("invalid PEM format")); +} + +#[tokio::test] +async fn test_certificate_error_key_parse_error() { + let err = CertificateError::KeyParseError("invalid key format".to_string()); + let msg = err.to_string(); + assert!(msg.contains("Failed to parse key")); + assert!(msg.contains("invalid key format")); +} + +#[tokio::test] +async fn test_certificate_error_no_key_found() { + let err = CertificateError::NoKeyFound; + let msg = err.to_string(); + assert!(msg.contains("No private key found")); +} + +#[tokio::test] +async fn test_certificate_error_certified_key_error() { + let err = CertificateError::CertifiedKeyError("signing error".to_string()); + let msg = err.to_string(); + assert!(msg.contains("Failed to build certified key")); + assert!(msg.contains("signing error")); +} + +#[tokio::test] +async fn test_certificate_error_server_not_found() { + let server_id = ServerIdentifier::HttpTracker("0.0.0.0:443".to_string()); + let err = CertificateError::ServerNotFound(server_id); + let msg = err.to_string(); + assert!(msg.contains("Server not found")); + assert!(msg.contains("HttpTracker")); +} + +#[tokio::test] +async fn test_certificate_error_debug() { + let err = CertificateError::NoKeyFound; + let debug_str = format!("{:?}", err); + assert!(debug_str.contains("NoKeyFound")); +} + +use actix_web::{test, web, App}; +use torrust_actix::api::api_certificate::{ + api_service_certificate_reload, api_service_certificate_status, +}; +use torrust_actix::api::structs::api_service_data::ApiServiceData; + +#[actix_web::test] +async fn test_api_certificate_status_empty() { + let tracker: common::TestTracker = common::create_test_tracker().await; + let api_config = common::create_test_api_config(); + let service_data = Arc::new(ApiServiceData { + torrent_tracker: tracker.clone(), + api_trackers_config: api_config, + }); + let app = test::init_service( + App::new() + .app_data(web::Data::new(service_data)) + .route( + "/api/certificate/status", + web::get().to(api_service_certificate_status), + ), + ) + .await; + let req = test::TestRequest::get() + .uri("/api/certificate/status?token=MyApiKey") + .peer_addr("127.0.0.1:8080".parse().unwrap()) + .to_request(); + let resp = test::call_service(&app, req).await; + assert!( + resp.status().is_success(), + "Certificate status endpoint should return 200" + ); + let body = test::read_body(resp).await; + let json: serde_json::Value = serde_json::from_slice(&body).unwrap(); + assert_eq!(json["status"], "ok"); + assert!(json["certificates"].is_array()); + assert_eq!(json["certificates"].as_array().unwrap().len(), 0); +} + +#[actix_web::test] +async fn test_api_certificate_reload_empty() { + let tracker: common::TestTracker = common::create_test_tracker().await; + let api_config = common::create_test_api_config(); + let service_data = Arc::new(ApiServiceData { + torrent_tracker: tracker.clone(), + api_trackers_config: api_config, + }); + let app = test::init_service( + App::new() + .app_data(web::Data::new(service_data)) + .route( + "/api/certificate/reload", + web::post().to(api_service_certificate_reload), + ), + ) + .await; + let req = test::TestRequest::post() + .uri("/api/certificate/reload?token=MyApiKey") + .peer_addr("127.0.0.1:8080".parse().unwrap()) + .to_request(); + let resp = test::call_service(&app, req).await; + assert!( + resp.status().is_success(), + "Certificate reload endpoint should return 200" + ); + let body = test::read_body(resp).await; + let json: serde_json::Value = serde_json::from_slice(&body).unwrap(); + assert_eq!(json["status"], "no_certificates"); + assert!(json["message"] + .as_str() + .unwrap() + .contains("No SSL certificates")); +} + +#[actix_web::test] +async fn test_api_certificate_status_requires_token() { + let tracker: common::TestTracker = common::create_test_tracker().await; + let api_config = common::create_test_api_config(); + let service_data = Arc::new(ApiServiceData { + torrent_tracker: tracker.clone(), + api_trackers_config: api_config, + }); + let app = test::init_service( + App::new() + .app_data(web::Data::new(service_data)) + .route( + "/api/certificate/status", + web::get().to(api_service_certificate_status), + ), + ) + .await; + let req = test::TestRequest::get() + .uri("/api/certificate/status") + .peer_addr("127.0.0.1:8080".parse().unwrap()) + .to_request(); + let resp = test::call_service(&app, req).await; + assert!( + resp.status().as_u16() == 401 || resp.status().as_u16() == 400, + "Certificate status should require authentication" + ); +} + +#[actix_web::test] +async fn test_api_certificate_reload_requires_token() { + let tracker: common::TestTracker = common::create_test_tracker().await; + let api_config = common::create_test_api_config(); + let service_data = Arc::new(ApiServiceData { + torrent_tracker: tracker.clone(), + api_trackers_config: api_config, + }); + let app = test::init_service( + App::new() + .app_data(web::Data::new(service_data)) + .route( + "/api/certificate/reload", + web::post().to(api_service_certificate_reload), + ), + ) + .await; + let req = test::TestRequest::post() + .uri("/api/certificate/reload") + .peer_addr("127.0.0.1:8080".parse().unwrap()) + .to_request(); + let resp = test::call_service(&app, req).await; + assert!( + resp.status().as_u16() == 401 || resp.status().as_u16() == 400, + "Certificate reload should require authentication" + ); +} + +#[actix_web::test] +async fn test_api_certificate_reload_with_filter() { + let tracker: common::TestTracker = common::create_test_tracker().await; + let api_config = common::create_test_api_config(); + let service_data = Arc::new(ApiServiceData { + torrent_tracker: tracker.clone(), + api_trackers_config: api_config, + }); + let app = test::init_service( + App::new() + .app_data(web::Data::new(service_data)) + .route( + "/api/certificate/reload", + web::post().to(api_service_certificate_reload), + ), + ) + .await; + let req = test::TestRequest::post() + .uri("/api/certificate/reload?token=MyApiKey") + .peer_addr("127.0.0.1:8080".parse().unwrap()) + .set_json(serde_json::json!({ + "server_type": "http", + "bind_address": "0.0.0.0:443" + })) + .to_request(); + let resp = test::call_service(&app, req).await; + assert!( + resp.status().is_success(), + "Certificate reload with filter should return 200" + ); + let body = test::read_body(resp).await; + let json: serde_json::Value = serde_json::from_slice(&body).unwrap(); + assert_eq!(json["status"], "no_certificates"); +} + +#[actix_web::test] +async fn test_tracker_has_certificate_store() { + let tracker: common::TestTracker = common::create_test_tracker().await; + assert!(tracker.certificate_store.all_servers().is_empty()); + assert!(tracker.certificate_store.get_all_certificates().is_empty()); +} + +#[actix_web::test] +async fn test_certificate_store_from_tracker() { + let tracker: common::TestTracker = common::create_test_tracker().await; + let server_id = ServerIdentifier::HttpTracker("0.0.0.0:443".to_string()); + assert!(tracker.certificate_store.get_certificate(&server_id).is_none()); + let result = tracker.certificate_store.reload_certificate(&server_id); + assert!(result.is_err()); +} + +#[tokio::test] +async fn test_create_certificate_store_helper() { + let store = torrust_actix::ssl::certificate_store::create_certificate_store(); + assert!(store.all_servers().is_empty()); +} \ No newline at end of file