diff --git a/pyproject.toml b/pyproject.toml index bd42ca9..7f49249 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,3 +38,8 @@ codesurface = "codesurface.server:main" [tool.hatch.build.targets.wheel] packages = ["src/codesurface"] + +[dependency-groups] +dev = [ + "pytest>=9.0.2", +] diff --git a/src/codesurface/db.py b/src/codesurface/db.py index fd46c01..126449b 100644 --- a/src/codesurface/db.py +++ b/src/codesurface/db.py @@ -143,11 +143,16 @@ def delete_by_files(conn: sqlite3.Connection, file_paths: list[str]) -> int: def search(conn: sqlite3.Connection, query: str, n: int = 10, - member_type: str | None = None) -> list[dict]: + member_type: str | None = None, + file_path: str | None = None, + include_tests: bool = False) -> list[dict]: """Full-text search with BM25 ranking + PascalCase-aware matching. Column weights: member_name (10x) > class_name (5x) > search_text (4x) > signature (3x) > fqn/summary (1x) Type bonus: class/struct/enum defs rank higher than same-named members. + + file_path: optional path prefix or exact file to scope results. + include_tests: if False (default), exclude test files from results. """ clean = _escape_fts(query) if not clean.strip(): @@ -157,27 +162,36 @@ def search(conn: sqlite3.Connection, query: str, n: int = 10, ranking = """bm25(api_fts, 1.0, 5.0, 10.0, 0.5, 3.0, 4.0) + CASE WHEN r.member_type = 'type' THEN -1.0 ELSE 0.0 END""" - if member_type: - sql = f""" - SELECT r.*, {ranking} AS rank - FROM api_fts f - JOIN api_records r ON r.rowid = f.rowid - WHERE api_fts MATCH ? AND r.member_type = ? - ORDER BY rank - LIMIT ? - """ - rows = conn.execute(sql, (clean, member_type, n)).fetchall() - else: - sql = f""" - SELECT r.*, {ranking} AS rank - FROM api_fts f - JOIN api_records r ON r.rowid = f.rowid - WHERE api_fts MATCH ? - ORDER BY rank - LIMIT ? 
- """ - rows = conn.execute(sql, (clean, n)).fetchall() + conditions = ["api_fts MATCH ?"] + params: list = [clean] + if member_type: + conditions.append("r.member_type = ?") + params.append(member_type) + + if file_path: + if file_path.endswith("/"): + conditions.append("r.file_path LIKE ?") + params.append(file_path + "%") + else: + conditions.append("(r.file_path = ? OR r.file_path LIKE ?)") + params.extend([file_path, file_path + "/%"]) + + if not include_tests: + _add_test_exclusion(conditions, params, alias="r.") + + where = " AND ".join(conditions) + params.append(n) + + sql = f""" + SELECT r.*, {ranking} AS rank + FROM api_fts f + JOIN api_records r ON r.rowid = f.rowid + WHERE {where} + ORDER BY rank + LIMIT ? + """ + rows = conn.execute(sql, params).fetchall() return [dict(row) for row in rows] @@ -190,19 +204,34 @@ def get_by_fqn(conn: sqlite3.Connection, fqn: str) -> dict | None: def get_class_members(conn: sqlite3.Connection, class_name: str, - namespace: str | None = None) -> list[dict]: - """Get all members of a class by class name, optionally filtered by namespace.""" + file_path: str | None = None, + namespace: str | None = None, + include_tests: bool = False) -> list[dict]: + """Get all members of a class by class name, optionally filtered by file_path or namespace.""" + clauses = ["class_name = ?"] + params: list[str] = [class_name] + if namespace is not None: - rows = conn.execute( - "SELECT * FROM api_records WHERE class_name = ? AND namespace = ? " - "ORDER BY member_type, member_name", - (class_name, namespace), - ).fetchall() - else: - rows = conn.execute( - "SELECT * FROM api_records WHERE class_name = ? ORDER BY member_type, member_name", - (class_name,), - ).fetchall() + clauses.append("namespace = ?") + params.append(namespace) + + if file_path: + if file_path.endswith("/"): + clauses.append("file_path LIKE ?") + params.append(file_path + "%") + else: + clauses.append("(file_path = ? 
OR file_path LIKE ?)") + params.extend([file_path, file_path + "/%"]) + + if not include_tests: + _add_test_exclusion(clauses, params) + + sql = ( + "SELECT * FROM api_records WHERE " + + " AND ".join(clauses) + + " ORDER BY member_type, member_name" + ) + rows = conn.execute(sql, params).fetchall() return [dict(row) for row in rows] @@ -217,6 +246,46 @@ def get_class_namespaces(conn: sqlite3.Connection, class_name: str) -> list[str] return [row["namespace"] for row in rows] +# --------------------------------------------------------------------------- +# Test-file exclusion helpers +# --------------------------------------------------------------------------- + +# Patterns that identify test files. Applied to the relative file_path stored in DB. +# Directory patterns match anywhere in the path; filename patterns use specific LIKE forms. +# Directory names that indicate test code. +_TEST_DIR_NAMES = ("__tests__", "__test__", "tests", "test") + +# Filename patterns: match within the basename of the file path. +# .test. and .spec. → e.g. Button.test.tsx, utils.spec.js +# _test. → e.g. calculator_test.py, foo_test.go +# /test_ → e.g. test_calculator.py (slash ensures it's the filename start) +_TEST_FILE_PATTERNS = ( + ".test.", # foo.test.ts + ".spec.", # foo.spec.ts + "_test.", # foo_test.py, foo_test.go + "/test_", # test_foo.py +) + + +def _add_test_exclusion(clauses: list[str], params: list, *, alias: str = "") -> None: + """Append SQL clauses that exclude test files. + + alias should be e.g. "r." when querying through a join, or "" for direct table access. + Handles both root-relative paths (tests/foo.py) and nested (src/tests/foo.py). 
+ """ + col = f"{alias}file_path" + for name in _TEST_DIR_NAMES: + # Nested: src/tests/foo.py + clauses.append(f"{col} NOT LIKE ?") + params.append(f"%/{name}/%") + # Root-relative: tests/foo.py + clauses.append(f"{col} NOT LIKE ?") + params.append(f"{name}/%") + for pat in _TEST_FILE_PATTERNS: + clauses.append(f"{col} NOT LIKE ?") + params.append(f"%{pat}%") + + def resolve_namespace(conn: sqlite3.Connection, name: str) -> list[dict]: """Find namespace for a class or member name.""" rows = conn.execute( diff --git a/src/codesurface/filters.py b/src/codesurface/filters.py new file mode 100644 index 0000000..4147ae0 --- /dev/null +++ b/src/codesurface/filters.py @@ -0,0 +1,126 @@ +"""Path filtering for codesurface indexing. + +Handles default exclusions (worktrees, submodules, vendored/build dirs) +and user-configured exclusions (.codesurfaceignore, --exclude CLI flag). +""" +from __future__ import annotations + +import fnmatch +from pathlib import Path + +# Directories excluded by name in every project — vendored deps, build +# output, VCS internals, and IDE config that never contain user source. +_DEFAULT_EXCLUDED_DIRS: frozenset[str] = frozenset({ + # JS / Node + "node_modules", "bower_components", + # Python + ".venv", "venv", "env", "__pycache__", ".tox", ".mypy_cache", + ".pytest_cache", "site-packages", + # Go + "vendor", "testdata", "third_party", "examples", "example", + # .NET / Java + "bin", "obj", "packages", ".gradle", ".mvn", + "generated", "generated-sources", "generated-test-sources", + # Build output / caches + "dist", "build", "out", "target", ".next", ".nuxt", ".nx", + # VCS / IDE + ".git", ".hg", ".svn", + ".idea", ".vscode", ".vs", + # Misc + ".yarn", ".pnp", "coverage", ".turbo", ".cache", ".worktrees", +}) + + +def _read_git_file(path: Path) -> str | None: + """Read .git FILE content if present. 
class PathFilter:
    """Decides which directories and files the indexer should skip.

    Default exclusions (always applied):
    - Any directory named .worktrees
    - Any subdirectory with a .git FILE referencing /worktrees/ (git worktree)
    - Any subdirectory with a .git FILE referencing /modules/ (submodule),
      unless include_submodules=True

    User exclusions via exclude_globs (CLI) and .codesurfaceignore (project file).
    """

    def __init__(
        self,
        project_root: Path,
        exclude_globs: list[str] | None = None,
        include_submodules: bool = False,
    ) -> None:
        self._root = project_root
        self._include_submodules = include_submodules
        # CLI-supplied globs first, then project-file globs appended.
        combined: list[str] = []
        if exclude_globs:
            combined.extend(exclude_globs)
        combined.extend(_read_ignore_file(project_root))
        self._globs: list[str] = combined

    def is_dir_excluded_name(self, name: str) -> bool:
        """Fast check using only the directory basename (no I/O)."""
        return name in _DEFAULT_EXCLUDED_DIRS

    def is_dir_excluded(self, path: Path) -> bool:
        """Return True if this directory should be skipped entirely."""
        if path.name in _DEFAULT_EXCLUDED_DIRS:
            return True

        # .git FILE detection (worktrees / submodules)
        git_content = _read_git_file(path)
        if git_content is None:
            return False
        if _is_git_worktree(git_content):
            return True
        return _is_git_submodule(git_content) and not self._include_submodules

    def _matches_any_glob(self, rel_path: str) -> bool:
        """True when rel_path matches at least one user exclusion glob."""
        for pattern in self._globs:
            if fnmatch.fnmatch(rel_path, pattern):
                return True
        return False

    def is_file_excluded(self, path: Path) -> bool:
        """Return True if this file matches any user exclusion glob."""
        if not self._globs:
            return False
        try:
            rel = path.relative_to(self._root)
        except ValueError:
            # Outside the project root: never user-excluded.
            return False
        return self._matches_any_glob(str(rel).replace("\\", "/"))

    def is_file_excluded_rel(self, rel_path: str) -> bool:
        """Return True if a relative path matches any user exclusion glob."""
        return bool(self._globs) and self._matches_any_glob(rel_path)
