[ENH]: fix high latency & response errors of frontend -> query service calls during rollout #5316


Open

codetheweb wants to merge 5 commits into main

Conversation

@codetheweb (Contributor) commented Aug 19, 2025

Description of changes

We've observed that during rollouts of query service pods, the frontend frequently returns errors to clients (originating from the query service), and in-flight calls to a query service pod that is being shut down can block for 30s+ before the frontend notices the connection is broken.

I was able to reproduce this locally by running a script that produces concurrent query load and executing kubectl scale --replicas=1 -n chroma statefulset query-service (scaling from 2 replicas to 1) while the script was running:

[Screenshot 2025-08-19 15:55:55: load-test script output during the scale-down]

The script output shows that several queries errored and a few took >20s. For the queries that take >20s, the pattern (as seen in staging/prod) is that the frontend tries twice to make a request to the query service. The first request takes 20-30s before failing with a variant of a disconnect error; the second attempt succeeds. To be honest, this behavior doesn't completely make sense to me: based on the little documentation I could find, tonic/hyper is supposed to send clients a GOAWAY frame during server shutdown, which should immediately surface an error on the client. Regardless, even if clients errored immediately, there is still the possibility that a client exhausts its retry budget by only retrying against servers that have been shut down.

Edit: chased down the client disconnect issue
  1. I validated that the server sends GOAWAY frames to all open streams using RUST_LOG=trace.
  2. I validated that the client receives the GOAWAY frames, again by adjusting the log level.
  3. I combed through issues & discussions in the tonic repo. There are quite a few open issues related to connection recovery/disconnection/shutdown; I believe these tonic-side issues are the root cause of the behavior we see.

This PR aims to fix the most common cause of these issues by giving the memberlist time to propagate & update on clients before terminating the query service pod. In other words, a pod that is scheduled to shut down will be removed from the memberlist but stay alive for N seconds to allow time for existing connections to drain and clients to update their local memberlist state. The same script after the changes in this PR:

[Screenshot 2025-08-19 16:07:26: load-test script output after the changes in this PR]
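
For illustration, here is a minimal sketch of the shutdown sequence described above, using tokio and tonic. The identifiers and exact wiring are assumptions for illustration, not the code in this PR:

use std::time::Duration;
use tokio::signal::unix::{signal, SignalKind};

// Hypothetical shutdown future: resolve only after SIGTERM has been received
// AND the configured grace period has elapsed, so the pod keeps serving
// in-flight requests while memberlist updates propagate to clients.
async fn sigterm_with_grace_period(grace_period: Duration) {
    let mut sigterm = signal(SignalKind::terminate()).expect("failed to install SIGTERM handler");
    sigterm.recv().await;
    tracing::info!("received SIGTERM; waiting {:?} before shutting down", grace_period);
    tokio::time::sleep(grace_period).await;
}

// Illustrative usage: tonic stops accepting new requests and drains existing
// ones once the shutdown future resolves.
//
// Server::builder()
//     .add_service(query_service)
//     .serve_with_shutdown(addr, sigterm_with_grace_period(Duration::from_secs(5)))
//     .await?;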

There are still some failure cases:

  1. The query service pod crashes. The memberlist won't be updated before the pod crashes.
  2. The Kubernetes preStop timeout is smaller than the query timeout. A query service pod may be shut down while a frontend is still waiting for a response. This is avoidable with the proper config values.
  3. SysDb doesn't update the memberlist. This may be a bit tricky to handle. We need to make sure that the SysDb leader is updated before any other pods are rolled (or that it is updated after all other pods are rolled).

Test plan

How are these changes tested?

Script used to test querying during scale up/down
#!/usr/bin/env python3
# /// script
# dependencies = [
#   "chromadb",
#   "numpy",
#   "rich",
# ]
# ///

"""
Chroma Concurrent Load Testing Script

This script creates Chroma collections with documents and runs continuous
concurrent get operations to generate load, reporting latency percentiles every 5 seconds.
"""

import argparse
import concurrent.futures
import queue
import random
import string
import threading
import time
from typing import List, Dict, Any, Deque, Optional
from collections import deque
from dataclasses import dataclass
from datetime import datetime
from rich.console import Console
from rich.table import Table
from rich.progress import track
from rich import print as rprint

import chromadb
from chromadb.config import Settings
import numpy as np


@dataclass
class QueryMetric:
    """Stores individual query metrics"""
    timestamp: float
    latency_ms: float
    success: bool
    error: Optional[str] = None


class ChromaLoadTester:
    """Load testing for Chroma concurrent get operations"""

    def __init__(self,
                 client_type: str = "ephemeral",
                 collection_name_prefix: str = "load_test_collection",
                 num_collections: int = 1,
                 num_documents: int = 100,
                 embedding_dim: int = 384):
        """
        Initialize the load tester.

        Args:
            client_type: Type of Chroma client ("ephemeral" or "persistent")
            collection_name_prefix: Prefix for collection names
            num_collections: Number of collections to create
            num_documents: Number of documents to generate per collection
            embedding_dim: Dimension of embeddings (default 384 for all-MiniLM-L6-v2)
        """
        self.console = Console()
        self.collection_name_prefix = collection_name_prefix
        self.num_collections = num_collections
        self.num_documents = num_documents
        self.embedding_dim = embedding_dim

        # Metrics storage - using deque for efficient windowed metrics
        self.metrics_window: Deque[QueryMetric] = deque(maxlen=10000)
        self.all_metrics: List[QueryMetric] = []
        self.metrics_lock = threading.Lock()
        self.start_time = time.time()
        self.total_queries = 0
        self.total_errors = 0

        # Collection-specific metrics
        self.collection_metrics = {}

        # Initialize Chroma client (always an HttpClient here; the client_type argument is currently unused)
        self.client = chromadb.HttpClient()

        # Initialize collections list
        self.collections = []
        self.collection_names = []

        # Create multiple collections
        for i in range(num_collections):
            collection_name = f"{collection_name_prefix}_{i:03d}"
            self.collection_names.append(collection_name)

            # Reset collection if it exists
            try:
                self.client.delete_collection(collection_name)
            except Exception:
                pass

            # Create new collection
            collection = self.client.create_collection(
                name=collection_name,
                metadata={"hnsw:space": "cosine"}
            )
            self.collections.append(collection)

            # Initialize collection-specific metrics
            self.collection_metrics[collection_name] = {
                'queries': 0,
                'errors': 0,
                'total_latency_ms': 0.0
            }

        self.console.print(f"[green]✓[/green] Created {num_collections} collections: {', '.join(self.collection_names)}")

    def generate_sample_documents(self) -> tuple[List[str], List[str], List[Dict], List[List[float]]]:
        """
        Generate sample documents with metadata and embeddings.

        Returns:
            Tuple of (ids, documents, metadatas, embeddings)
        """
        self.console.print(f"[yellow]Generating {self.num_documents} sample documents...[/yellow]")

        # Sample topics for realistic document generation
        topics = [
            "machine learning", "artificial intelligence", "data science",
            "software engineering", "cloud computing", "cybersecurity",
            "blockchain", "quantum computing", "robotics", "IoT",
            "natural language processing", "computer vision", "big data",
            "DevOps", "microservices", "edge computing", "5G technology",
            "augmented reality", "virtual reality", "digital transformation"
        ]

        ids = []
        documents = []
        metadatas = []
        embeddings = []

        for i in track(range(self.num_documents), description="Generating documents"):
            # Generate ID
            doc_id = f"doc_{i:04d}"
            ids.append(doc_id)

            # Generate document content
            topic = random.choice(topics)
            doc_type = random.choice(["article", "blog", "paper", "tutorial", "review"])

            document = f"""
            {doc_type.title()} about {topic}:
            This is a comprehensive {doc_type} discussing various aspects of {topic}.
            It covers fundamental concepts, recent developments, and future directions.
            The content includes practical examples and real-world applications.
            Document ID: {doc_id} | Topic: {topic} | Type: {doc_type}
            Additional keywords: {''.join(random.choices(string.ascii_lowercase + ' ', k=50))}
            """
            documents.append(document.strip())

            # Generate metadata
            metadata = {
                "topic": topic,
                "type": doc_type,
                "index": i,
                "created_at": time.time(),
                "word_count": len(document.split()),
                "category": random.choice(["technical", "educational", "research"])
            }
            metadatas.append(metadata)

            # Generate random embedding
            base_embedding = np.random.randn(self.embedding_dim)
            topic_offset = hash(topic) % 100 / 100.0
            base_embedding[:10] += topic_offset
            embedding = (base_embedding / np.linalg.norm(base_embedding)).tolist()
            embeddings.append(embedding)

        return ids, documents, metadatas, embeddings

    def populate_collections(self):
        """Add documents to all collections in batches"""
        self.console.print(f"[yellow]Populating {self.num_collections} collections with {self.num_documents} documents each...[/yellow]")

        # Populate each collection
        for i, (collection, collection_name) in enumerate(zip(self.collections, self.collection_names)):
            self.console.print(f"[yellow]Populating collection {i+1}/{self.num_collections}: {collection_name}[/yellow]")

            ids, documents, metadatas, embeddings = self.generate_sample_documents()

            # Add documents in batches
            batch_size = 10
            total_batches = len(ids) // batch_size + (1 if len(ids) % batch_size else 0)

            for j in track(range(0, len(ids), batch_size), description=f"Adding batches to {collection_name}"):
                end_idx = min(j + batch_size, len(ids))
                collection.add(
                    ids=[f"{collection_name}_{id}" for id in ids[j:end_idx]],  # Unique IDs per collection
                    documents=documents[j:end_idx],
                    metadatas=metadatas[j:end_idx],
                    embeddings=embeddings[j:end_idx]
                )

            self.console.print(f"[green]✓[/green] Added {len(ids)} documents to {collection_name}")

        # Store sample documents for querying (from last collection)
        if hasattr(self, 'collections') and self.collections:
            _, documents, _, _ = self.generate_sample_documents()
            self.sample_documents = documents[:20]

        self.console.print(f"[green]✓[/green] All {self.num_collections} collections populated successfully")

    def generate_document_id_batches(self, num_batches: int) -> List[List[str]]:
        """Generate batches of document IDs for get operations"""
        id_batches = []

        for i in range(num_batches):
            # Generate a batch of 1-5 unique document IDs
            batch_size = random.randint(1, 5)
            batch = set()  # Use set to ensure uniqueness

            while len(batch) < batch_size:
                # Randomly select a collection and document
                collection_idx = random.randint(0, self.num_collections - 1)
                collection_name = self.collection_names[collection_idx]
                doc_idx = random.randint(0, self.num_documents - 1)
                document_id = f"{collection_name}_doc_{doc_idx:04d}"
                batch.add(document_id)

            id_batches.append(list(batch))

        return id_batches

    def execute_get(self, document_ids: List[str], collection_idx: int = None) -> QueryMetric:
        """Execute a single get operation and record metrics"""
        start_time = time.perf_counter()

        # Select a collection (random choice if not specified)
        if collection_idx is None:
            collection_idx = random.randint(0, len(self.collections) - 1)
        else:
            collection_idx = collection_idx % len(self.collections)

        collection = self.collections[collection_idx]
        collection_name = self.collection_names[collection_idx]

        try:
            results = collection.get(
                ids=document_ids,
                include=["documents", "metadatas", "embeddings"]
            )

            latency_ms = (time.perf_counter() - start_time) * 1000

            # Update collection-specific metrics
            with self.metrics_lock:
                self.collection_metrics[collection_name]['queries'] += 1
                self.collection_metrics[collection_name]['total_latency_ms'] += latency_ms

            return QueryMetric(
                timestamp=time.time(),
                latency_ms=latency_ms,
                success=True
            )
        except Exception as e:
            latency_ms = (time.perf_counter() - start_time) * 1000

            # Update collection-specific metrics
            with self.metrics_lock:
                self.collection_metrics[collection_name]['queries'] += 1
                self.collection_metrics[collection_name]['errors'] += 1
                self.collection_metrics[collection_name]['total_latency_ms'] += latency_ms

            return QueryMetric(
                timestamp=time.time(),
                latency_ms=latency_ms,
                success=False,
                error=f"[{collection_name}] {str(e)}"
            )

    def get_worker(self, worker_id: int, get_queue: queue.Queue, stop_event: threading.Event):
        """Worker thread that continuously processes get operations from the queue"""
        get_count = 0
        while not stop_event.is_set():
            try:
                # Get document IDs with timeout to check stop event periodically
                document_ids = get_queue.get(timeout=0.1)

                # Distribute gets across collections (round-robin per worker)
                collection_idx = (worker_id + get_count) % len(self.collections)

                # Execute get operation and record metrics
                metric = self.execute_get(document_ids, collection_idx)
                get_count += 1

                # Update metrics (thread-safe)
                with self.metrics_lock:
                    self.metrics_window.append(metric)
                    self.all_metrics.append(metric)
                    self.total_queries += 1
                    if not metric.success:
                        self.total_errors += 1

            except queue.Empty:
                continue
            except Exception as e:
                self.console.print(f"[red]Worker {worker_id} error: {e}[/red]")

    def calculate_percentiles(self, metrics: List[QueryMetric], window_seconds: float = None) -> Dict[str, float]:
        """Calculate latency percentiles for given metrics"""
        if not metrics:
            return {"p50": 0, "p95": 0, "p99": 0, "p999": 0, "mean": 0, "max": 0, "min": 0}

        # Filter by time window if specified
        if window_seconds:
            cutoff_time = time.time() - window_seconds
            metrics = [m for m in metrics if m.timestamp >= cutoff_time]

        if not metrics:
            return {"p50": 0, "p95": 0, "p99": 0, "p999": 0, "mean": 0, "max": 0, "min": 0}

        # Extract successful query latencies
        latencies = [m.latency_ms for m in metrics if m.success]

        if not latencies:
            return {"p50": 0, "p95": 0, "p99": 0, "p999": 0, "mean": 0, "max": 0, "min": 0}

        latencies.sort()

        return {
            "p50": np.percentile(latencies, 50),
            "p95": np.percentile(latencies, 95),
            "p99": np.percentile(latencies, 99),
            "p999": np.percentile(latencies, 99.9),
            "mean": np.mean(latencies),
            "max": max(latencies),
            "min": min(latencies)
        }

    def format_stats_table(self) -> Table:
        """Create a formatted table with current statistics"""
        with self.metrics_lock:
            # Make copies to avoid holding lock too long
            metrics_window_copy = list(self.metrics_window)
            all_metrics_copy = list(self.all_metrics)
            total_queries = self.total_queries
            total_errors = self.total_errors

        # Calculate metrics for different windows
        last_5s_metrics = [m for m in metrics_window_copy if m.timestamp >= time.time() - 5]
        last_30s_metrics = [m for m in metrics_window_copy if m.timestamp >= time.time() - 30]

        # Calculate percentiles
        last_5s_stats = self.calculate_percentiles(last_5s_metrics)
        last_30s_stats = self.calculate_percentiles(last_30s_metrics)
        all_time_stats = self.calculate_percentiles(all_metrics_copy)

        # Calculate rates
        runtime = time.time() - self.start_time
        qps_total = total_queries / runtime if runtime > 0 else 0

        last_5s_count = len(last_5s_metrics)
        qps_5s = last_5s_count / 5 if last_5s_count > 0 else 0

        last_5s_errors = sum(1 for m in last_5s_metrics if not m.success)
        error_rate_5s = (last_5s_errors / last_5s_count * 100) if last_5s_count > 0 else 0

        # Create table
        table = Table(title=f"Load Test Statistics - {datetime.now().strftime('%H:%M:%S')}")
        table.add_column("Metric", style="cyan")
        table.add_column("Last 5s", style="yellow")
        table.add_column("Last 30s", style="magenta")
        table.add_column("All Time", style="green")

        # Add latency rows
        table.add_row("P50 (ms)", f"{last_5s_stats['p50']:.2f}", f"{last_30s_stats['p50']:.2f}", f"{all_time_stats['p50']:.2f}")
        table.add_row("P95 (ms)", f"{last_5s_stats['p95']:.2f}", f"{last_30s_stats['p95']:.2f}", f"{all_time_stats['p95']:.2f}")
        table.add_row("P99 (ms)", f"{last_5s_stats['p99']:.2f}", f"{last_30s_stats['p99']:.2f}", f"{all_time_stats['p99']:.2f}")
        table.add_row("P99.9 (ms)", f"{last_5s_stats['p999']:.2f}", f"{last_30s_stats['p999']:.2f}", f"{all_time_stats['p999']:.2f}")
        table.add_row("Mean (ms)", f"{last_5s_stats['mean']:.2f}", f"{last_30s_stats['mean']:.2f}", f"{all_time_stats['mean']:.2f}")
        table.add_row("Max (ms)", f"{last_5s_stats['max']:.2f}", f"{last_30s_stats['max']:.2f}", f"{all_time_stats['max']:.2f}")

        # Add throughput and error rows
        table.add_row("", "", "", "")  # Empty row for spacing
        table.add_row("GPS (Gets/sec)", f"{qps_5s:.1f}", "-", f"{qps_total:.1f}")
        table.add_row("Total Gets", f"{last_5s_count}", f"{len(last_30s_metrics)}", f"{total_queries}")
        table.add_row("Error Rate", f"{error_rate_5s:.1f}%", "-", f"{(total_errors/total_queries*100) if total_queries > 0 else 0:.1f}%")

        return table

    def get_error_summary(self, metrics: List[QueryMetric], top_n: int = 5) -> Dict[str, int]:
        """Get summary of errors from metrics"""
        error_counts = {}
        for metric in metrics:
            if not metric.success and metric.error:
                # Simplify error message for grouping
                error_msg = metric.error.strip()
                # Truncate very long errors
                if len(error_msg) > 200:
                    error_msg = error_msg[:200] + "..."
                error_counts[error_msg] = error_counts.get(error_msg, 0) + 1

        # Sort by count and return top N
        sorted_errors = sorted(error_counts.items(), key=lambda x: x[1], reverse=True)
        return dict(sorted_errors[:top_n])

    def log_percentiles(self):
        """Log percentiles to console in a simple format"""
        with self.metrics_lock:
            metrics_window_copy = list(self.metrics_window)
            collection_metrics_copy = dict(self.collection_metrics)

        last_5s_metrics = [m for m in metrics_window_copy if m.timestamp >= time.time() - 5]

        if not last_5s_metrics:
            return

        stats = self.calculate_percentiles(last_5s_metrics)
        qps = len(last_5s_metrics) / 5
        error_metrics = [m for m in last_5s_metrics if not m.success]
        error_count = len(error_metrics)

        timestamp = datetime.now().strftime('%H:%M:%S')

        # Main metrics line
        self.console.print(
            f"[{timestamp}] Total GPS (Gets/sec): {qps:.1f} | "
            f"P50: {stats['p50']:.1f}ms | "
            f"P95: {stats['p95']:.1f}ms | "
            f"P99: {stats['p99']:.1f}ms | "
            f"Errors: {error_count}"
        )

        # Show top errors if any
        if error_count > 0:
            top_errors = self.get_error_summary(error_metrics, top_n=5)
            self.console.print("[red]Top Errors:[/red]")
            for i, (error_msg, count) in enumerate(top_errors.items(), 1):
                self.console.print(f"  {i}. ({count}x) {error_msg}")

    def run_load_test(self,
                     concurrency: int = 10,
                     target_gps: int = None,
                     duration: int = None,
                     reporting_interval: int = 5):
        """
        Run continuous load test with specified concurrency.

        Args:
            concurrency: Number of concurrent get operation workers
            target_gps: Target gets per second (None for unlimited)
            duration: Test duration in seconds (None for continuous)
            reporting_interval: Interval in seconds for reporting metrics
        """
        self.console.print(f"\n[bold cyan]Starting Load Test[/bold cyan]")
        self.console.print(f"  • Collections: {self.num_collections} ({', '.join(self.collection_names)})")
        self.console.print(f"  • Documents per collection: {self.num_documents}")
        self.console.print(f"  • Concurrency: {concurrency}")
        self.console.print(f"  • Target GPS: {target_gps if target_gps else 'Unlimited'}")
        self.console.print(f"  • Duration: {f'{duration}s' if duration else 'Continuous'}")
        self.console.print(f"  • Reporting Interval: {reporting_interval}s")
        self.console.print("\nPress Ctrl+C to stop...\n")

        # Generate a large pool of document ID batches
        id_batches_pool = self.generate_document_id_batches(1000)

        # Create queue for get operations
        get_queue = queue.Queue(maxsize=concurrency * 2)
        stop_event = threading.Event()

        # Start worker threads
        with concurrent.futures.ThreadPoolExecutor(max_workers=concurrency) as executor:
            # Start worker tasks
            workers = []
            for i in range(concurrency):
                future = executor.submit(self.get_worker, i, get_queue, stop_event)
                workers.append(future)

            # Get operation generator thread
            def get_generator():
                """Generate get operations at target rate"""
                batch_index = 0
                last_time = time.time()

                while not stop_event.is_set():
                    # Get next batch of document IDs
                    id_batch = id_batches_pool[batch_index % len(id_batches_pool)]
                    batch_index += 1

                    # Add to queue (non-blocking)
                    try:
                        get_queue.put(id_batch, timeout=0.1)
                    except queue.Full:
                        continue

                    # Rate limiting if target GPS specified
                    if target_gps:
                        # Calculate sleep time to maintain target GPS
                        elapsed = time.time() - last_time
                        target_interval = 1.0 / target_gps
                        if elapsed < target_interval:
                            time.sleep(target_interval - elapsed)
                        last_time = time.time()

            # Start get operation generator thread
            generator_thread = threading.Thread(target=get_generator, daemon=True)
            generator_thread.start()

            # Reporting thread
            def report_metrics():
                """Periodically report metrics"""
                while not stop_event.is_set():
                    time.sleep(reporting_interval)
                    if not stop_event.is_set():
                        self.log_percentiles()

            report_thread = threading.Thread(target=report_metrics, daemon=True)
            report_thread.start()

            # Run for specified duration or until interrupted
            try:
                if duration:
                    time.sleep(duration)
                    self.console.print(f"\n[yellow]Test duration of {duration}s completed[/yellow]")
                else:
                    # Run until interrupted
                    while True:
                        time.sleep(1)
            except KeyboardInterrupt:
                self.console.print("\n[yellow]Stopping load test...[/yellow]")
            finally:
                # Stop all threads
                stop_event.set()

                # Wait for workers to finish current tasks
                concurrent.futures.wait(workers, timeout=5)

                # Print final statistics
                self.console.print("\n[bold green]Final Statistics[/bold green]")
                self.console.print(self.format_stats_table())

                # Save detailed metrics if needed
                runtime = time.time() - self.start_time
                self.console.print(f"\n[green]Total Runtime: {runtime:.1f}s[/green]")
                self.console.print(f"[green]Total Gets: {self.total_queries}[/green]")
                self.console.print(f"[green]Average GPS (Gets/sec): {self.total_queries/runtime:.1f}[/green]")

                # Show final error summary if any errors occurred
                if self.total_errors > 0:
                    with self.metrics_lock:
                        all_errors = [m for m in self.all_metrics if not m.success]

                    self.console.print(f"\n[bold red]Error Summary (Total: {self.total_errors})[/bold red]")
                    top_errors = self.get_error_summary(all_errors, top_n=10)
                    for i, (error_msg, count) in enumerate(top_errors.items(), 1):
                        percentage = (count / self.total_errors) * 100
                        self.console.print(f"  {i}. ({count}x, {percentage:.1f}%) {error_msg}")


def main():
    """Main entry point for the script"""
    parser = argparse.ArgumentParser(
        description="Chroma concurrent load testing tool using get operations",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--num-documents",
        type=int,
        default=100,
        help="Number of documents to create per collection"
    )

    parser.add_argument(
        "--num-collections",
        type=int,
        default=1,
        help="Number of collections to create and test concurrently"
    )

    parser.add_argument(
        "--concurrency",
        type=int,
        default=10,
        help="Number of concurrent get operation workers"
    )

    parser.add_argument(
        "--target-gps",
        type=int,
        default=None,
        help="Target gets per second (unlimited if not specified)"
    )

    parser.add_argument(
        "--duration",
        type=int,
        default=None,
        help="Test duration in seconds (continuous if not specified)"
    )

    parser.add_argument(
        "--reporting-interval",
        type=int,
        default=5,
        help="Interval in seconds for reporting metrics"
    )

    parser.add_argument(
        "--client-type",
        choices=["ephemeral", "persistent"],
        default="ephemeral",
        help="Type of Chroma client to use"
    )

    parser.add_argument(
        "--collection-name-prefix",
        type=str,
        default="load_test_collection",
        help="Prefix for collection names (collections will be named prefix_000, prefix_001, etc.)"
    )

    args = parser.parse_args()

    console = Console()

    # Print configuration
    console.print("\n[bold cyan]Chroma Load Testing Tool[/bold cyan]")
    console.print("=" * 50)

    try:
        # Initialize tester
        tester = ChromaLoadTester(
            client_type=args.client_type,
            collection_name_prefix=args.collection_name_prefix,
            num_collections=args.num_collections,
            num_documents=args.num_documents
        )

        # Populate collections
        tester.populate_collections()

        # Run load test
        tester.run_load_test(
            concurrency=args.concurrency,
            target_gps=args.target_gps,
            duration=args.duration,
            reporting_interval=args.reporting_interval
        )

    except KeyboardInterrupt:
        console.print("\n[yellow]Load test interrupted by user[/yellow]")
    except Exception as e:
        console.print(f"\n[bold red]Error: {e}[/bold red]")
        raise


if __name__ == "__main__":
    main()

Migration plan

Are there any migrations, or any forwards/backwards compatibility changes needed in order to make sure this change deploys reliably?

Observability plan

What is the plan to instrument and monitor this change?

Documentation Changes

Are all docstrings for user-facing APIs updated if required? Do we need to make documentation changes in the docs section?


Reviewer Checklist

Please leverage this checklist to ensure your code review is thorough before approving

Testing, Bugs, Errors, Logs, Documentation

  • Can you think of any use case in which the code does not behave as intended? Have they been tested?
  • Can you think of any inputs or external events that could break the code? Is user input validated and safe? Have they been tested?
  • If appropriate, are there adequate property based tests?
  • If appropriate, are there adequate unit tests?
  • Should any logging, debugging, tracing information be added or removed?
  • Are error messages user-friendly?
  • Have all documentation changes needed been made?
  • Have all non-obvious changes been commented?

System Compatibility

  • Are there any potential impacts on other parts of the system or backward compatibility?
  • Does this change intersect with any items on our roadmap, and if so, is there a plan for fitting them together?

Quality

  • Is this code of unexpectedly high quality (Readability, Modularity, Intuitiveness)?

@codetheweb marked this pull request as ready for review August 19, 2025 23:55

@propel-code-bot (bot) commented Aug 19, 2025

Graceful Shutdown and Memberlist Propagation for Query/Log Services

This PR implements coordinated, configurable graceful shutdown logic for Chroma's query and log services to address high-latency client errors and in-flight call blocking observed during Kubernetes pod rollouts. It introduces a shutdown 'grace period' to allow memberlist updates to propagate, prevents pods marked for deletion from being considered healthy by clients, and exposes grace period configuration in service config files and Helm charts for operational tuning. The change is applied to both the Rust (query/log) servers and the Go memberlist watcher, with attention to deployment compatibility and observability.

Key Changes

• Adds a grpc_shutdown_grace_period configuration (default 1s, configurable via YAML/Helm) to Rust query and log servers; on SIGTERM, servers sleep for this grace period before terminating.
• Applies the grace period on SIGTERM to both query and log gRPC servers, allowing time for memberlist updates and client connection draining.
• Rust config/helpers module introduced for (de)serializing Duration fields in config as seconds (a hedged sketch of this pattern follows this list).
• Query and log service config, as well as Helm templates, updated to support the new grace period.
• Go memberlist_manager updated: pods with DeletionTimestamp are excluded from memberlist reported to clients, ensuring graceful exclusion from quorum.
• Tests and default config values updated for the new config options.
• Improved shutdown signal handling and tracing messages; logs SIGTERM receipt and the wait period.
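
A hedged sketch of the seconds-based Duration config field mentioned above; the field name and 1s default come from this summary, but the exact serde wiring is an assumption:

use std::time::Duration;
use serde::{Deserialize, Deserializer};

// Deserialize an integer number of seconds (as written in the YAML config)
// into a std::time::Duration.
fn duration_from_seconds<'de, D>(deserializer: D) -> Result<Duration, D::Error>
where
    D: Deserializer<'de>,
{
    let seconds = u64::deserialize(deserializer)?;
    Ok(Duration::from_secs(seconds))
}

fn default_grace_period() -> Duration {
    // Matches the 1s default described above.
    Duration::from_secs(1)
}

#[derive(Deserialize)]
struct ServerConfig {
    // How long to keep serving after SIGTERM so memberlist updates can propagate.
    #[serde(deserialize_with = "duration_from_seconds", default = "default_grace_period")]
    grpc_shutdown_grace_period: Duration,
}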

Affected Areas

• rust/worker (query service): config, main entrance, shutdown handling
• rust/log-service: config, shutdown handling
• rust/config: config helpers, duration (de)serialization, new helpers module
• go/memberlist_manager: node watcher, memberlist healthy/finality logic
• Helm charts/configuration for query/log services

This summary was automatically generated by @propel-code-bot

@HammadB self-requested a review August 19, 2025 23:58
@codetheweb requested a review from HammadB August 21, 2025 00:06
@blacksmith-sh (bot) deleted a comment from codetheweb Aug 21, 2025
@codetheweb requested a review from tanujnay112 August 21, 2025 18:21