Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions crates/bashkit/src/scripted_tool/execute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use crate::tool::{
use crate::tool_def::usage_from_schema;
use async_trait::async_trait;
use schemars::schema_for;
use std::collections::VecDeque;
use std::sync::{Arc, Mutex};

// ============================================================================
Expand Down Expand Up @@ -113,7 +114,7 @@ impl ScriptedTool {
};
}

let log: InvocationLog = Arc::new(Mutex::new(Vec::new()));
let log: InvocationLog = Arc::new(Mutex::new(VecDeque::new()));
let mut bash = self.create_bash(Arc::clone(&log));

let response = if let Some(sender) = stream_sender {
Expand Down Expand Up @@ -146,10 +147,12 @@ impl ScriptedTool {
..Default::default()
},
};
let invocations = log
let invocations: Vec<_> = log
.lock()
.expect("scripted invocation log poisoned")
.clone();
.iter()
.cloned()
.collect();
self.store_last_execution_trace(ScriptedExecutionTrace { invocations });
response
}
Expand Down
49 changes: 41 additions & 8 deletions crates/bashkit/src/scripted_tool/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,13 @@ use crate::error::Result;
use crate::interpreter::ExecResult;
use crate::tool_def::{parse_flags, usage_from_schema};
use async_trait::async_trait;
use std::collections::VecDeque;
use std::future::Future;
use std::sync::{Arc, Mutex};

pub(crate) type InvocationLog = Arc<Mutex<Vec<ScriptedCommandInvocation>>>;
pub(crate) type InvocationLog = Arc<Mutex<VecDeque<ScriptedCommandInvocation>>>;
const MAX_LOG_ENTRIES: usize = 256;
const MAX_LOG_ARG_BYTES: usize = 1024;

