diff --git a/.github/workflows/pkg-pr-new.yaml b/.github/workflows/pkg-pr-new.yaml index a3e38cb7e0..260d428ba4 100644 --- a/.github/workflows/pkg-pr-new.yaml +++ b/.github/workflows/pkg-pr-new.yaml @@ -24,4 +24,4 @@ jobs: find /tmp/inspector-repack -name '*.map' -delete tar czf inspector.tar.gz -C /tmp/inspector-repack . rm -rf /tmp/inspector-repack - - run: pnpm dlx pkg-pr-new publish 'shared/typescript/*' 'engine/sdks/typescript/runner/' 'engine/sdks/typescript/runner-protocol/' 'rivetkit-typescript/packages/*' --packageManager pnpm --template './examples/*' + - run: pnpm dlx pkg-pr-new publish 'shared/typescript/*' 'engine/sdks/typescript/runner/' 'engine/sdks/typescript/runner-protocol/' 'engine/sdks/typescript/envoy-client/' 'engine/sdks/typescript/envoy-protocol/' 'rivetkit-typescript/packages/*' --packageManager pnpm --template './examples/*' diff --git a/engine/artifacts/errors/actor.no_runner_config_configured.json b/engine/artifacts/errors/actor.no_runner_config_configured.json new file mode 100644 index 0000000000..bbd759bc23 --- /dev/null +++ b/engine/artifacts/errors/actor.no_runner_config_configured.json @@ -0,0 +1,5 @@ +{ + "code": "no_runner_config_configured", + "group": "actor", + "message": "No runner config configured in any datacenter. Validate a provider is listed that matches requested pool name." +} \ No newline at end of file diff --git a/engine/artifacts/errors/actor.no_runners_available.json b/engine/artifacts/errors/actor.no_runners_available.json deleted file mode 100644 index b187a48345..0000000000 --- a/engine/artifacts/errors/actor.no_runners_available.json +++ /dev/null @@ -1,5 +0,0 @@ -{ - "code": "no_runners_available", - "group": "actor", - "message": "No runners are available in any datacenter. Validate the runner is listed in the Connect tab and that the runner's name matches the requested runner name." 
-} \ No newline at end of file diff --git a/engine/artifacts/openapi.json b/engine/artifacts/openapi.json index bfc639b46f..5f8800a9c2 100644 --- a/engine/artifacts/openapi.json +++ b/engine/artifacts/openapi.json @@ -337,50 +337,37 @@ ] } }, - "/actors2": { - "put": { + "/datacenters": { + "get": { "tags": [ - "actors::get_or_create" - ], - "operationId": "actors2_get_or_create", - "parameters": [ - { - "name": "namespace", - "in": "query", - "required": true, - "schema": { - "type": "string" - } - } + "datacenters" ], - "requestBody": { - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/ActorsGetOrCreateRequest" - } - } - }, - "required": true - }, + "operationId": "datacenters_list", "responses": { "200": { "description": "", "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/ActorsGetOrCreateResponse" + "$ref": "#/components/schemas/DatacentersListResponse" } } } } - } - }, - "post": { + }, + "security": [ + { + "bearer_auth": [] + } + ] + } + }, + "/envoys": { + "get": { "tags": [ - "actors::create" + "envoys" ], - "operationId": "actors2_create", + "operationId": "envoys_list", "parameters": [ { "name": "namespace", @@ -389,45 +376,51 @@ "schema": { "type": "string" } - } - ], - "requestBody": { - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/ActorsCreateRequest" - } + }, + { + "name": "name", + "in": "query", + "required": false, + "schema": { + "type": "string" } }, - "required": true - }, - "responses": { - "200": { - "description": "", - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/ActorsCreateResponse" - } + { + "name": "envoy_key", + "in": "query", + "required": false, + "schema": { + "type": "array", + "items": { + "type": "string" } } + }, + { + "name": "limit", + "in": "query", + "required": false, + "schema": { + "type": "integer", + "minimum": 0 + } + }, + { + "name": "cursor", + "in": "query", + "required": 
false, + "schema": { + "type": "string" + } } - } - } - }, - "/datacenters": { - "get": { - "tags": [ - "datacenters" ], - "operationId": "datacenters_list", "responses": { "200": { "description": "", "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/DatacentersListResponse" + "$ref": "#/components/schemas/EnvoysListResponse" } } } @@ -1377,6 +1370,101 @@ }, "additionalProperties": false }, + "Envoy": { + "type": "object", + "required": [ + "envoy_key", + "namespace_id", + "datacenter", + "pool_name", + "version", + "slots", + "create_ts", + "last_ping_ts", + "last_rtt" + ], + "properties": { + "create_ts": { + "type": "integer", + "format": "int64" + }, + "datacenter": { + "type": "string" + }, + "envoy_key": { + "type": "string" + }, + "last_connected_ts": { + "type": [ + "integer", + "null" + ], + "format": "int64" + }, + "last_ping_ts": { + "type": "integer", + "format": "int64" + }, + "last_rtt": { + "type": "integer", + "format": "int32", + "minimum": 0 + }, + "metadata": { + "type": [ + "object", + "null" + ], + "additionalProperties": {}, + "propertyNames": { + "type": "string" + } + }, + "namespace_id": { + "$ref": "#/components/schemas/RivetId" + }, + "pool_name": { + "type": "string" + }, + "slots": { + "type": "integer", + "format": "int64", + "minimum": 0 + }, + "stop_ts": { + "type": [ + "integer", + "null" + ], + "format": "int64" + }, + "version": { + "type": "integer", + "format": "int32", + "minimum": 0 + } + }, + "additionalProperties": false + }, + "EnvoysListResponse": { + "type": "object", + "required": [ + "envoys", + "pagination" + ], + "properties": { + "envoys": { + "type": "array", + "items": { + "$ref": "#/components/schemas/Envoy" + } + }, + "pagination": { + "$ref": "#/components/schemas/Pagination" + } + }, + "additionalProperties": false + }, "HealthFanoutResponse": { "type": "object", "required": [ diff --git a/engine/packages/api-peer/src/actors/create.rs b/engine/packages/api-peer/src/actors/create.rs 
index 17413e4104..5f632ba7d3 100644 --- a/engine/packages/api-peer/src/actors/create.rs +++ b/engine/packages/api-peer/src/actors/create.rs @@ -38,39 +38,3 @@ pub async fn create( Ok(CreateResponse { actor: res.actor }) } - -#[tracing::instrument(skip_all)] -pub async fn create2( - ctx: ApiCtx, - _path: (), - query: CreateQuery, - body: CreateRequest, -) -> Result { - let namespace = ctx - .op(namespace::ops::resolve_for_name_global::Input { - name: query.namespace.clone(), - }) - .await? - .ok_or_else(|| namespace::errors::Namespace::NotFound.build())?; - - let actor_id = Id::new_v1(ctx.config().dc_label()); - - let res = ctx - .op(pegboard::ops::actor::create::Input2 { - actor_id, - namespace_id: namespace.namespace_id, - name: body.name.clone(), - key: body.key, - pool_name: body.runner_name_selector, - input: body.input.clone(), - crash_policy: body.crash_policy, - // NOTE: This can forward if the user attempts to create an actor with a target dc and this dc - // ends up forwarding to another. - forward_request: true, - // api-peer is always creating in its own datacenter - datacenter_name: None, - }) - .await?; - - Ok(CreateResponse { actor: res.actor }) -} diff --git a/engine/packages/api-peer/src/actors/get_or_create.rs b/engine/packages/api-peer/src/actors/get_or_create.rs index 364d6d146c..8b7332c064 100644 --- a/engine/packages/api-peer/src/actors/get_or_create.rs +++ b/engine/packages/api-peer/src/actors/get_or_create.rs @@ -97,97 +97,6 @@ pub async fn get_or_create( } } -#[tracing::instrument(skip_all)] -pub async fn get_or_create2( - ctx: ApiCtx, - _path: (), - query: GetOrCreateQuery, - body: GetOrCreateRequest, -) -> Result { - let namespace = ctx - .op(namespace::ops::resolve_for_name_global::Input { - name: query.namespace.clone(), - }) - .await? 
- .ok_or_else(|| namespace::errors::Namespace::NotFound.build())?; - - // Check if actor already exists for the key - let existing = ctx - .op(pegboard::ops::actor::get_for_key::Input { - namespace_id: namespace.namespace_id, - name: body.name.clone(), - key: body.key.clone(), - fetch_error: true, - }) - .await?; - - if let Some(actor) = existing.actor { - // Actor exists, return it - return Ok(GetOrCreateResponse { - actor, - created: false, - }); - } - - // Actor doesn't exist, create it - let actor_id = Id::new_v1(ctx.config().dc_label()); - - match ctx - .op(pegboard::ops::actor::create::Input2 { - actor_id, - namespace_id: namespace.namespace_id, - name: body.name.clone(), - key: Some(body.key.clone()), - pool_name: body.runner_name_selector, - input: body.input.clone(), - crash_policy: body.crash_policy, - // NOTE: This can forward if the user attempts to create an actor with a target dc and this dc - // ends up forwarding to another. - forward_request: true, - // api-peer is always creating in its own datacenter - datacenter_name: None, - }) - .await - { - Ok(res) => Ok(GetOrCreateResponse { - actor: res.actor, - created: true, - }), - Err(err) => { - // Check if this is a DuplicateKey error and extract the existing actor ID - if let Some(existing_actor_id) = extract_duplicate_key_error(&err) { - tracing::info!( - ?existing_actor_id, - "received duplicate key error, fetching existing actor" - ); - - // Fetch the existing actor - it should be in this datacenter since - // the duplicate key error came from this datacenter - let res = ctx - .op(pegboard::ops::actor::get::Input { - actor_ids: vec![existing_actor_id], - fetch_error: true, - }) - .await?; - - let actor = res - .actors - .into_iter() - .next() - .ok_or_else(|| pegboard::errors::Actor::NotFound.build())?; - - return Ok(GetOrCreateResponse { - actor, - created: false, - }); - } - - // Re-throw the original error if it's not a DuplicateKey - Err(err) - } - } -} - /// Helper function to extract the 
existing actor ID from a duplicate key error /// /// Returns Some(actor_id) if the error is a duplicate key error with metadata, None otherwise diff --git a/engine/packages/api-peer/src/envoys.rs b/engine/packages/api-peer/src/envoys.rs new file mode 100644 index 0000000000..6c5670c89e --- /dev/null +++ b/engine/packages/api-peer/src/envoys.rs @@ -0,0 +1,57 @@ +use anyhow::Result; +use rivet_api_builder::ApiCtx; +use rivet_api_types::{envoys::list::*, pagination::Pagination}; + +#[utoipa::path( + get, + operation_id = "envoys_list", + path = "/envoys", + params(ListQuery), + responses( + (status = 200, body = ListResponse), + ), +)] +#[tracing::instrument(skip_all)] +pub async fn list(ctx: ApiCtx, _path: (), query: ListQuery) -> Result { + let namespace = ctx + .op(namespace::ops::resolve_for_name_global::Input { + name: query.namespace.clone(), + }) + .await? + .ok_or_else(|| namespace::errors::Namespace::NotFound.build())?; + + if !query.envoy_key.is_empty() { + let envoys = ctx + .op(pegboard::ops::envoy::get::Input { + namespace_id: namespace.namespace_id, + envoy_keys: query.envoy_key.clone(), + }) + .await? 
+ .envoys; + + Ok(ListResponse { + envoys, + pagination: Pagination { cursor: None }, + }) + } else { + let list_res = ctx + .op(pegboard::ops::envoy::list::Input { + namespace_id: namespace.namespace_id, + pool_name: query.name, + created_before: query + .cursor + .as_deref() + .map(|c| c.parse::()) + .transpose()?, + limit: query.limit.unwrap_or(100), + }) + .await?; + + let cursor = list_res.envoys.last().map(|x| x.create_ts.to_string()); + + Ok(ListResponse { + envoys: list_res.envoys, + pagination: Pagination { cursor }, + }) + } +} diff --git a/engine/packages/api-peer/src/lib.rs b/engine/packages/api-peer/src/lib.rs index 496d7f9459..b52cb2d70a 100644 --- a/engine/packages/api-peer/src/lib.rs +++ b/engine/packages/api-peer/src/lib.rs @@ -3,6 +3,7 @@ use std::net::SocketAddr; use anyhow::*; pub mod actors; +pub mod envoys; pub mod internal; pub mod namespaces; pub mod router; diff --git a/engine/packages/api-peer/src/router.rs b/engine/packages/api-peer/src/router.rs index a8c0a14f86..7fe2c9beb0 100644 --- a/engine/packages/api-peer/src/router.rs +++ b/engine/packages/api-peer/src/router.rs @@ -1,6 +1,6 @@ use rivet_api_builder::{create_router, prelude::*}; -use crate::{actors, internal, namespaces, runner_configs, runners}; +use crate::{actors, envoys, internal, namespaces, runner_configs, runners}; #[tracing::instrument(skip_all)] pub async fn router( @@ -23,9 +23,7 @@ pub async fn router( // MARK: Actors .route("/actors", get(actors::list::list)) .route("/actors", post(actors::create::create)) - .route("/actors2", post(actors::create::create2)) .route("/actors", put(actors::get_or_create::get_or_create)) - .route("/actors2", put(actors::get_or_create::get_or_create2)) .route("/actors/{actor_id}", delete(actors::delete::delete)) .route("/actors/names", get(actors::list_names::list_names)) .route( @@ -35,6 +33,8 @@ pub async fn router( // MARK: Runners .route("/runners", get(runners::list)) .route("/runners/names", get(runners::list_names)) + // MARK: Envoys 
+ .route("/envoys", get(envoys::list)) // MARK: Internal .route("/cache/purge", post(internal::cache_purge)) .route( diff --git a/engine/packages/api-public/src/actors/create.rs b/engine/packages/api-public/src/actors/create.rs index 109a0c65d7..c86a8e9647 100644 --- a/engine/packages/api-public/src/actors/create.rs +++ b/engine/packages/api-public/src/actors/create.rs @@ -88,67 +88,3 @@ async fn create_inner( .await } } - -#[utoipa::path( - post, - operation_id = "actors2_create", - path = "/actors2", - params(CreateQuery), - request_body(content = CreateRequest, content_type = "application/json"), - responses( - (status = 200, body = CreateResponse), - ), -)] -pub async fn create2( - Extension(ctx): Extension, - Query(query): Query, - Json(body): Json, -) -> Response { - match create2_inner(ctx, query, body).await { - Ok(response) => Json(response).into_response(), - Err(err) => ApiError::from(err).into_response(), - } -} - -#[tracing::instrument(skip_all)] -async fn create2_inner( - ctx: ApiCtx, - query: CreateQuery, - body: CreateRequest, -) -> Result { - ctx.skip_auth(); - - let namespace = ctx - .op(namespace::ops::resolve_for_name_global::Input { - name: query.namespace.clone(), - }) - .await? 
- .ok_or_else(|| namespace::errors::Namespace::NotFound.build())?; - - let target_dc_label = super::utils::find_dc_for_actor_creation( - &ctx, - namespace.namespace_id, - &query.namespace, - &body.runner_name_selector, - body.datacenter.as_ref().map(String::as_str), - ) - .await?; - - let query = rivet_api_types::actors::create::CreateQuery { - namespace: query.namespace, - }; - - if target_dc_label == ctx.config().dc_label() { - rivet_api_peer::actors::create::create2(ctx.into(), (), query, body).await - } else { - request_remote_datacenter::( - ctx.config(), - target_dc_label, - "/actors2", - axum::http::Method::POST, - Some(&query), - Some(&body), - ) - .await - } -} diff --git a/engine/packages/api-public/src/actors/get_or_create.rs b/engine/packages/api-public/src/actors/get_or_create.rs index b48e21bb89..73cf084077 100644 --- a/engine/packages/api-public/src/actors/get_or_create.rs +++ b/engine/packages/api-public/src/actors/get_or_create.rs @@ -98,67 +98,3 @@ async fn get_or_create_inner( .await } } - -#[utoipa::path( - put, - operation_id = "actors2_get_or_create", - path = "/actors2", - params(GetOrCreateQuery), - request_body(content = GetOrCreateRequest, content_type = "application/json"), - responses( - (status = 200, body = GetOrCreateResponse), - ), -)] -pub async fn get_or_create2( - Extension(ctx): Extension, - Query(query): Query, - Json(body): Json, -) -> Response { - match get_or_create_inner2(ctx, query, body).await { - Ok(response) => Json(response).into_response(), - Err(err) => ApiError::from(err).into_response(), - } -} - -#[tracing::instrument(skip_all)] -async fn get_or_create_inner2( - ctx: ApiCtx, - query: GetOrCreateQuery, - body: GetOrCreateRequest, -) -> Result { - ctx.skip_auth(); - - let namespace = ctx - .op(namespace::ops::resolve_for_name_global::Input { - name: query.namespace.clone(), - }) - .await? 
- .ok_or_else(|| namespace::errors::Namespace::NotFound.build())?; - - let target_dc_label = super::utils::find_dc_for_actor_creation( - &ctx, - namespace.namespace_id, - &query.namespace, - &body.runner_name_selector, - body.datacenter.as_ref().map(String::as_str), - ) - .await?; - - let query = GetOrCreateQuery { - namespace: query.namespace, - }; - - if target_dc_label == ctx.config().dc_label() { - rivet_api_peer::actors::get_or_create::get_or_create2(ctx.into(), (), query, body).await - } else { - request_remote_datacenter::( - ctx.config(), - target_dc_label, - "/actors2", - axum::http::Method::PUT, - Some(&query), - Some(&body), - ) - .await - } -} diff --git a/engine/packages/api-public/src/actors/utils.rs b/engine/packages/api-public/src/actors/utils.rs index a312042695..10f58f70b4 100644 --- a/engine/packages/api-public/src/actors/utils.rs +++ b/engine/packages/api-public/src/actors/utils.rs @@ -156,9 +156,9 @@ pub async fn find_dc_for_actor_creation( if let Some(dc_label) = res.dc_labels.into_iter().next() { dc_label } else { - return Err(pegboard::errors::Actor::NoRunnersAvailable { + return Err(pegboard::errors::Actor::NoRunnerConfigConfigured { namespace: namespace_name.into(), - runner_name: runner_name.into(), + pool_name: runner_name.into(), } .build()); } diff --git a/engine/packages/api-public/src/envoys.rs b/engine/packages/api-public/src/envoys.rs new file mode 100644 index 0000000000..3307d2e0d5 --- /dev/null +++ b/engine/packages/api-public/src/envoys.rs @@ -0,0 +1,57 @@ +use anyhow::Result; +use axum::response::{IntoResponse, Response}; +use rivet_api_builder::{ + ApiError, + extract::{Extension, Json, Query}, +}; +use rivet_api_types::{envoys::list::*, pagination::Pagination}; +use rivet_api_util::fanout_to_datacenters; + +use crate::ctx::ApiCtx; + +#[utoipa::path( + get, + operation_id = "envoys_list", + path = "/envoys", + params(ListQuery), + responses( + (status = 200, body = ListResponse), + ), + security(("bearer_auth" = [])), +)] 
+#[tracing::instrument(skip_all)] +pub async fn list(Extension(ctx): Extension, Query(query): Query) -> Response { + match list_inner(ctx, query).await { + Ok(response) => Json(response).into_response(), + Err(err) => ApiError::from(err).into_response(), + } +} + +async fn list_inner(ctx: ApiCtx, query: ListQuery) -> Result { + ctx.auth().await?; + + // Fanout to all datacenters + let mut envoys = + fanout_to_datacenters::>( + ctx.into(), + "/envoys", + query.clone(), + |ctx, query| async move { rivet_api_peer::envoys::list(ctx, (), query).await }, + |_, res, agg| agg.extend(res.envoys), + ) + .await?; + + // Sort by create ts desc + envoys.sort_by_cached_key(|x| std::cmp::Reverse(x.create_ts)); + + // Shorten array since returning all envoys from all regions could end up returning `regions * + // limit` results, which is a lot. + envoys.truncate(query.limit.unwrap_or(100)); + + let cursor = envoys.last().map(|x| x.create_ts.to_string()); + + Ok(ListResponse { + envoys, + pagination: Pagination { cursor }, + }) +} diff --git a/engine/packages/api-public/src/lib.rs b/engine/packages/api-public/src/lib.rs index b3bf9a51a8..c7a2c33706 100644 --- a/engine/packages/api-public/src/lib.rs +++ b/engine/packages/api-public/src/lib.rs @@ -1,6 +1,7 @@ pub mod actors; pub mod ctx; pub mod datacenters; +pub mod envoys; mod errors; pub mod health; pub mod metadata; diff --git a/engine/packages/api-public/src/router.rs b/engine/packages/api-public/src/router.rs index debb8856a2..68d3a2a18e 100644 --- a/engine/packages/api-public/src/router.rs +++ b/engine/packages/api-public/src/router.rs @@ -8,21 +8,22 @@ use rivet_api_builder::{create_router, extract::FailedExtraction}; use tower_http::cors::CorsLayer; use utoipa::OpenApi; -use crate::{actors, ctx, datacenters, health, metadata, namespaces, runner_configs, runners, ui}; +use crate::{ + actors, ctx, datacenters, envoys, health, metadata, namespaces, runner_configs, runners, ui, +}; #[derive(OpenApi)] #[openapi( paths( 
actors::list::list, actors::create::create, - actors::create::create2, actors::delete::delete, actors::list_names::list_names, actors::get_or_create::get_or_create, - actors::get_or_create::get_or_create2, actors::kv_get::kv_get, runners::list, runners::list_names, + envoys::list, namespaces::list, namespaces::create, runner_configs::list::list, @@ -81,15 +82,10 @@ pub async fn router( // MARK: Actors .route("/actors", axum::routing::get(actors::list::list)) .route("/actors", axum::routing::post(actors::create::create)) - .route("/actors2", axum::routing::post(actors::create::create2)) .route( "/actors", axum::routing::put(actors::get_or_create::get_or_create), ) - .route( - "/actors2", - axum::routing::put(actors::get_or_create::get_or_create2), - ) .route( "/actors/{actor_id}", axum::routing::delete(actors::delete::delete), @@ -104,6 +100,8 @@ pub async fn router( ) // MARK: Runners .route("/runners", axum::routing::get(runners::list)) + // MARK: Envoys + .route("/envoys", axum::routing::get(envoys::list)) .route("/runners/names", axum::routing::get(runners::list_names)) // MARK: Datacenters .route("/datacenters", axum::routing::get(datacenters::list)) diff --git a/engine/packages/api-public/src/runner_configs/utils.rs b/engine/packages/api-public/src/runner_configs/utils.rs index 5049530795..04cdd35cd2 100644 --- a/engine/packages/api-public/src/runner_configs/utils.rs +++ b/engine/packages/api-public/src/runner_configs/utils.rs @@ -13,7 +13,7 @@ pub struct ServerlessMetadata { pub runtime: String, pub version: String, pub actor_names: HashMap, - pub runner_version: Option, + pub envoy_version: Option, } impl From for ServerlessMetadata { @@ -26,7 +26,7 @@ impl From for ServerlessMetad .into_iter() .map(|a| (a.name, serde_json::Value::Object(a.metadata))) .collect(), - runner_version: output.runner_version, + envoy_version: output.envoy_version, } } } @@ -40,12 +40,10 @@ pub async fn fetch_serverless_runner_metadata( url: String, headers: HashMap, ) -> Result { 
- let result = ctx - .op(pegboard::ops::serverless_metadata::fetch::Input { url, headers }) + ctx.op(pegboard::ops::serverless_metadata::fetch::Input { url, headers }) .await - .map_err(|_| ServerlessMetadataError::RequestFailed {})?; - - result.map(ServerlessMetadata::from) + .map_err(|_| ServerlessMetadataError::RequestFailed {})? + .map(ServerlessMetadata::from) } /// Fetches metadata from the given URL and populates actor names in the database. diff --git a/engine/packages/api-types/src/envoys/list.rs b/engine/packages/api-types/src/envoys/list.rs new file mode 100644 index 0000000000..64fbb710e4 --- /dev/null +++ b/engine/packages/api-types/src/envoys/list.rs @@ -0,0 +1,24 @@ +use serde::{Deserialize, Serialize}; +use utoipa::{IntoParams, ToSchema}; + +use crate::pagination::Pagination; + +#[derive(Debug, Serialize, Deserialize, Clone, IntoParams)] +#[serde(deny_unknown_fields)] +#[into_params(parameter_in = Query)] +pub struct ListQuery { + pub namespace: String, + pub name: Option, + #[serde(default)] + pub envoy_key: Vec, + pub limit: Option, + pub cursor: Option, +} + +#[derive(Serialize, Deserialize, ToSchema)] +#[serde(deny_unknown_fields)] +#[schema(as = EnvoysListResponse)] +pub struct ListResponse { + pub envoys: Vec, + pub pagination: Pagination, +} diff --git a/engine/packages/api-types/src/envoys/mod.rs b/engine/packages/api-types/src/envoys/mod.rs new file mode 100644 index 0000000000..d17e233fbf --- /dev/null +++ b/engine/packages/api-types/src/envoys/mod.rs @@ -0,0 +1 @@ +pub mod list; diff --git a/engine/packages/api-types/src/lib.rs b/engine/packages/api-types/src/lib.rs index 29ad6a120e..f90c62cb23 100644 --- a/engine/packages/api-types/src/lib.rs +++ b/engine/packages/api-types/src/lib.rs @@ -1,5 +1,6 @@ pub mod actors; pub mod datacenters; +pub mod envoys; pub mod namespaces; pub mod pagination; pub mod runner_configs; diff --git a/engine/packages/gasoline/src/worker.rs b/engine/packages/gasoline/src/worker.rs index 1ff51f8cdb..7654df9e64 
100644 --- a/engine/packages/gasoline/src/worker.rs +++ b/engine/packages/gasoline/src/worker.rs @@ -344,10 +344,10 @@ impl Worker { let shutdown_start = Instant::now(); loop { // Future will resolve once all workflow tasks complete - let join_fut = async { while let Some(_) = wf_futs.next().await {} }; + let complete_fut = async { while let Some(_) = wf_futs.next().await {} }; tokio::select! { - _ = join_fut => { + _ = complete_fut => { break; } _ = progress_interval.tick() => { diff --git a/engine/packages/guard/src/routing/pegboard_gateway.rs b/engine/packages/guard/src/routing/pegboard_gateway.rs index 15fa5ed2ab..e06fdb70b0 100644 --- a/engine/packages/guard/src/routing/pegboard_gateway.rs +++ b/engine/packages/guard/src/routing/pegboard_gateway.rs @@ -64,7 +64,9 @@ pub async fn route_request_path_based( .context("invalid x-rivet-token header")? }; - route_request_inner(ctx, shared_state, req_ctx, actor_id, stripped_path, token).await + route_request_inner(ctx, shared_state, req_ctx, actor_id, stripped_path, token) + .await + .map(Some) } /// Route requests to actor services based on headers @@ -167,7 +169,7 @@ async fn route_request_inner( .dc_for_label(actor_id.label()) .context("dc with the given label not found")?; - return Ok(Some(RoutingOutput::Route(RouteConfig { + return Ok(RoutingOutput::Route(RouteConfig { targets: vec![RouteTarget { host: peer_dc .proxy_url_host() @@ -178,7 +180,7 @@ async fn route_request_inner( .context("bad peer dc proxy url port")?, path: req_ctx.path().to_owned(), }], - }))); + })); } // Create subs before checking if actor exists/is not destroyed diff --git a/engine/packages/pegboard-envoy/src/conn.rs b/engine/packages/pegboard-envoy/src/conn.rs index 7294c8ed00..8c3e6ad8d0 100644 --- a/engine/packages/pegboard-envoy/src/conn.rs +++ b/engine/packages/pegboard-envoy/src/conn.rs @@ -197,6 +197,15 @@ pub async fn handle_init( ), (), )?; + tx.write( + &pegboard::keys::ns::ActiveEnvoyByNameKey::new( + namespace_id, + 
pool_name.clone(), + create_ts, + envoy_key.clone(), + ), + (), + )?; // Unset expired (upon reconnection) if create_ts_entry.is_some() { diff --git a/engine/packages/pegboard-envoy/src/tunnel_to_ws_task.rs b/engine/packages/pegboard-envoy/src/tunnel_to_ws_task.rs index 25b36c3212..e50f3c8c13 100644 --- a/engine/packages/pegboard-envoy/src/tunnel_to_ws_task.rs +++ b/engine/packages/pegboard-envoy/src/tunnel_to_ws_task.rs @@ -29,7 +29,10 @@ pub async fn task( .await? { Ok(msg) => { - handle_message(&ctx, &conn, msg).await?; + let evicted = handle_message(&ctx, &conn, msg).await?; + if evicted { + return Ok(LifecycleResult::Evicted); + } } Err(lifecycle_res) => return Ok(lifecycle_res), } @@ -78,13 +81,17 @@ async fn recv_msg( Ok(Ok(tunnel_msg)) } -async fn handle_message(ctx: &StandaloneCtx, conn: &Conn, tunnel_msg: ups::Message) -> Result<()> { +async fn handle_message( + ctx: &StandaloneCtx, + conn: &Conn, + tunnel_msg: ups::Message, +) -> Result { // Parse message let msg = match versioned::ToEnvoyConn::deserialize_with_embedded_version(&tunnel_msg.payload) { Result::Ok(x) => x, Err(err) => { tracing::error!(?err, "failed to parse tunnel message"); - return Ok(()); + return Ok(false); } }; @@ -113,8 +120,9 @@ async fn handle_message(ctx: &StandaloneCtx, conn: &Conn, tunnel_msg: ups::Messa })?; // Not sent to envoy - return Ok(()); + return Ok(false); } + protocol::ToEnvoyConn::ToEnvoyConnClose => return Ok(true), protocol::ToEnvoyConn::ToEnvoyCommands(mut command_wrappers) => { // TODO: Parallelize for command_wrapper in &mut command_wrappers { @@ -159,5 +167,5 @@ async fn handle_message(ctx: &StandaloneCtx, conn: &Conn, tunnel_msg: ups::Messa .await .context("failed to send message to WebSocket")?; - Ok(()) + Ok(false) } diff --git a/engine/packages/pegboard-outbound/src/lib.rs b/engine/packages/pegboard-outbound/src/lib.rs index 7d607727b2..4644b3648b 100644 --- a/engine/packages/pegboard-outbound/src/lib.rs +++ b/engine/packages/pegboard-outbound/src/lib.rs @@ 
-1,15 +1,15 @@ use anyhow::Result; -use futures_util::StreamExt; +use futures_util::{StreamExt, stream::FuturesUnordered}; use gas::prelude::*; use pegboard::pubsub_subjects::ServerlessOutboundSubject; use reqwest::header::{HeaderName, HeaderValue}; use reqwest_eventsource as sse; -use rivet_envoy_protocol::{self as protocol, PROTOCOL_VERSION, versioned}; +use rivet_envoy_protocol::{self as protocol, versioned}; use rivet_runtime::TermSignal; use rivet_types::actor::RunnerPoolError; use rivet_types::runner_configs::RunnerConfigKind; use std::collections::HashMap; -use std::time::Duration; +use std::time::{Duration, Instant}; use tokio::task::JoinHandle; use universalpubsub::NextOutput; use vbare::OwnedVersionedData; @@ -20,6 +20,7 @@ const X_RIVET_ENDPOINT: HeaderName = HeaderName::from_static("x-rivet-endpoint") const X_RIVET_TOKEN: HeaderName = HeaderName::from_static("x-rivet-token"); const X_RIVET_POOL_NAME: HeaderName = HeaderName::from_static("x-rivet-pool-name"); const X_RIVET_NAMESPACE_NAME: HeaderName = HeaderName::from_static("x-rivet-namespace-name"); +const SHUTDOWN_PROGRESS_INTERVAL: Duration = Duration::from_secs(7); #[tracing::instrument(skip_all)] pub async fn start(config: rivet_config::Config, pools: rivet_pools::Pools) -> Result<()> { @@ -45,8 +46,52 @@ pub async fn start(config: rivet_config::Config, pools: rivet_pools::Pools) -> R } } - // Wait for remaining conns to stop - futures_util::future::join_all(conns.into_iter().map(|c| c.handle).collect::>()).await; + let mut term_signal = TermSignal::get(); + let shutdown_duration = config.runtime.guard_shutdown_duration(); + tracing::info!(remaining_conns=%conns.len(), duration=?shutdown_duration, "starting outbound shutdown"); + + let mut conn_futs = conns + .iter_mut() + .map(|conn| &mut conn.handle) + .collect::>(); + + let mut progress_interval = tokio::time::interval(SHUTDOWN_PROGRESS_INTERVAL); + progress_interval.tick().await; + + let shutdown_start = Instant::now(); + loop { + // Future 
will resolve once all workflow tasks complete + let complete_fut = async { while let Some(_) = conn_futs.next().await {} }; + + tokio::select! { + // Wait for remaining conns to stop + _ = complete_fut => { + break; + } + _ = progress_interval.tick() => { + tracing::info!(remaining_conns=%conn_futs.len(), "outbound still shutting down"); + } + abort = term_signal.recv() => { + if abort { + tracing::warn!("aborting outbound shutdown"); + break; + } + } + _ = tokio::time::sleep(shutdown_duration.saturating_sub(shutdown_start.elapsed())) => { + tracing::warn!("outbound shutdown timed out"); + break; + } + } + } + + let remaining_conns = conn_futs.into_iter().count(); + if remaining_conns == 0 { + tracing::info!("all outbound connections complete"); + } else { + tracing::warn!(%remaining_conns, "not all outbound connections completed"); + } + + tracing::info!("outbound shutdown complete"); res } @@ -114,6 +159,8 @@ async fn handle(ctx: &StandaloneCtx, packet: protocol::ToOutbound) -> Result<()> let actor_id = Id::parse(&checkpoint.actor_id)?; let generation = checkpoint.generation; + tracing::debug!(?namespace_id, %pool_name, ?actor_id, ?generation, "received outbound request"); + // Check pool let (pool_res, namespace_res) = tokio::try_join!( ctx.op(pegboard::ops::runner_config::get::Input { @@ -151,7 +198,7 @@ async fn handle(ctx: &StandaloneCtx, packet: protocol::ToOutbound) -> Result<()> }), }, ])) - .serialize_with_embedded_version(PROTOCOL_VERSION)?; + .serialize_with_embedded_version(pool.protocol_version.unwrap_or(1))?; let RunnerConfigKind::Serverless { url, diff --git a/engine/packages/pegboard/src/errors.rs b/engine/packages/pegboard/src/errors.rs index 94a560446f..6744bce7c8 100644 --- a/engine/packages/pegboard/src/errors.rs +++ b/engine/packages/pegboard/src/errors.rs @@ -57,13 +57,13 @@ pub enum Actor { KeyReservedInDifferentDatacenter { datacenter_label: u16 }, #[error( - "no_runners_available", - "No runners are available in any datacenter. 
Validate the runner is listed in the Connect tab and that the runner's name matches the requested runner name.", - "No runners with name '{runner_name}' are available in any datacenter for the namespace '{namespace}'. Validate the runner is listed in the Connect tab and that the runner's name matches the requested runner name." + "no_runner_config_configured", + "No runner config configured in any datacenter. Validate a provider is listed that matches requested pool name.", + "No runner config with name '{pool_name}' are available in any datacenter for the namespace '{namespace}'. Validate a provider is listed that matches the requested pool name." )] - NoRunnersAvailable { + NoRunnerConfigConfigured { namespace: String, - runner_name: String, + pool_name: String, }, #[error("kv_key_not_found", "The KV key does not exist for this actor.")] diff --git a/engine/packages/pegboard/src/keys/ns.rs b/engine/packages/pegboard/src/keys/ns.rs index cbfd78ec25..04cba4ec44 100644 --- a/engine/packages/pegboard/src/keys/ns.rs +++ b/engine/packages/pegboard/src/keys/ns.rs @@ -1716,3 +1716,131 @@ impl TuplePack for ActiveEnvoySubspaceKey { Ok(offset) } } + +#[derive(Debug)] +pub struct ActiveEnvoyByNameKey { + namespace_id: Id, + pub name: String, + pub create_ts: i64, + pub envoy_key: String, +} + +impl ActiveEnvoyByNameKey { + pub fn new(namespace_id: Id, name: String, create_ts: i64, envoy_key: String) -> Self { + ActiveEnvoyByNameKey { + namespace_id, + name, + create_ts, + envoy_key, + } + } + + pub fn subspace(namespace_id: Id, name: String) -> ActiveEnvoyByNameSubspaceKey { + ActiveEnvoyByNameSubspaceKey::new(namespace_id, name) + } + + pub fn subspace_with_create_ts( + namespace_id: Id, + name: String, + create_ts: i64, + ) -> ActiveEnvoyByNameSubspaceKey { + ActiveEnvoyByNameSubspaceKey::new_with_create_ts(namespace_id, name, create_ts) + } +} + +impl FormalKey for ActiveEnvoyByNameKey { + type Value = (); + + fn deserialize(&self, _raw: &[u8]) -> Result { + Ok(()) + } + 
+ fn serialize(&self, _value: Self::Value) -> Result> { + Ok(Vec::new()) + } +} + +impl TuplePack for ActiveEnvoyByNameKey { + fn pack( + &self, + w: &mut W, + tuple_depth: TupleDepth, + ) -> std::io::Result { + let t = ( + NAMESPACE, + ENVOY, + BY_NAME, + ACTIVE, + self.namespace_id, + &self.name, + self.create_ts, + &self.envoy_key, + ); + t.pack(w, tuple_depth) + } +} + +impl<'de> TupleUnpack<'de> for ActiveEnvoyByNameKey { + fn unpack(input: &[u8], tuple_depth: TupleDepth) -> PackResult<(&[u8], Self)> { + let (input, (_, _, _, _, namespace_id, name, create_ts, envoy_key)) = + <(usize, usize, usize, usize, Id, String, i64, String)>::unpack(input, tuple_depth)?; + let v = ActiveEnvoyByNameKey { + namespace_id, + name, + create_ts, + envoy_key, + }; + + Ok((input, v)) + } +} + +pub struct ActiveEnvoyByNameSubspaceKey { + namespace_id: Id, + name: String, + create_ts: Option, +} + +impl ActiveEnvoyByNameSubspaceKey { + pub fn new(namespace_id: Id, name: String) -> Self { + ActiveEnvoyByNameSubspaceKey { + namespace_id, + name, + create_ts: None, + } + } + + pub fn new_with_create_ts(namespace_id: Id, name: String, create_ts: i64) -> Self { + ActiveEnvoyByNameSubspaceKey { + namespace_id, + name, + create_ts: Some(create_ts), + } + } +} + +impl TuplePack for ActiveEnvoyByNameSubspaceKey { + fn pack( + &self, + w: &mut W, + tuple_depth: TupleDepth, + ) -> std::io::Result { + let mut offset = VersionstampOffset::None { size: 0 }; + + let t = ( + NAMESPACE, + ENVOY, + BY_NAME, + ACTIVE, + self.namespace_id, + &self.name, + ); + offset += t.pack(w, tuple_depth)?; + + if let Some(create_ts) = &self.create_ts { + offset += create_ts.pack(w, tuple_depth)?; + } + + Ok(offset) + } +} diff --git a/engine/packages/pegboard/src/keys/runner_config.rs b/engine/packages/pegboard/src/keys/runner_config.rs index 7974a121ff..d3490d9fe9 100644 --- a/engine/packages/pegboard/src/keys/runner_config.rs +++ b/engine/packages/pegboard/src/keys/runner_config.rs @@ -209,3 +209,56 @@ impl 
TuplePack for ByVariantSubspaceKey { Ok(offset) } } + +#[derive(Debug)] +pub struct ProtocolVersionKey { + pub namespace_id: Id, + pub name: String, +} + +impl ProtocolVersionKey { + pub fn new(namespace_id: Id, name: String) -> Self { + ProtocolVersionKey { namespace_id, name } + } +} + +impl FormalKey for ProtocolVersionKey { + type Value = u16; + + fn deserialize(&self, raw: &[u8]) -> Result { + Ok(u16::from_be_bytes(raw.try_into()?)) + } + + fn serialize(&self, value: Self::Value) -> Result> { + Ok(value.to_be_bytes().to_vec()) + } +} + +impl TuplePack for ProtocolVersionKey { + fn pack( + &self, + w: &mut W, + tuple_depth: TupleDepth, + ) -> std::io::Result { + let t = ( + RUNNER, + CONFIG, + DATA, + PROTOCOL_VERSION, + self.namespace_id, + &self.name, + ); + t.pack(w, tuple_depth) + } +} + +impl<'de> TupleUnpack<'de> for ProtocolVersionKey { + fn unpack(input: &[u8], tuple_depth: TupleDepth) -> PackResult<(&[u8], Self)> { + let (input, (_, _, _, _, namespace_id, name)) = + <(usize, usize, usize, usize, Id, String)>::unpack(input, tuple_depth)?; + + let v = ProtocolVersionKey { namespace_id, name }; + + Ok((input, v)) + } +} diff --git a/engine/packages/pegboard/src/metrics.rs b/engine/packages/pegboard/src/metrics.rs index 6db13e6dee..0a7ad51802 100644 --- a/engine/packages/pegboard/src/metrics.rs +++ b/engine/packages/pegboard/src/metrics.rs @@ -37,6 +37,13 @@ lazy_static::lazy_static! 
{ *REGISTRY ).unwrap(); + pub static ref ENVOY_VERSION_UPGRADE_DRAIN_TOTAL: IntCounterVec = register_int_counter_vec_with_registry!( + "pegboard_envoy_version_upgrade_drain_total", + "Count of envoys drained due to version upgrade.", + &["namespace_id", "pool_name"], + *REGISTRY + ).unwrap(); + pub static ref SERVERLESS_OUTBOUND_REQ_TOTAL: IntCounterVec = register_int_counter_vec_with_registry!( "pegboard_serverless_outbound_req_total", "Count of serverless outbound requests made.", diff --git a/engine/packages/pegboard/src/ops/actor/create.rs b/engine/packages/pegboard/src/ops/actor/create.rs index 604975e813..2cd78a7da2 100644 --- a/engine/packages/pegboard/src/ops/actor/create.rs +++ b/engine/packages/pegboard/src/ops/actor/create.rs @@ -30,166 +30,133 @@ pub struct Output { #[operation] pub async fn pegboard_actor_create(ctx: &OperationCtx, input: &Input) -> Result { // Set up subscriptions before dispatching workflow - let mut create_sub = ctx - .subscribe::(("actor_id", input.actor_id)) - .await?; - let mut fail_sub = ctx - .subscribe::(("actor_id", input.actor_id)) - .await?; - let mut destroy_sub = ctx - .subscribe::(("actor_id", input.actor_id)) - .await?; - - // TODO: check rivetkit version before choosing actor version - - // Dispatch actor workflow - ctx.workflow(crate::workflows::actor::Input { - actor_id: input.actor_id, - name: input.name.clone(), - runner_name_selector: input.runner_name_selector.clone(), - key: input.key.clone(), - namespace_id: input.namespace_id, - crash_policy: input.crash_policy, - input: input.input.clone(), - }) - .tag("actor_id", input.actor_id) - .dispatch() - .await?; + let ( + mut create_sub, + mut fail_sub, + mut destroy_sub, + mut create_sub2, + mut fail_sub2, + mut destroy_sub2, + pool_res, + ) = tokio::try_join!( + ctx.subscribe::(("actor_id", input.actor_id)), + ctx.subscribe::(("actor_id", input.actor_id)), + ctx.subscribe::(("actor_id", input.actor_id)), + ctx.subscribe::(("actor_id", input.actor_id)), + 
ctx.subscribe::(("actor_id", input.actor_id)), + ctx.subscribe::(("actor_id", input.actor_id)), + ctx.op(crate::ops::runner_config::get::Input { + runners: vec![(input.namespace_id, input.runner_name_selector.clone())], + bypass_cache: false, + }), + )?; - // Wait for actor creation to complete, fail, or be destroyed - tokio::select! { - res = create_sub.next() => { res?; }, - res = fail_sub.next() => { - let msg = res?; - let error = msg.into_body().error; + let actor_v2 = pool_res + .into_iter() + .next() + .map(|p| p.protocol_version.is_some()) + .unwrap_or_default(); + + if actor_v2 { + // Dispatch actor workflow + ctx.workflow(crate::workflows::actor2::Input { + actor_id: input.actor_id, + name: input.name.clone(), + pool_name: input.runner_name_selector.clone(), + key: input.key.clone(), + namespace_id: input.namespace_id, + crash_policy: input.crash_policy, + input: input.input.clone(), + from_v1: false, + }) + .tag("actor_id", input.actor_id) + .dispatch() + .await?; - // Check if this request needs to be forwarded - // - // We cannot forward if `datacenter_name` is specified because this actor is being - // restricted to the given datacenter. - if input.forward_request && input.datacenter_name.is_none() { - if let crate::errors::Actor::KeyReservedInDifferentDatacenter { datacenter_label } = &error { - // Forward the request to the correct datacenter - return forward_to_datacenter( - ctx, - *datacenter_label, - input.namespace_id, - input.name.clone(), - input.key.clone(), - input.runner_name_selector.clone(), - input.input.clone(), - input.crash_policy - ).await; + // Wait for actor creation to complete, fail, or be destroyed + tokio::select! { + res = create_sub2.next() => { res?; }, + res = fail_sub2.next() => { + let msg = res?; + let error = msg.into_body().error; + + // Check if this request needs to be forwarded + // + // We cannot forward if `datacenter_name` is specified because this actor is being + // restricted to the given datacenter. 
+ if input.forward_request && input.datacenter_name.is_none() { + if let crate::errors::Actor::KeyReservedInDifferentDatacenter { datacenter_label } = &error { + // Forward the request to the correct datacenter + return forward_to_datacenter( + ctx, + *datacenter_label, + input.namespace_id, + input.name.clone(), + input.key.clone(), + input.runner_name_selector.clone(), + input.input.clone(), + input.crash_policy + ).await; + } } - } - // Otherwise, return the error as-is - return Err(error.build()); - } - res = destroy_sub.next() => { - res?; - return Err(crate::errors::Actor::DestroyedDuringCreation.build()); + // Otherwise, return the error as-is + return Err(error.build()); + } + res = destroy_sub2.next() => { + res?; + return Err(crate::errors::Actor::DestroyedDuringCreation.build()); + } } - } - - // Fetch the created actor - let actors_res = ctx - .op(crate::ops::actor::get::Input { - actor_ids: vec![input.actor_id], - fetch_error: false, + } else { + // Dispatch actor workflow + ctx.workflow(crate::workflows::actor::Input { + actor_id: input.actor_id, + name: input.name.clone(), + runner_name_selector: input.runner_name_selector.clone(), + key: input.key.clone(), + namespace_id: input.namespace_id, + crash_policy: input.crash_policy, + input: input.input.clone(), }) + .tag("actor_id", input.actor_id) + .dispatch() .await?; - let actor = actors_res - .actors - .into_iter() - .next() - .ok_or_else(|| crate::errors::Actor::NotFound.build())?; - - Ok(Output { actor }) -} - -#[derive(Debug)] -pub struct Input2 { - pub actor_id: Id, - pub namespace_id: Id, - pub name: String, - pub key: Option, - pub pool_name: String, - pub crash_policy: CrashPolicy, - pub input: Option, - /// If true, will handle ForwardToDatacenter errors by forwarding the request to the correct datacenter. - /// Used by api-public. api-peer should set this to false. 
- pub forward_request: bool, - /// Datacenter to create the actor in - /// - /// Providing this value will cause an error if attempting to create an actor where the key is - /// reserved in a different datacenter. - pub datacenter_name: Option, -} - -#[operation(Operation2)] -pub async fn pegboard_actor_create2(ctx: &OperationCtx, input: &Input2) -> Result { - // Set up subscriptions before dispatching workflow - let mut create_sub = ctx - .subscribe::(("actor_id", input.actor_id)) - .await?; - let mut fail_sub = ctx - .subscribe::(("actor_id", input.actor_id)) - .await?; - let mut destroy_sub = ctx - .subscribe::(("actor_id", input.actor_id)) - .await?; - - // TODO: check rivetkit version before choosing actor version - - // Dispatch actor workflow - ctx.workflow(crate::workflows::actor2::Input { - actor_id: input.actor_id, - name: input.name.clone(), - pool_name: input.pool_name.clone(), - key: input.key.clone(), - namespace_id: input.namespace_id, - crash_policy: input.crash_policy, - input: input.input.clone(), - }) - .tag("actor_id", input.actor_id) - .dispatch() - .await?; - - // Wait for actor creation to complete, fail, or be destroyed - tokio::select! { - res = create_sub.next() => { res?; }, - res = fail_sub.next() => { - let msg = res?; - let error = msg.into_body().error; - - // Check if this request needs to be forwarded - // - // We cannot forward if `datacenter_name` is specified because this actor is being - // restricted to the given datacenter. - if input.forward_request && input.datacenter_name.is_none() { - if let crate::errors::Actor::KeyReservedInDifferentDatacenter { datacenter_label } = &error { - // Forward the request to the correct datacenter - return forward_to_datacenter( - ctx, - *datacenter_label, - input.namespace_id, - input.name.clone(), - input.key.clone(), - input.pool_name.clone(), - input.input.clone(), - input.crash_policy - ).await; + // Wait for actor creation to complete, fail, or be destroyed + tokio::select! 
{ + res = create_sub.next() => { res?; }, + res = fail_sub.next() => { + let msg = res?; + let error = msg.into_body().error; + + // Check if this request needs to be forwarded + // + // We cannot forward if `datacenter_name` is specified because this actor is being + // restricted to the given datacenter. + if input.forward_request && input.datacenter_name.is_none() { + if let crate::errors::Actor::KeyReservedInDifferentDatacenter { datacenter_label } = &error { + // Forward the request to the correct datacenter + return forward_to_datacenter( + ctx, + *datacenter_label, + input.namespace_id, + input.name.clone(), + input.key.clone(), + input.runner_name_selector.clone(), + input.input.clone(), + input.crash_policy + ).await; + } } - } - // Otherwise, return the error as-is - return Err(error.build()); - } - res = destroy_sub.next() => { - res?; - return Err(crate::errors::Actor::DestroyedDuringCreation.build()); + // Otherwise, return the error as-is + return Err(error.build()); + } + res = destroy_sub.next() => { + res?; + return Err(crate::errors::Actor::DestroyedDuringCreation.build()); + } } } diff --git a/engine/packages/pegboard/src/ops/envoy/drain.rs b/engine/packages/pegboard/src/ops/envoy/drain.rs new file mode 100644 index 0000000000..4489a08282 --- /dev/null +++ b/engine/packages/pegboard/src/ops/envoy/drain.rs @@ -0,0 +1,101 @@ +use anyhow::Result; +use futures_util::TryStreamExt; +use gas::prelude::*; +use rivet_envoy_protocol::{self as protocol, PROTOCOL_VERSION, versioned}; +use universaldb::options::StreamingMode; +use universaldb::utils::IsolationLevel::*; +use universalpubsub::PublishOpts; +use vbare::OwnedVersionedData; + +use crate::{keys, metrics}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Input { + pub namespace_id: Id, + pub pool_name: String, + pub version: u32, +} + +#[operation] +pub async fn pegboard_envoy_drain_older_versions(ctx: &OperationCtx, input: &Input) -> Result<()> { + let pool_res = ctx + 
.op(crate::ops::runner_config::get::Input { + runners: vec![(input.namespace_id, input.pool_name.clone())], + bypass_cache: false, + }) + .await?; + + let Some(pool) = pool_res.into_iter().next() else { + return Ok(()); + }; + + // Use config's drain_on_version_upgrade if config exists, otherwise default to false + if !pool.config.drain_on_version_upgrade { + return Ok(()); + } + + // Scan EnvoyLoadBalancerIdxKey for older versions + let older_envoys = ctx + .udb()? + .run(|tx| async move { + let tx = tx.with_subspace(keys::subspace()); + let mut older_envoys = Vec::new(); + + let lb_subspace = + keys::subspace().subspace(&keys::ns::EnvoyLoadBalancerIdxKey::subspace( + input.namespace_id, + input.pool_name.clone(), + )); + + let mut stream = tx.get_ranges_keyvalues( + universaldb::RangeOption { + mode: StreamingMode::WantAll, + ..(&lb_subspace).into() + }, + Snapshot, + ); + + while let Some(entry) = stream.try_next().await? { + let (key, _) = tx.read_entry::(&entry)?; + + // Only collect envoys with older versions + if key.version < input.version { + older_envoys.push(key.envoy_key); + } + } + + Ok(older_envoys) + }) + .custom_instrument(tracing::info_span!("drain_older_versions_tx")) + .await?; + + if !older_envoys.is_empty() { + tracing::info!( + namespace_id = %input.namespace_id, + pool_name = %input.pool_name, + new_version = input.version, + older_envoy_count = older_envoys.len(), + "draining older envoy versions due to drain_on_version_upgrade" + ); + + metrics::ENVOY_VERSION_UPGRADE_DRAIN_TOTAL + .with_label_values(&[&input.namespace_id.to_string(), &input.pool_name]) + .inc_by(older_envoys.len() as u64); + + for envoy_key in older_envoys { + let receiver_subject = + crate::pubsub_subjects::EnvoyReceiverSubject::new(input.namespace_id, envoy_key) + .to_string(); + + let message_serialized = + versioned::ToEnvoyConn::wrap_latest(protocol::ToEnvoyConn::ToEnvoyConnClose) + .serialize_with_embedded_version(PROTOCOL_VERSION)?; + + ctx.ups()? 
+ .publish(&receiver_subject, &message_serialized, PublishOpts::one()) + .await?; + } + } + + Ok(()) +} diff --git a/engine/packages/pegboard/src/ops/envoy/expire.rs b/engine/packages/pegboard/src/ops/envoy/expire.rs index f9247575d8..e3b408787b 100644 --- a/engine/packages/pegboard/src/ops/envoy/expire.rs +++ b/engine/packages/pegboard/src/ops/envoy/expire.rs @@ -74,6 +74,14 @@ pub async fn pegboard_envoy_expire(ctx: &OperationCtx, input: &Input) -> Result< input.envoy_key.clone(), ), ); + tx.delete( + &keys::ns::ActiveEnvoyByNameKey::new( + input.namespace_id, + pool_name.clone(), + create_ts, + input.envoy_key.clone(), + ), + ); } Ok(()) diff --git a/engine/packages/pegboard/src/ops/envoy/get.rs b/engine/packages/pegboard/src/ops/envoy/get.rs index 0ebcedee1f..e3f6741379 100644 --- a/engine/packages/pegboard/src/ops/envoy/get.rs +++ b/engine/packages/pegboard/src/ops/envoy/get.rs @@ -10,107 +10,126 @@ use crate::keys; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Input { pub namespace_id: Id, - pub envoy_key: String, + pub envoy_keys: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Output { + pub envoys: Vec, } #[operation] -pub async fn pegboard_envoy_get(ctx: &OperationCtx, input: &Input) -> Result> { +pub async fn pegboard_envoy_get(ctx: &OperationCtx, input: &Input) -> Result { let dc_name = ctx.config().dc_name()?; - ctx.udb()? + let envoys = ctx + .udb()? .run(|tx| { + let dc_name = dc_name.to_string(); async move { - let tx = tx.with_subspace(keys::subspace()); + let mut envoys = Vec::new(); - // TODO: Make this part of the below try join to reduce round trip count - // Check if envoy exists by looking for workflow ID - if !tx - .exists( - &keys::envoy::PoolNameKey::new(input.namespace_id, input.envoy_key.clone()), - Serializable, - ) - .await? - { - return Ok(None); + for envoy_key in &input.envoy_keys { + if let Some(envoy) = + get_inner(&dc_name, &tx, input.namespace_id, envoy_key).await? 
+ { + envoys.push(envoy); + } } - let pool_name_key = - keys::envoy::PoolNameKey::new(input.namespace_id, input.envoy_key.clone()); - let version_key = - keys::envoy::VersionKey::new(input.namespace_id, input.envoy_key.clone()); - let slots_key = - keys::envoy::SlotsKey::new(input.namespace_id, input.envoy_key.clone()); - let create_ts_key = - keys::envoy::CreateTsKey::new(input.namespace_id, input.envoy_key.clone()); - let connected_ts_key = - keys::envoy::ConnectedTsKey::new(input.namespace_id, input.envoy_key.clone()); - let stop_ts_key = - keys::envoy::StopTsKey::new(input.namespace_id, input.envoy_key.clone()); - let last_ping_ts_key = - keys::envoy::LastPingTsKey::new(input.namespace_id, input.envoy_key.clone()); - let last_rtt_key = - keys::envoy::LastRttKey::new(input.namespace_id, input.envoy_key.clone()); - let metadata_key = - keys::envoy::MetadataKey::new(input.namespace_id, input.envoy_key.clone()); - let metadata_subspace = keys::subspace().subspace(&metadata_key); - - let ( - pool_name, - version, - slots, - create_ts, - connected_ts, - stop_ts, - last_ping_ts, - last_rtt, - metadata_chunks, - ) = tokio::try_join!( - // NOTE: These are not Serializable because this op is meant for basic information (i.e. 
data for the - // API) - tx.read(&pool_name_key, Snapshot), - tx.read(&version_key, Snapshot), - tx.read(&slots_key, Snapshot), - tx.read(&create_ts_key, Snapshot), - tx.read_opt(&connected_ts_key, Snapshot), - tx.read_opt(&stop_ts_key, Snapshot), - tx.read_opt(&last_ping_ts_key, Snapshot), - tx.read_opt(&last_rtt_key, Snapshot), - async { - tx.get_ranges_keyvalues( - universaldb::RangeOption { - mode: StreamingMode::WantAll, - ..(&metadata_subspace).into() - }, - Snapshot, - ) - .try_collect::>() - .await - .map_err(Into::into) - }, - )?; - - let metadata = if metadata_chunks.is_empty() { - None - } else { - Some(metadata_key.combine(metadata_chunks)?.metadata) - }; - - Ok(Some(Envoy { - envoy_key: input.envoy_key.clone(), - namespace_id: input.namespace_id, - datacenter: dc_name.to_string(), - pool_name, - version, - slots: slots.try_into()?, - create_ts, - last_connected_ts: connected_ts, - stop_ts, - last_ping_ts: last_ping_ts.unwrap_or_default(), - last_rtt: last_rtt.unwrap_or_default(), - metadata, - })) + Ok(envoys) } }) .custom_instrument(tracing::info_span!("envoy_get_tx")) - .await + .await?; + + Ok(Output { envoys }) +} + +pub(crate) async fn get_inner( + dc_name: &str, + tx: &universaldb::Transaction, + namespace_id: Id, + envoy_key: &str, +) -> Result> { + let tx = tx.with_subspace(keys::subspace()); + + // TODO: Make this part of the below try join to reduce round trip count + // Check if envoy exists by looking for workflow ID + if !tx + .exists( + &keys::envoy::PoolNameKey::new(namespace_id, envoy_key.to_string()), + Serializable, + ) + .await? 
+ { + return Ok(None); + } + + let pool_name_key = keys::envoy::PoolNameKey::new(namespace_id, envoy_key.to_string()); + let version_key = keys::envoy::VersionKey::new(namespace_id, envoy_key.to_string()); + let slots_key = keys::envoy::SlotsKey::new(namespace_id, envoy_key.to_string()); + let create_ts_key = keys::envoy::CreateTsKey::new(namespace_id, envoy_key.to_string()); + let connected_ts_key = keys::envoy::ConnectedTsKey::new(namespace_id, envoy_key.to_string()); + let stop_ts_key = keys::envoy::StopTsKey::new(namespace_id, envoy_key.to_string()); + let last_ping_ts_key = keys::envoy::LastPingTsKey::new(namespace_id, envoy_key.to_string()); + let last_rtt_key = keys::envoy::LastRttKey::new(namespace_id, envoy_key.to_string()); + let metadata_key = keys::envoy::MetadataKey::new(namespace_id, envoy_key.to_string()); + let metadata_subspace = keys::subspace().subspace(&metadata_key); + + let ( + pool_name, + version, + slots, + create_ts, + connected_ts, + stop_ts, + last_ping_ts, + last_rtt, + metadata_chunks, + ) = tokio::try_join!( + // NOTE: These are not Serializable because this op is meant for basic information (i.e. 
data for the + // API) + tx.read(&pool_name_key, Snapshot), + tx.read(&version_key, Snapshot), + tx.read(&slots_key, Snapshot), + tx.read(&create_ts_key, Snapshot), + tx.read_opt(&connected_ts_key, Snapshot), + tx.read_opt(&stop_ts_key, Snapshot), + tx.read_opt(&last_ping_ts_key, Snapshot), + tx.read_opt(&last_rtt_key, Snapshot), + async { + tx.get_ranges_keyvalues( + universaldb::RangeOption { + mode: StreamingMode::WantAll, + ..(&metadata_subspace).into() + }, + Snapshot, + ) + .try_collect::>() + .await + .map_err(Into::into) + }, + )?; + + let metadata = if metadata_chunks.is_empty() { + None + } else { + Some(metadata_key.combine(metadata_chunks)?.metadata) + }; + + Ok(Some(Envoy { + envoy_key: envoy_key.to_string(), + namespace_id: namespace_id, + datacenter: dc_name.to_string(), + pool_name, + version, + slots: slots.try_into()?, + create_ts, + last_connected_ts: connected_ts, + stop_ts, + last_ping_ts: last_ping_ts.unwrap_or_default(), + last_rtt: last_rtt.unwrap_or_default(), + metadata, + })) } diff --git a/engine/packages/pegboard/src/ops/envoy/list.rs b/engine/packages/pegboard/src/ops/envoy/list.rs new file mode 100644 index 0000000000..249a733186 --- /dev/null +++ b/engine/packages/pegboard/src/ops/envoy/list.rs @@ -0,0 +1,131 @@ +use anyhow::Result; +use futures_util::{StreamExt, TryStreamExt}; +use gas::prelude::*; +use rivet_types::envoys::Envoy; +use universaldb::options::StreamingMode; +use universaldb::utils::IsolationLevel::*; + +use crate::keys; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Input { + pub namespace_id: Id, + pub pool_name: Option, + pub created_before: Option, + pub limit: usize, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Output { + pub envoys: Vec, +} + +#[operation] +pub async fn pegboard_envoy_list(ctx: &OperationCtx, input: &Input) -> Result { + let dc_name = ctx.config().dc_name()?; + + let envoys = ctx + .udb()? 
+ .run(|tx| { + let dc_name = dc_name.to_string(); + async move { + let tx = tx.with_subspace(keys::subspace()); + let mut results = Vec::new(); + + if let Some(pool_name) = &input.pool_name { + let envoy_subspace = + keys::subspace().subspace(&keys::ns::ActiveEnvoyByNameKey::subspace( + input.namespace_id, + pool_name.clone(), + )); + let (start, end) = envoy_subspace.range(); + + let end = if let Some(created_before) = input.created_before { + universaldb::utils::end_of_key_range(&tx.pack( + &keys::ns::ActiveEnvoyByNameKey::subspace_with_create_ts( + input.namespace_id, + pool_name.clone(), + created_before, + ), + )) + } else { + end + }; + + let mut stream = tx.get_ranges_keyvalues( + universaldb::RangeOption { + mode: StreamingMode::Iterator, + reverse: true, + ..(start, end).into() + }, + // NOTE: Does not have to be serializable because we are listing, stale data does not matter + Snapshot, + ); + + while let Some(entry) = stream.try_next().await? { + let idx_key = tx.unpack::(entry.key())?; + + results.push(idx_key.envoy_key); + + if results.len() >= input.limit { + break; + } + } + } else { + let envoy_subspace = keys::subspace() + .subspace(&keys::ns::ActiveEnvoyKey::subspace(input.namespace_id)); + let (start, end) = envoy_subspace.range(); + + let end = if let Some(created_before) = input.created_before { + universaldb::utils::end_of_key_range(&tx.pack( + &keys::ns::ActiveEnvoyKey::subspace_with_create_ts( + input.namespace_id, + created_before, + ), + )) + } else { + end + }; + + let mut stream = tx.get_ranges_keyvalues( + universaldb::RangeOption { + mode: StreamingMode::Iterator, + reverse: true, + ..(start, end).into() + }, + // NOTE: Does not have to be serializable because we are listing, stale data does not matter + Snapshot, + ); + + while let Some(entry) = stream.try_next().await? 
{ + let idx_key = tx.unpack::(entry.key())?; + + results.push(idx_key.envoy_key); + + if results.len() >= input.limit { + break; + } + } + } + + futures_util::stream::iter(results) + .map(|envoy_key| { + let tx = tx.clone(); + let dc_name = dc_name.clone(); + + async move { + super::get::get_inner(&dc_name, &tx, input.namespace_id, &envoy_key) + .await + } + }) + .buffered(512) + .try_filter_map(|result| async move { Ok(result) }) + .try_collect::>() + .await + } + }) + .custom_instrument(tracing::info_span!("envoy_list_tx")) + .await?; + + Ok(Output { envoys }) +} diff --git a/engine/packages/pegboard/src/ops/envoy/mod.rs b/engine/packages/pegboard/src/ops/envoy/mod.rs index 2c1bbb9076..3f0c7d7a7f 100644 --- a/engine/packages/pegboard/src/ops/envoy/mod.rs +++ b/engine/packages/pegboard/src/ops/envoy/mod.rs @@ -1,4 +1,6 @@ +pub mod drain; pub mod evict_actors; pub mod expire; pub mod get; +pub mod list; pub mod update_ping; diff --git a/engine/packages/pegboard/src/ops/runner_config/get.rs b/engine/packages/pegboard/src/ops/runner_config/get.rs index 469506e74e..a089ad8f67 100644 --- a/engine/packages/pegboard/src/ops/runner_config/get.rs +++ b/engine/packages/pegboard/src/ops/runner_config/get.rs @@ -16,6 +16,8 @@ pub struct RunnerConfig { pub namespace_id: Id, pub name: String, pub config: rivet_types::runner_configs::RunnerConfig, + /// Unset if the runner's metadata endpoint has never returned `envoyProtocolVersion`` + pub protocol_version: Option, } #[operation] @@ -68,10 +70,17 @@ async fn runner_config_get_inner( namespace_id, runner_name.clone(), ); + let protocol_version_key = keys::runner_config::ProtocolVersionKey::new( + namespace_id, + runner_name.clone(), + ); + + let (runner_config_entry, protocol_version_entry) = tokio::try_join!( + tx.read_opt(&runner_config_key, Serializable), + tx.read_opt(&protocol_version_key, Serializable), + )?; - let Some(runner_config) = - tx.read_opt(&runner_config_key, Serializable).await? 
- else { + let Some(runner_config) = runner_config_entry else { // Runner config not found return Ok(None); }; @@ -80,6 +89,7 @@ async fn runner_config_get_inner( namespace_id, name: runner_name, config: runner_config, + protocol_version: protocol_version_entry, })) } }) diff --git a/engine/packages/pegboard/src/ops/serverless_metadata/fetch.rs b/engine/packages/pegboard/src/ops/serverless_metadata/fetch.rs index ac427028dc..281d954764 100644 --- a/engine/packages/pegboard/src/ops/serverless_metadata/fetch.rs +++ b/engine/packages/pegboard/src/ops/serverless_metadata/fetch.rs @@ -4,6 +4,7 @@ use std::time::Duration; use anyhow::Result; use gas::prelude::*; use reqwest::header::{HeaderMap as ReqwestHeaderMap, HeaderName, HeaderValue}; +use rivet_envoy_protocol::PROTOCOL_VERSION; use serde::{Deserialize, Serialize}; use utoipa::ToSchema; @@ -32,8 +33,10 @@ pub enum ServerlessMetadataError { pub struct Output { pub runtime: String, pub version: String, + pub envoy_protocol_version: Option, pub actor_names: Vec, pub runner_version: Option, + pub envoy_version: Option, } #[derive(Debug, Clone, Serialize, Deserialize, Hash)] @@ -42,6 +45,11 @@ pub struct ActorNameMetadata { pub metadata: serde_json::Map, } +#[derive(Deserialize)] +struct ServerlessMetadataEnvoy { + version: Option, +} + #[derive(Deserialize)] struct ServerlessMetadataRunner { version: Option, @@ -51,8 +59,11 @@ struct ServerlessMetadataRunner { struct ServerlessMetadataPayload { runtime: String, version: String, + #[serde(rename = "envoyProtocolVersion")] + envoy_protocol_version: Option, #[serde(rename = "actorNames", default)] actor_names: HashMap, + envoy: Option, runner: Option, } @@ -77,7 +88,7 @@ fn truncate_response_body(body: &str) -> String { #[operation] #[tracing::instrument(skip_all)] pub async fn pegboard_serverless_metadata_fetch( - _ctx: &OperationCtx, + ctx: &OperationCtx, input: &Input, ) -> Result> { tracing::debug!(url = ?input.url, "fetching serverless runner metadata"); @@ -159,17 
+170,20 @@ pub async fn pegboard_serverless_metadata_fetch( let ServerlessMetadataPayload { runtime, version, + envoy_protocol_version, actor_names, + envoy, runner, } = payload; let runner_version = runner.and_then(|r| r.version); + let envoy_version = envoy.and_then(|e| e.version); tracing::debug!( ?runtime, ?version, actor_names_count = actor_names.len(), - ?runner_version, + ?envoy_version, "parsed metadata payload" ); @@ -181,6 +195,14 @@ pub async fn pegboard_serverless_metadata_fetch( })); } + if let Some(envoy_protocol_version) = envoy_protocol_version { + if envoy_protocol_version < 1 || envoy_protocol_version > PROTOCOL_VERSION { + return Ok(Err(ServerlessMetadataError::InvalidResponseJson { + body: body_for_user, + })); + } + } + // Convert actor names, filtering out non-object metadata let actor_names: Vec = actor_names .into_iter() @@ -193,7 +215,9 @@ pub async fn pegboard_serverless_metadata_fetch( Ok(Ok(Output { runtime, version: trimmed_version.to_owned(), + envoy_protocol_version, actor_names, runner_version, + envoy_version, })) } diff --git a/engine/packages/pegboard/src/workflows/actor/mod.rs b/engine/packages/pegboard/src/workflows/actor/mod.rs index f7ea43786f..d29a1f8b32 100644 --- a/engine/packages/pegboard/src/workflows/actor/mod.rs +++ b/engine/packages/pegboard/src/workflows/actor/mod.rs @@ -217,6 +217,28 @@ pub async fn pegboard_actor(ctx: &mut WorkflowCtx, input: &Input) -> Result<()> return Ok(()); } + runtime::SpawnActorOutput::MigrateToV2 => { + ctx.signal(metrics::Destroy { + ts: util::timestamp::now(), + }) + .to_workflow_id(metrics_workflow_id) + .send() + .await?; + + ctx.workflow(crate::workflows::actor2::Input { + actor_id: input.actor_id, + name: input.name.clone(), + pool_name: input.runner_name_selector.clone(), + key: input.key.clone(), + namespace_id: input.namespace_id, + crash_policy: input.crash_policy, + input: input.input.clone(), + from_v1: true, + }) + .dispatch() + .await?; + return Ok(()); + } }; let lifecycle_res 
= ctx @@ -368,7 +390,7 @@ pub async fn pegboard_actor(ctx: &mut WorkflowCtx, input: &Input) -> Result<()> protocol::ActorState::ActorStateStopped( protocol::ActorStateStopped { code, message }, ) => { - if let StoppedResult::Destroy = handle_stopped( + match handle_stopped( ctx, &input, state, @@ -383,9 +405,19 @@ pub async fn pegboard_actor(ctx: &mut WorkflowCtx, input: &Input) -> Result<()> ) .await? { - return Ok(Loop::Break(runtime::LifecycleResult { - generation: state.generation, - })); + StoppedResult::Continue => {} + StoppedResult::Destroy => { + return Ok(Loop::Break(runtime::LifecycleResult { + generation: state.generation, + migrate_to_v2: false, + })); + } + StoppedResult::MigrateToV2 => { + return Ok(Loop::Break(runtime::LifecycleResult { + generation: state.generation, + migrate_to_v2: true, + })); + } } } }, @@ -511,7 +543,7 @@ pub async fn pegboard_actor(ctx: &mut WorkflowCtx, input: &Input) -> Result<()> protocol::mk2::ActorState::ActorStateStopped( protocol::mk2::ActorStateStopped { code, message }, ) => { - if let StoppedResult::Destroy = handle_stopped( + match handle_stopped( ctx, &input, state, @@ -523,9 +555,19 @@ pub async fn pegboard_actor(ctx: &mut WorkflowCtx, input: &Input) -> Result<()> ) .await? 
{ - return Ok(Loop::Break(runtime::LifecycleResult { - generation: state.generation, - })); + StoppedResult::Continue => {} + StoppedResult::Destroy => { + return Ok(Loop::Break(runtime::LifecycleResult { + generation: state.generation, + migrate_to_v2: false, + })); + } + StoppedResult::MigrateToV2 => { + return Ok(Loop::Break(runtime::LifecycleResult { + generation: state.generation, + migrate_to_v2: true, + })); + } } } }, @@ -616,6 +658,13 @@ pub async fn pegboard_actor(ctx: &mut WorkflowCtx, input: &Input) -> Result<()> // Destroyed early return Ok(Loop::Break(runtime::LifecycleResult { generation: state.generation, + migrate_to_v2: false, + })); + } + runtime::SpawnActorOutput::MigrateToV2 => { + return Ok(Loop::Break(runtime::LifecycleResult { + generation: state.generation, + migrate_to_v2: true, })); } } @@ -665,7 +714,7 @@ pub async fn pegboard_actor(ctx: &mut WorkflowCtx, input: &Input) -> Result<()> None }; - if let StoppedResult::Destroy = handle_stopped( + match handle_stopped( ctx, &input, state, @@ -677,9 +726,19 @@ pub async fn pegboard_actor(ctx: &mut WorkflowCtx, input: &Input) -> Result<()> ) .await? 
{ - return Ok(Loop::Break(runtime::LifecycleResult { - generation: state.generation, - })); + StoppedResult::Continue => {} + StoppedResult::Destroy => { + return Ok(Loop::Break(runtime::LifecycleResult { + generation: state.generation, + migrate_to_v2: false, + })); + } + StoppedResult::MigrateToV2 => { + return Ok(Loop::Break(runtime::LifecycleResult { + generation: state.generation, + migrate_to_v2: true, + })); + } } } Main::GoingAway(sig) => { @@ -763,6 +822,7 @@ pub async fn pegboard_actor(ctx: &mut WorkflowCtx, input: &Input) -> Result<()> return Ok(Loop::Break(runtime::LifecycleResult { generation: state.generation, + migrate_to_v2: false, })); } } @@ -945,6 +1005,7 @@ enum StoppedVariant { enum StoppedResult { Continue, Destroy, + MigrateToV2, } async fn handle_stopped( @@ -1128,6 +1189,7 @@ async fn handle_stopped( } // Destroyed early runtime::SpawnActorOutput::Destroy => return Ok(StoppedResult::Destroy), + runtime::SpawnActorOutput::MigrateToV2 => return Ok(StoppedResult::MigrateToV2), } } // Handle rescheduling if not marked as sleeping @@ -1170,6 +1232,9 @@ async fn handle_stopped( // Destroyed early return Ok(StoppedResult::Destroy); } + runtime::SpawnActorOutput::MigrateToV2 => { + return Ok(StoppedResult::MigrateToV2); + } } } (CrashPolicy::Sleep, false) => { @@ -1210,6 +1275,7 @@ async fn handle_stopped( } // Destroyed early runtime::SpawnActorOutput::Destroy => return Ok(StoppedResult::Destroy), + runtime::SpawnActorOutput::MigrateToV2 => return Ok(StoppedResult::MigrateToV2), } } diff --git a/engine/packages/pegboard/src/workflows/actor/runtime.rs b/engine/packages/pegboard/src/workflows/actor/runtime.rs index 46935f054b..0e6f631214 100644 --- a/engine/packages/pegboard/src/workflows/actor/runtime.rs +++ b/engine/packages/pegboard/src/workflows/actor/runtime.rs @@ -8,7 +8,8 @@ use rand::prelude::SliceRandom; use rivet_runner_protocol::{ self as protocol, PROTOCOL_MK1_VERSION, PROTOCOL_MK2_VERSION, versioned, }; -use 
rivet_types::{actors::CrashPolicy, keys::namespace::runner_config::RunnerConfigVariant}; +use rivet_types::actors::CrashPolicy; +use rivet_types::runner_configs::RunnerConfigKind; use std::time::Instant; use universaldb::prelude::*; use universalpubsub::PublishOpts; @@ -109,6 +110,8 @@ impl LifecycleState { #[derive(Serialize, Deserialize)] pub struct LifecycleResult { pub generation: u32, + #[serde(default)] + pub migrate_to_v2: bool, } #[derive(Serialize, Deserialize, Clone, Default)] @@ -227,6 +230,7 @@ enum AllocateActorStatus { pending_allocation_ts: i64, }, Sleep, + MigrateToV2, } // If no availability, returns the timestamp of the actor's queue key @@ -248,6 +252,26 @@ async fn allocate_actor_v2( .pegboard() .actor_allocation_candidate_sample_size(); + let pool_res = ctx + .op(crate::ops::runner_config::get::Input { + runners: vec![(namespace_id, runner_name_selector.clone())], + bypass_cache: false, + }) + .await?; + let pool = pool_res.into_iter().next(); + let for_serverless = pool + .as_ref() + .map(|pool| matches!(pool.config.kind, RunnerConfigKind::Serverless { .. 
})) + .unwrap_or(false); + + // Protocol version is set, we must migrate to actor v2 + if pool.and_then(|p| p.protocol_version).is_some() { + return Ok(AllocateActorOutputV2 { + status: AllocateActorStatus::MigrateToV2, + serverless: false, + }); + } + // NOTE: This txn should closely resemble the one found in the allocate_pending_actors activity of the // client wf let res = ctx @@ -265,12 +289,6 @@ async fn allocate_actor_v2( ), ); - let ns_tx = tx.with_subspace(namespace::keys::subspace()); - let runner_config_variant_key = keys::runner_config::ByVariantKey::new( - namespace_id, - RunnerConfigVariant::Serverless, - runner_name_selector.clone(), - ); let mut queue_stream = tx.get_ranges_keyvalues( universaldb::RangeOption { mode: StreamingMode::Exact, @@ -281,13 +299,7 @@ async fn allocate_actor_v2( // inserts/clears to this range Snapshot, ); - let (for_serverless_res, queue_exists_res) = tokio::join!( - // Check if runner is a serverless runner - ns_tx.exists(&runner_config_variant_key, Serializable), - queue_stream.next(), - ); - let for_serverless = for_serverless_res?; - let queue_exists = queue_exists_res.is_some(); + let queue_exists = queue_stream.next().await.is_some(); if for_serverless { tx.atomic_op( @@ -478,6 +490,7 @@ async fn allocate_actor_v2( AllocateActorStatus::Allocated { .. } => "allocated", AllocateActorStatus::Pending { .. } => "pending", AllocateActorStatus::Sleep { .. } => "sleep", + AllocateActorStatus::MigrateToV2 => bail!("should not be migrate_to_v2"), }, ]) .observe(dt); @@ -517,6 +530,7 @@ async fn allocate_actor_v2( state.failure_reason = Some(super::FailureReason::NoCapacity); } } + AllocateActorStatus::MigrateToV2 => bail!("should not be migrate_to_v2"), } Ok(res) @@ -620,6 +634,7 @@ pub enum SpawnActorOutput { }, Sleep, Destroy, + MigrateToV2, } /// Wrapper around `allocate_actor` that handles pending state. 
@@ -1033,6 +1048,7 @@ pub async fn spawn_actor( } } AllocateActorStatus::Sleep => Ok(SpawnActorOutput::Sleep), + AllocateActorStatus::MigrateToV2 => Ok(SpawnActorOutput::MigrateToV2), } } diff --git a/engine/packages/pegboard/src/workflows/actor2/mod.rs b/engine/packages/pegboard/src/workflows/actor2/mod.rs index 5c3c836950..47d70c0705 100644 --- a/engine/packages/pegboard/src/workflows/actor2/mod.rs +++ b/engine/packages/pegboard/src/workflows/actor2/mod.rs @@ -29,6 +29,7 @@ pub struct Input { /// Arbitrary user-provided binary data encoded in base64. pub input: Option, + pub from_v1: bool, } #[derive(Deserialize, Serialize)] @@ -123,61 +124,64 @@ pub async fn pegboard_actor2(ctx: &mut WorkflowCtx, input: &Input) -> Result<()> namespace_id: input.namespace_id, crash_policy: input.crash_policy, create_ts: ctx.create_ts(), + from_v1: input.from_v1, }) .await?; - if let Some(key) = &input.key { - match keys::reserve_key( - ctx, - input.namespace_id, - &input.name, - &key, - input.actor_id, - &input.pool_name, - ) - .await? - { - keys::ReserveKeyOutput::Success => {} - keys::ReserveKeyOutput::ForwardToDatacenter { dc_label } => { - ctx.msg(Failed { - error: errors::Actor::KeyReservedInDifferentDatacenter { - datacenter_label: dc_label, - }, - }) - .topic(("actor_id", input.actor_id)) - .send() - .await?; + if !input.from_v1 { + if let Some(key) = &input.key { + match keys::reserve_key( + ctx, + input.namespace_id, + &input.name, + &key, + input.actor_id, + &input.pool_name, + ) + .await? 
+ { + keys::ReserveKeyOutput::Success => {} + keys::ReserveKeyOutput::ForwardToDatacenter { dc_label } => { + ctx.msg(Failed { + error: errors::Actor::KeyReservedInDifferentDatacenter { + datacenter_label: dc_label, + }, + }) + .topic(("actor_id", input.actor_id)) + .send() + .await?; - // Destroyed early - destroy(ctx, input).await?; + // Destroyed early + destroy(ctx, input).await?; - return Ok(()); - } - keys::ReserveKeyOutput::KeyExists { existing_actor_id } => { - ctx.msg(Failed { - error: errors::Actor::DuplicateKey { - key: key.clone(), - existing_actor_id, - }, - }) - .topic(("actor_id", input.actor_id)) - .send() - .await?; + return Ok(()); + } + keys::ReserveKeyOutput::KeyExists { existing_actor_id } => { + ctx.msg(Failed { + error: errors::Actor::DuplicateKey { + key: key.clone(), + existing_actor_id, + }, + }) + .topic(("actor_id", input.actor_id)) + .send() + .await?; - // Destroyed early - destroy(ctx, input).await?; + // Destroyed early + destroy(ctx, input).await?; - return Ok(()); + return Ok(()); + } } } - } - ctx.activity(PopulateIndexesInput {}).await?; + ctx.activity(PopulateIndexesInput {}).await?; - ctx.msg(CreateComplete {}) - .topic(("actor_id", input.actor_id)) - .send() - .await?; + ctx.msg(CreateComplete {}) + .topic(("actor_id", input.actor_id)) + .send() + .await?; + } // Spawn adjacent workflows let metrics_workflow_id = ctx @@ -235,6 +239,7 @@ pub struct InitStateAndUdbInput { pub crash_policy: CrashPolicy, pub pool_name: String, pub create_ts: i64, + pub from_v1: bool, } #[activity(InitStateAndDb)] @@ -255,10 +260,12 @@ pub async fn insert_state_and_db(ctx: &ActivityCtx, input: &InitStateAndUdbInput .run(|tx| async move { let tx = tx.with_subspace(crate::keys::subspace()); - tx.write( - &crate::keys::actor::CreateTsKey::new(input.actor_id), - input.create_ts, - )?; + if !input.from_v1 { + tx.write( + &crate::keys::actor::CreateTsKey::new(input.actor_id), + input.create_ts, + )?; + } tx.write( 
&crate::keys::actor::WorkflowIdKey::new(input.actor_id), ctx.workflow_id(), @@ -284,13 +291,15 @@ pub async fn insert_state_and_db(ctx: &ActivityCtx, input: &InitStateAndUdbInput )?; } - // Update metrics - namespace::keys::metric::inc( - &tx.with_subspace(namespace::keys::subspace()), - input.namespace_id, - namespace::keys::metric::Metric::TotalActors(input.name.clone()), - 1, - ); + if !input.from_v1 { + // Update metrics + namespace::keys::metric::inc( + &tx.with_subspace(namespace::keys::subspace()), + input.namespace_id, + namespace::keys::metric::Metric::TotalActors(input.name.clone()), + 1, + ); + } Ok(()) }) @@ -519,7 +528,7 @@ async fn process_signal( match sig { Main::Allocated(sig) => { // Ignore signals for previous generations - if sig.generation == state.generation { + if sig.generation != state.generation { return Ok(Loop::Continue); } diff --git a/engine/packages/pegboard/src/workflows/actor2/runtime.rs b/engine/packages/pegboard/src/workflows/actor2/runtime.rs index 4340840458..76ee019907 100644 --- a/engine/packages/pegboard/src/workflows/actor2/runtime.rs +++ b/engine/packages/pegboard/src/workflows/actor2/runtime.rs @@ -182,6 +182,7 @@ pub async fn allocate(ctx: &ActivityCtx, input: &AllocateInput) -> Result None, }); + let actor_id = state.actor_id; let namespace_id = state.namespace_id; let pool_name = &state.pool_name; let envoy_eligible_threshold = ctx.config().pegboard().envoy_eligible_threshold(); @@ -197,6 +198,9 @@ pub async fn allocate(ctx: &ActivityCtx, input: &AllocateInput) -> Result (rewake_after_stop, false), @@ -526,16 +530,43 @@ pub async fn handle_stopped( _ => (true, false), }; - let stopped_res = if try_reallocate { - // An actor stopping with `StopCode::Ok` indicates a graceful exit (if not going away) - let graceful_exit = !was_going_away - && matches!( - variant, - StoppedVariant::Normal { - code: protocol::StopCode::Ok, - .. 
- } - ); + // Always immediately reallocate if going away + let stopped_res = if going_away { + let allocate_res = ctx.activity(AllocateInput {}).await?; + + if let Some(allocation) = allocate_res.allocation { + state.generation += 1; + + ctx.activity(SendOutboundInput { + generation: state.generation, + input: input.input.clone(), + allocation, + }) + .await?; + + // Transition to allocating + state.transition = Transition::Allocating { + destroy_after_start: false, + lost_timeout_ts: util::timestamp::now() + + ctx.config().pegboard().actor_allocation_threshold(), + }; + } else { + // Transition to retry backoff + state.transition = Transition::Sleeping { + attempting_reallocation: true, + }; + } + + StoppedResult::Continue + } else if try_reallocate { + // An actor stopping with `StopCode::Ok` indicates a graceful exit + let graceful_exit = matches!( + variant, + StoppedVariant::Normal { + code: protocol::StopCode::Ok, + .. + } + ); match (input.crash_policy, graceful_exit) { (CrashPolicy::Restart, false) => { @@ -569,8 +600,6 @@ pub async fn handle_stopped( (CrashPolicy::Sleep, false) => { tracing::debug!(actor_id=?input.actor_id, "actor sleeping due to ungraceful exit"); - ctx.activity(SetSleepingInput {}).await?; - // Clear alarm if let Some(alarm_ts) = state.alarm_ts { let now = ctx.activity(GetTsInput {}).await?; @@ -610,6 +639,10 @@ pub async fn handle_stopped( StoppedResult::Continue }; + if let Transition::Sleeping { .. 
} = state.transition { + ctx.activity(SetSleepingInput {}).await?; + } + ctx.msg(Stopped {}) .topic(("actor_id", input.actor_id)) .send() diff --git a/engine/packages/pegboard/src/workflows/runner_pool_metadata_poller.rs b/engine/packages/pegboard/src/workflows/runner_pool_metadata_poller.rs index 448638fae1..64de8d6551 100644 --- a/engine/packages/pegboard/src/workflows/runner_pool_metadata_poller.rs +++ b/engine/packages/pegboard/src/workflows/runner_pool_metadata_poller.rs @@ -2,9 +2,10 @@ use std::time::Duration; use futures_util::FutureExt; use gas::prelude::*; -use rivet_types::runner_configs::RunnerConfigKind; +use rivet_types::{actor::RunnerPoolError, runner_configs::RunnerConfigKind}; +use universaldb::prelude::*; -use crate::ops::actor_name::upsert_batch::ActorNameEntry; +use crate::{keys, ops::actor_name::upsert_batch::ActorNameEntry}; #[derive(Debug, Serialize, Deserialize, Clone)] pub struct Input { @@ -199,6 +200,39 @@ async fn poll_metadata(ctx: &ActivityCtx, input: &PollMetadataInput) -> Result

