diff --git a/graphindex/frontend/src/api/client.js b/graphindex/frontend/src/api/client.js index 501cc6f..de104fa 100644 --- a/graphindex/frontend/src/api/client.js +++ b/graphindex/frontend/src/api/client.js @@ -53,8 +53,14 @@ export const api = { extendedAsk: (opts) => postJSON('/api/extended_ask', opts), } +// Long-polling only: the server runs Flask-SocketIO under Werkzeug's +// threading mode, which does not handle the websocket transport reliably +// (server-side returns 500 "write() before start_response"). Polling is the +// supported transport for this deployment. +const SOCKET_OPTS = { transports: ['polling'], upgrade: false } + export function connectSocket(onEvent, onHello) { - const socket = io(BASE || '/', { transports: ['polling'], upgrade: false }) + const socket = io(BASE || '/', SOCKET_OPTS) socket.on('index_event', onEvent) if (onHello) socket.on('hello', onHello) return socket @@ -62,7 +68,7 @@ export function connectSocket(onEvent, onHello) { // Separate listener for extended_ask streaming (the "ext_event" channel). export function connectExtSocket(onEvent) { - const socket = io(BASE || '/', { transports: ['polling'], upgrade: false }) + const socket = io(BASE || '/', SOCKET_OPTS) socket.on('ext_event', onEvent) return socket } diff --git a/graphindex/graphindex/api/server.py b/graphindex/graphindex/api/server.py index 1247218..dc0c814 100644 --- a/graphindex/graphindex/api/server.py +++ b/graphindex/graphindex/api/server.py @@ -8,13 +8,17 @@ from __future__ import annotations +import logging import threading +import time from pathlib import Path from flask import Flask, send_from_directory from flask_cors import CORS from flask_socketio import SocketIO +log = logging.getLogger(__name__) + from ..config import Config from ..pipeline.events import IndexEvent from ..pipeline.orchestrator import Indexer @@ -46,31 +50,81 @@ def _resolve_frontend_dist() -> Path | None: _FRONTEND_DIST = _resolve_frontend_dist() +# Event coalescing: cap the number of events forwarded to SocketIO subscribers +# in any single batch, and the minimum interval between batches. Without this, +# a full re-index of a large repo will flood the message queue and the browser. +_BATCH_INTERVAL = 0.05 # seconds +_BATCH_MAX_EVENTS = 500 # events per emit + def create_app(cfg: Config): app = Flask(__name__, static_folder=None) CORS(app) - socketio = SocketIO( - app, - cors_allowed_origins="*", - async_mode="threading", - allow_upgrades=False, - transports=["polling"], - ) + # ``async_mode="threading"`` + the plain Werkzeug dev server does not handle + # WebSocket transport reliably (causes the "write() before start_response" + # 500s seen in production logs). ``allow_upgrades=False`` blocks the + # polling -> websocket *upgrade*, and the connect handler below rejects any + # client that arrives directly on the websocket transport so engine.io + # never enters that broken code path. Long-polling is fully functional. + socketio = SocketIO(app, cors_allowed_origins="*", async_mode="threading", + allow_upgrades=False) state = AppState(cfg) app.config["GRAPHINDEX_STATE"] = state app.config["GRAPHINDEX_SOCKETIO"] = socketio - # Bridge the event bus -> SocketIO. Indexing events go on "index_event"; - # extended_ask events (type starting with "ext_") also go on "ext_event". + # Bridge the event bus -> SocketIO with coalescing so a fast indexer cannot + # flood the WS channel / the browser. Events are buffered and flushed by a + # single background dispatcher thread at most every ``_BATCH_INTERVAL``. + pending_index: list[dict] = [] + pending_ext: list[dict] = [] + pending_lock = threading.Lock() + flush_event = threading.Event() + def _forward(evt: IndexEvent) -> None: d = evt.to_dict() - if evt.type.startswith("ext_"): - socketio.emit("ext_event", d) - else: - socketio.emit("index_event", d) - + with pending_lock: + bucket = pending_ext if evt.type.startswith("ext_") else pending_index + bucket.append(d) + # cap unbounded growth if the dispatcher falls behind + if len(bucket) > _BATCH_MAX_EVENTS * 20: + del bucket[:-_BATCH_MAX_EVENTS * 10] + flush_event.set() + + def _dispatcher() -> None: + # The dispatcher MUST keep running for the lifetime of the server. + # A single uncaught exception from socketio.emit() (e.g. transient + # network / client-disconnect / library error) would otherwise kill + # the thread and silently stop forwarding all future events. + while True: + try: + flush_event.wait() + flush_event.clear() + time.sleep(_BATCH_INTERVAL) # coalesce a burst + with pending_lock: + idx_batch = pending_index[:_BATCH_MAX_EVENTS] + del pending_index[:len(idx_batch)] + ext_batch = pending_ext[:_BATCH_MAX_EVENTS] + del pending_ext[:len(ext_batch)] + more = bool(pending_index or pending_ext) + for d in idx_batch: + try: + socketio.emit("index_event", d) + except Exception as exc: + log.warning("socketio.emit(index_event) failed: %s", exc) + for d in ext_batch: + try: + socketio.emit("ext_event", d) + except Exception as exc: + log.warning("socketio.emit(ext_event) failed: %s", exc) + if more: + flush_event.set() # keep draining + except Exception as exc: + log.exception("socketio dispatcher loop error: %s", exc) + time.sleep(1.0) # avoid hot-spin on a repeating failure + + threading.Thread(target=_dispatcher, daemon=True, + name="socketio-dispatcher").start() state.bus.subscribe(_forward) # Background extended_ask runner (single-flight). @@ -129,8 +183,21 @@ def _job(): @socketio.on("connect") def _on_connect(): + # Reject any client that arrived on the websocket transport directly + # (i.e. without going through polling first). ``allow_upgrades=False`` + # already disables polling -> ws upgrades; this guard closes the + # remaining hole where a client whose transports list starts with + # 'websocket' would still hit the broken Werkzeug WS path. + from flask import request as _flask_request + try: + transport = _flask_request.args.get("transport", "") + except Exception: + transport = "" + if transport == "websocket": + return False # engine.io interprets False as "reject connection" socketio.emit("hello", {"indexing": state.indexing, "repo": str(cfg.repo_path)}) + return None # ---- frontend (SPA) ---- @app.get("/") diff --git a/graphindex/graphindex/api/state.py b/graphindex/graphindex/api/state.py index 815b0ec..6d6a4a1 100644 --- a/graphindex/graphindex/api/state.py +++ b/graphindex/graphindex/api/state.py @@ -2,11 +2,21 @@ Holds the singletons the API/sockets operate on and (re)builds the read-side components after an index run so freshly indexed data is served immediately. + +The read-side components (db / vectors / embedder / search_engine / chat / +ask_engine) are exposed as a single immutable snapshot +(:class:`_ReadState`), referenced atomically by ``self._state``. Public +attributes ``db``, ``vectors``, ... proxy to the current snapshot via +``__getattr__``, so a request thread that captured one of them at the start +of a request keeps using a consistent set even if :meth:`reload` swaps in a +new snapshot mid-request. """ from __future__ import annotations import threading +from dataclasses import dataclass +from typing import Any from ..config import Config from ..embedding import get_embedder @@ -18,6 +28,17 @@ from ..storage.compsrc import CompSrc +@dataclass(frozen=True) +class _ReadState: + """Immutable bundle of read-side components built together.""" + db: GraphDB + vectors: VectorStore + embedder: Any + search_engine: SearchEngine + chat: Any + ask_engine: AskEngine + + class AppState: def __init__(self, cfg: Config): self.cfg = cfg @@ -25,24 +46,44 @@ def __init__(self, cfg: Config): self.bus = EventBus() self.index_lock = threading.Lock() self.ask_lock = threading.Lock() + # Guards the *swap* of self._state (writer-side only). Readers do not + # take this lock; they capture self._state into a local snapshot + # (which is just a pointer assignment — atomic under the GIL). + self._reload_lock = threading.Lock() self.indexing = False self.asking = False self.watcher = None self.last_extended = None self.compsrc = CompSrc(cfg.repo_path) - self._open() - - def _open(self) -> None: - self.db = GraphDB(self.cfg.db_path) - dim = self.db.get_meta("embed_dim", self.cfg.embed_dim) or self.cfg.embed_dim - self.vectors = VectorStore(self.cfg.vectors_path, dim) - # Embedder for query-time semantic search (lazy/real or fallback). - self.embedder = get_embedder(self.cfg) - self.search_engine = SearchEngine(self.db, self.vectors, self.embedder) - # Chat model is loaded lazily on first ask; building the engine is cheap. - self.chat = get_chat(self.cfg) - self.ask_engine = AskEngine(self.cfg, self.db, self.vectors, - self.embedder, chat=self.chat) + self._state: _ReadState = self._build() + + # ---- public attribute proxy ------------------------------------------ + def __getattr__(self, name: str) -> Any: + # Only invoked for attributes not found on the instance — i.e. the + # read-side fields, which live on the current snapshot. Note: each + # access reads the *current* snapshot. Callers that need a consistent + # view across multiple fields should call :meth:`snapshot` once. + if name in _ReadState.__dataclass_fields__: + return getattr(self._state, name) + raise AttributeError(name) + + def snapshot(self) -> _ReadState: + """Return the current read-side snapshot (consistent set of components).""" + return self._state + + # ---- construction ---------------------------------------------------- + def _build(self) -> _ReadState: + """Construct a fresh set of read-side components. Pure: no self mutation.""" + db = GraphDB(self.cfg.db_path) + dim = db.get_meta("embed_dim", self.cfg.embed_dim) or self.cfg.embed_dim + vectors = VectorStore(self.cfg.vectors_path, dim) + embedder = get_embedder(self.cfg) + search_engine = SearchEngine(db, vectors, embedder) + chat = get_chat(self.cfg) + ask_engine = AskEngine(self.cfg, db, vectors, embedder, chat=chat) + return _ReadState(db=db, vectors=vectors, embedder=embedder, + search_engine=search_engine, chat=chat, + ask_engine=ask_engine) def ensure_chat(self): """Retry chat-model discovery and keep AskEngine wired to it.""" @@ -53,10 +94,14 @@ def ensure_chat(self): return self.chat def build_extended(self, opts: dict, bus) -> ExtendedAsk: - """Construct an ExtendedAsk orchestrator with caps from the request.""" - self.ensure_chat() + """Construct an ExtendedAsk orchestrator with caps from the request. + + Reads from a single snapshot to avoid mixing old + new components if + a reload is racing this call. + """ + snap = self._state return ExtendedAsk( - self.cfg, self.db, self.vectors, self.embedder, chat=self.chat, bus=bus, + self.cfg, snap.db, snap.vectors, snap.embedder, chat=snap.chat, bus=bus, keyword_rounds=int(opts.get("keyword_rounds", 2)), keywords_per_round=int(opts.get("keywords_per_round", 4)), agents_per_round=int(opts.get("agents_per_round", 3)), @@ -64,9 +109,28 @@ def build_extended(self, opts: dict, bus) -> ExtendedAsk: ) def reload(self) -> None: - """Re-open read components (call after an index run completes).""" + """Re-open read components after an index run. + + Builds the new snapshot OUTSIDE the swap lock (so request threads + aren't blocked while we load the embedder / chat model), then takes + the lock just to perform the atomic pointer swap and capture the + old snapshot. The old DB connection is closed on a delayed daemon + timer so any in-flight request that already captured it can finish. + """ + new_state = self._build() # heavy work, no lock + with self._reload_lock: # short critical section + old_state = self._state + self._state = new_state try: - self.db.close() + timer = threading.Timer(2.0, lambda: _safe_close(old_state.db)) + timer.daemon = True # don't block interpreter shutdown + timer.start() except Exception: - pass - self._open() + _safe_close(old_state.db) + + +def _safe_close(db) -> None: + try: + db.close() + except Exception: + pass diff --git a/graphindex/graphindex/cli.py b/graphindex/graphindex/cli.py index b3fa58e..a9bf53b 100644 --- a/graphindex/graphindex/cli.py +++ b/graphindex/graphindex/cli.py @@ -204,7 +204,10 @@ def serve(repo, host, port, backend, watch): if watch: from graphindex.watcher import RepoWatcher state = app.config["GRAPHINDEX_STATE"] - w = RepoWatcher(cfg, bus=state.bus) + # Share the API's single-flight index_lock so the watcher and the + # foreground /api/index runner never collide on the same DB/vectors. + w = RepoWatcher(cfg, bus=state.bus, index_lock=state.index_lock, + on_reload=state.reload) w.start() state.watcher = w url = f"http://{cfg.host}:{cfg.port}" diff --git a/graphindex/graphindex/engine/server_mode.py b/graphindex/graphindex/engine/server_mode.py index 15abe74..70f709a 100644 --- a/graphindex/graphindex/engine/server_mode.py +++ b/graphindex/graphindex/engine/server_mode.py @@ -9,23 +9,96 @@ from __future__ import annotations +import atexit import subprocess import sys +import threading import time from urllib import request, error import json +# Track every llama_cpp.server subprocess we spawn so an atexit hook can stop +# them. Without this, a crashed parent leaves orphan model servers running +# (holding GPU/RAM + ports) until the user kills them manually. +_PROCS: list[subprocess.Popen] = [] +_PROCS_LOCK = threading.Lock() + + +def _terminate_proc(proc: subprocess.Popen) -> None: + """Best-effort: terminate -> wait -> kill -> wait. Always reaps the child.""" + if proc.poll() is not None: + return + try: + proc.terminate() + except Exception: + pass + try: + proc.wait(timeout=5) + return + except Exception: + pass + # Still alive after grace period — force-kill and reap to avoid zombies. + try: + proc.kill() + except Exception: + pass + try: + proc.wait(timeout=5) + except Exception: + pass + + +def _stop_all() -> None: + # Snapshot the list under the lock but DO NOT remove handles until each + # process is actually reaped (so a concurrent caller doesn't double-stop + # a child while it's still being killed). + with _PROCS_LOCK: + procs = list(_PROCS) + for p in procs: + _terminate_proc(p) + with _PROCS_LOCK: + try: + _PROCS.remove(p) + except ValueError: + pass + + +atexit.register(_stop_all) + + def start_server(model_path: str, host: str = "127.0.0.1", port: int = 8081, embedding: bool = False, extra: list[str] | None = None ) -> subprocess.Popen: + # ``cmd`` is built from a fixed executable + flags; ``shell=False`` (the + # default) ensures arguments are passed without shell interpretation so + # values like ``model_path`` cannot be used for command injection. cmd = [sys.executable, "-m", "llama_cpp.server", "--model", model_path, "--host", host, "--port", str(port)] if embedding: cmd += ["--embedding", "true"] if extra: cmd += extra - return subprocess.Popen(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) + proc = subprocess.Popen( # noqa: S603 (shell=False; args are controlled) + cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, + ) + with _PROCS_LOCK: + _PROCS.append(proc) + return proc + + +def stop_server(proc: subprocess.Popen) -> None: + """Explicitly terminate a server started by :func:`start_server`. + + Reaps the child before removing it from the tracking list, so an atexit + cleanup running concurrently can't drop the handle mid-shutdown. + """ + _terminate_proc(proc) + with _PROCS_LOCK: + try: + _PROCS.remove(proc) + except ValueError: + pass def wait_ready(host: str, port: int, timeout: float = 60.0) -> bool: diff --git a/graphindex/graphindex/storage/compsrc.py b/graphindex/graphindex/storage/compsrc.py index 741a2fc..cba51da 100644 --- a/graphindex/graphindex/storage/compsrc.py +++ b/graphindex/graphindex/storage/compsrc.py @@ -6,11 +6,14 @@ from __future__ import annotations -import json import hashlib +import json +import logging import zlib from pathlib import Path +log = logging.getLogger(__name__) + class CompSrc: def __init__(self, root_dir: str | Path): @@ -38,8 +41,8 @@ def store(self, node_id: str, code: str | None, summary: str = "", language: str compressed = zlib.compress(raw) try: self._path(node_id).write_bytes(compressed) - except Exception: - pass + except OSError as exc: + log.warning("compsrc.store failed for %s: %s", node_id, exc) def add_to_batch(self, node_id: str, code: str | None, summary: str = "", language: str = "") -> None: """Buffer a node for batch storage.""" @@ -59,7 +62,8 @@ def flush_batch(self) -> None: compressed = zlib.compress(raw) try: self._path(node_id).write_bytes(compressed) - except Exception: + except OSError as exc: + log.warning("compsrc.flush_batch failed for %s: %s", node_id, exc) continue self._batch.clear() @@ -72,7 +76,8 @@ def retrieve(self, node_id: str) -> dict | None: compressed = path.read_bytes() raw = zlib.decompress(compressed) return json.loads(raw.decode("utf-8")) - except Exception: + except (OSError, zlib.error, json.JSONDecodeError, UnicodeDecodeError) as exc: + log.warning("compsrc.retrieve failed for %s: %s", node_id, exc) return None def get_source_with_summary(self, node_id: str) -> str | None: @@ -113,6 +118,6 @@ def prune_stale(self, active_node_ids: set[str]) -> int: try: p.unlink() deleted += 1 - except Exception: - pass + except OSError as exc: + log.warning("compsrc.prune_stale: %s: %s", p, exc) return deleted diff --git a/graphindex/graphindex/storage/db.py b/graphindex/graphindex/storage/db.py index 1c61a9c..85fc7d4 100644 --- a/graphindex/graphindex/storage/db.py +++ b/graphindex/graphindex/storage/db.py @@ -152,16 +152,31 @@ def upsert_nodes(self, nodes: Iterable[Node]) -> None: def update_node_state(self, node_id: str, state: str) -> None: self.conn.execute("UPDATE nodes SET state=? WHERE id=?", (state, node_id)) + # Whitelist of columns that ``update_node_fields`` is allowed to set. + # Built from the node schema; anything not in this set is rejected to + # prevent SQL injection via attacker-controlled keys (the column name is + # interpolated into the UPDATE statement and cannot be parameterized). + _UPDATABLE_NODE_COLS = frozenset({ + "kind", "name", "path", "language", "start_line", "end_line", + "signature", "params", "search_string", "type_hint", "summary", + "tags", "state", "degree", "flags", "commit_id", "extra", + }) + @_locked def update_node_fields(self, node_id: str, **fields: Any) -> None: if not fields: return cols, vals = [], [] for k, v in fields.items(): + if k not in self._UPDATABLE_NODE_COLS: + raise ValueError(f"update_node_fields: column {k!r} not allowed") cols.append(f"{k}=?") vals.append(_dumps(v) if isinstance(v, (list, dict)) else v) vals.append(node_id) self.conn.execute(f"UPDATE nodes SET {','.join(cols)} WHERE id=?", vals) + # DML opens an implicit transaction; commit so changes persist and the + # write lock is released (otherwise other writers block forever). + self.conn.commit() @_locked def get_node(self, node_id: str) -> Node | None: diff --git a/graphindex/graphindex/storage/vectors.py b/graphindex/graphindex/storage/vectors.py index d4ff622..5bb817e 100644 --- a/graphindex/graphindex/storage/vectors.py +++ b/graphindex/graphindex/storage/vectors.py @@ -9,6 +9,7 @@ from __future__ import annotations import json +import threading from pathlib import Path import numpy as np @@ -38,6 +39,10 @@ def __init__(self, path: str | Path, dim: int): self._matrix: np.ndarray = np.zeros((0, dim), dtype="float32") self._index = faiss.IndexFlatIP(dim) if _HAS_FAISS else None self.backend = "faiss" if _HAS_FAISS else "numpy" + # Serialize mutation/query so the watcher (incremental updates) and the + # main indexer / API search calls cannot corrupt the ids ↔ matrix + # mapping by interleaving ``add``/``remove``/``search``. + self._lock = threading.RLock() self._load() # -- persistence ------------------------------------------------------ @@ -49,6 +54,14 @@ def _ids_file(self) -> Path: def _vec_file(self) -> Path: return self.path / "vectors.npy" + def _reset_state(self) -> None: + """Reset to an empty store in a consistent way (used on load failures).""" + self.ids = [] + self._id_to_idx = {} + self._matrix = np.zeros((0, self.dim), dtype="float32") + if self._index is not None: + self._index.reset() + def _load(self) -> None: if self._ids_file.exists() and self._vec_file.exists(): try: @@ -57,42 +70,55 @@ def _load(self) -> None: # _matrix end up inconsistent (silent corruption / IndexError). mat = np.load(self._vec_file) if mat.shape[1] != self.dim and mat.size: - return # dimension mismatch -> start fresh + self._reset_state() + return self.ids = json.loads(self._ids_file.read_text()) self._id_to_idx = {nid: i for i, nid in enumerate(self.ids)} self._matrix = mat.astype("float32") if self._index is not None and len(self.ids): + self._index.reset() self._index.add(self._matrix) except Exception: - self.ids, self._id_to_idx = [], {} - self._matrix = np.zeros((0, self.dim), dtype="float32") + self._reset_state() def save(self) -> None: - self._ids_file.write_text(json.dumps(self.ids)) - np.save(self._vec_file, self._matrix) + with self._lock: + self._ids_file.write_text(json.dumps(self.ids)) + np.save(self._vec_file, self._matrix) # -- mutation --------------------------------------------------------- def add(self, node_id: str, vector: np.ndarray | list[float]) -> None: - vec = _normalize(np.asarray(vector, dtype="float32").reshape(1, -1)) - idx = self._id_to_idx.get(node_id) - if idx is not None: - self._matrix[idx] = vec[0] - self._rebuild_faiss() - return - self._id_to_idx[node_id] = len(self.ids) - self.ids.append(node_id) - self._matrix = np.vstack([self._matrix, vec]) if self._matrix.size else vec - if self._index is not None: - self._index.add(vec) + raw = np.asarray(vector, dtype="float32").reshape(1, -1) + # Validate dimension BEFORE mutating any state so a bad input cannot + # leave ids / _id_to_idx / _matrix / FAISS index in an inconsistent + # state (which would later surface as a cryptic IndexError). + if raw.shape[1] != self.dim: + raise ValueError( + f"VectorStore.add: dimension mismatch (got {raw.shape[1]}, " + f"expected {self.dim})") + vec = _normalize(raw) + with self._lock: + idx = self._id_to_idx.get(node_id) + if idx is not None: + self._matrix[idx] = vec[0] + self._rebuild_faiss() + return + self._id_to_idx[node_id] = len(self.ids) + self.ids.append(node_id) + self._matrix = np.vstack([self._matrix, vec]) if self._matrix.size else vec + if self._index is not None: + self._index.add(vec) def remove(self, node_ids: set[str]) -> None: if not node_ids: return - keep = [i for i, nid in enumerate(self.ids) if nid not in node_ids] - self.ids = [self.ids[i] for i in keep] - self._id_to_idx = {nid: i for i, nid in enumerate(self.ids)} - self._matrix = self._matrix[keep] if keep else np.zeros((0, self.dim), dtype="float32") - self._rebuild_faiss() + with self._lock: + keep = [i for i, nid in enumerate(self.ids) if nid not in node_ids] + self.ids = [self.ids[i] for i in keep] + self._id_to_idx = {nid: i for i, nid in enumerate(self.ids)} + self._matrix = (self._matrix[keep] if keep + else np.zeros((0, self.dim), dtype="float32")) + self._rebuild_faiss() def get_vector(self, node_id: str) -> np.ndarray | None: """Return the stored (normalized) vector for ``node_id`` or None. @@ -100,8 +126,10 @@ def get_vector(self, node_id: str) -> np.ndarray | None: Public accessor so callers (e.g. duplicate detection) don't reach into ``_matrix``/``ids`` internals. """ - idx = self._id_to_idx.get(node_id) - return None if idx is None else self._matrix[idx] + with self._lock: + idx = self._id_to_idx.get(node_id) + # copy so callers can't mutate the stored row via a view + return None if idx is None else self._matrix[idx].copy() def _rebuild_faiss(self) -> None: if self._index is None: @@ -113,16 +141,24 @@ def _rebuild_faiss(self) -> None: # -- query ------------------------------------------------------------ def search(self, vector: np.ndarray | list[float], top_k: int = 20 ) -> list[tuple[str, float]]: - if not self.ids: - return [] - q = _normalize(np.asarray(vector, dtype="float32").reshape(1, -1)) - k = min(top_k, len(self.ids)) - if self._index is not None: - scores, idxs = self._index.search(q, k) - return [(self.ids[i], float(s)) for i, s in zip(idxs[0], scores[0]) if i >= 0] - sims = (self._matrix @ q[0]) - order = np.argsort(-sims)[:k] - return [(self.ids[i], float(sims[i])) for i in order] + raw = np.asarray(vector, dtype="float32").reshape(1, -1) + if raw.shape[1] != self.dim: + raise ValueError( + f"VectorStore.search: dimension mismatch (got {raw.shape[1]}, " + f"expected {self.dim})") + q = _normalize(raw) + with self._lock: + if not self.ids: + return [] + k = min(top_k, len(self.ids)) + if self._index is not None: + scores, idxs = self._index.search(q, k) + return [(self.ids[i], float(s)) + for i, s in zip(idxs[0], scores[0]) if i >= 0] + sims = (self._matrix @ q[0]) + order = np.argsort(-sims)[:k] + return [(self.ids[i], float(sims[i])) for i in order] def __len__(self) -> int: - return len(self.ids) + with self._lock: + return len(self.ids) diff --git a/graphindex/graphindex/watcher/incremental.py b/graphindex/graphindex/watcher/incremental.py index e5f4915..2e54784 100644 --- a/graphindex/graphindex/watcher/incremental.py +++ b/graphindex/graphindex/watcher/incremental.py @@ -9,9 +9,11 @@ from __future__ import annotations +import logging import os import threading import time +from typing import Callable, Optional from watchdog.events import FileSystemEventHandler from watchdog.observers import Observer @@ -23,6 +25,8 @@ from ..scanner.ignore import IgnoreEngine from ..scanner.walker import detect_language +log = logging.getLogger(__name__) + class _Handler(FileSystemEventHandler): def __init__(self, watcher: "RepoWatcher"): @@ -39,7 +43,9 @@ def on_any_event(self, event): class RepoWatcher: def __init__(self, cfg: Config, bus: EventBus | None = None, - debounce: float = 1.0, do_summarize: bool = False): + debounce: float = 1.0, do_summarize: bool = False, + index_lock: Optional["threading.Lock"] = None, + on_reload: Optional[Callable[[], None]] = None): self.cfg = cfg self.bus = bus or EventBus() self.debounce = debounce @@ -49,6 +55,13 @@ def __init__(self, cfg: Config, bus: EventBus | None = None, self._lock = threading.Lock() self._timer: threading.Timer | None = None self._observer: Observer | None = None + # Optional shared single-flight lock: when running inside the API + # process the watcher MUST coordinate with the foreground /api/index + # runner so they never open the same SQLite/FAISS state concurrently. + self._index_lock = index_lock + # Callback to invalidate read-side caches (e.g. AppState.reload) once + # an incremental update completes, so the WebUI gets fresh data. + self._on_reload = on_reload def _enqueue(self, abs_path: str) -> None: try: @@ -74,23 +87,62 @@ def _flush(self) -> None: self._pending.clear() if not changed: return - self.bus.emit("log", message=f"Incremental update: {len(changed)} file(s)") - indexer = Indexer(self.cfg, bus=self.bus, do_summarize=self.do_summarize) - # 1) Handle deletions from the changed set FIRST so node_remove events are - # emitted (a blanket prune first would delete the rows, leaving the - # explicit loop nothing to report and the UI graph stale). - existing = [c for c in changed if (self.cfg.repo_path / c).exists()] - for path in changed: - if not (self.cfg.repo_path / path).exists(): - ids = indexer.db.delete_file(path) - indexer.vectors.remove(set(ids)) - for nid in ids: - self.bus.emit("node_remove", id=nid) - # 2) Catch any other files deleted outside the changed set. - prune_deleted_files(self.cfg, indexer.db, indexer.vectors) - if existing: - indexer.index(only_changed=existing) - indexer.db.close() + + # Coordinate with the foreground indexer so we never open two writers + # on the same DB/vectors at once. If a full re-index is in progress we + # re-queue these files and let the next debounce tick try again. + lock = self._index_lock + if lock is not None and not lock.acquire(blocking=False): + with self._lock: + self._pending.update(changed) + if self._timer: + self._timer.cancel() + self._timer = threading.Timer(self.debounce, self._flush) + self._timer.daemon = True + self._timer.start() + return + + try: + self.bus.emit("log", message=f"Incremental update: {len(changed)} file(s)") + indexer = Indexer(self.cfg, bus=self.bus, do_summarize=self.do_summarize) + try: + # 1) Handle deletions from the changed set FIRST so node_remove + # events are emitted (a blanket prune first would delete the + # rows, leaving the explicit loop nothing to report and the + # UI graph stale). + existing = [c for c in changed if (self.cfg.repo_path / c).exists()] + removed_any = False + for path in changed: + if not (self.cfg.repo_path / path).exists(): + ids = indexer.db.delete_file(path) + if ids: + indexer.vectors.remove(set(ids)) + removed_any = True + for nid in ids: + self.bus.emit("node_remove", id=nid) + # 2) Catch any other files deleted outside the changed set. + # (prune_deleted_files persists vectors itself when it + # removes any nodes; we only need to handle the case where + # our own loop above removed vectors but prune found nothing.) + prune_deleted_files(self.cfg, indexer.db, indexer.vectors) + if existing: + indexer.index(only_changed=existing) + elif removed_any: + # Indexer.index() is what normally persists vectors. When + # the batch was purely deletions from the changed set, we + # must save here or the removals reappear after a restart. + indexer.vectors.save() + finally: + indexer.db.close() + if self._on_reload: + try: + self._on_reload() + except Exception as exc: + log.exception("on_reload callback failed: %s", exc) + self.bus.emit("log", message=f"reload failed: {exc}") + finally: + if lock is not None: + lock.release() def start(self) -> None: self._observer = Observer()