Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions crates/core/src/app_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ pub struct AppState {

impl AppState {
pub(crate) async fn new(args: AppStateArgs) -> Self {
crate::records::set_expose_internal_errors(args.dev);

let config = Reactive::new(args.config);

let public_url = args.public_url.clone();
Expand Down
30 changes: 29 additions & 1 deletion crates/core/src/records/error.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
use axum::body::Body;
use axum::http::{StatusCode, header::CONTENT_TYPE};
use axum::response::{IntoResponse, Response};
use std::sync::atomic::{AtomicBool, Ordering};
use thiserror::Error;

static EXPOSE_INTERNAL_ERRORS: AtomicBool = AtomicBool::new(cfg!(debug_assertions));

pub(crate) fn set_expose_internal_errors(expose: bool) {
EXPOSE_INTERNAL_ERRORS.store(expose, Ordering::Relaxed);
}

/// Publicly visible errors of record APIs.
///
/// This error is deliberately opaque and kept very close to HTTP error codes to avoid the leaking
Expand All @@ -20,10 +27,21 @@ pub enum RecordError {
Forbidden,
#[error("Bad request: {0}")]
BadRequest(&'static str),
#[error("Access check failed: {0}")]
AccessCheckFailed(Box<dyn std::error::Error + Send + Sync>),
#[error("Internal: {0}")]
Internal(Box<dyn std::error::Error + Send + Sync>),
}

impl RecordError {
pub(crate) fn stable_code(&self) -> Option<&'static str> {
return match self {
Self::AccessCheckFailed(_) => Some("ACCESS_CHECK_EVAL_FAILED"),
_ => None,
};
}
}

impl From<trailbase_sqlite::Error> for RecordError {
fn from(err: trailbase_sqlite::Error) -> Self {
return match err {
Expand Down Expand Up @@ -80,13 +98,23 @@ impl From<object_store::Error> for RecordError {

impl IntoResponse for RecordError {
fn into_response(self) -> Response {
let expose_internal_errors = EXPOSE_INTERNAL_ERRORS.load(Ordering::Relaxed);

let (status, body) = match self {
Self::ApiNotFound => (StatusCode::METHOD_NOT_ALLOWED, None),
Self::ApiRequiresTable => (StatusCode::METHOD_NOT_ALLOWED, None),
Self::RecordNotFound => (StatusCode::NOT_FOUND, None),
Self::Forbidden => (StatusCode::FORBIDDEN, None),
Self::BadRequest(msg) => (StatusCode::BAD_REQUEST, Some(msg.to_string())),
Self::Internal(err) if cfg!(debug_assertions) => {
Self::AccessCheckFailed(err) if expose_internal_errors => (
StatusCode::INTERNAL_SERVER_ERROR,
Some(format!("ACCESS_CHECK_EVAL_FAILED: {err}")),
),
Self::AccessCheckFailed(_err) => (
StatusCode::INTERNAL_SERVER_ERROR,
Some("ACCESS_CHECK_EVAL_FAILED".to_string()),
),
Self::Internal(err) if expose_internal_errors => {
(StatusCode::INTERNAL_SERVER_ERROR, Some(err.to_string()))
}
Self::Internal(_err) => (StatusCode::INTERNAL_SERVER_ERROR, None),
Expand Down
2 changes: 1 addition & 1 deletion crates/core/src/records/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ mod transaction;
mod update_record;
mod validate;

pub(crate) use error::RecordError;
pub(crate) use error::{RecordError, set_expose_internal_errors};
pub use record_api::RecordApi;
pub(crate) use validate::validate_record_api_config;

Expand Down
58 changes: 34 additions & 24 deletions crates/core/src/records/record_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -479,7 +479,7 @@ impl RecordApi {
access_query,
self.build_named_params(p, record_id, request_params, user)?,
)
.await
.await?
{
return Ok(());
}
Expand All @@ -504,37 +504,47 @@ impl RecordApi {

let params = self.build_named_params(p, record_id, request_params, user)?;

return match conn
let allowed = conn
.query_row(access_query, params)
.ok()
.and_then(|row| row.and_then(|r| r.get::<bool>(0).ok()))
{
Some(allowed) if allowed => Ok(()),
_ => Err(RecordError::Forbidden),
};
.map_err(|err| {
RecordError::AccessCheckFailed(
std::io::Error::other(format!(
"Failed to evaluate record-level access rule: {err}"
))
.into(),
)
})?
.and_then(|r| r.get::<bool>(0).ok())
.ok_or_else(|| {
RecordError::AccessCheckFailed(
std::io::Error::other("Access rule query returned no boolean result").into(),
)
})?;

if allowed {
return Ok(());
}

return Err(RecordError::Forbidden);
}

#[inline]
async fn check_record_level_access_impl(
&self,
query: impl AsRef<str> + Send + 'static,
params: impl trailbase_sqlite::Params + Send + 'static,
) -> bool {
// TODO: Remove query from debug_assert and thus the extra allocation.
let q = query.as_ref().to_string();

) -> Result<bool, RecordError> {
return match self.state.conn.read_query_row_get(query, params, 0).await {
Ok(Some(allowed)) => allowed,
Ok(None) => {
debug_assert!(false, "RLA query '{q}' returned no result");

false
}
Err(err) => {
debug_assert!(false, "RLA query '{q}' failed: {err}");

false
}
Ok(Some(allowed)) => Ok(allowed),
Ok(None) => Err(RecordError::AccessCheckFailed(
std::io::Error::other("Access rule query returned no boolean result").into(),
)),
Err(err) => Err(RecordError::AccessCheckFailed(
std::io::Error::other(format!(
"Failed to evaluate record-level access rule: {err}"
))
.into(),
)),
};
}

Expand Down Expand Up @@ -563,7 +573,7 @@ impl RecordApi {

if self
.check_record_level_access_impl(access_query.clone(), params)
.await
.await?
{
return Ok(());
}
Expand Down
5 changes: 5 additions & 0 deletions crates/core/src/records/subscribe/event.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ pub enum EventErrorStatus {
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
pub struct EventError {
pub status: EventErrorStatus,
#[serde(skip_serializing_if = "Option::is_none")]
pub code: Option<String>,
pub message: Option<String>,
}

Expand Down Expand Up @@ -135,6 +137,7 @@ pub enum TestJsonEventPayload {
Delete(JsonObject),
Error {
status: EventErrorStatus,
code: Option<String>,
message: Option<String>,
},
Ping,
Expand Down Expand Up @@ -180,6 +183,7 @@ mod tests {
event: Arc::new(EventPayload::from(&JsonEventPayload::Error {
value: EventError {
status: EventErrorStatus::Loss,
code: None,
message: Some("test".to_string()),
},
})),
Expand Down Expand Up @@ -213,6 +217,7 @@ mod tests {
assert_eq!(
TestJsonEventPayload::Error {
status: EventErrorStatus::Forbidden,
code: None,
message: Some("test".to_string()),
},
test_event.event
Expand Down
89 changes: 70 additions & 19 deletions crates/core/src/records/subscribe/handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,26 @@ async fn validate_event(
return Ok(Some(ev.payload));
}

fn access_check_error_event(err: &RecordError, dev_mode: bool) -> Arc<EventPayload> {
if matches!(err, RecordError::Forbidden) {
return ACCESS_DENIED_EVENT.clone();
}

let message = if dev_mode {
Some(format!("Internal access-check error: {err}"))
} else {
Some("Internal access-check error".to_string())
};

return Arc::new(EventPayload::from(&JsonEventPayload::Error {
value: EventError {
status: EventErrorStatus::Unknown,
code: err.stable_code().map(ToOwned::to_owned),
message,
},
}));
}

pub async fn subscribe_sse(
state: AppState,
api: RecordApi,
Expand Down Expand Up @@ -204,15 +224,16 @@ pub async fn subscribe_sse(
ev.into_sse_event(Some(seq.fetch_add(1, Ordering::SeqCst))),
))
.boxed(),
Err(_) => {
// Death sentence for record subscriptions to not have access
Err(err) => {
let is_forbidden = matches!(err, RecordError::Forbidden);
if !is_forbidden {
log::error!("SSE record access-check failed: {err}");
}

stream::iter(vec![
// First send an error event to the user.
ACCESS_DENIED_EVENT
.clone()
access_check_error_event(&err, args.state.dev_mode())
.into_sse_event(Some(seq.fetch_add(1, Ordering::SeqCst))),
// Then terminate the stream via the `take_while` below.
Err(RecordError::Forbidden),
Err(err),
])
.boxed()
}
Expand Down Expand Up @@ -312,17 +333,19 @@ pub async fn subscribe_ws(
Ok(None) => {
continue;
}
Err(_) => {
Err(err) => {
if !matches!(err, RecordError::Forbidden) {
log::error!("WS access-check failed: {err}");
}

if is_record_subscription {
// Death sentence for record subscriptions to not have access
let _ = ACCESS_DENIED_EVENT
.clone()
let _ = access_check_error_event(&err, args.state.dev_mode())
.into_ws_event()
.map(|ev| sender.send(ev));
return;
} else {
continue;
}

continue;
}
};

Expand Down Expand Up @@ -402,9 +425,22 @@ pub async fn subscribe_ws(
// NOTE: Access checking can only happen post upgrade, since browsers & Node.js don't allow
// setting custom headers for the UPGRADE HTTP request. We could maybe use cookies in some
// places but instead expect an explicit authorization.
if let Err(_err) = api.check_table_level_access(Permission::Read, user.as_ref()) {
abort(&mut ws_sender, Code::Policy, "unauthorized").await;
return;
match api.check_table_level_access(Permission::Read, user.as_ref()) {
Ok(()) => {}
Err(RecordError::Forbidden) => {
abort(&mut ws_sender, Code::Policy, "unauthorized").await;
return;
}
Err(err) => {
log::error!("WS table subscription access-check failed: {err}");
abort(
&mut ws_sender,
Code::Unexpected,
"internal access-check error",
)
.await;
return;
}
}

let (sender, receiver) = async_channel::bounded::<EventCandidate>(64);
Expand Down Expand Up @@ -440,12 +476,25 @@ pub async fn subscribe_ws(
// NOTE: Access checking can only happen post upgrade, since browsers & Node.js don't allow
// setting custom headers for the UPGRADE HTTP request. We could maybe use cookies in some
// places but instead expect an explicit authorization.
if let Err(_) = api
match api
.check_record_level_access(Permission::Read, Some(&record_id), None, user.as_ref())
.await
{
abort(&mut ws_sender, Code::Policy, "unauthorized").await;
return;
Ok(()) => {}
Err(RecordError::Forbidden) => {
abort(&mut ws_sender, Code::Policy, "unauthorized").await;
return;
}
Err(err) => {
log::error!("WS record subscription access-check failed: {err}");
abort(
&mut ws_sender,
Code::Unexpected,
"internal access-check error",
)
.await;
return;
}
}

let (sender, receiver) = async_channel::bounded::<EventCandidate>(64);
Expand Down Expand Up @@ -475,6 +524,7 @@ static ACCESS_DENIED_EVENT: LazyLock<Arc<EventPayload>> = LazyLock::new(|| {
Arc::new(EventPayload::from(&JsonEventPayload::Error {
value: EventError {
status: EventErrorStatus::Forbidden,
code: None,
message: Some("Access denied".into()),
},
}))
Expand All @@ -483,6 +533,7 @@ static EVENT_LOSS_EVENT: LazyLock<Arc<EventPayload>> = LazyLock::new(|| {
Arc::new(EventPayload::from(&JsonEventPayload::Error {
value: EventError {
status: EventErrorStatus::Loss,
code: None,
message: None,
},
}))
Expand Down
Loading