def detect_languages(
    project_dir: Path,
    path_filter: "PathFilter | None" = None,
) -> list[str]:
    """Detect which registered languages are present in *project_dir*.

    Walks with os.walk, pruning via *path_filter* so vendored directories
    (node_modules, .git, etc.) are skipped during detection; stops early
    once every registered language has been seen.
    """
    suffixes = tuple(_EXT_TO_LANG.keys())
    detected: set[str] = set()
    total_registered = len(_REGISTRY)

    for root, subdirs, filenames in os.walk(project_dir):
        if path_filter is not None:
            base = Path(root)
            subdirs[:] = [d for d in subdirs if not path_filter.is_dir_excluded(base / d)]

        for fname in filenames:
            hit = next((s for s in suffixes if fname.endswith(s)), None)
            if hit is not None:
                detected.add(_EXT_TO_LANG[hit])

        # Every registered language found — no point walking further.
        if len(detected) == total_registered:
            break

    return sorted(detected)
class BaseParser(ABC):
    """Base class that all language parsers must extend.

    Subclasses implement `file_extensions` and `parse_file`.
    The default `parse_directory` walks recursively for matching files,
    using raw str paths internally to avoid pathlib overhead at scale.

    Subclasses may override `skip_suffixes` or `skip_files` to filter
    out files by suffix or exact name (e.g. ".d.ts", "module-info.java").
    """

    @property
    @abstractmethod
    def file_extensions(self) -> list[str]:
        """File extensions this parser handles, e.g. ['.cs']."""

    @property
    def skip_suffixes(self) -> tuple[str, ...]:
        """File suffixes to skip (e.g. ('.d.ts',)). Override in subclass."""
        return ()

    @property
    def skip_files(self) -> frozenset[str]:
        """Exact filenames to skip (e.g. frozenset({'conftest.py'})). Override in subclass."""
        return frozenset()

    @abstractmethod
    def parse_file(self, path: Path, base_dir: Path) -> list[dict]:
        """Parse a single file and return API records."""

    def _should_skip_dir(self, name: str) -> bool:
        """Extra per-parser directory skip logic. Override if needed."""
        return False

    def parse_directory(
        self, directory: Path, path_filter: "PathFilter | None" = None,
        on_progress: "Callable[[Path], None] | None" = None,
    ) -> list[dict]:
        """Recursively parse all matching files under *directory*.

        Uses os.walk with str paths to avoid pathlib overhead.
        PathFilter handles all default exclusions (node_modules, .git, etc.).
        on_progress, when given, is invoked once per candidate file — even
        when parsing that file fails — so callers can report progress.
        """
        exts = tuple(self.file_extensions)
        skip_suf = self.skip_suffixes
        skip_fn = self.skip_files
        dir_str = str(directory)
        records: list[dict] = []

        for root, dirs, files in os.walk(dir_str):
            # Prune excluded directories IN PLACE so os.walk skips them
            if path_filter is not None:
                dirs[:] = [
                    d for d in dirs
                    if not path_filter.is_dir_excluded_name(d)
                    and not self._should_skip_dir(d)
                    and not path_filter.is_dir_excluded(Path(os.path.join(root, d)))
                ]
            else:
                dirs[:] = [d for d in dirs if not self._should_skip_dir(d)]

            for filename in files:
                if not filename.endswith(exts):
                    continue
                if skip_suf and filename.endswith(skip_suf):
                    continue
                if skip_fn and filename in skip_fn:
                    continue

                filepath = os.path.join(root, filename)

                # FIX: the previous filepath[len(dir_str) + 1:] slice dropped
                # an extra character whenever dir_str already ended with a path
                # separator (e.g. a filesystem root). os.path.relpath is robust
                # and matches how parsers themselves compute rel_path.
                if path_filter is not None and path_filter.is_file_excluded_rel(
                    os.path.relpath(filepath, dir_str).replace("\\", "/")
                ):
                    continue

                f = Path(filepath)
                try:
                    records.extend(self.parse_file(f, directory))
                except Exception as e:
                    # Best-effort: report and continue with the remaining files.
                    print(f"codesurface: failed to parse {filepath}: {e}", file=sys.stderr)
                finally:
                    if on_progress is not None:
                        on_progress(f)

        return records
