diff --git a/crates/core/src/app_state.rs b/crates/core/src/app_state.rs index 3e52d1e8..2338aaa1 100644 --- a/crates/core/src/app_state.rs +++ b/crates/core/src/app_state.rs @@ -92,6 +92,8 @@ pub struct AppState { impl AppState { pub(crate) async fn new(args: AppStateArgs) -> Self { + crate::records::set_expose_internal_errors(args.dev); + let config = Reactive::new(args.config); let public_url = args.public_url.clone(); diff --git a/crates/core/src/records/error.rs b/crates/core/src/records/error.rs index a13ece0f..efe3cbd8 100644 --- a/crates/core/src/records/error.rs +++ b/crates/core/src/records/error.rs @@ -1,8 +1,15 @@ use axum::body::Body; use axum::http::{StatusCode, header::CONTENT_TYPE}; use axum::response::{IntoResponse, Response}; +use std::sync::atomic::{AtomicBool, Ordering}; use thiserror::Error; +static EXPOSE_INTERNAL_ERRORS: AtomicBool = AtomicBool::new(cfg!(debug_assertions)); + +pub(crate) fn set_expose_internal_errors(expose: bool) { + EXPOSE_INTERNAL_ERRORS.store(expose, Ordering::Relaxed); +} + /// Publicly visible errors of record APIs. /// /// This error is deliberately opaque and kept very close to HTTP error codes to avoid the leaking @@ -20,10 +27,21 @@ pub enum RecordError { Forbidden, #[error("Bad request: {0}")] BadRequest(&'static str), + #[error("Access check failed: {0}")] + AccessCheckFailed(Box), #[error("Internal: {0}")] Internal(Box), } +impl RecordError { + pub(crate) fn stable_code(&self) -> Option<&'static str> { + return match self { + Self::AccessCheckFailed(_) => Some("ACCESS_CHECK_EVAL_FAILED"), + _ => None, + }; + } +} + impl From for RecordError { fn from(err: trailbase_sqlite::Error) -> Self { return match err { @@ -80,13 +98,23 @@ impl From for RecordError { impl IntoResponse for RecordError { fn into_response(self) -> Response { + let expose_internal_errors = EXPOSE_INTERNAL_ERRORS.load(Ordering::Relaxed); + let (status, body) = match self { Self::ApiNotFound => (StatusCode::METHOD_NOT_ALLOWED, None), Self::ApiRequiresTable => (StatusCode::METHOD_NOT_ALLOWED, None), Self::RecordNotFound => (StatusCode::NOT_FOUND, None), Self::Forbidden => (StatusCode::FORBIDDEN, None), Self::BadRequest(msg) => (StatusCode::BAD_REQUEST, Some(msg.to_string())), - Self::Internal(err) if cfg!(debug_assertions) => { + Self::AccessCheckFailed(err) if expose_internal_errors => ( + StatusCode::INTERNAL_SERVER_ERROR, + Some(format!("ACCESS_CHECK_EVAL_FAILED: {err}")), + ), + Self::AccessCheckFailed(_err) => ( + StatusCode::INTERNAL_SERVER_ERROR, + Some("ACCESS_CHECK_EVAL_FAILED".to_string()), + ), + Self::Internal(err) if expose_internal_errors => { (StatusCode::INTERNAL_SERVER_ERROR, Some(err.to_string())) } Self::Internal(_err) => (StatusCode::INTERNAL_SERVER_ERROR, None), diff --git a/crates/core/src/records/mod.rs b/crates/core/src/records/mod.rs index 7d41f3db..54250c00 100644 --- a/crates/core/src/records/mod.rs +++ b/crates/core/src/records/mod.rs @@ -45,7 +45,7 @@ mod transaction; mod update_record; mod validate; -pub(crate) use error::RecordError; +pub(crate) use error::{RecordError, set_expose_internal_errors}; pub use record_api::RecordApi; pub(crate) use validate::validate_record_api_config; diff --git a/crates/core/src/records/record_api.rs b/crates/core/src/records/record_api.rs index b072c3c1..c9a4bab7 100644 --- a/crates/core/src/records/record_api.rs +++ b/crates/core/src/records/record_api.rs @@ -479,7 +479,7 @@ impl RecordApi { access_query, self.build_named_params(p, record_id, request_params, user)?, ) - .await + .await? { return Ok(()); } @@ -504,14 +504,28 @@ impl RecordApi { let params = self.build_named_params(p, record_id, request_params, user)?; - return match conn + let allowed = conn .query_row(access_query, params) - .ok() - .and_then(|row| row.and_then(|r| r.get::(0).ok())) - { - Some(allowed) if allowed => Ok(()), - _ => Err(RecordError::Forbidden), - }; + .map_err(|err| { + RecordError::AccessCheckFailed( + std::io::Error::other(format!( + "Failed to evaluate record-level access rule: {err}" + )) + .into(), + ) + })? + .and_then(|r| r.get::(0).ok()) + .ok_or_else(|| { + RecordError::AccessCheckFailed( + std::io::Error::other("Access rule query returned no boolean result").into(), + ) + })?; + + if allowed { + return Ok(()); + } + + return Err(RecordError::Forbidden); } #[inline] @@ -519,22 +533,18 @@ impl RecordApi { &self, query: impl AsRef + Send + 'static, params: impl trailbase_sqlite::Params + Send + 'static, - ) -> bool { - // TODO: Remove query from debug_assert and thus the extra allocation. - let q = query.as_ref().to_string(); - + ) -> Result { return match self.state.conn.read_query_row_get(query, params, 0).await { - Ok(Some(allowed)) => allowed, - Ok(None) => { - debug_assert!(false, "RLA query '{q}' returned no result"); - - false - } - Err(err) => { - debug_assert!(false, "RLA query '{q}' failed: {err}"); - - false - } + Ok(Some(allowed)) => Ok(allowed), + Ok(None) => Err(RecordError::AccessCheckFailed( + std::io::Error::other("Access rule query returned no boolean result").into(), + )), + Err(err) => Err(RecordError::AccessCheckFailed( + std::io::Error::other(format!( + "Failed to evaluate record-level access rule: {err}" + )) + .into(), + )), }; } @@ -563,7 +573,7 @@ impl RecordApi { if self .check_record_level_access_impl(access_query.clone(), params) - .await + .await? { return Ok(()); } diff --git a/crates/core/src/records/subscribe/event.rs b/crates/core/src/records/subscribe/event.rs index 0b3a4833..05f9ce40 100644 --- a/crates/core/src/records/subscribe/event.rs +++ b/crates/core/src/records/subscribe/event.rs @@ -23,6 +23,8 @@ pub enum EventErrorStatus { #[derive(Debug, Clone, Deserialize, Serialize, PartialEq)] pub struct EventError { pub status: EventErrorStatus, + #[serde(skip_serializing_if = "Option::is_none")] + pub code: Option, pub message: Option, } @@ -135,6 +137,7 @@ pub enum TestJsonEventPayload { Delete(JsonObject), Error { status: EventErrorStatus, + code: Option, message: Option, }, Ping, @@ -180,6 +183,7 @@ mod tests { event: Arc::new(EventPayload::from(&JsonEventPayload::Error { value: EventError { status: EventErrorStatus::Loss, + code: None, message: Some("test".to_string()), }, })), @@ -213,6 +217,7 @@ mod tests { assert_eq!( TestJsonEventPayload::Error { status: EventErrorStatus::Forbidden, + code: None, message: Some("test".to_string()), }, test_event.event diff --git a/crates/core/src/records/subscribe/handler.rs b/crates/core/src/records/subscribe/handler.rs index 8400d321..c45f0be3 100644 --- a/crates/core/src/records/subscribe/handler.rs +++ b/crates/core/src/records/subscribe/handler.rs @@ -133,6 +133,26 @@ async fn validate_event( return Ok(Some(ev.payload)); } +fn access_check_error_event(err: &RecordError, dev_mode: bool) -> Arc { + if matches!(err, RecordError::Forbidden) { + return ACCESS_DENIED_EVENT.clone(); + } + + let message = if dev_mode { + Some(format!("Internal access-check error: {err}")) + } else { + Some("Internal access-check error".to_string()) + }; + + return Arc::new(EventPayload::from(&JsonEventPayload::Error { + value: EventError { + status: EventErrorStatus::Unknown, + code: err.stable_code().map(ToOwned::to_owned), + message, + }, + })); +} + pub async fn subscribe_sse( state: AppState, api: RecordApi, @@ -204,15 +224,16 @@ pub async fn subscribe_sse( ev.into_sse_event(Some(seq.fetch_add(1, Ordering::SeqCst))), )) .boxed(), - Err(_) => { - // Death sentence for record subscriptions to not have access + Err(err) => { + let is_forbidden = matches!(err, RecordError::Forbidden); + if !is_forbidden { + log::error!("SSE record access-check failed: {err}"); + } + stream::iter(vec![ - // First send an error event to the user. - ACCESS_DENIED_EVENT - .clone() + access_check_error_event(&err, args.state.dev_mode()) .into_sse_event(Some(seq.fetch_add(1, Ordering::SeqCst))), - // Then terminate the stream via the `take_while` below. - Err(RecordError::Forbidden), + Err(err), ]) .boxed() } @@ -312,17 +333,19 @@ pub async fn subscribe_ws( Ok(None) => { continue; } - Err(_) => { + Err(err) => { + if !matches!(err, RecordError::Forbidden) { + log::error!("WS access-check failed: {err}"); + } + if is_record_subscription { - // Death sentence for record subscriptions to not have access - let _ = ACCESS_DENIED_EVENT - .clone() + let _ = access_check_error_event(&err, args.state.dev_mode()) .into_ws_event() .map(|ev| sender.send(ev)); return; - } else { - continue; } + + continue; } }; @@ -402,9 +425,22 @@ pub async fn subscribe_ws( // NOTE: Access checking can only happen post upgrade, since browsers & Node.js don't allow // setting custom headers for the UPGRADE HTTP request. We could maybe use cookies in some // places but instead expect an explicit authorization. - if let Err(_err) = api.check_table_level_access(Permission::Read, user.as_ref()) { - abort(&mut ws_sender, Code::Policy, "unauthorized").await; - return; + match api.check_table_level_access(Permission::Read, user.as_ref()) { + Ok(()) => {} + Err(RecordError::Forbidden) => { + abort(&mut ws_sender, Code::Policy, "unauthorized").await; + return; + } + Err(err) => { + log::error!("WS table subscription access-check failed: {err}"); + abort( + &mut ws_sender, + Code::Unexpected, + "internal access-check error", + ) + .await; + return; + } } let (sender, receiver) = async_channel::bounded::(64); @@ -440,12 +476,25 @@ pub async fn subscribe_ws( // NOTE: Access checking can only happen post upgrade, since browsers & Node.js don't allow // setting custom headers for the UPGRADE HTTP request. We could maybe use cookies in some // places but instead expect an explicit authorization. - if let Err(_) = api + match api .check_record_level_access(Permission::Read, Some(&record_id), None, user.as_ref()) .await { - abort(&mut ws_sender, Code::Policy, "unauthorized").await; - return; + Ok(()) => {} + Err(RecordError::Forbidden) => { + abort(&mut ws_sender, Code::Policy, "unauthorized").await; + return; + } + Err(err) => { + log::error!("WS record subscription access-check failed: {err}"); + abort( + &mut ws_sender, + Code::Unexpected, + "internal access-check error", + ) + .await; + return; + } } let (sender, receiver) = async_channel::bounded::(64); @@ -475,6 +524,7 @@ static ACCESS_DENIED_EVENT: LazyLock> = LazyLock::new(|| { Arc::new(EventPayload::from(&JsonEventPayload::Error { value: EventError { status: EventErrorStatus::Forbidden, + code: None, message: Some("Access denied".into()), }, })) @@ -483,6 +533,7 @@ static EVENT_LOSS_EVENT: LazyLock> = LazyLock::new(|| { Arc::new(EventPayload::from(&JsonEventPayload::Error { value: EventError { status: EventErrorStatus::Loss, + code: None, message: None, }, }))