fn push_invocation(
log: &InvocationLog,
Expand All @@ -22,28 +25,48 @@ fn push_invocation(
args: &[String],
exit_code: i32,
) {
let args = truncate_args(args);
let mut invocations = log.lock().expect("tool-def invocation log poisoned");
invocations.push(ScriptedCommandInvocation {
if invocations.len() == MAX_LOG_ENTRIES {
invocations.pop_front();
}
invocations.push_back(ScriptedCommandInvocation {
name: name.to_string(),
kind,
args: args.to_vec(),
args,
exit_code,
});
}

/// Produces the logged copy of `args`: every element is passed through
/// `truncate_arg`, and the original order is preserved.
fn truncate_args(args: &[String]) -> Vec<String> {
    let mut bounded = Vec::with_capacity(args.len());
    for raw in args {
        bounded.push(truncate_arg(raw));
    }
    bounded
}

/// Returns `arg` limited to at most `MAX_LOG_ARG_BYTES` bytes.
///
/// Inputs within the limit are copied unchanged. Longer inputs are cut at the
/// last UTF-8 character boundary that does not exceed the limit, so the result
/// is always valid UTF-8 and may be a few bytes shorter than the cap.
fn truncate_arg(arg: &str) -> String {
    if arg.len() > MAX_LOG_ARG_BYTES {
        // Step back from the cap until the index falls between two characters.
        // Index 0 is always a boundary, so the search cannot underflow.
        let mut boundary = MAX_LOG_ARG_BYTES;
        while !arg.is_char_boundary(boundary) {
            boundary -= 1;
        }
        arg[..boundary].to_owned()
    } else {
        arg.to_owned()
    }
}

/// Builder for [`ToolDefExtension`].
///
/// Holds the configuration that [`ToolDefExtensionBuilder::build`] copies
/// into each extension it produces.
pub struct ToolDefExtensionBuilder {
// Registered tools; `build` clones this list into the extension.
tools: Vec<RegisteredTool>,
// Copied as-is into the built extension; `Default` sets it to `true`.
sanitize_errors: bool,
// NOTE(review): in this diff view `build` shows both
// `Arc::clone(&self.invocation_log)` and a freshly constructed log, so this
// field presumably belongs to the removed side of the change — confirm
// against the merged file.
invocation_log: InvocationLog,
}

impl Default for ToolDefExtensionBuilder {
/// Creates a builder with no registered tools and `sanitize_errors` enabled.
fn default() -> Self {
Self {
tools: Vec::new(),
sanitize_errors: true,
// NOTE(review): this initialises the log with a `Vec`, while the
// `InvocationLog` alias in this diff view is shown with both a `Vec` and a
// `VecDeque` payload; presumably this line is on the removed side of the
// change — confirm against the merged file.
invocation_log: Arc::new(Mutex::new(Vec::new())),
}
}
}
Expand Down Expand Up @@ -105,16 +128,26 @@ impl ToolDefExtensionBuilder {
}

/// Build the extension.
///
/// Each call mints a fresh, isolated invocation log. Clones of the
/// returned extension share the log with the original — keep a clone
/// before passing the extension to a `Bash` if you intend to call
/// [`ToolDefExtension::take_invocations`] later.
pub fn build(&self) -> ToolDefExtension {
ToolDefExtension {
// The tool list is cloned and `self` is only borrowed, so the same builder
// can be used for further `build` calls.
tools: self.tools.clone(),
sanitize_errors: self.sanitize_errors,
// NOTE(review): two initialisers for `invocation_log` appear here because
// this is a diff view. The `Arc::clone` line would share one log across all
// builds, which contradicts the doc comment above; presumably it is the
// removed line and the fresh `VecDeque` below is the replacement — confirm.
invocation_log: Arc::clone(&self.invocation_log),
invocation_log: Arc::new(Mutex::new(VecDeque::new())),
}
}
}

/// Bash extension that registers ToolDef-backed commands plus `help` and `discover`.
///
/// Each [`ToolDefExtensionBuilder::build`] mints a fresh invocation log, so
/// distinct builds (e.g. per tenant) never share traces. Cloning shares the
/// log with the original — that is the supported pattern for retaining a
/// `take_invocations` handle after passing the extension to a `Bash`.
#[derive(Clone)]
pub struct ToolDefExtension {
tools: Vec<RegisteredTool>,
Expand All @@ -132,7 +165,7 @@ impl ToolDefExtension {
Self {
tools,
sanitize_errors: true,
invocation_log: Arc::new(Mutex::new(Vec::new())),
invocation_log: Arc::new(Mutex::new(VecDeque::new())),
}
}

Expand All @@ -153,7 +186,7 @@ impl ToolDefExtension {
.invocation_log
.lock()
.expect("tool-def invocation log poisoned");
std::mem::take(&mut *invocations)
std::mem::take(&mut *invocations).into()
}

fn snapshots(&self) -> Vec<ToolDefSnapshot> {
Expand Down
110 changes: 110 additions & 0 deletions crates/bashkit/src/scripted_tool/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1470,6 +1470,116 @@ mod tests {
assert!(result.stdout.contains(r#""id":7"#));
}

#[tokio::test]
async fn test_tool_def_extension_builds_have_isolated_invocation_logs() {
    // Cross-tenant isolation contract: extensions produced by separate
    // `build()` calls on one builder must each record only their own
    // invocations.
    let builder = ToolDefExtension::builder()
        .tool_fn(ToolDef::new("echo_arg", "Echo"), |args: &ToolArgs| {
            Ok(format!("{}\n", args.param_str("msg").unwrap_or_default()))
        });
    let first = builder.build();
    let second = builder.build();
    // Clones keep access to each log once the extensions move into a `Bash`.
    let first_handle = first.clone();
    let second_handle = second.clone();

    let mut first_bash = crate::Bash::builder().extension(first).build();
    let mut second_bash = crate::Bash::builder().extension(second).build();
    first_bash
        .exec("echo_arg --msg alpha")
        .await
        .expect("bash a should execute");
    second_bash
        .exec("echo_arg --msg beta")
        .await
        .expect("bash b should execute");

    let first_trace = first_handle.take_invocations();
    let second_trace = second_handle.take_invocations();

    // Each log holds exactly the single call made through its own `Bash`.
    assert_eq!(first_trace.len(), 1);
    assert_eq!(first_trace[0].args, ["--msg", "alpha"]);
    assert_eq!(second_trace.len(), 1);
    assert_eq!(second_trace[0].args, ["--msg", "beta"]);
}

#[tokio::test]
async fn test_tool_def_extension_clones_share_invocation_log() {
    // A clone taken before the extension is handed to a `Bash` is the
    // supported way to keep a `take_invocations` handle.
    let builder = ToolDefExtension::builder()
        .tool_fn(ToolDef::new("echo_arg", "Echo"), |args: &ToolArgs| {
            Ok(format!("{}\n", args.param_str("msg").unwrap_or_default()))
        });
    let extension = builder.build();
    let retained = extension.clone();

    let mut shell = crate::Bash::builder().extension(extension).build();
    shell
        .exec("echo_arg --msg gamma")
        .await
        .expect("bash should execute");

    // The call made through the moved extension is visible via the clone.
    let recorded = retained.take_invocations();
    assert_eq!(recorded.len(), 1);
    assert_eq!(recorded[0].args, ["--msg", "gamma"]);
}

#[tokio::test]
async fn test_tool_def_extension_invocations_are_bounded_and_truncated() {
    // Runs more invocations than the log retains, each carrying one oversized
    // argument, then checks both the entry cap and the per-argument byte cap.
    let extension = ToolDefExtension::builder()
        .tool_fn(ToolDef::new("noop", "No-op"), |_args: &ToolArgs| {
            Ok("ok\n".to_string())
        })
        .build();
    let handle = extension.clone();
    let mut bash = crate::Bash::builder().extension(extension).build();

    // The command is the same on every iteration, so build the 1500-byte
    // argument once instead of re-allocating it 300 times.
    let cmd = format!("noop --msg {}", "x".repeat(1500));
    for _ in 0..300 {
        bash.exec(&cmd).await.expect("noop should execute");
    }

    let trace = handle.take_invocations();
    assert_eq!(trace.len(), 256, "log must be capped at MAX_LOG_ENTRIES");
    // Check every retained entry rather than only the first, so an entry that
    // slipped past truncation anywhere in the log fails the test.
    for invocation in &trace {
        assert_eq!(
            invocation.args[1].len(),
            1024,
            "long argv tokens must be truncated to MAX_LOG_ARG_BYTES"
        );
    }
}

#[tokio::test]
async fn test_tool_def_extension_truncation_is_byte_aware_for_utf8() {
    // U+1F600 (😀) encodes to 4 bytes in UTF-8, so 400 copies make a
    // 1600-byte argument — well past the 1024-byte cap. The logged value
    // must be limited by bytes (not chars) and must end on a whole code point.
    let builder = ToolDefExtension::builder()
        .tool_fn(ToolDef::new("noop", "No-op"), |_args: &ToolArgs| {
            Ok("ok\n".to_string())
        });
    let extension = builder.build();
    let retained = extension.clone();
    let mut shell = crate::Bash::builder().extension(extension).build();

    let oversized = "\u{1F600}".repeat(400);
    let command = format!("noop --msg {}", oversized);
    shell.exec(&command).await.expect("noop should execute");

    let recorded = retained.take_invocations();
    assert_eq!(recorded.len(), 1);

    let logged = &recorded[0].args[1];
    assert!(logged.len() <= 1024, "byte length must respect cap");
    // 1024 bytes / 4 bytes per code point = 256 whole code points.
    assert_eq!(logged.chars().count(), 256);
}

// -- Issue #1278: --help flag tests --

#[tokio::test]
Expand Down
Loading