""" +import os import re from pathlib import Path @@ -112,11 +113,12 @@ def parse_file(self, path: Path, base_dir: Path) -> list[dict]: def _parse_cs_file(path: Path, base_dir: Path) -> list[dict]: """Parse a single .cs file and extract all public members.""" try: - text = path.read_text(encoding="utf-8-sig", errors="replace") + with open(path, encoding="utf-8-sig", errors="replace") as fh: + text = fh.read() except (OSError, UnicodeDecodeError): return [] - rel_path = str(path.relative_to(base_dir)).replace("\\", "/") + rel_path = os.path.relpath(path, base_dir).replace("\\", "/") lines = text.splitlines() records = [] diff --git a/src/codesurface/parsers/go.py b/src/codesurface/parsers/go.py index e6be2f3..30a143a 100644 --- a/src/codesurface/parsers/go.py +++ b/src/codesurface/parsers/go.py @@ -8,21 +8,13 @@ Doc comments are consecutive // lines immediately before a declaration. """ +import os import re +import sys from pathlib import Path from .base import BaseParser - -# --- Skip patterns --- - -_SKIP_DIRS = frozenset({ - "vendor", "testdata", ".git", "node_modules", "third_party", - "examples", "example", -}) - -_SKIP_FILE_SUFFIX = "_test.go" - # Go reserved words that can't be identifiers _GO_KEYWORDS = frozenset({ "func", "type", "var", "const", "import", "package", @@ -141,24 +133,12 @@ class GoParser(BaseParser): def file_extensions(self) -> list[str]: return [".go"] - def parse_directory(self, directory: Path) -> list[dict]: - """Override to skip vendor/testdata/test files.""" - records: list[dict] = [] - for f in sorted(directory.rglob("*.go")): - parts = f.relative_to(directory).parts - # Skip dirs - if any(p in _SKIP_DIRS or p.startswith("_") for p in parts): - continue - # Skip test files - if f.name.endswith(_SKIP_FILE_SUFFIX): - continue - try: - records.extend(self.parse_file(f, directory)) - except Exception as e: - import sys - print(f"codesurface: failed to parse {f}: {e}", file=sys.stderr) - continue - return records + @property + def 
skip_suffixes(self) -> tuple[str, ...]: + return ("_test.go",) + + def _should_skip_dir(self, name: str) -> bool: + return name.startswith("_") def parse_file(self, path: Path, base_dir: Path) -> list[dict]: return _parse_go_file(path, base_dir) @@ -172,11 +152,12 @@ def _is_exported(name: str) -> bool: def _parse_go_file(path: Path, base_dir: Path) -> list[dict]: """Parse a single .go file and extract exported API members.""" try: - text = path.read_text(encoding="utf-8", errors="replace") + with open(path, encoding="utf-8", errors="replace") as fh: + text = fh.read() except (OSError, UnicodeDecodeError): return [] - rel_path = path.relative_to(base_dir).as_posix() + rel_path = os.path.relpath(path, base_dir).replace("\\", "/") lines = text.splitlines() # Skip generated files (Go standard: "Code generated ... DO NOT EDIT") diff --git a/src/codesurface/parsers/java.py b/src/codesurface/parsers/java.py index f29390e..8aee366 100644 --- a/src/codesurface/parsers/java.py +++ b/src/codesurface/parsers/java.py @@ -5,21 +5,13 @@ Javadoc comments (/** ... */) are extracted as summaries. 
""" +import os import re +import sys from pathlib import Path from .base import BaseParser - -# --- Skip patterns --- - -_SKIP_DIRS = frozenset({ - "test", "tests", "target", "build", ".gradle", ".git", - "node_modules", ".mvn", "out", "generated", - "generated-sources", "generated-test-sources", - ".idea", "bin", -}) - _SKIP_SUFFIXES = ( "Test.java", "Tests.java", @@ -131,25 +123,13 @@ class JavaParser(BaseParser): def file_extensions(self) -> list[str]: return [".java"] - def parse_directory(self, directory: Path) -> list[dict]: - """Override to skip test/build directories.""" - records = [] - for f in sorted(directory.rglob("*.java")): - parts = f.relative_to(directory).parts - if any(p in _SKIP_DIRS for p in parts): - continue - fname = f.name - if fname in _SKIP_FILES: - continue - if any(fname.endswith(s) for s in _SKIP_SUFFIXES): - continue - try: - records.extend(self.parse_file(f, directory)) - except Exception as e: - import sys - print(f"codesurface: failed to parse {f}: {e}", file=sys.stderr) - continue - return records + @property + def skip_suffixes(self) -> tuple[str, ...]: + return _SKIP_SUFFIXES + + @property + def skip_files(self) -> frozenset[str]: + return _SKIP_FILES def parse_file(self, path: Path, base_dir: Path) -> list[dict]: return _parse_java_file(path, base_dir) @@ -158,11 +138,12 @@ def parse_file(self, path: Path, base_dir: Path) -> list[dict]: def _parse_java_file(path: Path, base_dir: Path) -> list[dict]: """Parse a single .java file and extract public API members.""" try: - text = path.read_text(encoding="utf-8", errors="replace") + with open(path, encoding="utf-8", errors="replace") as fh: + text = fh.read() except (OSError, UnicodeDecodeError): return [] - rel_path = str(path.relative_to(base_dir)).replace("\\", "/") + rel_path = os.path.relpath(path, base_dir).replace("\\", "/") lines = text.splitlines() records: list[dict] = [] diff --git a/src/codesurface/parsers/python_parser.py b/src/codesurface/parsers/python_parser.py index 
459fd96..9fbc068 100644 --- a/src/codesurface/parsers/python_parser.py +++ b/src/codesurface/parsers/python_parser.py @@ -5,20 +5,13 @@ Docstrings are extracted as summaries. """ +import os import re +import sys from pathlib import Path from .base import BaseParser - -# --- Skip patterns --- - -_SKIP_DIRS = frozenset({ - "__pycache__", ".git", ".venv", "venv", "env", - "node_modules", ".tox", ".mypy_cache", ".pytest_cache", - "dist", "build", "egg-info", -}) - _SKIP_FILES = frozenset({ "setup.py", "conftest.py", }) @@ -73,25 +66,12 @@ class PythonParser(BaseParser): def file_extensions(self) -> list[str]: return [".py"] - def parse_directory(self, directory: Path) -> list[dict]: - """Override to skip common non-source directories.""" - records = [] - for f in sorted(directory.rglob("*.py")): - # Skip files in excluded directories - parts = f.relative_to(directory).parts - if any(p in _SKIP_DIRS for p in parts): - continue - if any(p.endswith(".egg-info") for p in parts): - continue - if f.name in _SKIP_FILES: - continue - try: - records.extend(self.parse_file(f, directory)) - except Exception as e: - import sys - print(f"codesurface: failed to parse {f}: {e}", file=sys.stderr) - continue - return records + @property + def skip_files(self) -> frozenset[str]: + return _SKIP_FILES + + def _should_skip_dir(self, name: str) -> bool: + return name.endswith(".egg-info") def parse_file(self, path: Path, base_dir: Path) -> list[dict]: return _parse_py_file(path, base_dir) @@ -100,11 +80,12 @@ def parse_file(self, path: Path, base_dir: Path) -> list[dict]: def _parse_py_file(path: Path, base_dir: Path) -> list[dict]: """Parse a single .py file and extract public API members.""" try: - text = path.read_text(encoding="utf-8", errors="replace") + with open(path, encoding="utf-8", errors="replace") as fh: + text = fh.read() except (OSError, UnicodeDecodeError): return [] - rel_path = str(path.relative_to(base_dir)).replace("\\", "/") + rel_path = os.path.relpath(path, 
base_dir).replace("\\", "/") lines = text.splitlines() records = [] @@ -690,8 +671,8 @@ def _file_to_module(path: Path, base_dir: Path) -> str: Walks up from the file looking for __init__.py to determine package boundaries. """ - rel = path.relative_to(base_dir) - parts = list(rel.parts) + rel = os.path.relpath(path, base_dir).replace("\\", "/") + parts = rel.split("/") # Remove .py extension from last part if parts and parts[-1].endswith(".py"): diff --git a/src/codesurface/parsers/typescript.py b/src/codesurface/parsers/typescript.py index 41eb435..4065a50 100644 --- a/src/codesurface/parsers/typescript.py +++ b/src/codesurface/parsers/typescript.py @@ -5,7 +5,9 @@ JSDoc comments (/** ... */) are extracted as summaries. """ +import os import re +import sys from pathlib import Path from .base import BaseParser @@ -13,16 +15,8 @@ # --- Skip patterns --- -_SKIP_DIRS = frozenset({ - "node_modules", "dist", "build", ".git", ".next", - "__tests__", "__mocks__", "coverage", ".turbo", ".cache", - ".tox", ".mypy_cache", ".venv", "venv", -}) - _SKIP_SUFFIXES = ( ".d.ts", - ".test.ts", ".test.tsx", - ".spec.ts", ".spec.tsx", ".stories.ts", ".stories.tsx", ) @@ -147,26 +141,11 @@ class TypeScriptParser(BaseParser): @property def file_extensions(self) -> list[str]: - return [".ts", ".tsx"] - - def parse_directory(self, directory: Path) -> list[dict]: - """Override to skip test/build/node_modules directories.""" - records = [] - for ext in self.file_extensions: - for f in sorted(directory.rglob(f"*{ext}")): - parts = f.relative_to(directory).parts - if any(p in _SKIP_DIRS for p in parts): - continue - fname = f.name - if any(fname.endswith(s) for s in _SKIP_SUFFIXES): - continue - try: - records.extend(self.parse_file(f, directory)) - except Exception as e: - import sys - print(f"codesurface: failed to parse {f}: {e}", file=sys.stderr) - continue - return records + return [".ts", ".tsx", ".js", ".jsx"] + + @property + def skip_suffixes(self) -> tuple[str, ...]: + return 
_SKIP_SUFFIXES def parse_file(self, path: Path, base_dir: Path) -> list[dict]: return _parse_ts_file(path, base_dir) @@ -175,15 +154,16 @@ def parse_file(self, path: Path, base_dir: Path) -> list[dict]: def _parse_ts_file(path: Path, base_dir: Path) -> list[dict]: """Parse a single TypeScript file and extract exported API members.""" try: - text = path.read_text(encoding="utf-8", errors="replace") + with open(path, encoding="utf-8", errors="replace") as fh: + text = fh.read() except (OSError, UnicodeDecodeError): return [] - rel_path = str(path.relative_to(base_dir)).replace("\\", "/") + rel_path = os.path.relpath(path, base_dir).replace("\\", "/") lines = text.splitlines() records: list[dict] = [] - namespace = _file_to_module(path, base_dir) + namespace = _file_to_module(rel_path) # State class_stack: list[tuple[str, str, int]] = [] # (name, kind, brace_depth) @@ -191,8 +171,9 @@ def _parse_ts_file(path: Path, base_dir: Path) -> list[dict]: paren_depth = 0 # track () to skip multi-line parameter lists in_multiline_comment = False + n_lines = len(lines) i = 0 - while i < len(lines): + while i < n_lines: line = lines[i] # --- Multi-line comment tracking --- @@ -1020,52 +1001,33 @@ def _extract_field_type(stripped: str, field_name: str) -> str: # --- Brace counting --- -def _count_braces_and_parens(line: str) -> tuple[int, int]: - """Count net brace and paren depth changes, skipping inside strings.""" - brace_depth = 0 - paren_depth = 0 - in_single = False - in_double = False - in_template = False - escape = False - - for ch in line: - if escape: - escape = False - continue - if ch == "\\": - escape = True - continue +# Strip string literals (single, double, template) including escaped quotes. +# Order matters: escaped quotes must be consumed before quote boundaries. +_STRING_RE = re.compile( + r"""'(?:[^'\\]|\\.)*'?""" # single-quoted (possibly unterminated) + r'|"(?:[^"\\]|\\.)*"?' # double-quoted + r"|`(?:[^`\\]|\\.)*`?" 
# template literal (simplified: ignores ${}) + r"|//.*$", # line comments also contain no braces + re.DOTALL, +) - if in_single: - if ch == "'": - in_single = False - continue - if in_double: - if ch == '"': - in_double = False - continue - if in_template: - if ch == "`": - in_template = False - continue - if ch == "'": - in_single = True - elif ch == '"': - in_double = True - elif ch == "`": - in_template = True - elif ch == "{": - brace_depth += 1 - elif ch == "}": - brace_depth -= 1 - elif ch == "(": - paren_depth += 1 - elif ch == ")": - paren_depth -= 1 +_HAS_STRING_OR_COMMENT = frozenset("'\"`/") + - return brace_depth, paren_depth +def _count_braces_and_parens(line: str) -> tuple[int, int]: + """Count net brace and paren depth changes, skipping inside strings.""" + # Fast path: most lines have no strings or comments — skip regex entirely + if _HAS_STRING_OR_COMMENT.isdisjoint(line): + return ( + line.count("{") - line.count("}"), + line.count("(") - line.count(")"), + ) + stripped = _STRING_RE.sub("", line) + return ( + stripped.count("{") - stripped.count("}"), + stripped.count("(") - stripped.count(")"), + ) # --- Visibility --- @@ -1176,25 +1138,19 @@ def _split_params(params_str: str) -> list[str]: # --- File path to module --- -def _file_to_module(path: Path, base_dir: Path) -> str: - """Convert file path to a dot-separated module namespace. +def _file_to_module(rel_path: str) -> str: + """Convert relative file path to a dot-separated module namespace. 
src/services/myService.ts -> src.services.myService """ - rel = path.relative_to(base_dir) - parts = list(rel.parts) - - if not parts: - return "" - - # Remove extension from last part - last = parts[-1] + # Strip extension for ext in (".tsx", ".ts"): - if last.endswith(ext): - parts[-1] = last[:-len(ext)] + if rel_path.endswith(ext): + rel_path = rel_path[:-len(ext)] break - # Drop index files (index.ts re-exports) + # Split on / and drop trailing "index" + parts = rel_path.split("/") if parts and parts[-1] == "index": parts = parts[:-1] diff --git a/src/codesurface/server.py b/src/codesurface/server.py index 89dd355..1dbaeb0 100644 --- a/src/codesurface/server.py +++ b/src/codesurface/server.py @@ -2,6 +2,7 @@ import argparse import json +import os import re import sys import time @@ -10,6 +11,7 @@ from mcp.server.fastmcp import FastMCP from . import db +from .filters import PathFilter from .parsers import all_extensions, detect_languages, get_parser, get_parsers_for_project mcp = FastMCP( @@ -25,6 +27,32 @@ _project_path: Path | None = None _file_mtimes: dict[str, float] = {} # rel_path → mtime _index_fresh: bool = True # True = checked for changes since last hit; skip auto-reindex +_path_filter: PathFilter | None = None + + +def _count_files( + project_path: Path, + parsers: list, + path_filter: "PathFilter | None", +) -> int: + """Quick pre-scan: count source files that will be parsed.""" + extensions = set() + for p in parsers: + extensions.update(p.file_extensions) + exts = tuple(extensions) + total = 0 + for root, dirs, files in os.walk(project_path): + root_path = Path(root) + if path_filter is not None: + dirs[:] = [d for d in dirs if not path_filter.is_dir_excluded(root_path / d)] + for filename in files: + if not filename.endswith(exts): + continue + f = root_path / filename + if path_filter is not None and path_filter.is_file_excluded(f): + continue + total += 1 + return total def _index_full(project_path: Path, language: str | None = None) -> str: @@ 
-35,28 +63,65 @@ def _index_full(project_path: Path, language: str | None = None) -> str: if language: parsers = [get_parser(language)] else: - parsers = get_parsers_for_project(project_path) + parsers = get_parsers_for_project(project_path, path_filter=_path_filter) if not parsers: return "No supported source files detected in project directory." + total = _count_files(project_path, parsers, _path_filter) + print(f"[codesurface] scanning {total:,} files...", file=sys.stderr, flush=True) + + parsed = [0] + last_pct = [0.0] + last_time = [t0] + + # Emit the 0% line immediately + print( + f"[codesurface] indexing: 0% ({0:>6,} / {total:,})", + file=sys.stderr, + flush=True, + ) + + def on_progress(f: Path) -> None: + parsed[0] += 1 + now = time.perf_counter() + pct = parsed[0] / max(total, 1) + elapsed = now - t0 + if pct - last_pct[0] >= 0.05 or now - last_time[0] >= 3.0: + print( + f"[codesurface] indexing: {pct:3.0%} ({parsed[0]:>6,} / {total:,}) {elapsed:.1f}s", + file=sys.stderr, + flush=True, + ) + last_pct[0] = pct + last_time[0] = now + records = [] for parser in parsers: - records.extend(parser.parse_directory(project_path)) + records.extend( + parser.parse_directory(project_path, path_filter=_path_filter, on_progress=on_progress) + ) parse_time = time.perf_counter() - t0 t1 = time.perf_counter() _conn = db.create_memory_db(records) db_time = time.perf_counter() - t1 - # Snapshot mtimes for all registered extensions used + # Snapshot mtimes for all registered extensions (pruning excluded dirs) extensions = set() for parser in parsers: extensions.update(parser.file_extensions) + exts = tuple(extensions) _file_mtimes = {} - for ext in extensions: - for f in sorted(project_path.rglob(f"*{ext}")): + for root, dirs, files in os.walk(project_path): + root_path = Path(root) + if _path_filter is not None: + dirs[:] = [d for d in dirs if not _path_filter.is_dir_excluded(root_path / d)] + for filename in files: + if not filename.endswith(exts): + continue + f = 
root_path / filename rel = str(f.relative_to(project_path)).replace("\\", "/") try: _file_mtimes[rel] = f.stat().st_mtime @@ -65,11 +130,12 @@ def _index_full(project_path: Path, language: str | None = None) -> str: stats = db.get_stats(_conn) langs = ", ".join(type(p).__name__.replace("Parser", "") for p in parsers) - return ( - f"Indexed {stats['total']} records from {stats.get('files', 0)} files " - f"({langs}) in {parse_time + db_time:.2f}s " - f"(parse: {parse_time:.2f}s, db: {db_time:.2f}s)" + summary = ( + f"[codesurface] done: {stats['total']:,} records from {stats.get('files', 0):,} files " + f"({langs}) in {parse_time + db_time:.2f}s" ) + print(summary, file=sys.stderr, flush=True) + return summary def _index_incremental(project_path: Path) -> tuple[str, bool]: @@ -84,12 +150,20 @@ def _index_incremental(project_path: Path) -> tuple[str, bool]: t0 = time.perf_counter() # Collect all registered extensions - extensions = set(all_extensions()) + exts = tuple(all_extensions()) - # Scan current files + # Scan current files, pruning excluded directories during walk current: dict[str, float] = {} - for ext in extensions: - for f in sorted(project_path.rglob(f"*{ext}")): + for root, dirs, files in os.walk(project_path): + root_path = Path(root) + if _path_filter is not None: + dirs[:] = [d for d in dirs if not _path_filter.is_dir_excluded(root_path / d)] + for filename in files: + if not filename.endswith(exts): + continue + f = root_path / filename + if _path_filter is not None and _path_filter.is_file_excluded(f): + continue rel = str(f.relative_to(project_path)).replace("\\", "/") try: current[rel] = f.stat().st_mtime @@ -120,7 +194,7 @@ def _index_incremental(project_path: Path) -> tuple[str, bool]: db.delete_by_files(_conn, list(stale)) # Build extension-to-parser map for dirty files - parsers = get_parsers_for_project(project_path) + parsers = get_parsers_for_project(project_path, path_filter=_path_filter) ext_to_parser: dict[str, object] = {} for parser in 
parsers: for ext in parser.file_extensions: @@ -129,7 +203,7 @@ def _index_incremental(project_path: Path) -> tuple[str, bool]: # Parse dirty files new_records = [] for rel in sorted(dirty): - full_path = project_path / rel.replace("/", "\\") + full_path = project_path / Path(rel) suffix = full_path.suffix parser = ext_to_parser.get(suffix) if parser is None: @@ -199,6 +273,14 @@ def _pick_primary_namespace(namespaces: list[str], members: list[dict]) -> str | return None +def _is_test_file(file_path: str) -> bool: + """Check if a file path looks like a test file.""" + for name in db._TEST_DIR_NAMES: + if file_path.startswith(f"{name}/") or f"/{name}/" in file_path: + return True + return any(pat in file_path for pat in db._TEST_FILE_PATTERNS) + + def _format_file_location(r: dict) -> str: """Format file path with optional line range from a record.""" fp = r.get("file_path", "") @@ -261,6 +343,8 @@ def search( query: str, n_results: int = 5, member_type: str | None = None, + file_path: str | None = None, + include_tests: bool = False, ) -> str: """Search the indexed API by keyword. @@ -271,17 +355,22 @@ def search( query: Search terms (e.g. "MergeService", "BlastBoard", "GridCoord") n_results: Max results to return (default 5, max 20) member_type: Optional filter — "type", "method", "property", "field", or "event" + file_path: Optional path prefix or exact file to scope results + (e.g. "src/services/" or "src/services/foo.ts") + include_tests: If true, include test files in results (default false) """ if _conn is None: return "No codebase indexed. Start the server with --project ." 
global _index_fresh n_results = min(max(n_results, 1), 20) - results = db.search(_conn, query, n=n_results, member_type=member_type) + results = db.search(_conn, query, n=n_results, member_type=member_type, + file_path=file_path, include_tests=include_tests) if not results: if _auto_reindex(): - results = db.search(_conn, query, n=n_results, member_type=member_type) + results = db.search(_conn, query, n=n_results, member_type=member_type, + file_path=file_path, include_tests=include_tests) if not results: return f"No results found for '{query}'. Try broader search terms." @@ -296,7 +385,8 @@ def search( @mcp.tool() -def get_signature(name: str) -> str: +def get_signature(name: str, file_path: str | None = None, + include_tests: bool = False) -> str: """Look up the exact signature of an API member by name or FQN. Use when you need exact parameter types, return types, or method signatures @@ -304,6 +394,8 @@ def get_signature(name: str) -> str: Args: name: Member name or FQN, e.g. "TryMerge", "CampGame.Services.IMergeService.TryMerge" + file_path: Optional path prefix to scope the lookup + include_tests: If true, include test files in results (default false) """ global _index_fresh if _conn is None: @@ -313,12 +405,32 @@ def _lookup() -> str | None: # 1. Exact FQN match record = db.get_by_fqn(_conn, name) if record: - return _format_record(record) + # Skip test-file results unless include_tests + if not include_tests and _is_test_file(record.get("file_path", "")): + pass # fall through to substring search + else: + return _format_record(record) # 2. Substring match (overloads or partial FQN) + file_clause = "" + file_params: list = [] + if file_path: + if file_path.endswith("/"): + file_clause = " AND file_path LIKE ?" + file_params = [file_path + "%"] + else: + file_clause = " AND (file_path = ? 
OR file_path LIKE ?)" + file_params = [file_path, file_path + "/%"] + + test_clauses: list[str] = [] + test_params: list = [] + if not include_tests: + db._add_test_exclusion(test_clauses, test_params) + test_clause = "".join(f" AND {c}" for c in test_clauses) + rows = _conn.execute( - "SELECT * FROM api_records WHERE fqn LIKE ? ORDER BY fqn", - (f"%{name}%",), + f"SELECT * FROM api_records WHERE fqn LIKE ?{file_clause}{test_clause} ORDER BY fqn", + (f"%{name}%", *file_params, *test_params), ).fetchall() if rows: parts = [f"Found {len(rows)} match(es) for '{name}':\n"] @@ -330,7 +442,7 @@ def _lookup() -> str | None: return "\n".join(parts) # 3. FTS fallback - results = db.search(_conn, name, n=5) + results = db.search(_conn, name, n=5, file_path=file_path, include_tests=include_tests) if results: parts = [f"No exact match for '{name}'. Did you mean:\n"] for r in results: @@ -351,7 +463,8 @@ def _lookup() -> str | None: @mcp.tool() -def get_class(class_name: str) -> str: +def get_class(class_name: str, file_path: str | None = None, + include_tests: bool = False) -> str: """Get a complete reference card for a class — all public members. Shows every method, property, field, and event with signatures. @@ -359,6 +472,8 @@ def get_class(class_name: str) -> str: Args: class_name: Class name, e.g. 
"BlastBoardModel", "IMergeService", "CampGridService" + file_path: Optional path prefix to scope the lookup + include_tests: If true, include test files in results (default false) """ global _index_fresh if _conn is None: @@ -373,13 +488,16 @@ def get_class(class_name: str) -> str: else: short_name = re.split(r"[.:]", class_name)[-1] - members = db.get_class_members(_conn, short_name, namespace=ns_filter) + members = db.get_class_members(_conn, short_name, file_path=file_path, + namespace=ns_filter, include_tests=include_tests) if not members: if _auto_reindex(): - members = db.get_class_members(_conn, short_name, namespace=ns_filter) + members = db.get_class_members(_conn, short_name, file_path=file_path, + namespace=ns_filter, include_tests=include_tests) if not members: - results = db.search(_conn, class_name, n=5, member_type="type") + results = db.search(_conn, class_name, n=5, member_type="type", + file_path=file_path, include_tests=include_tests) if results: parts = [f"No class '{class_name}' found. Did you mean:\n"] for r in results: @@ -521,17 +639,28 @@ def main(): help="Path to source directory to index") parser.add_argument("--language", default=None, help="Language to parse (e.g. csharp). Auto-detected if omitted.") + parser.add_argument("--exclude", default=None, + help="Comma-separated glob patterns to exclude from indexing " + "(e.g. 
'tests/**,generated/**')") + parser.add_argument("--include-submodules", action="store_true", default=False, + help="Include git submodules in indexing (excluded by default)") args, remaining = parser.parse_known_args() - global _project_path + global _project_path, _path_filter + + exclude_globs = [g.strip() for g in args.exclude.split(",")] if args.exclude else [] if args.project: - _project_path = Path(args.project) + _project_path = Path(args.project).expanduser().resolve() if not _project_path.is_dir(): print(f"Warning: Project path not found: {args.project}", file=sys.stderr) else: - summary = _index_full(_project_path, language=args.language) - print(summary, file=sys.stderr) + _path_filter = PathFilter( + _project_path, + exclude_globs=exclude_globs, + include_submodules=args.include_submodules, + ) + _index_full(_project_path, language=args.language) mcp.run() diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_filters.py b/tests/test_filters.py new file mode 100644 index 0000000..a9fd61c --- /dev/null +++ b/tests/test_filters.py @@ -0,0 +1,181 @@ +"""Tests for PathFilter default skip rules.""" +import os +from pathlib import Path +import pytest +from codesurface.filters import PathFilter + + +@pytest.fixture +def tmp_project(tmp_path): + """Project root with a variety of subdirectories.""" + # Normal source file + (tmp_path / "src").mkdir() + (tmp_path / "src" / "main.ts").write_text("export class Foo {}") + + # .worktrees directory (should always be skipped) + wt = tmp_path / ".worktrees" / "pr-42" + wt.mkdir(parents=True) + (wt / "src").mkdir() + (wt / "src" / "main.ts").write_text("export class Bar {}") + # .git FILE in worktree (git worktree marker) + (wt / ".git").write_text("gitdir: /repo/.git/worktrees/pr-42\n") + + # Submodule (should be skipped by default) + sub = tmp_path / "vendor" / "mylib" + sub.mkdir(parents=True) + (sub / "lib.ts").write_text("export class Lib {}") + (sub / 
".git").write_text("gitdir: /repo/.git/modules/mylib\n") + + # Regular nested dir (should NOT be skipped) + (tmp_path / "packages" / "core").mkdir(parents=True) + (tmp_path / "packages" / "core" / "index.ts").write_text("export class Core {}") + + return tmp_path + + +def test_worktrees_dir_skipped(tmp_project): + pf = PathFilter(tmp_project) + assert pf.is_dir_excluded(tmp_project / ".worktrees") + + +def test_worktree_subdir_skipped(tmp_project): + pf = PathFilter(tmp_project) + assert pf.is_dir_excluded(tmp_project / ".worktrees" / "pr-42") + + +def test_git_file_worktree_skipped(tmp_project): + pf = PathFilter(tmp_project) + wt = tmp_project / ".worktrees" / "pr-42" + assert pf.is_dir_excluded(wt) + + +def test_submodule_skipped_by_default(tmp_project): + pf = PathFilter(tmp_project) + assert pf.is_dir_excluded(tmp_project / "vendor" / "mylib") + + +def test_submodule_included_when_opted_in(tmp_project): + pf = PathFilter(tmp_project, include_submodules=True) + assert not pf.is_dir_excluded(tmp_project / "vendor" / "mylib") + + +def test_worktree_still_skipped_even_with_include_submodules(tmp_project): + pf = PathFilter(tmp_project, include_submodules=True) + wt = tmp_project / ".worktrees" / "pr-42" + assert pf.is_dir_excluded(wt) + + +def test_normal_dir_not_skipped(tmp_project): + pf = PathFilter(tmp_project) + assert not pf.is_dir_excluded(tmp_project / "packages" / "core") + + +def test_src_dir_not_skipped(tmp_project): + pf = PathFilter(tmp_project) + assert not pf.is_dir_excluded(tmp_project / "src") + + +def test_exclude_glob_skips_matching_file(tmp_project): + pf = PathFilter(tmp_project, exclude_globs=["tests/**"]) + (tmp_project / "tests").mkdir() + test_file = tmp_project / "tests" / "foo.ts" + test_file.write_text("") + assert pf.is_file_excluded(test_file) + + +def test_exclude_glob_does_not_skip_nonmatching(tmp_project): + pf = PathFilter(tmp_project, exclude_globs=["tests/**"]) + assert not pf.is_file_excluded(tmp_project / "src" / "main.ts") + 
+ +def test_codesurfaceignore_loaded(tmp_project): + (tmp_project / ".codesurfaceignore").write_text("generated/**\n# comment\n\n") + pf = PathFilter(tmp_project) + gen_file = tmp_project / "generated" / "types.ts" + assert pf.is_file_excluded(gen_file) + + +def test_codesurfaceignore_and_cli_globs_merged(tmp_project): + (tmp_project / ".codesurfaceignore").write_text("generated/**\n") + pf = PathFilter(tmp_project, exclude_globs=["tests/**"]) + gen_file = tmp_project / "generated" / "types.ts" + test_file = tmp_project / "tests" / "foo.ts" + assert pf.is_file_excluded(gen_file) + assert pf.is_file_excluded(test_file) + + +def test_codesurfaceignore_missing_is_fine(tmp_project): + # No .codesurfaceignore present — should not raise + pf = PathFilter(tmp_project) + assert not pf.is_file_excluded(tmp_project / "src" / "main.ts") + + +# ---- Query-time file_path filtering ---- +from codesurface import db as csdb + + +def _make_db(): + records = [ + { + "fqn": "Services.FooService", + "namespace": "Services", + "class_name": "FooService", + "member_name": "FooService", + "member_type": "type", + "signature": "class FooService", + "file_path": "src/services/foo.ts", + "line_start": 1, + "line_end": 10, + }, + { + "fqn": "Utils.BarUtil", + "namespace": "Utils", + "class_name": "BarUtil", + "member_name": "BarUtil", + "member_type": "type", + "signature": "class BarUtil", + "file_path": "src/utils/bar.ts", + "line_start": 1, + "line_end": 5, + }, + ] + return csdb.create_memory_db(records) + + +def test_search_file_path_prefix_filters(): + conn = _make_db() + results = csdb.search(conn, "Service", file_path="src/services/") + assert len(results) == 1 + assert results[0]["class_name"] == "FooService" + + +def test_search_file_path_exact_file(): + conn = _make_db() + results = csdb.search(conn, "Bar", file_path="src/utils/bar.ts") + assert len(results) == 1 + assert results[0]["class_name"] == "BarUtil" + + +def test_search_file_path_no_match_returns_empty(): + conn = 
_make_db() + results = csdb.search(conn, "Foo", file_path="src/utils/") + assert len(results) == 0 + + +def test_search_no_file_path_returns_all(): + conn = _make_db() + results = csdb.search(conn, "Service OR Bar", file_path=None) + assert len(results) == 2 + + +def test_get_class_members_file_path_prefix(): + conn = _make_db() + results = csdb.get_class_members(conn, "FooService", file_path="src/services/") + assert len(results) == 1 + assert results[0]["class_name"] == "FooService" + + +def test_get_class_members_file_path_no_match(): + conn = _make_db() + results = csdb.get_class_members(conn, "FooService", file_path="src/utils/") + assert len(results) == 0 diff --git a/tests/test_include_tests.py b/tests/test_include_tests.py new file mode 100644 index 0000000..d6bae45 --- /dev/null +++ b/tests/test_include_tests.py @@ -0,0 +1,134 @@ +"""Tests for the include_tests parameter on search and get_class_members.""" + +from codesurface import db + + +def _make_record(fqn, file_path, class_name="MyClass", member_name="myMethod", + member_type="method"): + return { + "fqn": fqn, + "namespace": "", + "class_name": class_name, + "member_name": member_name, + "member_type": member_type, + "signature": f"void {member_name}()", + "summary": "", + "params_json": [], + "returns_text": "", + "file_path": file_path, + "line_start": 1, + "line_end": 10, + } + + +def _setup_db(): + """Create a DB with both source and test records.""" + records = [ + # Source files + _make_record("MyClass", "src/MyClass.ts", member_type="type", + member_name="MyClass"), + _make_record("MyClass.foo", "src/MyClass.ts", member_name="foo"), + _make_record("MyClass.bar", "src/MyClass.ts", member_name="bar"), + # .test. pattern + _make_record("MyClass.testHelper", "src/MyClass.test.ts", + member_name="testHelper"), + # .spec. 
pattern + _make_record("MyClass.specHelper", "src/MyClass.spec.ts", + member_name="specHelper"), + # __tests__/ directory pattern + _make_record("MyClass.dirTest", "src/__tests__/MyClass.ts", + member_name="dirTest"), + # test_ filename pattern (Python convention) + _make_record("TestUtils", "src/test_utils.py", class_name="TestUtils", + member_type="type", member_name="TestUtils"), + _make_record("TestUtils.setup", "src/test_utils.py", + class_name="TestUtils", member_name="setup"), + # _test. filename pattern (Go convention) + _make_record("MyClass.goTest", "src/myclass_test.go", + member_name="goTest"), + # /tests/ directory pattern + _make_record("Fixtures", "tests/fixtures.ts", class_name="Fixtures", + member_type="type", member_name="Fixtures"), + ] + return db.create_memory_db(records) + + +class TestSearchIncludeTests: + def test_excludes_test_files_by_default(self): + conn = _setup_db() + results = db.search(conn, "MyClass", include_tests=False) + paths = [r["file_path"] for r in results] + assert all( + ".test." not in p + and ".spec." not in p + and "__tests__" not in p + and "_test." not in p + for p in paths + ) + + def test_includes_test_files_when_requested(self): + conn = _setup_db() + results = db.search(conn, "MyClass", include_tests=True) + paths = [r["file_path"] for r in results] + assert any(".test." in p or "__tests__" in p or "_test." 
in p for p in paths) + + def test_excludes_tests_dir(self): + conn = _setup_db() + results = db.search(conn, "Fixtures", include_tests=False) + paths = [r["file_path"] for r in results] + assert all("tests/" not in p for p in paths) + + def test_includes_tests_dir_when_requested(self): + conn = _setup_db() + results = db.search(conn, "Fixtures", include_tests=True) + paths = [r["file_path"] for r in results] + assert any("tests/" in p for p in paths) + + def test_excludes_test_underscore_prefix(self): + conn = _setup_db() + results = db.search(conn, "TestUtils", include_tests=False) + paths = [r["file_path"] for r in results] + assert all("/test_" not in p for p in paths) + + def test_excludes_go_test_suffix(self): + conn = _setup_db() + results = db.search(conn, "goTest", include_tests=False) + assert len(results) == 0 + + def test_does_not_exclude_non_test_dirs_containing_test(self): + """A dir like _test_fixture/ should NOT be excluded.""" + conn = _setup_db() + records = [ + _make_record("Calc", "_test_fixture/calculator.py", + class_name="Calc", member_type="type", + member_name="Calc"), + ] + db.insert_records(conn, records) + results = db.search(conn, "Calc", include_tests=False) + assert len(results) == 1 + assert results[0]["file_path"] == "_test_fixture/calculator.py" + + +class TestGetClassMembersIncludeTests: + def test_excludes_test_files_by_default(self): + conn = _setup_db() + members = db.get_class_members(conn, "MyClass", include_tests=False) + paths = [m["file_path"] for m in members] + assert all( + ".test." not in p + and ".spec." 
not in p + and "__tests__" not in p + for p in paths + ) + assert len(members) == 3 # type + foo + bar + + def test_includes_test_files_when_requested(self): + conn = _setup_db() + members = db.get_class_members(conn, "MyClass", include_tests=True) + assert len(members) == 7 # type + foo + bar + testHelper + specHelper + dirTest + goTest + + def test_default_is_exclude(self): + conn = _setup_db() + members_default = db.get_class_members(conn, "MyClass") + members_explicit = db.get_class_members(conn, "MyClass", include_tests=False) + assert len(members_default) == len(members_explicit) diff --git a/tests/test_parsers.py b/tests/test_parsers.py new file mode 100644 index 0000000..4c19c26 --- /dev/null +++ b/tests/test_parsers.py @@ -0,0 +1,115 @@ +"""Tests for PathFilter integration with parse_directory.""" +from pathlib import Path +import pytest +from codesurface.filters import PathFilter +from codesurface.parsers.typescript import TypeScriptParser +from codesurface.parsers.python_parser import PythonParser +from codesurface.parsers.go import GoParser +from codesurface.parsers.java import JavaParser + + +@pytest.fixture +def ts_project(tmp_path): + (tmp_path / "src").mkdir() + (tmp_path / "src" / "service.ts").write_text( + "export class FooService { bar(): void {} }" + ) + # A worktree that should be skipped + wt = tmp_path / ".worktrees" / "pr-1" + wt.mkdir(parents=True) + (wt / ".git").write_text("gitdir: /repo/.git/worktrees/pr-1\n") + (wt / "service.ts").write_text( + "export class WtService { baz(): void {} }" + ) + # A generated file that should be skipped + (tmp_path / "src" / "gen.ts").write_text( + "export class Generated {}" + ) + return tmp_path + + +def test_worktree_files_not_indexed(ts_project): + pf = PathFilter(ts_project) + parser = TypeScriptParser() + records = parser.parse_directory(ts_project, path_filter=pf) + names = [r["class_name"] for r in records] + assert "WtService" not in names + + +def test_src_files_indexed(ts_project): + pf = 
PathFilter(ts_project) + parser = TypeScriptParser() + records = parser.parse_directory(ts_project, path_filter=pf) + names = [r["class_name"] for r in records] + assert "FooService" in names + + +def test_excluded_file_not_indexed(ts_project): + pf = PathFilter(ts_project, exclude_globs=["src/gen.ts"]) + parser = TypeScriptParser() + records = parser.parse_directory(ts_project, path_filter=pf) + names = [r["class_name"] for r in records] + assert "Generated" not in names + + +def test_no_filter_indexes_worktrees_too(ts_project): + # Without PathFilter, worktrees are NOT excluded — old behaviour preserved + parser = TypeScriptParser() + records = parser.parse_directory(ts_project) + names = [r["class_name"] for r in records] + assert "FooService" in names + # WtService IS found without a filter (old behaviour) + assert "WtService" in names + + +def test_on_progress_called_per_file(ts_project): + """on_progress is called once per successfully parsed file.""" + parser = TypeScriptParser() + visited = [] + parser.parse_directory(ts_project, on_progress=lambda f: visited.append(f)) + # ts_project has service.ts, gen.ts, and a worktree service.ts (3 .ts files total without filter) + assert len(visited) == 3 + assert all(isinstance(f, Path) for f in visited) + + +def test_on_progress_none_is_default(ts_project): + """Omitting on_progress works exactly as before.""" + parser = TypeScriptParser() + records = parser.parse_directory(ts_project) + assert len(records) > 0 + + +@pytest.fixture +def py_project(tmp_path): + (tmp_path / "mod.py").write_text("def hello(): pass\n") + return tmp_path + + +def test_typescript_on_progress(ts_project): + parser = TypeScriptParser() + visited = [] + parser.parse_directory(ts_project, on_progress=lambda f: visited.append(f)) + assert len(visited) >= 1 + + +def test_python_on_progress(py_project): + parser = PythonParser() + visited = [] + parser.parse_directory(py_project, on_progress=lambda f: visited.append(f)) + assert len(visited) == 
1 + + +def test_go_on_progress(tmp_path): + (tmp_path / "main.go").write_text("package main\nfunc Hello() {}\n") + parser = GoParser() + visited = [] + parser.parse_directory(tmp_path, on_progress=lambda f: visited.append(f)) + assert len(visited) == 1 + + +def test_java_on_progress(tmp_path): + (tmp_path / "Foo.java").write_text("public class Foo { public void bar() {} }\n") + parser = JavaParser() + visited = [] + parser.parse_directory(tmp_path, on_progress=lambda f: visited.append(f)) + assert len(visited) == 1 diff --git a/tests/test_server_progress.py b/tests/test_server_progress.py new file mode 100644 index 0000000..a2dfef4 --- /dev/null +++ b/tests/test_server_progress.py @@ -0,0 +1,55 @@ +"""Tests for _count_files and _index_full progress output.""" +from pathlib import Path +import pytest + + +def test_count_files_basic(tmp_path): + """_count_files counts matching source files, ignores other extensions.""" + from codesurface.server import _count_files + from codesurface.parsers.python_parser import PythonParser + + (tmp_path / "a.py").write_text("x = 1") + (tmp_path / "b.py").write_text("y = 2") + (tmp_path / "c.txt").write_text("ignored") + + parsers = [PythonParser()] + assert _count_files(tmp_path, parsers, path_filter=None) == 2 + + +def test_count_files_prunes_excluded_dirs(tmp_path): + """_count_files respects path_filter dir exclusions.""" + from codesurface.server import _count_files + from codesurface.parsers.python_parser import PythonParser + from codesurface.filters import PathFilter + + (tmp_path / "src").mkdir() + (tmp_path / "src" / "a.py").write_text("x = 1") + wt = tmp_path / ".worktrees" / "pr-1" + wt.mkdir(parents=True) + (wt / ".git").write_text("gitdir: /repo/.git/worktrees/pr-1\n") + (wt / "b.py").write_text("y = 2") + + pf = PathFilter(tmp_path) + parsers = [PythonParser()] + assert _count_files(tmp_path, parsers, path_filter=pf) == 1 + + +def test_index_full_emits_progress_to_stderr(tmp_path, capsys): + """_index_full prints at 
least one progress line and a done line to stderr.""" + from codesurface import server + from codesurface.filters import PathFilter + + for i in range(5): + (tmp_path / f"m{i}.py").write_text(f"def f{i}(): pass\n") + + server._conn = None + server._project_path = tmp_path + server._path_filter = PathFilter(tmp_path) + + server._index_full(tmp_path) + + captured = capsys.readouterr() + assert "[codesurface]" in captured.err + assert "scanning" in captured.err # e.g. "[codesurface] scanning 5 files..." + assert "indexing:" in captured.err # e.g. "[codesurface] indexing: 0% ..." + assert "done:" in captured.err