diff --git a/pyproject.toml b/pyproject.toml index bd42ca9..7f49249 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,3 +38,8 @@ codesurface = "codesurface.server:main" [tool.hatch.build.targets.wheel] packages = ["src/codesurface"] + +[dependency-groups] +dev = [ + "pytest>=9.0.2", +] diff --git a/src/codesurface/db.py b/src/codesurface/db.py index fd46c01..dcb025d 100644 --- a/src/codesurface/db.py +++ b/src/codesurface/db.py @@ -143,11 +143,14 @@ def delete_by_files(conn: sqlite3.Connection, file_paths: list[str]) -> int: def search(conn: sqlite3.Connection, query: str, n: int = 10, - member_type: str | None = None) -> list[dict]: + member_type: str | None = None, + file_path: str | None = None) -> list[dict]: """Full-text search with BM25 ranking + PascalCase-aware matching. Column weights: member_name (10x) > class_name (5x) > search_text (4x) > signature (3x) > fqn/summary (1x) Type bonus: class/struct/enum defs rank higher than same-named members. + + file_path: optional path prefix or exact file to scope results. """ clean = _escape_fts(query) if not clean.strip(): @@ -157,27 +160,33 @@ def search(conn: sqlite3.Connection, query: str, n: int = 10, ranking = """bm25(api_fts, 1.0, 5.0, 10.0, 0.5, 3.0, 4.0) + CASE WHEN r.member_type = 'type' THEN -1.0 ELSE 0.0 END""" - if member_type: - sql = f""" - SELECT r.*, {ranking} AS rank - FROM api_fts f - JOIN api_records r ON r.rowid = f.rowid - WHERE api_fts MATCH ? AND r.member_type = ? - ORDER BY rank - LIMIT ? - """ - rows = conn.execute(sql, (clean, member_type, n)).fetchall() - else: - sql = f""" - SELECT r.*, {ranking} AS rank - FROM api_fts f - JOIN api_records r ON r.rowid = f.rowid - WHERE api_fts MATCH ? - ORDER BY rank - LIMIT ? 
- """ - rows = conn.execute(sql, (clean, n)).fetchall() + conditions = ["api_fts MATCH ?"] + params: list = [clean] + if member_type: + conditions.append("r.member_type = ?") + params.append(member_type) + + if file_path: + if file_path.endswith("/"): + conditions.append("r.file_path LIKE ?") + params.append(file_path + "%") + else: + conditions.append("(r.file_path = ? OR r.file_path LIKE ?)") + params.extend([file_path, file_path + "/%"]) + + where = " AND ".join(conditions) + params.append(n) + + sql = f""" + SELECT r.*, {ranking} AS rank + FROM api_fts f + JOIN api_records r ON r.rowid = f.rowid + WHERE {where} + ORDER BY rank + LIMIT ? + """ + rows = conn.execute(sql, params).fetchall() return [dict(row) for row in rows] @@ -190,19 +199,29 @@ def get_by_fqn(conn: sqlite3.Connection, fqn: str) -> dict | None: def get_class_members(conn: sqlite3.Connection, class_name: str, - namespace: str | None = None) -> list[dict]: - """Get all members of a class by class name, optionally filtered by namespace.""" + namespace: str | None = None, + file_path: str | None = None) -> list[dict]: + """Get all members of a class by class name, optionally filtered by namespace and/or file_path.""" + conditions = ["class_name = ?"] + params: list = [class_name] + if namespace is not None: - rows = conn.execute( - "SELECT * FROM api_records WHERE class_name = ? AND namespace = ? " - "ORDER BY member_type, member_name", - (class_name, namespace), - ).fetchall() - else: - rows = conn.execute( - "SELECT * FROM api_records WHERE class_name = ? ORDER BY member_type, member_name", - (class_name,), - ).fetchall() + conditions.append("namespace = ?") + params.append(namespace) + + if file_path: + if file_path.endswith("/"): + conditions.append("file_path LIKE ?") + params.append(file_path + "%") + else: + conditions.append("(file_path = ? 
OR file_path LIKE ?)") + params.extend([file_path, file_path + "/%"]) + + where = " AND ".join(conditions) + rows = conn.execute( + f"SELECT * FROM api_records WHERE {where} ORDER BY member_type, member_name", + params, + ).fetchall() return [dict(row) for row in rows] diff --git a/src/codesurface/filters.py b/src/codesurface/filters.py new file mode 100644 index 0000000..4147ae0 --- /dev/null +++ b/src/codesurface/filters.py @@ -0,0 +1,126 @@ +"""Path filtering for codesurface indexing. + +Handles default exclusions (worktrees, submodules, vendored/build dirs) +and user-configured exclusions (.codesurfaceignore, --exclude CLI flag). +""" +from __future__ import annotations + +import fnmatch +from pathlib import Path + +# Directories excluded by name in every project — vendored deps, build +# output, VCS internals, and IDE config that never contain user source. +_DEFAULT_EXCLUDED_DIRS: frozenset[str] = frozenset({ + # JS / Node + "node_modules", "bower_components", + # Python + ".venv", "venv", "env", "__pycache__", ".tox", ".mypy_cache", + ".pytest_cache", "site-packages", + # Go + "vendor", "testdata", "third_party", "examples", "example", + # .NET / Java + "bin", "obj", "packages", ".gradle", ".mvn", + "generated", "generated-sources", "generated-test-sources", + # Build output / caches + "dist", "build", "out", "target", ".next", ".nuxt", ".nx", + # VCS / IDE + ".git", ".hg", ".svn", + ".idea", ".vscode", ".vs", + # Misc + ".yarn", ".pnp", "coverage", ".turbo", ".cache", ".worktrees", +}) + + +def _read_git_file(path: Path) -> str | None: + """Read .git FILE content if present. 
Returns None if .git is missing, is a directory, or cannot be read.""" + git = path / ".git" + if git.is_file(): + try: + return git.read_text().strip() + except OSError: + return None + return None + + +def _is_git_worktree(git_content: str) -> bool: + """True if .git file references a worktrees/ path.""" + return "/worktrees/" in git_content + + +def _is_git_submodule(git_content: str) -> bool: + """True if .git file references a modules/ path.""" + return "/modules/" in git_content + + +def _read_ignore_file(project_root: Path) -> list[str]: + """Read .codesurfaceignore and return non-empty, non-comment lines.""" + ignore_path = project_root / ".codesurfaceignore" + if not ignore_path.is_file(): + return [] + lines = [] + for line in ignore_path.read_text().splitlines(): + stripped = line.strip() + if stripped and not stripped.startswith("#"): + lines.append(stripped) + return lines + + +class PathFilter: + """Determines which directories and files to skip during indexing. + + Default exclusions (always applied): + - Any directory whose name is in _DEFAULT_EXCLUDED_DIRS (e.g. node_modules, .git, .worktrees) + - Any subdirectory with a .git FILE referencing /worktrees/ (git worktree) + - Any subdirectory with a .git FILE referencing /modules/ (submodule), + unless include_submodules=True + + User exclusions via exclude_globs (CLI) and .codesurfaceignore (project file). 
+ """ + + def __init__( + self, + project_root: Path, + exclude_globs: list[str] | None = None, + include_submodules: bool = False, + ) -> None: + self._root = project_root + self._include_submodules = include_submodules + self._globs: list[str] = list(exclude_globs or []) + self._globs.extend(_read_ignore_file(project_root)) + + def is_dir_excluded_name(self, name: str) -> bool: + """Fast check using only the directory basename (no I/O).""" + return name in _DEFAULT_EXCLUDED_DIRS + + def is_dir_excluded(self, path: Path) -> bool: + """Return True if this directory should be skipped entirely.""" + name = path.name + + if name in _DEFAULT_EXCLUDED_DIRS: + return True + + # .git FILE detection (worktrees / submodules) + git_content = _read_git_file(path) + if git_content is not None: + if _is_git_worktree(git_content): + return True + if _is_git_submodule(git_content) and not self._include_submodules: + return True + + return False + + def is_file_excluded(self, path: Path) -> bool: + """Return True if this file matches any user exclusion glob.""" + if not self._globs: + return False + try: + rel = str(path.relative_to(self._root)).replace("\\", "/") + except ValueError: + return False + return any(fnmatch.fnmatch(rel, g) for g in self._globs) + + def is_file_excluded_rel(self, rel_path: str) -> bool: + """Return True if a relative path matches any user exclusion glob.""" + if not self._globs: + return False + return any(fnmatch.fnmatch(rel_path, g) for g in self._globs) diff --git a/src/codesurface/parsers/__init__.py b/src/codesurface/parsers/__init__.py index 73cea5e..32e7b82 100644 --- a/src/codesurface/parsers/__init__.py +++ b/src/codesurface/parsers/__init__.py @@ -2,10 +2,12 @@ from __future__ import annotations +import os from pathlib import Path from typing import TYPE_CHECKING if TYPE_CHECKING: + from ..filters import PathFilter from .base import BaseParser _REGISTRY: dict[str, type[BaseParser]] = {} @@ -28,22 +30,42 @@ def get_parser(lang: str) -> 
BaseParser: return cls() -def detect_languages(project_dir: Path) -> list[str]: - """Detect which registered languages are present in *project_dir*.""" +def detect_languages( + project_dir: Path, + path_filter: "PathFilter | None" = None, +) -> list[str]: + """Detect which registered languages are present in *project_dir*. + + Uses os.walk with *path_filter* pruning so vendored directories + (node_modules, .git, etc.) are skipped during detection. + """ + exts = tuple(_EXT_TO_LANG.keys()) found: set[str] = set() - for ext, lang in _EXT_TO_LANG.items(): - # Quick check: does at least one file with this extension exist? - try: - next(project_dir.rglob(f"*{ext}")) - found.add(lang) - except StopIteration: - pass + + for root, dirs, files in os.walk(project_dir): + root_path = Path(root) + if path_filter is not None: + dirs[:] = [d for d in dirs if not path_filter.is_dir_excluded(root_path / d)] + + for filename in files: + for ext in exts: + if filename.endswith(ext): + found.add(_EXT_TO_LANG[ext]) + break + + # Stop early once all registered languages are found + if len(found) == len(_REGISTRY): + break + return sorted(found) -def get_parsers_for_project(project_dir: Path) -> list[BaseParser]: +def get_parsers_for_project( + project_dir: Path, + path_filter: "PathFilter | None" = None, +) -> list[BaseParser]: """Return parser instances for every language detected in *project_dir*.""" - return [get_parser(lang) for lang in detect_languages(project_dir)] + return [get_parser(lang) for lang in detect_languages(project_dir, path_filter)] def all_extensions() -> list[str]: diff --git a/src/codesurface/parsers/base.py b/src/codesurface/parsers/base.py index 5e264ea..c085e13 100644 --- a/src/codesurface/parsers/base.py +++ b/src/codesurface/parsers/base.py @@ -1,32 +1,113 @@ """Abstract base class for language parsers.""" +import os +import sys from abc import ABC, abstractmethod from pathlib import Path +from typing import TYPE_CHECKING, Callable + +if TYPE_CHECKING: + 
from ..filters import PathFilter class BaseParser(ABC): """Base class that all language parsers must extend. Subclasses implement `file_extensions` and `parse_file`. - The default `parse_directory` walks recursively for matching files. + The default `parse_directory` walks recursively for matching files, + using raw str paths internally to avoid pathlib overhead at scale. + + Subclasses may override `skip_suffixes` or `skip_files` to filter + out files by suffix or exact name (e.g. ".d.ts", "module-info.java"). """ @property @abstractmethod def file_extensions(self) -> list[str]: - """File extensions this parser handles, e.g. [".cs"].""" + """File extensions this parser handles, e.g. ['.cs'].""" + + @property + def skip_suffixes(self) -> tuple[str, ...]: + """File suffixes to skip (e.g. ('.d.ts',)). Override in subclass.""" + return () + + @property + def skip_files(self) -> frozenset[str]: + """Exact filenames to skip (e.g. frozenset({'conftest.py'})). Override in subclass.""" + return frozenset() @abstractmethod def parse_file(self, path: Path, base_dir: Path) -> list[dict]: """Parse a single file and return API records.""" - def parse_directory(self, directory: Path) -> list[dict]: - """Recursively parse all matching files under *directory*.""" - records = [] - for ext in self.file_extensions: - for f in sorted(directory.rglob(f"*{ext}")): - try: - records.extend(self.parse_file(f, directory)) - except Exception: + def _should_skip_dir(self, name: str) -> bool: + """Extra per-parser directory skip logic. Override if needed.""" + return False + + def _walk_files( + self, directory: Path, path_filter: "PathFilter | None" = None, + ) -> list[str]: + """Collect all file paths this parser would process. + + Returns str paths under *directory* (absolute only if *directory* is), applying all filters (PathFilter, + skip_suffixes, skip_files, _should_skip_dir). 
+ """ + exts = tuple(self.file_extensions) + skip_suf = self.skip_suffixes + skip_fn = self.skip_files + dir_str = str(directory) + result: list[str] = [] + + for root, dirs, files in os.walk(dir_str): + if path_filter is not None: + dirs[:] = [ + d for d in dirs + if not path_filter.is_dir_excluded_name(d) + and not self._should_skip_dir(d) + and not path_filter.is_dir_excluded(Path(os.path.join(root, d))) + ] + else: + dirs[:] = [d for d in dirs if not self._should_skip_dir(d)] + + for filename in files: + if not filename.endswith(exts): + continue + if skip_suf and filename.endswith(skip_suf): continue + if skip_fn and filename in skip_fn: + continue + + filepath = os.path.join(root, filename) + + if path_filter is not None and path_filter.is_file_excluded_rel( + filepath[len(dir_str) + 1:].replace("\\", "/") + ): + continue + + result.append(filepath) + + return result + + def parse_directory( + self, directory: Path, path_filter: "PathFilter | None" = None, + on_progress: "Callable[[Path], None] | None" = None, + ) -> list[dict]: + """Recursively parse all matching files under *directory*. + + Uses os.walk with str paths to avoid pathlib overhead. + PathFilter handles all default exclusions (node_modules, .git, etc.). + """ + file_paths = self._walk_files(directory, path_filter) + records: list[dict] = [] + + for filepath in file_paths: + f = Path(filepath) + try: + records.extend(self.parse_file(f, directory)) + except Exception as e: + print(f"codesurface: failed to parse {filepath}: {e}", file=sys.stderr) + finally: + if on_progress is not None: + on_progress(f) + return records diff --git a/src/codesurface/parsers/cpp.py b/src/codesurface/parsers/cpp.py index 534cc83..6d3cd49 100644 --- a/src/codesurface/parsers/cpp.py +++ b/src/codesurface/parsers/cpp.py @@ -12,6 +12,7 @@ extraction. 
""" +import os import re from pathlib import Path @@ -21,12 +22,6 @@ # Skip patterns # --------------------------------------------------------------------------- -_SKIP_DIRS = frozenset({ - "build", ".git", "third_party", "vendor", "test", "tests", - "examples", "node_modules", ".cache", "obj", "out", - "Debug", "Release", "x64", "x86", ".vs", -}) - _SKIP_DIR_PREFIXES = ("cmake-build-",) # C++ keywords that can't be member names @@ -239,27 +234,11 @@ class CppParser(BaseParser): def file_extensions(self) -> list[str]: return [".h", ".hpp", ".hxx", ".h++"] - def parse_directory(self, directory: Path) -> list[dict]: - """Override to skip build/vendor/test directories.""" - records: list[dict] = [] - ext_set = set(self.file_extensions) - for f in sorted(directory.rglob("*")): - if f.suffix not in ext_set: - continue - parts = f.relative_to(directory).parts - if any( - p in _SKIP_DIRS - or any(p.startswith(pfx) for pfx in _SKIP_DIR_PREFIXES) - for p in parts - ): - continue - try: - records.extend(self.parse_file(f, directory)) - except Exception as e: - import sys - print(f"codesurface: failed to parse {f}: {e}", file=sys.stderr) - continue - return records + def _should_skip_dir(self, name: str) -> bool: + return ( + name in ("Debug", "Release", "x64", "x86") + or any(name.startswith(pfx) for pfx in _SKIP_DIR_PREFIXES) + ) def parse_file(self, path: Path, base_dir: Path) -> list[dict]: return _parse_cpp_file(path, base_dir) @@ -272,11 +251,12 @@ def parse_file(self, path: Path, base_dir: Path) -> list[dict]: def _parse_cpp_file(path: Path, base_dir: Path) -> list[dict]: """Parse a single C++ header file and extract public API members.""" try: - text = path.read_text(encoding="utf-8", errors="replace") + with open(path, encoding="utf-8", errors="replace") as fh: + text = fh.read() except (OSError, UnicodeDecodeError): return [] - rel_path = path.relative_to(base_dir).as_posix() + rel_path = os.path.relpath(path, base_dir).replace("\\", "/") lines = text.splitlines() 
# Skip generated files diff --git a/src/codesurface/parsers/csharp.py b/src/codesurface/parsers/csharp.py index c7daf50..dbd8c0d 100644 --- a/src/codesurface/parsers/csharp.py +++ b/src/codesurface/parsers/csharp.py @@ -4,6 +4,7 @@ with their full signatures. Doc comments (///) are extracted as bonus data. """ +import os import re from pathlib import Path @@ -112,11 +113,12 @@ def parse_file(self, path: Path, base_dir: Path) -> list[dict]: def _parse_cs_file(path: Path, base_dir: Path) -> list[dict]: """Parse a single .cs file and extract all public members.""" try: - text = path.read_text(encoding="utf-8-sig", errors="replace") + with open(path, encoding="utf-8-sig", errors="replace") as fh: + text = fh.read() except (OSError, UnicodeDecodeError): return [] - rel_path = str(path.relative_to(base_dir)).replace("\\", "/") + rel_path = os.path.relpath(path, base_dir).replace("\\", "/") lines = text.splitlines() records = [] diff --git a/src/codesurface/parsers/go.py b/src/codesurface/parsers/go.py index e6be2f3..30a143a 100644 --- a/src/codesurface/parsers/go.py +++ b/src/codesurface/parsers/go.py @@ -8,21 +8,13 @@ Doc comments are consecutive // lines immediately before a declaration. 
""" +import os import re +import sys from pathlib import Path from .base import BaseParser - -# --- Skip patterns --- - -_SKIP_DIRS = frozenset({ - "vendor", "testdata", ".git", "node_modules", "third_party", - "examples", "example", -}) - -_SKIP_FILE_SUFFIX = "_test.go" - # Go reserved words that can't be identifiers _GO_KEYWORDS = frozenset({ "func", "type", "var", "const", "import", "package", @@ -141,24 +133,12 @@ class GoParser(BaseParser): def file_extensions(self) -> list[str]: return [".go"] - def parse_directory(self, directory: Path) -> list[dict]: - """Override to skip vendor/testdata/test files.""" - records: list[dict] = [] - for f in sorted(directory.rglob("*.go")): - parts = f.relative_to(directory).parts - # Skip dirs - if any(p in _SKIP_DIRS or p.startswith("_") for p in parts): - continue - # Skip test files - if f.name.endswith(_SKIP_FILE_SUFFIX): - continue - try: - records.extend(self.parse_file(f, directory)) - except Exception as e: - import sys - print(f"codesurface: failed to parse {f}: {e}", file=sys.stderr) - continue - return records + @property + def skip_suffixes(self) -> tuple[str, ...]: + return ("_test.go",) + + def _should_skip_dir(self, name: str) -> bool: + return name.startswith("_") def parse_file(self, path: Path, base_dir: Path) -> list[dict]: return _parse_go_file(path, base_dir) @@ -172,11 +152,12 @@ def _is_exported(name: str) -> bool: def _parse_go_file(path: Path, base_dir: Path) -> list[dict]: """Parse a single .go file and extract exported API members.""" try: - text = path.read_text(encoding="utf-8", errors="replace") + with open(path, encoding="utf-8", errors="replace") as fh: + text = fh.read() except (OSError, UnicodeDecodeError): return [] - rel_path = path.relative_to(base_dir).as_posix() + rel_path = os.path.relpath(path, base_dir).replace("\\", "/") lines = text.splitlines() # Skip generated files (Go standard: "Code generated ... 
DO NOT EDIT") diff --git a/src/codesurface/parsers/java.py b/src/codesurface/parsers/java.py index f29390e..8aee366 100644 --- a/src/codesurface/parsers/java.py +++ b/src/codesurface/parsers/java.py @@ -5,21 +5,13 @@ Javadoc comments (/** ... */) are extracted as summaries. """ +import os import re +import sys from pathlib import Path from .base import BaseParser - -# --- Skip patterns --- - -_SKIP_DIRS = frozenset({ - "test", "tests", "target", "build", ".gradle", ".git", - "node_modules", ".mvn", "out", "generated", - "generated-sources", "generated-test-sources", - ".idea", "bin", -}) - _SKIP_SUFFIXES = ( "Test.java", "Tests.java", @@ -131,25 +123,13 @@ class JavaParser(BaseParser): def file_extensions(self) -> list[str]: return [".java"] - def parse_directory(self, directory: Path) -> list[dict]: - """Override to skip test/build directories.""" - records = [] - for f in sorted(directory.rglob("*.java")): - parts = f.relative_to(directory).parts - if any(p in _SKIP_DIRS for p in parts): - continue - fname = f.name - if fname in _SKIP_FILES: - continue - if any(fname.endswith(s) for s in _SKIP_SUFFIXES): - continue - try: - records.extend(self.parse_file(f, directory)) - except Exception as e: - import sys - print(f"codesurface: failed to parse {f}: {e}", file=sys.stderr) - continue - return records + @property + def skip_suffixes(self) -> tuple[str, ...]: + return _SKIP_SUFFIXES + + @property + def skip_files(self) -> frozenset[str]: + return _SKIP_FILES def parse_file(self, path: Path, base_dir: Path) -> list[dict]: return _parse_java_file(path, base_dir) @@ -158,11 +138,12 @@ def parse_file(self, path: Path, base_dir: Path) -> list[dict]: def _parse_java_file(path: Path, base_dir: Path) -> list[dict]: """Parse a single .java file and extract public API members.""" try: - text = path.read_text(encoding="utf-8", errors="replace") + with open(path, encoding="utf-8", errors="replace") as fh: + text = fh.read() except (OSError, UnicodeDecodeError): return [] - 
rel_path = str(path.relative_to(base_dir)).replace("\\", "/") + rel_path = os.path.relpath(path, base_dir).replace("\\", "/") lines = text.splitlines() records: list[dict] = [] diff --git a/src/codesurface/parsers/python_parser.py b/src/codesurface/parsers/python_parser.py index 459fd96..9fbc068 100644 --- a/src/codesurface/parsers/python_parser.py +++ b/src/codesurface/parsers/python_parser.py @@ -5,20 +5,13 @@ Docstrings are extracted as summaries. """ +import os import re +import sys from pathlib import Path from .base import BaseParser - -# --- Skip patterns --- - -_SKIP_DIRS = frozenset({ - "__pycache__", ".git", ".venv", "venv", "env", - "node_modules", ".tox", ".mypy_cache", ".pytest_cache", - "dist", "build", "egg-info", -}) - _SKIP_FILES = frozenset({ "setup.py", "conftest.py", }) @@ -73,25 +66,12 @@ class PythonParser(BaseParser): def file_extensions(self) -> list[str]: return [".py"] - def parse_directory(self, directory: Path) -> list[dict]: - """Override to skip common non-source directories.""" - records = [] - for f in sorted(directory.rglob("*.py")): - # Skip files in excluded directories - parts = f.relative_to(directory).parts - if any(p in _SKIP_DIRS for p in parts): - continue - if any(p.endswith(".egg-info") for p in parts): - continue - if f.name in _SKIP_FILES: - continue - try: - records.extend(self.parse_file(f, directory)) - except Exception as e: - import sys - print(f"codesurface: failed to parse {f}: {e}", file=sys.stderr) - continue - return records + @property + def skip_files(self) -> frozenset[str]: + return _SKIP_FILES + + def _should_skip_dir(self, name: str) -> bool: + return name.endswith(".egg-info") def parse_file(self, path: Path, base_dir: Path) -> list[dict]: return _parse_py_file(path, base_dir) @@ -100,11 +80,12 @@ def parse_file(self, path: Path, base_dir: Path) -> list[dict]: def _parse_py_file(path: Path, base_dir: Path) -> list[dict]: """Parse a single .py file and extract public API members.""" try: - text = 
path.read_text(encoding="utf-8", errors="replace") + with open(path, encoding="utf-8", errors="replace") as fh: + text = fh.read() except (OSError, UnicodeDecodeError): return [] - rel_path = str(path.relative_to(base_dir)).replace("\\", "/") + rel_path = os.path.relpath(path, base_dir).replace("\\", "/") lines = text.splitlines() records = [] @@ -690,8 +671,8 @@ def _file_to_module(path: Path, base_dir: Path) -> str: Walks up from the file looking for __init__.py to determine package boundaries. """ - rel = path.relative_to(base_dir) - parts = list(rel.parts) + rel = os.path.relpath(path, base_dir).replace("\\", "/") + parts = rel.split("/") # Remove .py extension from last part if parts and parts[-1].endswith(".py"): diff --git a/src/codesurface/parsers/typescript.py b/src/codesurface/parsers/typescript.py index 41eb435..a75be5d 100644 --- a/src/codesurface/parsers/typescript.py +++ b/src/codesurface/parsers/typescript.py @@ -5,7 +5,9 @@ JSDoc comments (/** ... */) are extracted as summaries. 
""" +import os import re +import sys from pathlib import Path from .base import BaseParser @@ -13,12 +15,6 @@ # --- Skip patterns --- -_SKIP_DIRS = frozenset({ - "node_modules", "dist", "build", ".git", ".next", - "__tests__", "__mocks__", "coverage", ".turbo", ".cache", - ".tox", ".mypy_cache", ".venv", "venv", -}) - _SKIP_SUFFIXES = ( ".d.ts", ".test.ts", ".test.tsx", @@ -149,24 +145,9 @@ class TypeScriptParser(BaseParser): def file_extensions(self) -> list[str]: return [".ts", ".tsx"] - def parse_directory(self, directory: Path) -> list[dict]: - """Override to skip test/build/node_modules directories.""" - records = [] - for ext in self.file_extensions: - for f in sorted(directory.rglob(f"*{ext}")): - parts = f.relative_to(directory).parts - if any(p in _SKIP_DIRS for p in parts): - continue - fname = f.name - if any(fname.endswith(s) for s in _SKIP_SUFFIXES): - continue - try: - records.extend(self.parse_file(f, directory)) - except Exception as e: - import sys - print(f"codesurface: failed to parse {f}: {e}", file=sys.stderr) - continue - return records + @property + def skip_suffixes(self) -> tuple[str, ...]: + return _SKIP_SUFFIXES def parse_file(self, path: Path, base_dir: Path) -> list[dict]: return _parse_ts_file(path, base_dir) @@ -175,15 +156,16 @@ def parse_file(self, path: Path, base_dir: Path) -> list[dict]: def _parse_ts_file(path: Path, base_dir: Path) -> list[dict]: """Parse a single TypeScript file and extract exported API members.""" try: - text = path.read_text(encoding="utf-8", errors="replace") + with open(path, encoding="utf-8", errors="replace") as fh: + text = fh.read() except (OSError, UnicodeDecodeError): return [] - rel_path = str(path.relative_to(base_dir)).replace("\\", "/") + rel_path = os.path.relpath(path, base_dir).replace("\\", "/") lines = text.splitlines() records: list[dict] = [] - namespace = _file_to_module(path, base_dir) + namespace = _file_to_module(rel_path) # State class_stack: list[tuple[str, str, int]] = [] # (name, 
kind, brace_depth) @@ -191,8 +173,9 @@ def _parse_ts_file(path: Path, base_dir: Path) -> list[dict]: paren_depth = 0 # track () to skip multi-line parameter lists in_multiline_comment = False + n_lines = len(lines) i = 0 - while i < len(lines): + while i < n_lines: line = lines[i] # --- Multi-line comment tracking --- @@ -1020,52 +1003,33 @@ def _extract_field_type(stripped: str, field_name: str) -> str: # --- Brace counting --- -def _count_braces_and_parens(line: str) -> tuple[int, int]: - """Count net brace and paren depth changes, skipping inside strings.""" - brace_depth = 0 - paren_depth = 0 - in_single = False - in_double = False - in_template = False - escape = False - - for ch in line: - if escape: - escape = False - continue - if ch == "\\": - escape = True - continue +# Strip string literals (single, double, template) including escaped quotes. +# Order matters: escaped quotes must be consumed before quote boundaries. +_STRING_RE = re.compile( + r"""'(?:[^'\\]|\\.)*'?""" # single-quoted (possibly unterminated) + r'|"(?:[^"\\]|\\.)*"?' # double-quoted + r"|`(?:[^`\\]|\\.)*`?" 
# template literal (simplified: ignores ${}) + r"|//.*$", # line comments: braces/parens inside them are not counted + re.DOTALL, +) - if in_single: - if ch == "'": - in_single = False - continue - if in_double: - if ch == '"': - in_double = False - continue - if in_template: - if ch == "`": - in_template = False - continue - if ch == "'": - in_single = True - elif ch == '"': - in_double = True - elif ch == "`": - in_template = True - elif ch == "{": - brace_depth += 1 - elif ch == "}": - brace_depth -= 1 - elif ch == "(": - paren_depth += 1 - elif ch == ")": - paren_depth -= 1 +_HAS_STRING_OR_COMMENT = frozenset("'\"`/") - return brace_depth, paren_depth + +def _count_braces_and_parens(line: str) -> tuple[int, int]: + """Count net brace and paren depth changes, skipping strings and // line comments.""" + # Fast path: most lines have no strings or comments — skip regex entirely + if _HAS_STRING_OR_COMMENT.isdisjoint(line): + return ( + line.count("{") - line.count("}"), + line.count("(") - line.count(")"), + ) + stripped = _STRING_RE.sub("", line) + return ( + stripped.count("{") - stripped.count("}"), + stripped.count("(") - stripped.count(")"), + ) # --- Visibility --- @@ -1176,25 +1140,19 @@ def _split_params(params_str: str) -> list[str]: # --- File path to module --- -def _file_to_module(path: Path, base_dir: Path) -> str: - """Convert file path to a dot-separated module namespace. +def _file_to_module(rel_path: str) -> str: + """Convert relative file path to a dot-separated module namespace. 
src/services/myService.ts -> src.services.myService """ - rel = path.relative_to(base_dir) - parts = list(rel.parts) - - if not parts: - return "" - - # Remove extension from last part - last = parts[-1] + # Strip extension for ext in (".tsx", ".ts"): - if last.endswith(ext): - parts[-1] = last[:-len(ext)] + if rel_path.endswith(ext): + rel_path = rel_path[:-len(ext)] break - # Drop index files (index.ts re-exports) + # Split on / and drop trailing "index" + parts = rel_path.split("/") if parts and parts[-1] == "index": parts = parts[:-1] diff --git a/src/codesurface/server.py b/src/codesurface/server.py index 89dd355..535bde1 100644 --- a/src/codesurface/server.py +++ b/src/codesurface/server.py @@ -2,6 +2,7 @@ import argparse import json +import os import re import sys import time @@ -10,6 +11,7 @@ from mcp.server.fastmcp import FastMCP from . import db +from .filters import PathFilter from .parsers import all_extensions, detect_languages, get_parser, get_parsers_for_project mcp = FastMCP( @@ -25,6 +27,19 @@ _project_path: Path | None = None _file_mtimes: dict[str, float] = {} # rel_path → mtime _index_fresh: bool = True # True = checked for changes since last hit; skip auto-reindex +_path_filter: PathFilter | None = None + + +def _count_files( + project_path: Path, + parsers: list, + path_filter: "PathFilter | None", +) -> int: + """Pre-scan: count source files using each parser's full filter rules.""" + total = 0 + for parser in parsers: + total += len(parser._walk_files(project_path, path_filter)) + return total def _index_full(project_path: Path, language: str | None = None) -> str: @@ -35,41 +50,67 @@ def _index_full(project_path: Path, language: str | None = None) -> str: if language: parsers = [get_parser(language)] else: - parsers = get_parsers_for_project(project_path) + parsers = get_parsers_for_project(project_path, path_filter=_path_filter) if not parsers: return "No supported source files detected in project directory." 
+ total = _count_files(project_path, parsers, _path_filter) + print(f"[codesurface] scanning {total:,} files...", file=sys.stderr, flush=True) + + parsed = [0] + last_pct = [0.0] + last_time = [t0] + new_mtimes: dict[str, float] = {} + + # Emit the 0% line immediately + print( + f"[codesurface] indexing: 0% ({0:>6,} / {total:,})", + file=sys.stderr, + flush=True, + ) + + def on_progress(f: Path) -> None: + parsed[0] += 1 + # Snapshot mtime while we're at it (avoids a third walk) + rel = str(f.relative_to(project_path)).replace("\\", "/") + try: + new_mtimes[rel] = f.stat().st_mtime + except OSError: + pass + now = time.perf_counter() + pct = parsed[0] / max(total, 1) + elapsed = now - t0 + if pct - last_pct[0] >= 0.05 or now - last_time[0] >= 3.0: + print( + f"[codesurface] indexing: {pct:3.0%} ({parsed[0]:>6,} / {total:,}) {elapsed:.1f}s", + file=sys.stderr, + flush=True, + ) + last_pct[0] = pct + last_time[0] = now + records = [] for parser in parsers: - records.extend(parser.parse_directory(project_path)) + records.extend( + parser.parse_directory(project_path, path_filter=_path_filter, on_progress=on_progress) + ) parse_time = time.perf_counter() - t0 t1 = time.perf_counter() _conn = db.create_memory_db(records) db_time = time.perf_counter() - t1 - # Snapshot mtimes for all registered extensions used - extensions = set() - for parser in parsers: - extensions.update(parser.file_extensions) - - _file_mtimes = {} - for ext in extensions: - for f in sorted(project_path.rglob(f"*{ext}")): - rel = str(f.relative_to(project_path)).replace("\\", "/") - try: - _file_mtimes[rel] = f.stat().st_mtime - except OSError: - pass + _file_mtimes = new_mtimes stats = db.get_stats(_conn) langs = ", ".join(type(p).__name__.replace("Parser", "") for p in parsers) - return ( - f"Indexed {stats['total']} records from {stats.get('files', 0)} files " - f"({langs}) in {parse_time + db_time:.2f}s " - f"(parse: {parse_time:.2f}s, db: {db_time:.2f}s)" + summary = ( + f"[codesurface] done: 
{stats['total']:,} records from {stats.get('files', 0):,} files " + f"({langs}) in {parse_time + db_time:.2f}s" ) + print(summary, file=sys.stderr, flush=True) + return summary def _index_incremental(project_path: Path) -> tuple[str, bool]: @@ -84,12 +125,20 @@ def _index_incremental(project_path: Path) -> tuple[str, bool]: t0 = time.perf_counter() # Collect all registered extensions - extensions = set(all_extensions()) + exts = tuple(all_extensions()) - # Scan current files + # Scan current files, pruning excluded directories during walk current: dict[str, float] = {} - for ext in extensions: - for f in sorted(project_path.rglob(f"*{ext}")): + for root, dirs, files in os.walk(project_path): + root_path = Path(root) + if _path_filter is not None: + dirs[:] = [d for d in dirs if not _path_filter.is_dir_excluded(root_path / d)] + for filename in files: + if not filename.endswith(exts): + continue + f = root_path / filename + if _path_filter is not None and _path_filter.is_file_excluded(f): + continue rel = str(f.relative_to(project_path)).replace("\\", "/") try: current[rel] = f.stat().st_mtime @@ -120,7 +169,7 @@ def _index_incremental(project_path: Path) -> tuple[str, bool]: db.delete_by_files(_conn, list(stale)) # Build extension-to-parser map for dirty files - parsers = get_parsers_for_project(project_path) + parsers = get_parsers_for_project(project_path, path_filter=_path_filter) ext_to_parser: dict[str, object] = {} for parser in parsers: for ext in parser.file_extensions: @@ -129,7 +178,7 @@ def _index_incremental(project_path: Path) -> tuple[str, bool]: # Parse dirty files new_records = [] for rel in sorted(dirty): - full_path = project_path / rel.replace("/", "\\") + full_path = project_path / Path(rel) suffix = full_path.suffix parser = ext_to_parser.get(suffix) if parser is None: @@ -261,6 +310,7 @@ def search( query: str, n_results: int = 5, member_type: str | None = None, + file_path: str | None = None, ) -> str: """Search the indexed API by keyword. 
@@ -271,17 +321,19 @@ def search( query: Search terms (e.g. "MergeService", "BlastBoard", "GridCoord") n_results: Max results to return (default 5, max 20) member_type: Optional filter — "type", "method", "property", "field", or "event" + file_path: Optional path prefix or exact file to scope results + (e.g. "src/services/" or "src/services/foo.ts") """ if _conn is None: return "No codebase indexed. Start the server with --project ." global _index_fresh n_results = min(max(n_results, 1), 20) - results = db.search(_conn, query, n=n_results, member_type=member_type) + results = db.search(_conn, query, n=n_results, member_type=member_type, file_path=file_path) if not results: if _auto_reindex(): - results = db.search(_conn, query, n=n_results, member_type=member_type) + results = db.search(_conn, query, n=n_results, member_type=member_type, file_path=file_path) if not results: return f"No results found for '{query}'. Try broader search terms." @@ -296,7 +348,7 @@ def search( @mcp.tool() -def get_signature(name: str) -> str: +def get_signature(name: str, file_path: str | None = None) -> str: """Look up the exact signature of an API member by name or FQN. Use when you need exact parameter types, return types, or method signatures @@ -304,6 +356,7 @@ def get_signature(name: str) -> str: Args: name: Member name or FQN, e.g. "TryMerge", "CampGame.Services.IMergeService.TryMerge" + file_path: Optional path prefix to scope the lookup """ global _index_fresh if _conn is None: @@ -316,9 +369,20 @@ def _lookup() -> str | None: return _format_record(record) # 2. Substring match (overloads or partial FQN) + # Build optional file_path filter + file_clause = "" + file_params: list = [] + if file_path: + if file_path.endswith("/"): + file_clause = " AND file_path LIKE ?" + file_params = [file_path + "%"] + else: + file_clause = " AND (file_path = ? 
OR file_path LIKE ?)" + file_params = [file_path, file_path + "/%"] + rows = _conn.execute( - "SELECT * FROM api_records WHERE fqn LIKE ? ORDER BY fqn", - (f"%{name}%",), + f"SELECT * FROM api_records WHERE fqn LIKE ?{file_clause} ORDER BY fqn", + (f"%{name}%", *file_params), ).fetchall() if rows: parts = [f"Found {len(rows)} match(es) for '{name}':\n"] @@ -330,7 +394,7 @@ def _lookup() -> str | None: return "\n".join(parts) # 3. FTS fallback - results = db.search(_conn, name, n=5) + results = db.search(_conn, name, n=5, file_path=file_path) if results: parts = [f"No exact match for '{name}'. Did you mean:\n"] for r in results: @@ -351,7 +415,7 @@ def _lookup() -> str | None: @mcp.tool() -def get_class(class_name: str) -> str: +def get_class(class_name: str, file_path: str | None = None) -> str: """Get a complete reference card for a class — all public members. Shows every method, property, field, and event with signatures. @@ -359,6 +423,7 @@ def get_class(class_name: str) -> str: Args: class_name: Class name, e.g. "BlastBoardModel", "IMergeService", "CampGridService" + file_path: Optional path prefix to scope the lookup """ global _index_fresh if _conn is None: @@ -373,13 +438,13 @@ def get_class(class_name: str) -> str: else: short_name = re.split(r"[.:]", class_name)[-1] - members = db.get_class_members(_conn, short_name, namespace=ns_filter) + members = db.get_class_members(_conn, short_name, namespace=ns_filter, file_path=file_path) if not members: if _auto_reindex(): - members = db.get_class_members(_conn, short_name, namespace=ns_filter) + members = db.get_class_members(_conn, short_name, namespace=ns_filter, file_path=file_path) if not members: - results = db.search(_conn, class_name, n=5, member_type="type") + results = db.search(_conn, class_name, n=5, member_type="type", file_path=file_path) if results: parts = [f"No class '{class_name}' found. 
Did you mean:\n"] for r in results: @@ -395,7 +460,7 @@ def get_class(class_name: str) -> str: # Show disambiguation notice ns_filter = _pick_primary_namespace(namespaces, members) if ns_filter is not None: - members = db.get_class_members(_conn, short_name, namespace=ns_filter) + members = db.get_class_members(_conn, short_name, namespace=ns_filter, file_path=file_path) _index_fresh = False type_record = next((m for m in members if m["member_type"] == "type"), None) @@ -521,17 +586,28 @@ def main(): help="Path to source directory to index") parser.add_argument("--language", default=None, help="Language to parse (e.g. csharp). Auto-detected if omitted.") + parser.add_argument("--exclude", default=None, + help="Comma-separated glob patterns to exclude from indexing " + "(e.g. 'tests/**,generated/**')") + parser.add_argument("--include-submodules", action="store_true", default=False, + help="Include git submodules in indexing (excluded by default)") args, remaining = parser.parse_known_args() - global _project_path + global _project_path, _path_filter + + exclude_globs = [g.strip() for g in args.exclude.split(",")] if args.exclude else [] if args.project: - _project_path = Path(args.project) + _project_path = Path(args.project).expanduser().resolve() if not _project_path.is_dir(): print(f"Warning: Project path not found: {args.project}", file=sys.stderr) else: - summary = _index_full(_project_path, language=args.language) - print(summary, file=sys.stderr) + _path_filter = PathFilter( + _project_path, + exclude_globs=exclude_globs, + include_submodules=args.include_submodules, + ) + _index_full(_project_path, language=args.language) mcp.run() diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_filters.py b/tests/test_filters.py new file mode 100644 index 0000000..a9fd61c --- /dev/null +++ b/tests/test_filters.py @@ -0,0 +1,181 @@ +"""Tests for PathFilter default skip rules.""" +import os +from pathlib 
import Path +import pytest +from codesurface.filters import PathFilter + + +@pytest.fixture +def tmp_project(tmp_path): + """Project root with a variety of subdirectories.""" + # Normal source file + (tmp_path / "src").mkdir() + (tmp_path / "src" / "main.ts").write_text("export class Foo {}") + + # .worktrees directory (should always be skipped) + wt = tmp_path / ".worktrees" / "pr-42" + wt.mkdir(parents=True) + (wt / "src").mkdir() + (wt / "src" / "main.ts").write_text("export class Bar {}") + # .git FILE in worktree (git worktree marker) + (wt / ".git").write_text("gitdir: /repo/.git/worktrees/pr-42\n") + + # Submodule (should be skipped by default) + sub = tmp_path / "vendor" / "mylib" + sub.mkdir(parents=True) + (sub / "lib.ts").write_text("export class Lib {}") + (sub / ".git").write_text("gitdir: /repo/.git/modules/mylib\n") + + # Regular nested dir (should NOT be skipped) + (tmp_path / "packages" / "core").mkdir(parents=True) + (tmp_path / "packages" / "core" / "index.ts").write_text("export class Core {}") + + return tmp_path + + +def test_worktrees_dir_skipped(tmp_project): + pf = PathFilter(tmp_project) + assert pf.is_dir_excluded(tmp_project / ".worktrees") + + +def test_worktree_subdir_skipped(tmp_project): + pf = PathFilter(tmp_project) + assert pf.is_dir_excluded(tmp_project / ".worktrees" / "pr-42") + + +def test_git_file_worktree_skipped(tmp_project): + pf = PathFilter(tmp_project) + wt = tmp_project / ".worktrees" / "pr-42" + assert pf.is_dir_excluded(wt) + + +def test_submodule_skipped_by_default(tmp_project): + pf = PathFilter(tmp_project) + assert pf.is_dir_excluded(tmp_project / "vendor" / "mylib") + + +def test_submodule_included_when_opted_in(tmp_project): + pf = PathFilter(tmp_project, include_submodules=True) + assert not pf.is_dir_excluded(tmp_project / "vendor" / "mylib") + + +def test_worktree_still_skipped_even_with_include_submodules(tmp_project): + pf = PathFilter(tmp_project, include_submodules=True) + wt = tmp_project / ".worktrees" 
/ "pr-42" + assert pf.is_dir_excluded(wt) + + +def test_normal_dir_not_skipped(tmp_project): + pf = PathFilter(tmp_project) + assert not pf.is_dir_excluded(tmp_project / "packages" / "core") + + +def test_src_dir_not_skipped(tmp_project): + pf = PathFilter(tmp_project) + assert not pf.is_dir_excluded(tmp_project / "src") + + +def test_exclude_glob_skips_matching_file(tmp_project): + pf = PathFilter(tmp_project, exclude_globs=["tests/**"]) + (tmp_project / "tests").mkdir() + test_file = tmp_project / "tests" / "foo.ts" + test_file.write_text("") + assert pf.is_file_excluded(test_file) + + +def test_exclude_glob_does_not_skip_nonmatching(tmp_project): + pf = PathFilter(tmp_project, exclude_globs=["tests/**"]) + assert not pf.is_file_excluded(tmp_project / "src" / "main.ts") + + +def test_codesurfaceignore_loaded(tmp_project): + (tmp_project / ".codesurfaceignore").write_text("generated/**\n# comment\n\n") + pf = PathFilter(tmp_project) + gen_file = tmp_project / "generated" / "types.ts" + assert pf.is_file_excluded(gen_file) + + +def test_codesurfaceignore_and_cli_globs_merged(tmp_project): + (tmp_project / ".codesurfaceignore").write_text("generated/**\n") + pf = PathFilter(tmp_project, exclude_globs=["tests/**"]) + gen_file = tmp_project / "generated" / "types.ts" + test_file = tmp_project / "tests" / "foo.ts" + assert pf.is_file_excluded(gen_file) + assert pf.is_file_excluded(test_file) + + +def test_codesurfaceignore_missing_is_fine(tmp_project): + # No .codesurfaceignore present — should not raise + pf = PathFilter(tmp_project) + assert not pf.is_file_excluded(tmp_project / "src" / "main.ts") + + +# ---- Query-time file_path filtering ---- +from codesurface import db as csdb + + +def _make_db(): + records = [ + { + "fqn": "Services.FooService", + "namespace": "Services", + "class_name": "FooService", + "member_name": "FooService", + "member_type": "type", + "signature": "class FooService", + "file_path": "src/services/foo.ts", + "line_start": 1, + "line_end": 
10, + }, + { + "fqn": "Utils.BarUtil", + "namespace": "Utils", + "class_name": "BarUtil", + "member_name": "BarUtil", + "member_type": "type", + "signature": "class BarUtil", + "file_path": "src/utils/bar.ts", + "line_start": 1, + "line_end": 5, + }, + ] + return csdb.create_memory_db(records) + + +def test_search_file_path_prefix_filters(): + conn = _make_db() + results = csdb.search(conn, "Service", file_path="src/services/") + assert len(results) == 1 + assert results[0]["class_name"] == "FooService" + + +def test_search_file_path_exact_file(): + conn = _make_db() + results = csdb.search(conn, "Bar", file_path="src/utils/bar.ts") + assert len(results) == 1 + assert results[0]["class_name"] == "BarUtil" + + +def test_search_file_path_no_match_returns_empty(): + conn = _make_db() + results = csdb.search(conn, "Foo", file_path="src/utils/") + assert len(results) == 0 + + +def test_search_no_file_path_returns_all(): + conn = _make_db() + results = csdb.search(conn, "Service OR Bar", file_path=None) + assert len(results) == 2 + + +def test_get_class_members_file_path_prefix(): + conn = _make_db() + results = csdb.get_class_members(conn, "FooService", file_path="src/services/") + assert len(results) == 1 + assert results[0]["class_name"] == "FooService" + + +def test_get_class_members_file_path_no_match(): + conn = _make_db() + results = csdb.get_class_members(conn, "FooService", file_path="src/utils/") + assert len(results) == 0 diff --git a/tests/test_parsers.py b/tests/test_parsers.py new file mode 100644 index 0000000..4c19c26 --- /dev/null +++ b/tests/test_parsers.py @@ -0,0 +1,115 @@ +"""Tests for PathFilter integration with parse_directory.""" +from pathlib import Path +import pytest +from codesurface.filters import PathFilter +from codesurface.parsers.typescript import TypeScriptParser +from codesurface.parsers.python_parser import PythonParser +from codesurface.parsers.go import GoParser +from codesurface.parsers.java import JavaParser + + +@pytest.fixture 
+def ts_project(tmp_path): + (tmp_path / "src").mkdir() + (tmp_path / "src" / "service.ts").write_text( + "export class FooService { bar(): void {} }" + ) + # A worktree that should be skipped + wt = tmp_path / ".worktrees" / "pr-1" + wt.mkdir(parents=True) + (wt / ".git").write_text("gitdir: /repo/.git/worktrees/pr-1\n") + (wt / "service.ts").write_text( + "export class WtService { baz(): void {} }" + ) + # A generated file that should be skipped + (tmp_path / "src" / "gen.ts").write_text( + "export class Generated {}" + ) + return tmp_path + + +def test_worktree_files_not_indexed(ts_project): + pf = PathFilter(ts_project) + parser = TypeScriptParser() + records = parser.parse_directory(ts_project, path_filter=pf) + names = [r["class_name"] for r in records] + assert "WtService" not in names + + +def test_src_files_indexed(ts_project): + pf = PathFilter(ts_project) + parser = TypeScriptParser() + records = parser.parse_directory(ts_project, path_filter=pf) + names = [r["class_name"] for r in records] + assert "FooService" in names + + +def test_excluded_file_not_indexed(ts_project): + pf = PathFilter(ts_project, exclude_globs=["src/gen.ts"]) + parser = TypeScriptParser() + records = parser.parse_directory(ts_project, path_filter=pf) + names = [r["class_name"] for r in records] + assert "Generated" not in names + + +def test_no_filter_indexes_worktrees_too(ts_project): + # Without PathFilter, worktrees are NOT excluded — old behaviour preserved + parser = TypeScriptParser() + records = parser.parse_directory(ts_project) + names = [r["class_name"] for r in records] + assert "FooService" in names + # WtService IS found without a filter (old behaviour) + assert "WtService" in names + + +def test_on_progress_called_per_file(ts_project): + """on_progress is called once per successfully parsed file.""" + parser = TypeScriptParser() + visited = [] + parser.parse_directory(ts_project, on_progress=lambda f: visited.append(f)) + # ts_project has service.ts, gen.ts, and a 
worktree service.ts (3 .ts files total without filter) + assert len(visited) == 3 + assert all(isinstance(f, Path) for f in visited) + + +def test_on_progress_none_is_default(ts_project): + """Omitting on_progress works exactly as before.""" + parser = TypeScriptParser() + records = parser.parse_directory(ts_project) + assert len(records) > 0 + + +@pytest.fixture +def py_project(tmp_path): + (tmp_path / "mod.py").write_text("def hello(): pass\n") + return tmp_path + + +def test_typescript_on_progress(ts_project): + parser = TypeScriptParser() + visited = [] + parser.parse_directory(ts_project, on_progress=lambda f: visited.append(f)) + assert len(visited) >= 1 + + +def test_python_on_progress(py_project): + parser = PythonParser() + visited = [] + parser.parse_directory(py_project, on_progress=lambda f: visited.append(f)) + assert len(visited) == 1 + + +def test_go_on_progress(tmp_path): + (tmp_path / "main.go").write_text("package main\nfunc Hello() {}\n") + parser = GoParser() + visited = [] + parser.parse_directory(tmp_path, on_progress=lambda f: visited.append(f)) + assert len(visited) == 1 + + +def test_java_on_progress(tmp_path): + (tmp_path / "Foo.java").write_text("public class Foo { public void bar() {} }\n") + parser = JavaParser() + visited = [] + parser.parse_directory(tmp_path, on_progress=lambda f: visited.append(f)) + assert len(visited) == 1 diff --git a/tests/test_server_progress.py b/tests/test_server_progress.py new file mode 100644 index 0000000..a2dfef4 --- /dev/null +++ b/tests/test_server_progress.py @@ -0,0 +1,55 @@ +"""Tests for _count_files and _index_full progress output.""" +from pathlib import Path +import pytest + + +def test_count_files_basic(tmp_path): + """_count_files counts matching source files, ignores other extensions.""" + from codesurface.server import _count_files + from codesurface.parsers.python_parser import PythonParser + + (tmp_path / "a.py").write_text("x = 1") + (tmp_path / "b.py").write_text("y = 2") + (tmp_path / 
"c.txt").write_text("ignored") + + parsers = [PythonParser()] + assert _count_files(tmp_path, parsers, path_filter=None) == 2 + + +def test_count_files_prunes_excluded_dirs(tmp_path): + """_count_files respects path_filter dir exclusions.""" + from codesurface.server import _count_files + from codesurface.parsers.python_parser import PythonParser + from codesurface.filters import PathFilter + + (tmp_path / "src").mkdir() + (tmp_path / "src" / "a.py").write_text("x = 1") + wt = tmp_path / ".worktrees" / "pr-1" + wt.mkdir(parents=True) + (wt / ".git").write_text("gitdir: /repo/.git/worktrees/pr-1\n") + (wt / "b.py").write_text("y = 2") + + pf = PathFilter(tmp_path) + parsers = [PythonParser()] + assert _count_files(tmp_path, parsers, path_filter=pf) == 1 + + +def test_index_full_emits_progress_to_stderr(tmp_path, capsys): + """_index_full prints at least one progress line and a done line to stderr.""" + from codesurface import server + from codesurface.filters import PathFilter + + for i in range(5): + (tmp_path / f"m{i}.py").write_text(f"def f{i}(): pass\n") + + server._conn = None + server._project_path = tmp_path + server._path_filter = PathFilter(tmp_path) + + server._index_full(tmp_path) + + captured = capsys.readouterr() + assert "[codesurface]" in captured.err + assert "scanning" in captured.err # e.g. "[codesurface] scanning 5 files..." + assert "indexing:" in captured.err # e.g. "[codesurface] indexing: 0% ..." + assert "done:" in captured.err