= metadata @@ -219,26 +253,56 @@ async fn poll_metadata(ctx: &ActivityCtx, input: &PollMetadataInput) -> Result

() + .tag("namespace_id", namespace_id) + .tag("runner_name", pool_name) + .graceful_not_found() + .send() + .await + { + tracing::warn!(?err, "failed to report serverless error"); + } +} + #[signal("pegboard_runner_pool_metadata_poller_endpoint_config_changed")] #[derive(Debug)] pub struct EndpointConfigChanged {} diff --git a/engine/packages/service-manager/src/lib.rs b/engine/packages/service-manager/src/lib.rs index 6b3c48d8e4..f7dbc90c82 100644 --- a/engine/packages/service-manager/src/lib.rs +++ b/engine/packages/service-manager/src/lib.rs @@ -369,16 +369,18 @@ pub async fn start( break; } abort = term_signal.recv() => { - shutting_down.store(true, Ordering::SeqCst); + if !shutting_down.load(Ordering::SeqCst) { + // Spawn force exit task in case of a lingering task + let force_shutdown_duration = config.runtime.force_shutdown_duration(); + tokio::spawn(async move { + tracing::info!(?force_shutdown_duration, "force shutdown timer started"); + tokio::time::sleep(force_shutdown_duration).await; + tracing::warn!("force shutdown timeout reached, exiting process, this indicates a bug"); + std::process::exit(1); + }); + } - // Spawn force exit task in case of a lingering task - let force_shutdown_duration = config.runtime.force_shutdown_duration(); - tokio::spawn(async move { - tracing::info!(?force_shutdown_duration, "force shutdown timer started"); - tokio::time::sleep(force_shutdown_duration).await; - tracing::warn!("force shutdown timeout reached, exiting process, this indicates a bug"); - std::process::exit(1); - }); + shutting_down.store(true, Ordering::SeqCst); // Abort services that don't require graceful shutdown running_services.retain(|task| { diff --git a/engine/packages/types/src/actor/error.rs b/engine/packages/types/src/actor/error.rs index 061924a4ea..76d74c5d14 100644 --- a/engine/packages/types/src/actor/error.rs +++ b/engine/packages/types/src/actor/error.rs @@ -21,6 +21,9 @@ pub enum RunnerPoolError { raw_payload: Option, }, + /// RivetKit was 
downgraded to a version after being upgraded to use envoys + Downgrade, + /// Internal error InternalError, } diff --git a/engine/sdks/schemas/envoy-protocol/v1.bare b/engine/sdks/schemas/envoy-protocol/v1.bare index 008ac6a5b2..977f35d268 100644 --- a/engine/sdks/schemas/envoy-protocol/v1.bare +++ b/engine/sdks/schemas/envoy-protocol/v1.bare @@ -410,10 +410,11 @@ type ToEnvoyConnPing struct { ts: i64 } -# We have to re-declare the entire union since BARE will not generate the -# ser/de for ToEnvoy if it's not a top-level type +type ToEnvoyConnClose void + type ToEnvoyConn union { ToEnvoyConnPing | + ToEnvoyConnClose | ToEnvoyCommands | ToEnvoyAckEvents | ToEnvoyTunnelMessage diff --git a/engine/sdks/typescript/envoy-client/src/config.ts b/engine/sdks/typescript/envoy-client/src/config.ts index 576de298d6..f83c337758 100644 --- a/engine/sdks/typescript/envoy-client/src/config.ts +++ b/engine/sdks/typescript/envoy-client/src/config.ts @@ -30,6 +30,9 @@ export interface EnvoyConfig { request: Request, ) => Promise; + /** Payload to start an actor from a serverless SSE POST request. Can also use `EnvoyHandle.startServerless` */ + serverlessStartPayload?: ArrayBuffer; + // TODO: fix doc comment /** * Called when receiving a WebSocket connection. diff --git a/engine/sdks/typescript/envoy-client/src/handle.ts b/engine/sdks/typescript/envoy-client/src/handle.ts index c09e575ff2..6c3ecb87f2 100644 --- a/engine/sdks/typescript/envoy-client/src/handle.ts +++ b/engine/sdks/typescript/envoy-client/src/handle.ts @@ -1,5 +1,6 @@ import * as protocol from "@rivetkit/engine-envoy-protocol"; import { ActorEntry } from "./tasks/envoy"; +import { HibernatingWebSocketMetadata } from "./tasks/envoy/tunnel"; export interface KvListOptions { reverse?: boolean; @@ -14,6 +15,8 @@ export interface EnvoyHandle { getEnvoyKey(): string; + started(): Promise; + getActor(actorId: string, generation?: number): ActorEntry | undefined; /** Send sleep intent for an actor. 
*/ @@ -81,4 +84,17 @@ export interface EnvoyHandle { /** Drop all key-value data for an actor. */ kvDrop(actorId: string): Promise; + + restoreHibernatingRequests( + actorId: string, + metaEntries: HibernatingWebSocketMetadata[], + ): void; + + sendHibernatableWebSocketMessageAck( + gatewayId: protocol.GatewayId, + requestId: protocol.RequestId, + clientMessageIndex: number, + ): void; + + startServerless(payload: ArrayBuffer): void; } diff --git a/engine/sdks/typescript/envoy-client/src/index.ts b/engine/sdks/typescript/envoy-client/src/index.ts index 9ed8e96299..71940fcb85 100644 --- a/engine/sdks/typescript/envoy-client/src/index.ts +++ b/engine/sdks/typescript/envoy-client/src/index.ts @@ -8,5 +8,6 @@ export { startEnvoy, startEnvoySync, } from "./tasks/envoy/index.js"; +export { type HibernatingWebSocketMetadata } from "./tasks/envoy/tunnel.js"; export * as protocol from "@rivetkit/engine-envoy-protocol"; export * as utils from './utils.js'; \ No newline at end of file diff --git a/engine/sdks/typescript/envoy-client/src/tasks/actor.ts b/engine/sdks/typescript/envoy-client/src/tasks/actor.ts index e0a5a816a3..8a8aff33fd 100644 --- a/engine/sdks/typescript/envoy-client/src/tasks/actor.ts +++ b/engine/sdks/typescript/envoy-client/src/tasks/actor.ts @@ -8,58 +8,78 @@ import { spawn } from "antiox/task"; import type { SharedContext } from "../context.js"; import { logger } from "../log.js"; import { unreachable } from "antiox/panic"; -import { stringifyError } from "../utils.js"; -import { sendResponse } from "./envoy/tunnel.js"; +import { arraysEqual, BufferMap, idToStr, stringifyError } from "../utils.js"; +import { HibernatingWebSocketMetadata } from "./envoy/tunnel.js"; +import { HIBERNATABLE_SYMBOL, WebSocketTunnelAdapter } from "@/websocket.js"; +import { wsSend } from "./connection.js"; +import { stringifyToRivetTunnelMessageKind } from "@/stringify.js"; export interface CreateActorOpts { - commandIdx: bigint; actorId: string; generation: number; config: 
protocol.ActorConfig; + hibernatingRequests: readonly protocol.HibernatingRequest[]; } -/** - * - * Stop sequence: - * 1. X -> Actor: stop-intent (optional) - * 1. Actor -> Envoy: send-events (optional) - * 1. Envoy -> Actor: command-stop-actor - * 1. Actor: async cleanup - * 1. Actor -> Envoy: state update (stopped) - */ - -// TODO: envoy lost export type ToActor = // Sent when wants to stop the actor, will be forwarded to Envoy | { - type: "actor-intent"; - commandIdx: bigint; + type: "intent"; intent: protocol.ActorIntent; + error?: string; } // Sent when actor is told to stop | { - type: "command-stop-actor"; + type: "stop"; commandIdx: bigint; reason: protocol.StopActorReason; } + | { type: "lost" } // Set or clear an alarm | { type: "set-alarm"; alarmTs: bigint | null; } | { - type: "request-start"; - messageId: protocol.MessageId, - req: protocol.ToEnvoyRequestStart, + type: "req-start"; + messageId: protocol.MessageId; + req: protocol.ToEnvoyRequestStart; } | { - type: "request-chunk"; - messageId: protocol.MessageId, + type: "req-chunk"; + messageId: protocol.MessageId; chunk: protocol.ToEnvoyRequestChunk; - } | { - type: "request-abort"; - messageId: protocol.MessageId, + } + | { + type: "req-abort"; + messageId: protocol.MessageId; + } + | { + type: "ws-open"; + messageId: protocol.MessageId; + path: string; + headers: ReadonlyMap; + } + | { + type: "ws-msg"; + messageId: protocol.MessageId; + msg: protocol.ToEnvoyWebSocketMessage; + } + | { + type: "ws-close"; + messageId: protocol.MessageId; + close: protocol.ToEnvoyWebSocketClose; + } + | { + type: "hws-restore"; + metaEntries: HibernatingWebSocketMetadata[]; + } + | { + type: "hws-ack"; + gatewayId: protocol.GatewayId; + requestId: protocol.RequestId; + envoyMessageIndex: number; }; interface ActorContext { @@ -67,15 +87,19 @@ interface ActorContext { actorId: string; generation: number; config: protocol.ActorConfig; + commandIdx: bigint; eventIndex: bigint; - pendingRequests: Map< - 
[protocol.GatewayId, protocol.RequestId], + error?: string; + + // Tunnel requests, not http requests + pendingRequests: BufferMap< PendingRequest >; - webSockets: Map< - [protocol.GatewayId, protocol.RequestId], + webSockets: BufferMap< WebSocketTunnelAdapter >; + hibernationRestored: boolean; + hibernatingRequests: readonly protocol.HibernatingRequest[]; } export function createActor( @@ -97,13 +121,14 @@ async function actorInner( actorId: opts.actorId, generation: opts.generation, config: opts.config, + commandIdx: 0n, eventIndex: 0n, - pendingRequests: new Map(), - // webSockets: new Map(), - }; - let stopCode = protocol.StopCode.Ok; - let stopMessage: string | null = null; + pendingRequests: new BufferMap(), + webSockets: new BufferMap(), + hibernationRestored: false, + hibernatingRequests: opts.hibernatingRequests, + }; try { await shared.config.onActorStart( @@ -119,18 +144,14 @@ async function actorInner( error: stringifyError(error), }); - stopCode = protocol.StopCode.Error; - stopMessage = - error instanceof Error ? error.message : "actor start failed"; - sendEvent(ctx, { tag: "EventActorStateUpdate", val: { state: { tag: "ActorStateStopped", val: { - code: stopCode, - message: stopMessage + code: protocol.StopCode.Error, + message: error instanceof Error ? 
error.message : "actor start failed" }, }, }, @@ -144,169 +165,62 @@ async function actorInner( }); for await (const msg of rx) { - if (msg.type === "actor-intent") { + if (msg.type === "intent") { sendEvent(ctx, { tag: "EventActorIntent", val: { intent: msg.intent }, }); - } else if (msg.type === "command-stop-actor") { - try { - await ctx.shared.config.onActorStop( - ctx.shared.handle, - ctx.actorId, - ctx.generation, - msg.reason, - ); - } catch (error) { - log(ctx)?.error({ - msg: "actor stop failed", - actorId: ctx.actorId, - error: stringifyError(error), + if (msg.error) ctx.error = msg.error; + } else if (msg.type === "stop") { + if (msg.commandIdx <= ctx.commandIdx) { + log(ctx)?.warn({ + msg: "ignoring already seen command", + commandIdx: msg.commandIdx }); - - stopCode = protocol.StopCode.Error; - stopMessage = - error instanceof Error - ? error.message - : "actor stop failed"; } + ctx.commandIdx = msg.commandIdx; - sendEvent(ctx, { - tag: "EventActorStateUpdate", - val: { - state: { - tag: "ActorStateStopped", - val: { - code: stopCode, - message: stopMessage - }, - }, - }, - }); - return; + handleStop(ctx, msg.reason); + break; + } else if (msg.type === "lost") { + handleStop(ctx, protocol.StopActorReason.Lost); + break; } else if (msg.type === "set-alarm") { sendEvent(ctx, { tag: "EventActorSetAlarm", val: { alarmTs: msg.alarmTs }, }); - } else if (msg.type === "request-start") { - // Convert headers map to Headers object - const headers = new Headers(); - for (const [key, value] of msg.req.headers) { - headers.append(key, value); - } - - // Create Request object - const request = new Request(`http://localhost${msg.req.path}`, { - method: msg.req.method, - headers, - body: msg.req.body ? 
new Uint8Array(msg.req.body) : undefined, - }); - - // Handle streaming request - if (msg.req.stream) { - // Create a stream for the request body - const stream = new ReadableStream({ - start: (controller) => { - // Store controller for chunks - ctx.pendingRequests.set( - [msg.messageId.gatewayId, msg.messageId.requestId], - { - clientMessageIndex: 0, - streamController: controller, - } - ); - }, - }); - - // Create request with streaming body - const streamingRequest = new Request(request, { - body: stream, - duplex: "half", - } as any); - - spawn(async () => { - const response = await ctx.shared.config.fetch( - ctx.shared.handle, - ctx.actorId, - msg.messageId.gatewayId, - msg.messageId.requestId, - streamingRequest, - ); - await sendResponse( - ctx.shared, - { - gatewayId: msg.messageId.gatewayId, - requestId: msg.messageId.requestId, - messageIndex: 0, - }, - response, - ); - }); - } else { - // Non-streaming request - spawn(async () => { - const response = await ctx.shared.config.fetch( - ctx.shared.handle, - ctx.actorId, - msg.messageId.gatewayId, - msg.messageId.requestId, - request, - ); - await sendResponse( - ctx.shared, - { - gatewayId: msg.messageId.gatewayId, - requestId: msg.messageId.requestId, - messageIndex: 0, - }, - response, - ); - }); - } - } else if (msg.type === "request-chunk") { - const existing = ctx.pendingRequests.get( - [msg.messageId.gatewayId, msg.messageId.requestId] - ); - if (existing) { - existing.streamController.enqueue(new Uint8Array(msg.chunk.body)); - - if (msg.chunk.finish) { - existing.streamController.close(); - - ctx.pendingRequests.delete( - [msg.messageId.gatewayId, msg.messageId.requestId], - ); - } - } else { - log(ctx)?.warn({ - msg: "received chunk for unknown pending request", - }); - } - } else if (msg.type === "request-abort") { - const existing = ctx.pendingRequests.get( - [msg.messageId.gatewayId, msg.messageId.requestId] - ); - if (existing) { - existing.streamController.error(new Error("Request aborted")); - 
- ctx.pendingRequests.delete( - [msg.messageId.gatewayId, msg.messageId.requestId], - ); - } else { - log(ctx)?.warn({ - msg: "received abort for unknown pending request", - }); - } + } else if (msg.type === "req-start") { + handleReqStart(ctx, msg.messageId, msg.req); + } else if (msg.type === "req-chunk") { + handleReqChunk(ctx, msg.messageId, msg.chunk); + } else if (msg.type === "req-abort") { + handleReqAbort(ctx, msg.messageId); + } else if (msg.type === "ws-open") { + handleWsOpen(ctx, msg.messageId, msg.path, msg.headers); + } else if (msg.type === "ws-msg") { + handleWsMessage(ctx, msg.messageId, msg.msg); + } else if (msg.type === "ws-close") { + handleWsClose(ctx, msg.messageId, msg.close); + } else if (msg.type === "hws-restore") { + handleHwsRestore(ctx, msg.metaEntries); + } else if (msg.type === "hws-ack") { + handleHwsAck(ctx, msg.gatewayId, msg.requestId, msg.envoyMessageIndex); } else { unreachable(msg); } } + + log(ctx)?.debug({ + msg: "envoy actor stopped" + }); + + rx.close(); } interface PendingRequest { - clientMessageIndex: number; - streamController: ReadableStreamDefaultController; + envoyMessageIndex: number; + streamController?: ReadableStreamDefaultController; } function sendEvent(ctx: ActorContext, inner: protocol.Event) { @@ -321,6 +235,440 @@ function sendEvent(ctx: ActorContext, inner: protocol.Event) { }); } +async function handleStop(ctx: ActorContext, reason: protocol.StopActorReason) { + let stopCode = ctx.error ? protocol.StopCode.Error : protocol.StopCode.Ok; + let stopMessage: string | null = ctx.error ?? null; + + try { + await ctx.shared.config.onActorStop( + ctx.shared.handle, + ctx.actorId, + ctx.generation, + reason, + ); + } catch (error) { + log(ctx)?.error({ + msg: "actor stop failed", + actorId: ctx.actorId, + error: stringifyError(error), + }); + + stopCode = protocol.StopCode.Error; + if (!stopMessage) { + stopMessage = + error instanceof Error + ? 
error.message + : "actor stop failed"; + } + } + + sendEvent(ctx, { + tag: "EventActorStateUpdate", + val: { + state: { + tag: "ActorStateStopped", + val: { + code: stopCode, + message: stopMessage + }, + }, + }, + }); +} + +function handleReqStart(ctx: ActorContext, messageId: protocol.MessageId, req: protocol.ToEnvoyRequestStart) { + let pendingReq: PendingRequest = { + envoyMessageIndex: 0, + }; + ctx.pendingRequests.set( + [messageId.gatewayId, messageId.requestId], + pendingReq, + ); + + // Convert headers map to Headers object + const headers = new Headers(); + for (const [key, value] of req.headers) { + headers.append(key, value); + } + + // Create Request object + const request = new Request(`http://localhost${req.path}`, { + method: req.method, + headers, + body: req.body ? new Uint8Array(req.body) : undefined, + }); + + // Handle streaming request + if (req.stream) { + // Create a stream for the request body + const stream = new ReadableStream({ + start: (controller) => { + // Store controller for chunks + pendingReq.streamController = controller; + }, + }); + + // Create request with streaming body + const streamingRequest = new Request(request, { + body: stream, + duplex: "half", + } as any); + + spawn(async () => { + const response = await ctx.shared.config.fetch( + ctx.shared.handle, + ctx.actorId, + messageId.gatewayId, + messageId.requestId, + streamingRequest, + ); + await sendResponse( + ctx, + messageId.gatewayId, + messageId.requestId, + response, + ); + }); + } else { + // Non-streaming request + spawn(async () => { + const response = await ctx.shared.config.fetch( + ctx.shared.handle, + ctx.actorId, + messageId.gatewayId, + messageId.requestId, + request, + ); + await sendResponse( + ctx, + messageId.gatewayId, + messageId.requestId, + response, + ); + ctx.pendingRequests.delete( + [messageId.gatewayId, messageId.requestId], + ); + }); + } +} + +function handleReqChunk(ctx: ActorContext, messageId: protocol.MessageId, chunk: 
protocol.ToEnvoyRequestChunk) { + const req = ctx.pendingRequests.get( + [messageId.gatewayId, messageId.requestId] + ); + if (req) { + if (req.streamController) { + req.streamController.enqueue(new Uint8Array(chunk.body)); + + if (chunk.finish) { + req.streamController.close(); + + ctx.pendingRequests.delete( + [messageId.gatewayId, messageId.requestId], + ); + } + } else { + log(ctx)?.warn({ + msg: "received chunk for pending request without stream controller", + }); + + } + } else { + log(ctx)?.warn({ + msg: "received chunk for unknown pending request", + }); + } +} + +function handleReqAbort(ctx: ActorContext, messageId: protocol.MessageId) { + const req = ctx.pendingRequests.get( + [messageId.gatewayId, messageId.requestId] + ); + if (req) { + if (req.streamController) { + req.streamController.error(new Error("Request aborted")); + } + + ctx.pendingRequests.delete( + [messageId.gatewayId, messageId.requestId], + ); + } else { + log(ctx)?.warn({ + msg: "received abort for unknown pending request", + }); + } +} + +async function handleWsOpen(ctx: ActorContext, messageId: protocol.MessageId, path: string, headers: ReadonlyMap) { + ctx.pendingRequests.set( + [messageId.gatewayId, messageId.requestId], + { + envoyMessageIndex: 0, + } + ); + + try { + // #createWebSocket will call `runner.config.websocket` under the + // hood to add the event listeners for open, etc. If this handler + // throws, then the WebSocket will be closed before sending the + // open event. 
+ const adapter = await createWebSocket( + ctx, + messageId, + false, + path, + Object.fromEntries(headers), + ); + ctx.webSockets.set([messageId.gatewayId, messageId.requestId], adapter); + + sendMessage(ctx, messageId.gatewayId, messageId.requestId, { + tag: "ToRivetWebSocketOpen", + val: { + canHibernate: adapter[HIBERNATABLE_SYMBOL], + }, + }); + + adapter._handleOpen(); + } catch (error) { + log(ctx)?.error({ msg: "error handling websocket open", error }); + + // Send close on error + sendMessage(ctx, messageId.gatewayId, messageId.requestId, { + tag: "ToRivetWebSocketClose", + val: { + code: 1011, + reason: "Server Error", + hibernate: false, + }, + }); + + ctx.pendingRequests.delete([messageId.gatewayId, messageId.requestId]); + ctx.webSockets.delete([messageId.gatewayId, messageId.requestId]); + } +} + +function handleWsMessage(ctx: ActorContext, messageId: protocol.MessageId, msg: protocol.ToEnvoyWebSocketMessage) { + const ws = ctx.webSockets.get( + [messageId.gatewayId, messageId.requestId] + ); + if (ws) { + const data = msg.binary + ? 
new Uint8Array(msg.data) + : new TextDecoder().decode(new Uint8Array(msg.data)); + + ws._handleMessage( + data, + messageId.messageIndex, + msg.binary, + ); + } else { + log(ctx)?.warn({ + msg: "received message for unknown ws", + }); + } +} + +function handleWsClose(ctx: ActorContext, messageId: protocol.MessageId, close: protocol.ToEnvoyWebSocketClose) { + const ws = ctx.webSockets.get( + [messageId.gatewayId, messageId.requestId] + ); + if (ws) { + // We don't need to send a close response + ws._handleClose( + close.code || undefined, + close.reason || undefined, + ); + ctx.webSockets.delete( + [messageId.gatewayId, messageId.requestId] + ); + ctx.pendingRequests.delete( + [messageId.gatewayId, messageId.requestId] + ); + } else { + log(ctx)?.warn({ + msg: "received close for unknown ws", + }); + } +} + +async function handleHwsRestore(ctx: ActorContext, metaEntries: HibernatingWebSocketMetadata[]) { + if (ctx.hibernationRestored) { + throw new Error( + `Actor ${ctx.actorId} already restored hibernating requests`, + ); + } + + log(ctx)?.debug({ + msg: "restoring hibernating requests", + requests: ctx.hibernatingRequests.length, + }); + + // Track all background operations + const backgroundOperations: Promise[] = []; + + // Process connected WebSockets + let connectedButNotLoadedCount = 0; + let restoredCount = 0; + for (const { gatewayId, requestId } of ctx.hibernatingRequests) { + const requestIdStr = idToStr(requestId); + const meta = metaEntries.find( + (entry) => + arraysEqual(entry.gatewayId, gatewayId) && + arraysEqual(entry.requestId, requestId), + ); + + if (!meta) { + // Connected but not loaded (not persisted) - close it + // + // This may happen if the metadata was not successfully persisted + log(ctx)?.warn({ + msg: "closing websocket that is not persisted", + requestId: requestIdStr, + }); + + sendMessage(ctx, gatewayId, requestId, { + tag: "ToRivetWebSocketClose", + val: { + code: 1000, + reason: "ws.meta_not_found_during_restore", + hibernate: 
false, + }, + }); + + connectedButNotLoadedCount++; + } else { + ctx.pendingRequests.set([gatewayId, requestId], { envoyMessageIndex: 0 }); + + // This will call `runner.config.websocket` under the hood to + // attach the event listeners to the WebSocket. + // Track this operation to ensure it completes + const restoreOperation = createWebSocket( + ctx, + { + gatewayId, + requestId, + messageIndex: meta.rivetMessageIndex, + }, + true, + meta.path, + meta.headers, + ) + .then(adapter => { + ctx.webSockets.set([gatewayId, requestId], adapter); + + log(ctx)?.info({ + msg: "connection successfully restored", + requestId: requestIdStr, + }); + }) + .catch((err) => { + log(ctx)?.error({ + msg: "error creating websocket during restore", + requestId: requestIdStr, + error: stringifyError(err), + }); + + // Close the WebSocket on error + sendMessage(ctx, gatewayId, requestId, { + tag: "ToRivetWebSocketClose", + val: { + code: 1011, + reason: "ws.restore_error", + hibernate: false, + }, + }); + + ctx.pendingRequests.delete([gatewayId, requestId]); + }); + + backgroundOperations.push(restoreOperation); + restoredCount++; + } + } + + // Process loaded but not connected (stale) - remove them + let loadedButNotConnectedCount = 0; + for (const meta of metaEntries) { + const requestIdStr = idToStr(meta.requestId); + const isConnected = ctx.hibernatingRequests.some( + (req) => + arraysEqual(req.gatewayId, meta.gatewayId) && + arraysEqual(req.requestId, meta.requestId), + ); + if (!isConnected) { + log(ctx)?.warn({ + msg: "removing stale persisted websocket", + requestId: requestIdStr, + }); + + // Create adapter to register user's event listeners. + // Pass engineAlreadyClosed=true so close callback won't send tunnel message. 
+ // Track this operation to ensure it completes + const cleanupOperation = createWebSocket( + ctx, + { + gatewayId: meta.gatewayId, + requestId: meta.requestId, + messageIndex: meta.rivetMessageIndex, + }, + true, + meta.path, + meta.headers, + ) + .then((adapter) => { + // Close the adapter normally - this will fire user's close event handler + // (which should clean up persistence) and trigger the close callback + // (which will clean up maps but skip sending tunnel message) + adapter.close(1000, "ws.stale_metadata"); + }) + .catch((err) => { + log(ctx)?.error({ + msg: "error creating stale websocket during restore", + requestId: requestIdStr, + error: stringifyError(err), + }); + }); + + backgroundOperations.push(cleanupOperation); + loadedButNotConnectedCount++; + } + } + + // Wait for all background operations to complete before finishing + await Promise.allSettled(backgroundOperations); + + // Mark restoration as complete + ctx.hibernationRestored = true; + + log(ctx)?.info({ + msg: "restored hibernatable websockets", + restoredCount, + connectedButNotLoadedCount, + loadedButNotConnectedCount, + }); +} + +function handleHwsAck(ctx: ActorContext, gatewayId: protocol.GatewayId, requestId: protocol.RequestId, envoyMessageIndex: number) { + const requestIdStr = idToStr(requestId); + + log(ctx)?.debug({ + msg: "ack ws msg", + requestId: requestIdStr, + index: envoyMessageIndex, + }); + + if (envoyMessageIndex < 0 || envoyMessageIndex > 65535) + throw new Error("Invalid websocket ack index"); + + // Send the ack message + sendMessage(ctx, gatewayId, requestId, { + tag: "ToRivetWebSocketMessageAck", + val: { + index: envoyMessageIndex, + }, + }); +} + function incrementCheckpoint(ctx: ActorContext): protocol.ActorCheckpoint { const index = ctx.eventIndex; ctx.eventIndex++; @@ -328,6 +676,190 @@ function incrementCheckpoint(ctx: ActorContext): protocol.ActorCheckpoint { return { actorId: ctx.actorId, generation: ctx.generation, index }; } +async function 
createWebSocket( + ctx: ActorContext, + messageId: protocol.MessageId, + isRestoringHibernatable: boolean, + path: string, + headers: Record, +): Promise { + // We need to manually ensure the original Upgrade/Connection WS + // headers are present + const fullHeaders = { + ...headers, + Upgrade: "websocket", + Connection: "Upgrade", + }; + + if (!path.startsWith("/")) { + throw new Error("Path must start with leading slash"); + } + + const request = new Request(`http://actor${path}`, { + method: "GET", + headers: fullHeaders, + }); + + const isHibernatable = isRestoringHibernatable || + ctx.shared.config.hibernatableWebSocket.canHibernate( + ctx.actorId, + messageId.gatewayId, + messageId.requestId, + request, + ); + + // Create WebSocket adapter + const adapter = new WebSocketTunnelAdapter( + ctx.shared, + ctx.actorId, + messageId.gatewayId, + messageId.requestId, + messageId.messageIndex, + isHibernatable, + isRestoringHibernatable, + request, + (data: ArrayBuffer | string, isBinary: boolean) => { + // Send message through tunnel + const dataBuffer = + typeof data === "string" + ? (new TextEncoder().encode(data).buffer as ArrayBuffer) + : data; + + sendMessage(ctx, messageId.gatewayId, messageId.requestId, { + tag: "ToRivetWebSocketMessage", + val: { + data: dataBuffer, + binary: isBinary, + }, + }); + }, + (code?: number, reason?: string) => { + sendMessage(ctx, messageId.gatewayId, messageId.requestId, { + tag: "ToRivetWebSocketClose", + val: { + code: code || null, + reason: reason || null, + hibernate: false, + }, + }); + + ctx.pendingRequests.delete([messageId.gatewayId, messageId.requestId]); + ctx.webSockets.delete([messageId.gatewayId, messageId.requestId]); + }, + ); + + // Call WebSocket handler. This handler will add event listeners + // for `open`, etc. Pass the VirtualWebSocket (not the adapter) to the actor. 
/**
 * Sends a fully-buffered HTTP response for a tunneled request back to the
 * engine.
 *
 * The body is read into memory before sending and rejected if it exceeds
 * the engine's advertised payload limit (unbounded when protocol metadata
 * has not been received yet). A Content-Length header is added when absent.
 *
 * @throws Error if the response body exceeds maxResponsePayloadSize.
 */
async function sendResponse(ctx: ActorContext, gatewayId: protocol.GatewayId, requestId: protocol.RequestId, response: Response) {
	// Always treat responses as non-streaming for now
	// In the future, we could detect streaming responses based on:
	// - Transfer-Encoding: chunked
	// - Content-Type: text/event-stream
	// - Explicit stream flag from the handler

	// Read the body first to get the actual content
	const body = response.body ? await response.arrayBuffer() : null;

	// Enforce the engine's payload limit; Infinity disables the check until
	// protocol metadata arrives.
	if (body && body.byteLength > (ctx.shared.protocolMetadata?.maxResponsePayloadSize ?? Infinity)) {
		throw new Error("Response body too large");
	}

	// Convert headers to map and add Content-Length if not present
	const headers = new Map();
	response.headers.forEach((value, key) => {
		headers.set(key, value);
	});

	// Add Content-Length header if we have a body and it's not already set
	if (body && !headers.has("content-length")) {
		headers.set("content-length", String(body.byteLength));
	}

	sendMessage(
		ctx,
		gatewayId,
		requestId,
		{
			tag: "ToRivetResponseStart",
			val: {
				status: response.status as protocol.u16,
				headers,
				body: body || null,
				stream: false,
			}
		}
	);
}
requestIdStr, + }); + return; + } + + const envoyMessageIndex = req.envoyMessageIndex; + req.envoyMessageIndex++; + + const msg = { + messageId: { + gatewayId, + requestId, + messageIndex: envoyMessageIndex, + }, + messageKind, + }; + + const failed = wsSend( + ctx.shared, + { + tag: "ToRivetTunnelMessage", + val: msg, + }, + ); + + // Buffer message if not connected + if (failed) { + log(ctx)?.debug({ + msg: "buffering tunnel message, socket not connected to engine", + requestId: idToStr(requestId), + message: stringifyToRivetTunnelMessageKind(msg.messageKind), + }); + ctx.shared.envoyTx.send({ type: "buffer-tunnel-msg", msg }); + return; + } +} + function log(ctx: ActorContext) { const baseLogger = ctx.shared.config.logger ?? logger(); if (!baseLogger) return undefined; diff --git a/engine/sdks/typescript/envoy-client/src/tasks/connection.ts b/engine/sdks/typescript/envoy-client/src/tasks/connection.ts index 8b3bdd8d76..8e94bdfe36 100644 --- a/engine/sdks/typescript/envoy-client/src/tasks/connection.ts +++ b/engine/sdks/typescript/envoy-client/src/tasks/connection.ts @@ -1,7 +1,7 @@ import * as protocol from "@rivetkit/engine-envoy-protocol"; import type { UnboundedSender } from "antiox/sync/mpsc"; import { sleep } from "antiox/time"; -import { spawn } from "antiox/task"; +import { JoinHandle, spawn } from "antiox/task"; import type { SharedContext } from "../context.js"; import { logger } from "../log.js"; import { stringifyToEnvoy, stringifyToRivet } from "../stringify.js"; @@ -12,31 +12,42 @@ import { webSocket, } from "../websocket.js"; -export function startConnection(ctx: SharedContext) { - spawn(() => connectionLoop(ctx)); +export function startConnection(ctx: SharedContext): JoinHandle { + return spawn(signal => connectionLoop(ctx, signal)); } const STABLE_CONNECTION_MS = 60_000; -async function connectionLoop(ctx: SharedContext) { +async function connectionLoop(ctx: SharedContext, signal: AbortSignal) { let attempt = 0; while (true) { const connectedAt = 
Date.now(); try { - const res = await singleConnection(ctx); + const res = await singleConnection(ctx, signal); - if (res?.group == 'ws' && res?.error == "eviction") { - log(ctx)?.debug({ - msg: "connection evicted", - }); - ctx.envoyTx.send({ type: "evict" }); - return; + if (res) { + if (res.group === "ws" && res.error === "eviction") { + log(ctx)?.debug({ + msg: "connection evicted", + }); + + ctx.envoyTx.send({ type: "conn-close", evict: true }); + + return; + } else if (res.group === 'channel' && res.error === "closed") { + // Client side shutdown + return; + } } + + ctx.envoyTx.send({ type: "conn-close", evict: false }); } catch (error) { log(ctx)?.error({ msg: "connection failed", error, }); + + ctx.envoyTx.send({ type: "conn-close", evict: false }); } if (Date.now() - connectedAt >= STABLE_CONNECTION_MS) { @@ -54,7 +65,7 @@ async function connectionLoop(ctx: SharedContext) { } } -async function singleConnection(ctx: SharedContext): Promise { +async function singleConnection(ctx: SharedContext, signal: AbortSignal): Promise { const { config } = ctx; const protocols = ["rivet"]; @@ -114,6 +125,8 @@ async function singleConnection(ctx: SharedContext): Promise>; kvRequests: Map; nextKvRequestId: number; - // Maps tunnel requests to actors - requestToActor: Map<[protocol.GatewayId, protocol.RequestId], string>; + // Maps tunnel requests to actors (not http requests) + requestToActor: BufferMap; + bufferedMessages: protocol.ToRivetTunnelMessage[]; } export interface ActorEntry { @@ -59,6 +63,10 @@ export type ToEnvoyFromConnMessage = Exclude< export type ToEnvoyMessage = // Inbound from connection | { type: "conn-message"; message: ToEnvoyFromConnMessage } + | { + type: "conn-close"; + evict: boolean; + } // Sent from actor | { type: "send-events"; @@ -71,20 +79,21 @@ export type ToEnvoyMessage = resolve: (data: protocol.KvResponseData) => void; reject: (error: Error) => void; } + | { type: "buffer-tunnel-msg", msg: protocol.ToRivetTunnelMessage } | { type: 
"shutdown" } - | { type: "evict" }; + | { type: "stop" }; export async function startEnvoy(config: EnvoyConfig): Promise { - const [handle, startRx] = startEnvoySync(config); + const handle = startEnvoySync(config); // Wait for envoy start - await startRx.changed(); + await handle.started(); return handle; } // Must manually wait for envoy to start. -export function startEnvoySync(config: EnvoyConfig): [EnvoyHandle, WatchReceiver] { +export function startEnvoySync(config: EnvoyConfig): EnvoyHandle { const [envoyTx, envoyRx] = unboundedChannel(); const [startTx, startRx] = watch(void 0); const actors: Map> = new Map(); @@ -97,18 +106,20 @@ export function startEnvoySync(config: EnvoyConfig): [EnvoyHandle, WatchReceiver handle: null as any, }; - startConnection(shared); + const connHandle = startConnection(shared); const ctx: EnvoyContext = { shared, + serverless: false, actors, kvRequests: new Map(), nextKvRequestId: 0, - requestToActor: new Map(), + requestToActor: new BufferMap(), + bufferedMessages: [], }; // Set shared handle - const handle = createHandle(ctx); + const handle = createHandle(ctx, startRx); shared.handle = handle; log(ctx.shared)?.info({ msg: "starting envoy" }); @@ -122,19 +133,65 @@ export function startEnvoySync(config: EnvoyConfig): [EnvoyHandle, WatchReceiver cleanupOldKvRequests(ctx); }, KV_CLEANUP_INTERVAL_MS); + let lostTimeout: NodeJS.Timeout | undefined = undefined; + for await (const msg of envoyRx) { if (msg.type === "conn-message") { - await handleConnMessage(ctx, startTx, msg.message); + await handleConnMessage(ctx, startTx, lostTimeout, msg.message); + } else if (msg.type === "conn-close") { + await handleConnClose(ctx, lostTimeout); + if (msg.evict) break; } else if (msg.type === "send-events") { - handleSendEvents(ctx, msg.events); + const stop = handleSendEvents(ctx, msg.events); + + if (stop) { + log(ctx.shared)?.info({ + msg: "serverless actor stopped, stopping envoy" + }); + break; + } } else if (msg.type === "kv-request") { 
handleKvRequest(ctx, msg); + } else if (msg.type === "buffer-tunnel-msg") { + ctx.bufferedMessages.push(msg.msg); } else if (msg.type === "shutdown") { wsSend(ctx.shared, { tag: "ToRivetStopping", val: null, }); - } else if (msg.type === "evict") { + + // Start shutdown checker + spawn(async () => { + let i = 0; + + while (true) { + let total = 0; + + // Check for actors with open handles + for (const gens of ctx.actors.values()) { + const last = Array.from(gens.values())[gens.size - 1]; + + if (last && !last.handle.isClosed()) total++; + } + + // Wait until no actors remain + if (total === 0) { + ctx.shared.envoyTx.send({ type: "stop" }); + break; + } + + await sleep(1000); + + if (i % 10 === 0) { + log(ctx.shared)?.info({ + msg: "waiting on actors to stop before shutdown", + actors: total, + }); + } + i++; + } + }); + } else if (msg.type === "stop") { break; } else { unreachable(msg); @@ -146,6 +203,7 @@ export function startEnvoySync(config: EnvoyConfig): [EnvoyHandle, WatchReceiver }); // Cleanup + ctx.shared.wsTx?.send({ type: "close", code: 1000, reason: "envoy.shutdown" }); clearInterval(ackInterval); clearInterval(kvCleanupInterval); @@ -162,12 +220,18 @@ export function startEnvoySync(config: EnvoyConfig): [EnvoyHandle, WatchReceiver ctx.actors.clear(); }); - return [handle, startRx]; + // Queue start actor + if (shared.config.serverlessStartPayload) { + handle.startServerless(shared.config.serverlessStartPayload); + } + + return handle; } async function handleConnMessage( ctx: EnvoyContext, startTx: WatchSender, + lostTimeout: NodeJS.Timeout | undefined, message: ToEnvoyFromConnMessage, ) { if (message.tag === "ToEnvoyInit") { @@ -177,8 +241,10 @@ async function handleConnMessage( protocolMetadata: message.val.metadata, }); + clearTimeout(lostTimeout); resendUnacknowledgedEvents(ctx); processUnsentKvRequests(ctx); + resendBufferedTunnelMessages(ctx); startTx.send(); } else if (message.tag === "ToEnvoyCommands") { @@ -194,6 +260,45 @@ async function 
handleConnMessage( } } +async function handleConnClose(ctx: EnvoyContext, lostTimeout: NodeJS.Timeout | undefined) { + if (!lostTimeout) { + let lostThreshold = ctx.shared.protocolMetadata ? Number(ctx.shared.protocolMetadata.envoyLostThreshold) : 10000; + log(ctx.shared)?.debug({ + msg: "starting runner lost timeout", + seconds: lostThreshold / 1000, + }); + + lostTimeout = setTimeout( + () => { + // Remove all remaining kv requests + for (const [_, request] of ctx.kvRequests.entries()) { + request.reject(new EnvoyShutdownError()); + } + + ctx.kvRequests.clear(); + + if (ctx.actors.size == 0) return; + + log(ctx.shared)?.warn({ + msg: "stopping all actors due to runner lost threshold", + }); + + // Stop all actors + for (const [_, gens] of ctx.actors) { + for (const [_, entry] of gens) { + if (!entry.handle.isClosed()) { + entry.handle.send({ type: "lost" }); + } + } + } + + ctx.actors.clear(); + }, + lostThreshold, + ); + } +} + // MARK: Util export function log(ctx: SharedContext) { @@ -220,7 +325,10 @@ export function getActorEntry( function createHandle( ctx: EnvoyContext, + startRx: WatchReceiver, ): EnvoyHandle { + let startedPromise = startRx.changed(); + return { shutdown(immediate: boolean) { ctx.shared.envoyTx.send({ type: "shutdown" }); @@ -235,6 +343,10 @@ function createHandle( return ctx.shared.envoyKey; }, + started(): Promise { + return startedPromise; + }, + getActor(actorId: string, generation?: number): ActorEntry | undefined { return getActor(ctx, actorId, generation); }, @@ -248,12 +360,13 @@ function createHandle( ); }, - stopActor(actorId: string, generation?: number): void { + stopActor(actorId: string, generation?: number, error?: string): void { sendActorIntent( ctx, actorId, { tag: "ActorIntentStop", val: null }, generation, + error, ); }, @@ -426,6 +539,54 @@ function createHandle( val: null, }); }, + + restoreHibernatingRequests( + actorId: string, + metaEntries: HibernatingWebSocketMetadata[], + ) { + const actor = getActor(ctx, 
actorId); + if (!actor) { + throw new Error( + `Actor ${actorId} not found for restoring hibernating requests`, + ); + } + + actor.handle.send({ type: "hws-restore", metaEntries }); + }, + + sendHibernatableWebSocketMessageAck( + gatewayId: protocol.GatewayId, + requestId: protocol.RequestId, + clientMessageIndex: number, + ) { + sendHibernatableWebSocketMessageAck(ctx, gatewayId, requestId, clientMessageIndex); + }, + + startServerless(payload: ArrayBuffer) { + if (ctx.serverless) throw new Error("Already started serverless actor"); + ctx.serverless = true; + + let version = new DataView(payload).getUint16(0, true); + + if (version != protocol.VERSION) + throw new Error(`Serverless start payload does not match protocol version: ${version} vs ${protocol.VERSION}`); + + // Skip first 2 bytes (version) + const message = protocol.decodeToEnvoy(new Uint8Array(payload, 2)); + + if (message.tag !== "ToEnvoyCommands") throw new Error("invalid serverless body"); + if (message.val.length !== 1) throw new Error("invalid serverless body"); + if (message.val[0].inner.tag !== "CommandStartActor") throw new Error("invalid serverless body"); + + // Wait for envoy to start before adding message + startedPromise.then(() => { + log(ctx.shared)?.debug({ + msg: "received serverless start", + data: stringifyToEnvoy(message), + }); + ctx.shared.envoyTx.send({ type: "conn-message", message }); + }); + } }; } @@ -434,13 +595,14 @@ function sendActorIntent( actorId: string, intent: protocol.ActorIntent, generation?: number, + error?: string, ): void { const entry = getActor(ctx, actorId, generation); if (!entry) return; entry.handle.send({ - type: "actor-intent", - commandIdx: 0n, + type: "intent", intent, + error, }); } diff --git a/engine/sdks/typescript/envoy-client/src/tasks/envoy/tunnel.ts b/engine/sdks/typescript/envoy-client/src/tasks/envoy/tunnel.ts index 6f9d5fbcdf..686cffc010 100644 --- a/engine/sdks/typescript/envoy-client/src/tasks/envoy/tunnel.ts +++ 
b/engine/sdks/typescript/envoy-client/src/tasks/envoy/tunnel.ts @@ -3,6 +3,18 @@ import { EnvoyContext, getActor, log } from "./index.js"; import { SharedContext } from "@/context.js"; import { unreachable } from "antiox"; import { wsSend } from "../connection.js"; +import { idToStr } from "@/utils.js"; +import { stringifyToRivetTunnelMessageKind } from "@/stringify.js"; + +export interface HibernatingWebSocketMetadata { + gatewayId: protocol.GatewayId; + requestId: protocol.RequestId; + envoyMessageIndex: number; + rivetMessageIndex: number; + + path: string; + headers: Record; +} export function handleTunnelMessage(ctx: EnvoyContext, msg: protocol.ToEnvoyTunnelMessage) { const { @@ -21,7 +33,7 @@ export function handleTunnelMessage(ctx: EnvoyContext, msg: protocol.ToEnvoyTunn } else if (tag === "ToEnvoyWebSocketMessage") { handleWebSocketMessage(ctx, messageId, val); } else if (tag === "ToEnvoyWebSocketClose") { - handleWebSocketClose(ctx, messageId); + handleWebSocketClose(ctx, messageId, val); } else { unreachable(tag); } @@ -36,20 +48,15 @@ function handleRequestStart(ctx: EnvoyContext, messageId: protocol.MessageId, re actorId: req.actorId, }); - // NOTE: This is a special response that will cause Guard to retry the request - // - // See should_retry_request_inner - // https://github.com/rivet-dev/rivet/blob/222dae87e3efccaffa2b503de40ecf8afd4e31eb/engine/packages/guard-core/src/proxy_service.rs#L2458 - sendResponse(ctx.shared, messageId, new Response("Actor not found", { - status: 503, - headers: { "x-rivet-error": "envoy.actor_not_found" }, - })); + sendErrorResponse(ctx, messageId.gatewayId, messageId.requestId); return; } + ctx.requestToActor.set([messageId.gatewayId, messageId.requestId], req.actorId); + actor.handle.send({ - type: "request-start", + type: "req-start", messageId, req, }); @@ -60,7 +67,7 @@ function handleRequestChunk(ctx: EnvoyContext, messageId: protocol.MessageId, ch if (actorId) { let actor = getActor(ctx, actorId); if (actor) { - 
actor.handle.send({ type: "request-chunk", messageId, chunk }); + actor.handle.send({ type: "req-chunk", messageId, chunk }); } } @@ -74,54 +81,166 @@ function handleRequestAbort(ctx: EnvoyContext, messageId: protocol.MessageId) { if (actorId) { let actor = getActor(ctx, actorId); if (actor) { - actor.handle.send({ type: "request-abort", messageId }); + actor.handle.send({ type: "req-abort", messageId }); } } ctx.requestToActor.delete([messageId.gatewayId, messageId.requestId]); } -export async function sendResponse(ctx: SharedContext, messageId: protocol.MessageId, response: Response) { - // Always treat responses as non-streaming for now - // In the future, we could detect streaming responses based on: - // - Transfer-Encoding: chunked - // - Content-Type: text/event-stream - // - Explicit stream flag from the handler +function handleWebSocketOpen(ctx: EnvoyContext, messageId: protocol.MessageId, open: protocol.ToEnvoyWebSocketOpen) { + const actor = getActor(ctx, open.actorId); + + if (!actor) { + log(ctx.shared)?.warn({ + msg: "received request for unknown actor", + actorId: open.actorId, + }); + + wsSend(ctx.shared, { + tag: "ToRivetTunnelMessage", + val: { + messageId, + messageKind: { + tag: "ToRivetWebSocketClose", + val: { + code: 1011, + reason: "Actor not found", + hibernate: false, + }, + } + } + }); + + return; + } + + ctx.requestToActor.set([messageId.gatewayId, messageId.requestId], open.actorId); + + actor.handle.send({ + type: "ws-open", + messageId, + path: open.path, + headers: open.headers, + }); +} + +function handleWebSocketMessage(ctx: EnvoyContext, messageId: protocol.MessageId, msg: protocol.ToEnvoyWebSocketMessage) { + const actorId = ctx.requestToActor.get([messageId.gatewayId, messageId.requestId]); + if (actorId) { + let actor = getActor(ctx, actorId); + if (actor) { + actor.handle.send({ type: "ws-msg", messageId, msg }); + } + } +} + +function handleWebSocketClose(ctx: EnvoyContext, messageId: protocol.MessageId, close: 
protocol.ToEnvoyWebSocketClose) { + const actorId = ctx.requestToActor.get([messageId.gatewayId, messageId.requestId]); + if (actorId) { + let actor = getActor(ctx, actorId); + if (actor) { + actor.handle.send({ type: "ws-close", messageId, close }); + } + } - // Read the body first to get the actual content - const body = response.body ? await response.arrayBuffer() : null; + ctx.requestToActor.delete([messageId.gatewayId, messageId.requestId]); +} + +export function sendHibernatableWebSocketMessageAck( + ctx: EnvoyContext, + gatewayId: protocol.GatewayId, + requestId: protocol.RequestId, + envoyMessageIndex: number, +) { + const actorId = ctx.requestToActor.get([gatewayId, requestId]); + if (actorId) { + let actor = getActor(ctx, actorId); + if (actor) { + actor.handle.send({ type: "hws-ack", gatewayId, requestId, envoyMessageIndex }); + } + } +} - if (body && body.byteLength > ctx.protocolMetadata?.maxPayloadSize) { - throw new Error("Response body too large"); +export function resendBufferedTunnelMessages(ctx: EnvoyContext) { + if (ctx.bufferedMessages.length === 0) { + return; } - // Convert headers to map and add Content-Length if not present - const headers = new Map(); - response.headers.forEach((value, key) => { - headers.set(key, value); + log(ctx.shared)?.info({ + msg: "resending buffered tunnel messages", + count: ctx.bufferedMessages.length, }); + const messages = ctx.bufferedMessages; + ctx.bufferedMessages = []; + + for (const msg of messages) { + wsSend( + ctx.shared, + { + tag: "ToRivetTunnelMessage", + val: msg, + }, + ); + } +} + +// NOTE: This is a special response that will cause Guard to retry the request +// +// See should_retry_request_inner +// https://github.com/rivet-dev/rivet/blob/222dae87e3efccaffa2b503de40ecf8afd4e31eb/engine/packages/guard-core/src/proxy_service.rs#L2458 +function sendErrorResponse(ctx: EnvoyContext, gatewayId: protocol.GatewayId, requestId: protocol.RequestId) { + const body = new TextEncoder().encode("Actor not 
found").buffer; + const headers = new Map([["x-rivet-error", "envoy.actor_not_found"]]); + // Add Content-Length header if we have a body and it's not already set if (body && !headers.has("content-length")) { headers.set("content-length", String(body.byteLength)); } - wsSend( - ctx, { - tag: "ToRivetTunnelMessage", - val: { - messageId, - messageKind: { - tag: "ToRivetResponseStart", - val: { - status: response.status as protocol.u16, - headers, - body: body || null, - stream: false, - } + sendMessage( + ctx, + gatewayId, + requestId, + { + tag: "ToRivetResponseStart", + val: { + status: 503, + headers, + body: body, + stream: false, } } - } + ); +} + +export async function sendMessage(ctx: EnvoyContext, gatewayId: protocol.GatewayId, requestId: protocol.RequestId, msg: protocol.ToRivetTunnelMessageKind) { + const payload = { + messageId: { + gatewayId, + requestId, + messageIndex: 0, + }, + messageKind: msg, + }; + + const failed = wsSend( + ctx.shared, + { + tag: "ToRivetTunnelMessage", + val: payload + }, ); -} \ No newline at end of file + // Buffer message if not connected + if (failed) { + log(ctx.shared)?.debug({ + msg: "buffering tunnel message, socket not connected to engine", + requestId: idToStr(requestId), + message: stringifyToRivetTunnelMessageKind(msg), + }); + ctx.bufferedMessages.push(payload); + return; + } +} diff --git a/engine/sdks/typescript/envoy-client/src/utils.ts b/engine/sdks/typescript/envoy-client/src/utils.ts index 1cb4965a86..f1a76bbac2 100644 --- a/engine/sdks/typescript/envoy-client/src/utils.ts +++ b/engine/sdks/typescript/envoy-client/src/utils.ts @@ -1,5 +1,48 @@ import { logger } from "./log"; +export class BufferMap { + #inner: Map; + constructor() { + this.#inner = new Map(); + } + + get(buffers: ArrayBuffer[]): T | undefined { + return this.#inner.get(cyrb53(buffers)); + } + + set(buffers: ArrayBuffer[], value: T) { + this.#inner.set(cyrb53(buffers), value); + } + + delete(buffers: ArrayBuffer[]): boolean { + return 
this.#inner.delete(cyrb53(buffers)); + } + + has(buffers: ArrayBuffer[]): boolean { + return this.#inner.has(cyrb53(buffers)); + } +} + +function cyrb53(buffers: ArrayBuffer[], seed: number = 0): string { + let h1 = 0xdeadbeef ^ seed, h2 = 0x41c6ce57 ^ seed; + for (const buffer of buffers) { + const bytes = new Uint8Array(buffer); + for (const b of bytes) { + h1 = Math.imul(h1 ^ b, 2654435761); + h2 = Math.imul(h2 ^ b, 1597334677); + } + } + h1 = Math.imul(h1 ^ (h1 >>> 16), 2246822507) ^ Math.imul(h2 ^ (h2 >>> 13), 3266489909); + h2 = Math.imul(h2 ^ (h2 >>> 16), 2246822507) ^ Math.imul(h1 ^ (h1 >>> 13), 3266489909); + return (4294967296 * (2097151 & h2) + (h1 >>> 0)).toString(16); +} + +export class EnvoyShutdownError extends Error { + constructor() { + super("Envoy shut down"); + } +} + /** Resolves after the configured debug latency, or immediately if none. */ export function injectLatency(ms?: number): Promise { if (!ms) return Promise.resolve(); diff --git a/engine/sdks/typescript/envoy-client/src/websocket.ts b/engine/sdks/typescript/envoy-client/src/websocket.ts index cd74066fda..9001ca2a6d 100644 --- a/engine/sdks/typescript/envoy-client/src/websocket.ts +++ b/engine/sdks/typescript/envoy-client/src/websocket.ts @@ -1,3 +1,4 @@ +import type * as protocol from "@rivetkit/engine-runner-protocol"; import type { UnboundedReceiver, UnboundedSender } from "antiox/sync/mpsc"; import { OnceCell } from "antiox/sync/once_cell"; import { spawn } from "antiox/task"; @@ -5,9 +6,10 @@ import type WsWebSocket from "ws"; import { latencyChannel } from "./latency-channel.js"; import { logger } from "./log.js"; import { VirtualWebSocket, type UniversalWebSocket, type RivetMessageEvent } from "@rivetkit/virtual-websocket"; -import { wrappingAddU16, wrappingLteU16, wrappingSubU16 } from "./utils"; +import { idToStr, wrappingAddU16, wrappingLteU16, wrappingSubU16 } from "./utils"; import { SharedContext } from "./context.js"; import { log } from "./tasks/envoy/index.js"; +import 
{ unreachable } from "antiox"; export const HIBERNATABLE_SYMBOL = Symbol("hibernatable"); @@ -87,24 +89,28 @@ export async function webSocket( }); raw.addEventListener("close", (event) => { - inboundTx.send({ - type: "close", - code: event.code, - reason: event.reason, - }); + if (!inboundTx.isClosed()) { + inboundTx.send({ + type: "close", + code: event.code, + reason: event.reason, + }); + } inboundTx.close(); outboundRx.close(); }); raw.addEventListener("error", (event) => { - const error = - typeof event === "object" && event !== null && "error" in event - ? event.error - : new Error("WebSocket error"); - inboundTx.send({ - type: "error", - error: error instanceof Error ? error : new Error(String(error)), - }); + if (!inboundTx.isClosed()) { + const error = + typeof event === "object" && event !== null && "error" in event + ? event.error + : new Error("WebSocket error"); + inboundTx.send({ + type: "error", + error: error instanceof Error ? error : new Error(String(error)), + }); + } inboundTx.close(); outboundRx.close(); }); @@ -113,9 +119,11 @@ export async function webSocket( for await (const message of outboundRx) { if (message.type === "send") { raw.send(message.data); - } else { + } else if (message.type === "close") { raw.close(message.code, message.reason); break; + } else { + unreachable(message); } } @@ -141,9 +149,10 @@ export class WebSocketTunnelAdapter { #shared: SharedContext; #ws: VirtualWebSocket; #actorId: string; - #requestId: string; + #gatewayId: protocol.GatewayId; + #requestId: protocol.RequestId; #hibernatable: boolean; - #messageIndex: number; + #rivetMessageIndex: number; #sendCallback: (data: ArrayBuffer | string, isBinary: boolean) => void; #closeCallback: (code?: number, reason?: string) => void; @@ -154,8 +163,9 @@ export class WebSocketTunnelAdapter { constructor( ctx: SharedContext, actorId: string, - requestId: string, - messageIndex: number, + gatewayId: protocol.GatewayId, + requestId: protocol.RequestId, + rivetMessageIndex: 
number, hibernatable: boolean, isRestoringHibernatable: boolean, public readonly request: Request, @@ -164,9 +174,10 @@ export class WebSocketTunnelAdapter { ) { this.#shared = ctx; this.#actorId = actorId; + this.#gatewayId = gatewayId; this.#requestId = requestId; this.#hibernatable = hibernatable; - this.#messageIndex = messageIndex; + this.#rivetMessageIndex = rivetMessageIndex; this.#sendCallback = sendCallback; this.#closeCallback = closeCallback; @@ -181,7 +192,7 @@ export class WebSocketTunnelAdapter { log(this.#shared)?.debug({ msg: "setting WebSocket to OPEN state for restored connection", actorId: this.#actorId, - requestId: this.#requestId, + requestId: idToStr(this.#requestId), }); this.#readyState = 1; } @@ -226,23 +237,22 @@ export class WebSocketTunnelAdapter { } // Called by Tunnel when WebSocket is opened - _handleOpen(requestId: ArrayBuffer): void { + _handleOpen(): void { if (this.#readyState !== 0) return; this.#readyState = 1; - this.#ws.dispatchEvent({ type: "open", rivetRequestId: requestId, target: this.#ws }); + this.#ws.dispatchEvent({ type: "open", rivetGatewayId: this.#gatewayId, rivetRequestId: this.#requestId, target: this.#ws }); } // Called by Tunnel when message is received _handleMessage( - requestId: ArrayBuffer, data: string | Uint8Array, - messageIndex: number, + rivetMessageIndex: number, isBinary: boolean, ): boolean { if (this.#readyState !== 1) { log(this.#shared)?.warn({ msg: "WebSocket message ignored - not in OPEN state", - requestId: this.#requestId, + requestId: idToStr(this.#requestId), actorId: this.#actorId, currentReadyState: this.#readyState, }); @@ -251,37 +261,37 @@ export class WebSocketTunnelAdapter { // Validate message index for hibernatable websockets if (this.#hibernatable) { - const previousIndex = this.#messageIndex; + const previousIndex = this.#rivetMessageIndex; - if (wrappingLteU16(messageIndex, previousIndex)) { + if (wrappingLteU16(rivetMessageIndex, previousIndex)) { log(this.#shared)?.info({ msg: 
"received duplicate hibernating websocket message", - requestId, + requestId: idToStr(this.#requestId), actorId: this.#actorId, previousIndex, - receivedIndex: messageIndex, + receivedIndex: rivetMessageIndex, }); return true; } const expectedIndex = wrappingAddU16(previousIndex, 1); - if (messageIndex !== expectedIndex) { + if (rivetMessageIndex !== expectedIndex) { const closeReason = "ws.message_index_skip"; log(this.#shared)?.warn({ msg: "hibernatable websocket message index out of sequence, closing connection", - requestId, + requestId: idToStr(this.#requestId), actorId: this.#actorId, previousIndex, expectedIndex, - receivedIndex: messageIndex, + receivedIndex: rivetMessageIndex, closeReason, - gap: wrappingSubU16(wrappingSubU16(messageIndex, previousIndex), 1), + gap: wrappingSubU16(wrappingSubU16(rivetMessageIndex, previousIndex), 1), }); this.#close(1008, closeReason, true); return true; } - this.#messageIndex = messageIndex; + this.#rivetMessageIndex = rivetMessageIndex; } // Convert data based on binaryType @@ -297,8 +307,9 @@ export class WebSocketTunnelAdapter { this.#ws.dispatchEvent({ type: "message", data: messageData, - rivetRequestId: requestId, - rivetMessageIndex: messageIndex, + rivetGatewayId: this.#gatewayId, + rivetRequestId: this.#requestId, + rivetMessageIndex: rivetMessageIndex, target: this.#ws, } as RivetMessageEvent); @@ -306,7 +317,7 @@ export class WebSocketTunnelAdapter { } // Called by Tunnel when close is received - _handleClose(_requestId: ArrayBuffer, code?: number, reason?: string): void { + _handleClose(code?: number, reason?: string): void { this.#close(code, reason, true); } diff --git a/engine/sdks/typescript/envoy-protocol/src/index.ts b/engine/sdks/typescript/envoy-protocol/src/index.ts index ff5bb6cb56..ba20853621 100644 --- a/engine/sdks/typescript/envoy-protocol/src/index.ts +++ b/engine/sdks/typescript/envoy-protocol/src/index.ts @@ -2028,12 +2028,11 @@ export function writeToEnvoyConnPing(bc: bare.ByteCursor, x: 
ToEnvoyConnPing): v bare.writeI64(bc, x.ts) } -/** - * We have to re-declare the entire union since BARE will not generate the - * ser/de for ToEnvoy if it's not a top-level type - */ +export type ToEnvoyConnClose = null + export type ToEnvoyConn = | { readonly tag: "ToEnvoyConnPing"; readonly val: ToEnvoyConnPing } + | { readonly tag: "ToEnvoyConnClose"; readonly val: ToEnvoyConnClose } | { readonly tag: "ToEnvoyCommands"; readonly val: ToEnvoyCommands } | { readonly tag: "ToEnvoyAckEvents"; readonly val: ToEnvoyAckEvents } | { readonly tag: "ToEnvoyTunnelMessage"; readonly val: ToEnvoyTunnelMessage } @@ -2045,10 +2044,12 @@ export function readToEnvoyConn(bc: bare.ByteCursor): ToEnvoyConn { case 0: return { tag: "ToEnvoyConnPing", val: readToEnvoyConnPing(bc) } case 1: - return { tag: "ToEnvoyCommands", val: readToEnvoyCommands(bc) } + return { tag: "ToEnvoyConnClose", val: null } case 2: - return { tag: "ToEnvoyAckEvents", val: readToEnvoyAckEvents(bc) } + return { tag: "ToEnvoyCommands", val: readToEnvoyCommands(bc) } case 3: + return { tag: "ToEnvoyAckEvents", val: readToEnvoyAckEvents(bc) } + case 4: return { tag: "ToEnvoyTunnelMessage", val: readToEnvoyTunnelMessage(bc) } default: { bc.offset = offset @@ -2064,18 +2065,22 @@ export function writeToEnvoyConn(bc: bare.ByteCursor, x: ToEnvoyConn): void { writeToEnvoyConnPing(bc, x.val) break } - case "ToEnvoyCommands": { + case "ToEnvoyConnClose": { bare.writeU8(bc, 1) + break + } + case "ToEnvoyCommands": { + bare.writeU8(bc, 2) writeToEnvoyCommands(bc, x.val) break } case "ToEnvoyAckEvents": { - bare.writeU8(bc, 2) + bare.writeU8(bc, 3) writeToEnvoyAckEvents(bc, x.val) break } case "ToEnvoyTunnelMessage": { - bare.writeU8(bc, 3) + bare.writeU8(bc, 4) writeToEnvoyTunnelMessage(bc, x.val) break } diff --git a/engine/sdks/typescript/test-envoy/src/index.ts b/engine/sdks/typescript/test-envoy/src/index.ts index 7edbb3edcd..d4f56d8ebc 100644 --- a/engine/sdks/typescript/test-envoy/src/index.ts +++ 
b/engine/sdks/typescript/test-envoy/src/index.ts @@ -1,6 +1,6 @@ import { serve } from "@hono/node-server"; import * as protocol from "@rivetkit/engine-envoy-protocol"; -import { EnvoyHandle, startEnvoy } from "@rivetkit/engine-envoy-client"; +import { EnvoyHandle, startEnvoy, startEnvoySync } from "@rivetkit/engine-envoy-client"; import { Hono, type Context as HonoContext, type Next } from "hono"; import { streamSSE } from "hono/streaming"; import type { Logger } from "pino"; @@ -23,6 +23,122 @@ const AUTOCONFIGURE_SERVERLESS = (process.env.AUTOCONFIGURE_SERVERLESS ?? "1") = let envoy: EnvoyHandle | null = null; const websocketLastMsgIndexes: Map = new Map(); +const config = { + logger: getLogger(), + version: RIVET_ENVOY_VERSION, + endpoint: RIVET_ENDPOINT, + token: RIVET_TOKEN, + namespace: RIVET_NAMESPACE, + poolName: RIVET_POOL_NAME, + prepopulateActorNames: {}, + fetch: async ( + envoy: EnvoyHandle, + actorId: string, + _gatewayId: ArrayBuffer, + _requestId: ArrayBuffer, + request: Request, + ) => { + getLogger().info( + `Fetch called for actor ${actorId}, URL: ${request.url}`, + ); + const url = new URL(request.url); + if (url.pathname === "/ping") { + // Return the actor ID in response + const responseData = { + actorId, + status: "ok", + timestamp: Date.now(), + }; + + return new Response(JSON.stringify(responseData), { + status: 200, + headers: { "Content-Type": "application/json" }, + }); + } else if (url.pathname === "/sleep") { + envoy.sleepActor(actorId); + + return new Response("ok", { + status: 200, + headers: { "Content-Type": "application/json" }, + }); + } + + return new Response("ok", { status: 200 }); + }, + onActorStart: async ( + _envoy: EnvoyHandle, + _actorId: string, + _generation: number, + _config: protocol.ActorConfig, + ) => { + getLogger().info( + `Actor ${_actorId} started (generation ${_generation})`, + ); + }, + onActorStop: async ( + _envoy: EnvoyHandle, + _actorId: string, + _generation: number, + reason: 
protocol.StopActorReason, + ) => { + getLogger().info( + `Actor ${_actorId} stopped (generation ${_generation})`, + ); + }, + onShutdown() { }, + websocket: async ( + envoy: EnvoyHandle, + actorId: string, + ws: WebSocket, + _gatewayId: ArrayBuffer, + _requestId: ArrayBuffer, + _request: Request, + ) => { + getLogger().info(`WebSocket connected for actor ${actorId}`); + + // Echo server - send back any messages received + ws.addEventListener("message", (event) => { + const data = event.data; + getLogger().info({ + msg: `WebSocket message from actor ${actorId}`, + data, + index: (event as any).rivetMessageIndex, + }); + + ws.send(`Echo: ${data}`); + + // Ack + const websocketId = Buffer.from( + (event as any).rivetRequestId, + ).toString("base64"); + websocketLastMsgIndexes.set( + websocketId, + (event as any).rivetMessageIndex, + ); + envoy.sendHibernatableWebSocketMessageAck( + (event as any).rivetGatewayId, + (event as any).rivetRequestId, + (event as any).rivetMessageIndex, + ); + }); + + ws.addEventListener("close", () => { + getLogger().info(`WebSocket closed for actor ${actorId}`); + }); + + ws.addEventListener("error", (error) => { + getLogger().error({ + msg: `WebSocket error for actor ${actorId}:`, + error, + }); + }); + }, + hibernatableWebSocket: { + canHibernate() { + return true; + }, + }, +}; // Create internal server const app = new Hono(); @@ -59,24 +175,26 @@ app.get("/shutdown", async (c) => { return c.text("ok"); }); -// TODO: -app.get("/api/rivet/start", async (c) => { - // return streamSSE(c, async (stream) => { - // const runnerStarted = Promise.withResolvers(); - // const runnerStopped = Promise.withResolvers(); - // const runner = await startRunner(runnerStarted, runnerStopped); +app.post("/api/rivet/start", async (c) => { + let payload = await c.req.arrayBuffer(); - // c.req.raw.signal.addEventListener("abort", () => { - // getLogger().debug("SSE aborted, shutting down runner"); - // runner!.shutdown(true); - // }); - - // await 
runnerStarted.promise; + return streamSSE(c, async (stream) => { + const stopped = Promise.withResolvers(); + const envoy = startEnvoySync({ + ...config, + serverlessStartPayload: payload, + onShutdown() { + stopped.resolve(); + } + }); - // stream.writeSSE({ data: runner.getServerlessInitPacket()! }); + c.req.raw.signal.addEventListener("abort", () => { + getLogger().debug("SSE aborted, shutting down runner"); + envoy!.shutdown(true); + }); - // await runnerStopped.promise; - // }); + await stopped.promise; + }); }); app.get("/api/rivet/metadata", async (c) => { @@ -84,6 +202,7 @@ app.get("/api/rivet/metadata", async (c) => { // Not actually rivetkit runtime: "rivetkit", version: "1", + envoyProtocolVersion: protocol.VERSION, }); }); @@ -98,123 +217,7 @@ if (AUTOSTART_SERVER) { } if (AUTOSTART_ENVOY) { - envoy = await startEnvoy({ - logger: getLogger(), - version: RIVET_ENVOY_VERSION, - endpoint: RIVET_ENDPOINT, - token: RIVET_TOKEN, - namespace: RIVET_NAMESPACE, - poolName: RIVET_POOL_NAME, - prepopulateActorNames: {}, - fetch: async ( - envoy: EnvoyHandle, - actorId: string, - _gatewayId: ArrayBuffer, - _requestId: ArrayBuffer, - request: Request, - ) => { - getLogger().info( - `Fetch called for actor ${actorId}, URL: ${request.url}`, - ); - const url = new URL(request.url); - if (url.pathname === "/ping") { - // Return the actor ID in response - const responseData = { - actorId, - status: "ok", - timestamp: Date.now(), - }; - - return new Response(JSON.stringify(responseData), { - status: 200, - headers: { "Content-Type": "application/json" }, - }); - } else if (url.pathname === "/sleep") { - envoy.sleepActor(actorId); - - return new Response("ok", { - status: 200, - headers: { "Content-Type": "application/json" }, - }); - } - - return new Response("ok", { status: 200 }); - }, - onActorStart: async ( - _envoy: EnvoyHandle, - _actorId: string, - _generation: number, - _config: protocol.ActorConfig, - ) => { - getLogger().info( - `Actor ${_actorId} started 
(generation ${_generation})`, - ); - }, - onActorStop: async ( - _envoy: EnvoyHandle, - _actorId: string, - _generation: number, - reason: protocol.StopActorReason, - ) => { - getLogger().info( - `Actor ${_actorId} stopped (generation ${_generation})`, - ); - }, - onShutdown() { }, - // TODO: - websocket: async ( - envoy: EnvoyHandle, - actorId: string, - ws: WebSocket, - _gatewayId: ArrayBuffer, - _requestId: ArrayBuffer, - _request: Request, - ) => { - // getLogger().info(`WebSocket connected for actor ${actorId}`); - - // // Echo server - send back any messages received - // ws.addEventListener("message", (event) => { - // const data = event.data; - // getLogger().info({ - // msg: `WebSocket message from actor ${actorId}`, - // data, - // index: (event as any).rivetMessageIndex, - // }); - - // ws.send(`Echo: ${data}`); - - // // Ack - // const websocketId = Buffer.from( - // (event as any).rivetRequestId, - // ).toString("base64"); - // websocketLastMsgIndexes.set( - // websocketId, - // (event as any).rivetMessageIndex, - // ); - // envoy.sendHibernatableWebSocketMessageAck( - // (event as any).rivetGatewayId, - // (event as any).rivetRequestId, - // (event as any).rivetMessageIndex, - // ); - // }); - - // ws.addEventListener("close", () => { - // getLogger().info(`WebSocket closed for actor ${actorId}`); - // }); - - // ws.addEventListener("error", (error) => { - // getLogger().error({ - // msg: `WebSocket error for actor ${actorId}:`, - // error, - // }); - // }); - }, - hibernatableWebSocket: { - canHibernate() { - return true; - }, - }, - }); + envoy = await startEnvoy(config); } else if (AUTOCONFIGURE_SERVERLESS) { await autoConfigureServerless(); } diff --git a/examples/hono/package.json b/examples/hono/package.json index 7b590878ac..09c42674e6 100644 --- a/examples/hono/package.json +++ b/examples/hono/package.json @@ -5,10 +5,8 @@ "type": "module", "scripts": { "dev": "tsx --watch src/index.ts", - "dev:server": "tsx --watch src/server.ts", 
"check-types": "tsc --noEmit", - "start": "tsx src/index.ts", - "start:server": "tsx src/server.ts" + "start": "tsx src/index.ts" }, "devDependencies": { "@types/node": "^22.13.9", @@ -16,10 +14,9 @@ "typescript": "^5.5.2" }, "dependencies": { - "@hono/node-server": "^1.19.12", "hono": "^4.7.0", "rivetkit": "^2.2.1" }, "stableVersion": "0.8.0", "license": "MIT" -} +} \ No newline at end of file diff --git a/examples/hono/src/server.ts b/examples/hono/src/server.ts index 7406f9ba51..eb6d6a7a8f 100644 --- a/examples/hono/src/server.ts +++ b/examples/hono/src/server.ts @@ -18,12 +18,3 @@ app.post("/increment/:name", async (c) => { }); export default app; - -// Start server when run directly -if (import.meta.url === `file://${process.argv[1]}`) { - const { serve } = await import("@hono/node-server"); - const port = 3000; - serve({ fetch: app.fetch, port }, () => - console.log(`Server running at http://localhost:${port}`), - ); -} diff --git a/rivetkit-asyncapi/asyncapi.json b/rivetkit-asyncapi/asyncapi.json index 96c8ee0784..aea3e81cbd 100644 --- a/rivetkit-asyncapi/asyncapi.json +++ b/rivetkit-asyncapi/asyncapi.json @@ -2,7 +2,7 @@ "asyncapi": "3.0.0", "info": { "title": "RivetKit WebSocket Protocol", - "version": "2.1.11-rc.1", + "version": "2.2.0", "description": "WebSocket protocol for bidirectional communication between RivetKit clients and actors" }, "channels": { diff --git a/rivetkit-openapi/openapi.json b/rivetkit-openapi/openapi.json index f1e1ec418d..ca482f752c 100644 --- a/rivetkit-openapi/openapi.json +++ b/rivetkit-openapi/openapi.json @@ -1,7 +1,7 @@ { "openapi": "3.0.0", "info": { - "version": "2.1.11-rc.1", + "version": "2.2.0", "title": "RivetKit API" }, "components": { diff --git a/rivetkit-typescript/packages/next-js/src/mod.ts b/rivetkit-typescript/packages/next-js/src/mod.ts index fbf568fcf0..e150bd3f42 100644 --- a/rivetkit-typescript/packages/next-js/src/mod.ts +++ b/rivetkit-typescript/packages/next-js/src/mod.ts @@ -1,14 +1,14 @@ import type { 
Registry } from "rivetkit"; import { logger } from "./log"; -// Runner version set to seconds since epoch when the module loads in development mode. +// Envoy version set to seconds since epoch when the module loads in development mode. // // This creates a version number that increments each time the code is updated // and the module reloads, allowing the engine to detect code changes via the -// /metadata endpoint and hot-reload all actors by draining older runners. +// /metadata endpoint and hot-reload all actors by draining older envoys. // -// We use seconds (not milliseconds) because the runner version is a u32 on the engine side. -const DEV_RUNNER_VERSION = Math.floor(Date.now() / 1000); +// We use seconds (not milliseconds) because the envoy version is a u32 on the engine side. +const DEV_ENVOY_VERSION = Math.floor(Date.now() / 1000); export const toNextHandler = (registry: Registry) => { // Don't run server locally since we're using the fetch handler directly @@ -44,10 +44,10 @@ export const toNextHandler = (registry: Registry) => { slotsPerRunner: 1, }; - // Set runner version to enable hot-reloading on code changes + // Set envoy version to enable hot-reloading on code changes registry.config.envoy = { ...registry.config.envoy, - version: DEV_RUNNER_VERSION, + version: DEV_ENVOY_VERSION, }; } else { logger().debug( diff --git a/rivetkit-typescript/packages/rivetkit/src/common/router.ts b/rivetkit-typescript/packages/rivetkit/src/common/router.ts index 3919f3572d..44ed5f723a 100644 --- a/rivetkit-typescript/packages/rivetkit/src/common/router.ts +++ b/rivetkit-typescript/packages/rivetkit/src/common/router.ts @@ -1,6 +1,7 @@ import * as cbor from "cbor-x"; import type { Context as HonoContext, Next } from "hono"; import type { Encoding } from "@/actor/protocol/serde"; +import { protocol as envoyProtocol } from "@rivetkit/engine-envoy-client"; import { getRequestEncoding, getRequestExposeInternalError, @@ -103,7 +104,7 @@ export function 
handleRouteError(error: unknown, c: HonoContext) { return c.body(output as any, { status: statusCode }); } -export type MetadataRunnerKind = +export type MetadataEnvoyKind = | { serverless: Record } | { normal: Record }; @@ -113,17 +114,18 @@ export type MetadataRunnerKind = export interface MetadataResponse { runtime: string; version: string; - runner?: { - kind: MetadataRunnerKind; + envoy?: { + kind: MetadataEnvoyKind; version?: number; }; + envoyProtocolVersion: number, actorNames: ReturnType; /** - * Endpoint that the client should connect to to access this runner. + * Endpoint that the client should connect to to access this envoy. * * If defined, will override the endpoint the user has configured on startup. * - * This is helpful if attempting to connect to a serverless runner, so the serverless runner can define where the main endpoint lives. + * This is helpful if attempting to connect to a serverless envoy, so the serverless runner can define where the main endpoint lives. * * This is also helpful for setting up clean redirects as needed. 
**/ @@ -141,7 +143,7 @@ export interface MetadataResponse { export function handleMetadataRequest( c: HonoContext, config: RegistryConfig, - runnerKind: MetadataRunnerKind, + envoyKind: MetadataEnvoyKind, clientEndpoint: string | undefined, clientNamespace: string | undefined, clientToken: string | undefined, @@ -149,10 +151,11 @@ export function handleMetadataRequest( const response: MetadataResponse = { runtime: "rivetkit", version: VERSION, - runner: { - kind: runnerKind, - version: config.runner.version, + envoy: { + kind: envoyKind, + version: config.envoy.version, }, + envoyProtocolVersion: envoyProtocol.VERSION, actorNames: buildActorNames(config), clientEndpoint, clientNamespace, diff --git a/rivetkit-typescript/packages/rivetkit/src/drivers/engine/actor-driver.ts b/rivetkit-typescript/packages/rivetkit/src/drivers/engine/actor-driver.ts index 04584677eb..67bf366670 100644 --- a/rivetkit-typescript/packages/rivetkit/src/drivers/engine/actor-driver.ts +++ b/rivetkit-typescript/packages/rivetkit/src/drivers/engine/actor-driver.ts @@ -1,7 +1,7 @@ import type { EnvoyConfig } from "@rivetkit/engine-envoy-client"; import type { ISqliteVfs } from "@rivetkit/sqlite-vfs"; import { SqliteVfsPoolManager } from "@/driver-helpers/sqlite-pool"; -import { protocol, utils, EnvoyHandle, startEnvoySync } from "@rivetkit/engine-envoy-client"; +import { type HibernatingWebSocketMetadata, protocol, utils, EnvoyHandle, startEnvoySync } from "@rivetkit/engine-envoy-client"; import * as cbor from "cbor-x"; import type { Context as HonoContext } from "hono"; import { streamSSE } from "hono/streaming"; @@ -190,11 +190,11 @@ export class EngineActorDriver implements ActorDriver { }; // Create and start envoy - const [envoy, startRx] = startEnvoySync(envoyConfig); + const envoy = startEnvoySync(envoyConfig); this.#envoy = envoy; - startRx.changed().then(() => { + envoy.started().then(() => { this.#envoyStarted.resolve(undefined); }); @@ -459,47 +459,46 @@ export class 
EngineActorDriver implements ActorDriver { } async serverlessHandleStart(c: HonoContext): Promise { + let payload = await c.req.arrayBuffer(); + return streamSSE(c, async (stream) => { - // TODO: - // // NOTE: onAbort does not work reliably - // stream.onAbort(() => { }); - // c.req.raw.signal.addEventListener("abort", () => { - // logger().debug("SSE aborted, shutting down runner"); - - // // We cannot assume that the request will always be closed gracefully by Rivet. We always proceed with a graceful shutdown in case the request was terminated for any other reason. - // // - // // If we did not use a graceful shutdown, the runner would - // this.shutdownRunner(false); - // }); - - // await this.#envoyStarted.promise; - - // // Runner id should be set if the runner started - // const payload = this.#envoy.getServerlessInitPacket(); - // invariant(payload, "runnerId not set"); - // await stream.writeSSE({ data: payload }); - - // // Send ping every second to keep the connection alive - // while (true) { - // if (this.#isRunnerStopped) { - // logger().debug({ - // msg: "runner is stopped", - // }); - // break; - // } - - // if (stream.closed || stream.aborted) { - // logger().debug({ - // msg: "runner sse stream closed", - // closed: stream.closed, - // aborted: stream.aborted, - // }); - // break; - // } - - // await stream.writeSSE({ event: "ping", data: "" }); - // await stream.sleep(RUNNER_SSE_PING_INTERVAL); - // } + // NOTE: onAbort does not work reliably + stream.onAbort(() => { }); + c.req.raw.signal.addEventListener("abort", () => { + logger().debug("SSE aborted, shutting down runner"); + + // We cannot assume that the request will always be closed gracefully by Rivet. We always proceed with a graceful shutdown in case the request was terminated for any other reason. 
+ // + // If we did not use a graceful shutdown, the runner would + this.shutdown(false); + }); + + await this.#envoyStarted.promise; + + // Runner id should be set if the runner started + this.#envoy.startServerless(payload); + + // Send ping every second to keep the connection alive + while (true) { + if (this.#isEnvoyStopped) { + logger().debug({ + msg: "envoy is stopped", + }); + break; + } + + if (stream.closed || stream.aborted) { + logger().debug({ + msg: "envoy sse stream closed", + closed: stream.closed, + aborted: stream.aborted, + }); + break; + } + + await stream.writeSSE({ event: "ping", data: "" }); + await stream.sleep(ENVOY_SSE_PING_INTERVAL); + } // Wait for the runner to stop if the SSE stream aborted early for any reason await this.#envoyStopped.promise; @@ -725,7 +724,7 @@ export class EngineActorDriver implements ActorDriver { }); try { - this.#envoy.stopActor(actorId); + this.#envoy.stopActor(actorId, undefined, stringifyError(error)); } catch (stopError) { logger().debug({ msg: "failed to stop actor after start failure", @@ -1130,28 +1129,28 @@ export class EngineActorDriver implements ActorDriver { } } - // async #hwsLoadAll( - // actorId: string, - // ): Promise { - // const actor = await this.loadActor(actorId); - // return actor.conns - // .values() - // .map((conn) => { - // const connStateManager = conn[CONN_STATE_MANAGER_SYMBOL]; - // const hibernatable = connStateManager.hibernatableData; - // if (!hibernatable) return undefined; - // return { - // gatewayId: hibernatable.gatewayId, - // requestId: hibernatable.requestId, - // serverMessageIndex: hibernatable.serverMessageIndex, - // clientMessageIndex: hibernatable.clientMessageIndex, - // path: hibernatable.requestPath, - // headers: hibernatable.requestHeaders, - // } satisfies HibernatingWebSocketMetadata; - // }) - // .filter((x) => x !== undefined) - // .toArray(); - // } + async #hwsLoadAll( + actorId: string, + ): Promise { + const actor = await this.loadActor(actorId); + 
return actor.conns + .values() + .map((conn) => { + const connStateManager = conn[CONN_STATE_MANAGER_SYMBOL]; + const hibernatable = connStateManager.hibernatableData; + if (!hibernatable) return undefined; + return { + gatewayId: hibernatable.gatewayId, + requestId: hibernatable.requestId, + rivetMessageIndex: hibernatable.serverMessageIndex, + envoyMessageIndex: hibernatable.clientMessageIndex, + path: hibernatable.requestPath, + headers: hibernatable.requestHeaders, + } satisfies HibernatingWebSocketMetadata; + }) + .filter((x) => x !== undefined) + .toArray(); + } async onBeforeActorStart(actor: AnyActorInstance): Promise { // Resolve promise if waiting @@ -1161,10 +1160,9 @@ export class EngineActorDriver implements ActorDriver { handler.actorStartPromise?.resolve(); handler.actorStartPromise = undefined; - // TODO: - // // Restore hibernating requests - // const metaEntries = await this.#hwsLoadAll(actor.id); - // await this.#envoy.restoreHibernatingRequests(actor.id, metaEntries); + // Restore hibernating requests + const metaEntries = await this.#hwsLoadAll(actor.id); + await this.#envoy.restoreHibernatingRequests(actor.id, metaEntries); } onCreateConn(conn: AnyConn) { @@ -1231,12 +1229,11 @@ export class EngineActorDriver implements ActorDriver { entry.pendingAckFromMessageIndex || entry.pendingAckFromBufferSize ) { - // TODO: - // this.#envoy.sendHibernatableWebSocketMessageAck( - // hibernatable.gatewayId, - // hibernatable.requestId, - // entry.serverMessageIndex, - // ); + this.#envoy.sendHibernatableWebSocketMessageAck( + hibernatable.gatewayId, + hibernatable.requestId, + entry.serverMessageIndex, + ); entry.pendingAckFromMessageIndex = false; entry.pendingAckFromBufferSize = false; entry.bufferedMessageSize = 0; diff --git a/rivetkit-typescript/packages/rivetkit/src/manager/router-schema.ts b/rivetkit-typescript/packages/rivetkit/src/manager/router-schema.ts index cc29e62db6..6fe8d6ec8a 100644 --- 
a/rivetkit-typescript/packages/rivetkit/src/manager/router-schema.ts +++ b/rivetkit-typescript/packages/rivetkit/src/manager/router-schema.ts @@ -7,14 +7,8 @@ export const ServerlessStartHeadersSchema = z.object({ token: z .string({ error: "x-rivet-token header must be a string" }) .optional(), - totalSlots: z.coerce - .number({ - error: "x-rivet-total-slots header must be a number", - }) - .int({ error: "x-rivet-total-slots header must be an integer" }) - .gte(1, { error: "x-rivet-total-slots header must be positive" }), - runnerName: z.string({ - error: "x-rivet-runner-name header is required", + poolName: z.string({ + error: "x-rivet-pool-name header is required", }), namespace: z.string({ error: "x-rivet-namespace-name header is required", diff --git a/rivetkit-typescript/packages/rivetkit/src/manager/router.ts b/rivetkit-typescript/packages/rivetkit/src/manager/router.ts index 4767396a38..c05ede00b9 100644 --- a/rivetkit-typescript/packages/rivetkit/src/manager/router.ts +++ b/rivetkit-typescript/packages/rivetkit/src/manager/router.ts @@ -95,9 +95,9 @@ export function buildManagerRouter( const actorIdsParsed = actor_ids ? actor_ids - .split(",") - .map((id) => id.trim()) - .filter((id) => id.length > 0) + .split(",") + .map((id) => id.trim()) + .filter((id) => id.length > 0) : undefined; const actors: ActorOutput[] = []; diff --git a/rivetkit-typescript/packages/rivetkit/src/serverless/router.ts b/rivetkit-typescript/packages/rivetkit/src/serverless/router.ts index 669b212843..a4ad2c5ba0 100644 --- a/rivetkit-typescript/packages/rivetkit/src/serverless/router.ts +++ b/rivetkit-typescript/packages/rivetkit/src/serverless/router.ts @@ -26,29 +26,27 @@ export function buildServerlessRouter( }); // Serverless start endpoint - router.get("/start", async (c) => { + router.post("/start", async (c) => { // Parse headers const parseResult = ServerlessStartHeadersSchema.safeParse({ endpoint: c.req.header("x-rivet-endpoint"), token: c.req.header("x-rivet-token") ?? 
undefined, - totalSlots: c.req.header("x-rivet-total-slots"), - runnerName: c.req.header("x-rivet-runner-name"), + poolName: c.req.header("x-rivet-pool-name"), namespace: c.req.header("x-rivet-namespace-name"), }); if (!parseResult.success) { throw new InvalidRequest( parseResult.error.issues[0]?.message ?? - "invalid serverless start headers", + "invalid serverless start headers", ); } - const { endpoint, token, totalSlots, runnerName, namespace } = + const { endpoint, token, poolName, namespace } = parseResult.data; logger().debug({ msg: "received serverless runner start request", endpoint, - totalSlots, - runnerName, + poolName, namespace, }); @@ -72,10 +70,9 @@ export function buildServerlessRouter( ...config, endpoint, namespace, - runner: { - ...config.runner, - totalSlots, - runnerName, + envoy: { + ...config.envoy, + poolName, }, }; const runnerConfig: RegistryConfig = { diff --git a/scripts/tests/actor_e2e.ts b/scripts/tests/actor_e2e.ts index e62e040b66..a97dbf7350 100755 --- a/scripts/tests/actor_e2e.ts +++ b/scripts/tests/actor_e2e.ts @@ -64,7 +64,7 @@ async function main() { await actorPingResponse2.text(); console.timeEnd("ping 2"); - // await testWebSocket(actorResponse.actor.actor_id); + await testWebSocket(actorResponse.actor.actor_id); } catch (error) { console.error(`Actor test failed:`, error); } finally { diff --git a/shared/typescript/virtual-websocket/src/interface.ts b/shared/typescript/virtual-websocket/src/interface.ts index 6268f159f5..9101c20cf1 100644 --- a/shared/typescript/virtual-websocket/src/interface.ts +++ b/shared/typescript/virtual-websocket/src/interface.ts @@ -5,7 +5,12 @@ export interface RivetEvent { currentTarget?: any; /** * @experimental - * Request ID for hibernatable websockets (provided by engine runner) + * Gateway ID for hibernatable websockets (provided by engine envoy) + **/ + rivetGatewayId?: ArrayBuffer; + /** + * @experimental + * Request ID for hibernatable websockets (provided by engine envoy) **/ 
rivetRequestId?: ArrayBuffer; } @@ -14,7 +19,7 @@ export interface RivetMessageEvent extends RivetEvent { data: any; /** * @experimental - * Message index for hibernatable websockets (provided by engine runner) + * Message index for hibernatable websockets (provided by engine envoy) **/ rivetMessageIndex?: number; }