Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 109 additions & 19 deletions server/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use base64::engine::general_purpose::STANDARD as base64_engine;
use base64::Engine;
use dashmap::mapref::entry::Entry;
use dashmap::DashMap;
use rocket::data::{Data, ToByteUnit, Limits};
use rocket::data::{Data, Limits, ToByteUnit};
use rocket::fairing::{Fairing, Info, Kind};
use rocket::http::Header;
use rocket::response::content::RawText;
Expand Down Expand Up @@ -99,7 +99,8 @@ struct PairingEntry {
expired: bool,
}

type SharedPairingState = Arc<Mutex<HashMap<String, Arc<Mutex<PairingEntry>>>>>;
type PairingSessionKey = (String, String);
type SharedPairingState = Arc<Mutex<HashMap<PairingSessionKey, Arc<Mutex<PairingEntry>>>>>;
type AllEventState = Arc<DashMap<String, EventState>>;

// Simple rate limiters for the server
Expand All @@ -109,6 +110,10 @@ const MAX_LIVESTREAM_FILE_SIZE: usize = 20; // in mebibytes
const MAX_NUM_PENDING_LIVESTREAM_FILES: usize = 50;
const MAX_COMMAND_FILE_SIZE: usize = 100; // in kibibytes
const MAX_JSON_SIZE: usize = 10; // in kibibytes
#[cfg(not(test))]
const PAIRING_SESSION_TIMEOUT: Duration = Duration::from_secs(45);
#[cfg(test)]
const PAIRING_SESSION_TIMEOUT: Duration = Duration::from_millis(250);

async fn get_num_files(path: &Path) -> io::Result<usize> {
let mut entries = fs::read_dir(path).await?;
Expand Down Expand Up @@ -184,11 +189,25 @@ async fn pair(
}

let token = &data.pairing_token;

// Check for disallowed quote characters in the token
if token.contains('"') {
debug!("[PAIR] Invalid token contains quote character: {}", token);
return Json(PairingResponse {
status: "invalid_token".into(),
notification_target: None,
});
}

let session_key = (auth.username.clone(), token.clone());
let entry_arc = {
let mut sessions = state.lock().unwrap();
debug!("[PAIR] Looking up or creating session for token: {}", token);
debug!(
"[PAIR] Looking up or creating session for user: {}, token: {}",
auth.username, token
);
sessions
.entry(token.clone())
.entry(session_key)
.or_insert_with(|| {
debug!("[PAIR] No existing session found. Creating new entry.");
Arc::new(Mutex::new(PairingEntry {
Expand All @@ -205,17 +224,6 @@ async fn pair(
.clone()
};

let token = &data.pairing_token;

// Check for disallowed quote characters in the token
if token.contains('"') {
debug!("[PAIR] Invalid token contains quote character: {}", token);
return Json(PairingResponse {
status: "invalid_token".into(),
notification_target: None,
});
}

let notify;
let expired_at;
let target_to_persist;
Expand All @@ -236,7 +244,7 @@ async fn pair(
elapsed, entry.phone_notified, entry.camera_notified
);

if elapsed > Duration::from_secs(45) || entry.phone_notified || entry.camera_notified {
if elapsed > PAIRING_SESSION_TIMEOUT || entry.phone_notified || entry.camera_notified {
debug!("[PAIR] Expiring session due to timeout or notification flag");
entry.expired = true;
return Json(PairingResponse {
Expand Down Expand Up @@ -295,7 +303,7 @@ async fn pair(
}

notify = entry.notify.clone();
expired_at = entry.created_at + Duration::from_secs(45);
expired_at = entry.created_at + PAIRING_SESSION_TIMEOUT;
debug!(
"[PAIR] Only one side connected, waiting until {:?}",
expired_at
Expand Down Expand Up @@ -1161,8 +1169,7 @@ pub fn build_rocket() -> rocket::Rocket<rocket::Build> {
let config = rocket::Config {
port: listen_port.unwrap_or(8000),
address: address.parse().unwrap(),
limits: Limits::default()
.limit("json", MAX_JSON_SIZE.kibibytes()),
limits: Limits::default().limit("json", MAX_JSON_SIZE.kibibytes()),
..rocket::Config::default()
};

Expand Down Expand Up @@ -1215,6 +1222,89 @@ pub fn build_rocket() -> rocket::Rocket<rocket::Build> {
)
}

#[cfg(test)]
mod pairing_tests {
use super::{auth::BasicAuth, notification_target, pair, PairingRequest, SharedPairingState};
use rocket::serde::json::Json;
use rocket::State;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};

fn test_auth(username: &str) -> BasicAuth {
BasicAuth {
username: username.to_string(),
authenticated: true,
}
}

#[rocket::async_test]
async fn pairing_succeeds_with_same_authenticated_account() {
let state: SharedPairingState = Arc::new(Mutex::new(HashMap::new()));
let policy = notification_target::UnifiedPushPolicy::from_env().unwrap();
let auth = test_auth("sameacctuser01");

let (phone_response, camera_response) = rocket::tokio::join!(
pair(
Json(PairingRequest {
pairing_token: "shared-token".to_string(),
role: "phone".to_string(),
notification_target: None,
}),
State::from(&state),
State::from(&policy),
&auth,
),
pair(
Json(PairingRequest {
pairing_token: "shared-token".to_string(),
role: "camera".to_string(),
notification_target: None,
}),
State::from(&state),
State::from(&policy),
&auth,
)
);

assert_eq!(phone_response.into_inner().status, "paired");
assert_eq!(camera_response.into_inner().status, "paired");
}

#[rocket::async_test]
async fn pairing_does_not_cross_authenticated_accounts() {
let state: SharedPairingState = Arc::new(Mutex::new(HashMap::new()));
let policy = notification_target::UnifiedPushPolicy::from_env().unwrap();
let phone_auth = test_auth("phoneaccount01");
let camera_auth = test_auth("cameraaccount1");

let (phone_response, camera_response) = rocket::tokio::join!(
pair(
Json(PairingRequest {
pairing_token: "shared-token".to_string(),
role: "phone".to_string(),
notification_target: None,
}),
State::from(&state),
State::from(&policy),
&phone_auth,
),
pair(
Json(PairingRequest {
pairing_token: "shared-token".to_string(),
role: "camera".to_string(),
notification_target: None,
}),
State::from(&state),
State::from(&policy),
&camera_auth,
)
);

assert_eq!(phone_response.into_inner().status, "expired");
assert_eq!(camera_response.into_inner().status, "expired");
}
}

#[cfg(test)]
mod contract_tests {
use super::build_rocket;
Expand Down
Loading