From 332dd6de4183e3c3501c2f8e17870130459c2806 Mon Sep 17 00:00:00 2001 From: Stefan Bethge Date: Tue, 14 Apr 2026 16:34:10 +0200 Subject: [PATCH 1/7] allow loading in intellij >= 261 --- build.gradle.kts | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/build.gradle.kts b/build.gradle.kts index 52d9419..6992c34 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -15,7 +15,7 @@ plugins { val remoteRobotVersion = "0.11.20" val pluginId = "dev.sweep.assistant" -val pluginName = "Self-Hosted Enterprise Updater" +val pluginName = "Sweep Autocomplete OSS" println("Building plugin: $pluginName with ID: $pluginId") group = "dev.sweep" version = "1.29.3" @@ -67,7 +67,7 @@ intellijPlatform { ) channels.set(listOf(ProductRelease.Channel.RELEASE)) sinceBuild.set("241") - untilBuild.set("253.*") + untilBuild.set("261.*") } } } @@ -86,7 +86,7 @@ tasks { patchPluginXml { sinceBuild.set("241") - untilBuild.set("253.*") + untilBuild.set("261.*") } signPlugin { From 60f24c36d961b5f07af5171a1b82463272b36186 Mon Sep 17 00:00:00 2001 From: Stefan Bethge Date: Tue, 14 Apr 2026 16:34:31 +0200 Subject: [PATCH 2/7] handle EDT exception on load --- .../assistant/services/LocalAutocompleteServerManager.kt | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/main/kotlin/dev/sweep/assistant/services/LocalAutocompleteServerManager.kt b/src/main/kotlin/dev/sweep/assistant/services/LocalAutocompleteServerManager.kt index 73bfe2f..6236063 100644 --- a/src/main/kotlin/dev/sweep/assistant/services/LocalAutocompleteServerManager.kt +++ b/src/main/kotlin/dev/sweep/assistant/services/LocalAutocompleteServerManager.kt @@ -389,8 +389,10 @@ class LocalAutocompleteServerManager : Disposable { // Small delay so the shell is ready to accept input, then send command ApplicationManager.getApplication().executeOnPooledThread { Thread.sleep(1000) + // Compute isPowerShell on BGT to avoid EDT threading violation + val isPowerShell = TerminalApiWrapper.isPowerShell(project) ApplicationManager.getApplication().invokeLater { - TerminalApiWrapper.sendCommand(targetWidget, command, project) + TerminalApiWrapper.sendCommand(targetWidget, command, project, isPowerShell) } } logger.info("Started local autocomplete server in terminal: $command") From 7466ddc3b2a261c032f494955e7b41024a62dd56 Mon Sep 17 00:00:00 2001 From: Stefan Bethge Date: Tue, 14 Apr 2026 20:50:59 +0200 Subject: [PATCH 3/7] Allow loading mlx models and use mlx-lm --- .gitignore | 2 + build.gradle.kts | 22 +- .../sweep/assistant/components/SweepConfig.kt | 28 + .../LocalAutocompleteServerManager.kt | 43 +- .../sweep/assistant/settings/SweepSettings.kt | 2 + .../settings/SweepSettingsConfigurable.kt | 10 +- src/main/resources/META-INF/plugin.xml | 4 +- vendor/sweep-autocomplete-mlx/pyproject.toml | 31 + .../sweep_autocomplete/__init__.py | 0 .../sweep_autocomplete/app.py | 106 + .../autocomplete/__init__.py | 0 .../autocomplete/llm_local.py | 111 + .../autocomplete/next_edit_autocomplete.py | 2004 +++++++++++++++++ .../next_edit_autocomplete_retrieval.py | 335 +++ .../next_edit_autocomplete_service.py | 8 + .../next_edit_autocomplete_utils.py | 631 ++++++ .../sweep_autocomplete/cli.py | 15 + .../sweep_autocomplete/config.py | 9 + .../dataclasses/__init__.py | 0 .../dataclasses/file_chunk_data.py | 35 + .../sweep_autocomplete/utils/__init__.py | 0 .../utils/compression_middleware.py | 74 + .../sweep_autocomplete/utils/str_utils.py | 40 + .../sweep_autocomplete/utils/timer.py | 59 + 24 files changed, 3558 insertions(+), 11 deletions(-) create mode 100644 vendor/sweep-autocomplete-mlx/pyproject.toml create mode 100644 vendor/sweep-autocomplete-mlx/sweep_autocomplete/__init__.py create mode 100644 vendor/sweep-autocomplete-mlx/sweep_autocomplete/app.py create mode 100644 vendor/sweep-autocomplete-mlx/sweep_autocomplete/autocomplete/__init__.py create mode 100644 vendor/sweep-autocomplete-mlx/sweep_autocomplete/autocomplete/llm_local.py create mode 100644 vendor/sweep-autocomplete-mlx/sweep_autocomplete/autocomplete/next_edit_autocomplete.py create mode 100644 vendor/sweep-autocomplete-mlx/sweep_autocomplete/autocomplete/next_edit_autocomplete_retrieval.py create mode 100644 vendor/sweep-autocomplete-mlx/sweep_autocomplete/autocomplete/next_edit_autocomplete_service.py create mode 100644 vendor/sweep-autocomplete-mlx/sweep_autocomplete/autocomplete/next_edit_autocomplete_utils.py create mode 100644 vendor/sweep-autocomplete-mlx/sweep_autocomplete/cli.py create mode 100644 vendor/sweep-autocomplete-mlx/sweep_autocomplete/config.py create mode 100644 vendor/sweep-autocomplete-mlx/sweep_autocomplete/dataclasses/__init__.py create mode 100644 vendor/sweep-autocomplete-mlx/sweep_autocomplete/dataclasses/file_chunk_data.py create mode 100644 vendor/sweep-autocomplete-mlx/sweep_autocomplete/utils/__init__.py create mode 100644 vendor/sweep-autocomplete-mlx/sweep_autocomplete/utils/compression_middleware.py create mode 100644 vendor/sweep-autocomplete-mlx/sweep_autocomplete/utils/str_utils.py create mode 100644 vendor/sweep-autocomplete-mlx/sweep_autocomplete/utils/timer.py diff --git a/.gitignore b/.gitignore index e44eae6..608bdb7 100644 --- a/.gitignore +++ b/.gitignore @@ -49,4 +49,6 @@ video/*.avi .idea/gradle.xml .kotlin/* vendor/** +!vendor/sweep-autocomplete-mlx/ +!vendor/sweep-autocomplete-mlx/** *.gguf \ No newline at end of file diff --git a/build.gradle.kts b/build.gradle.kts index 6992c34..01773f7 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -130,9 +130,24 @@ tasks { } } + val copyBundledPackagesToSandbox by creating(Copy::class) { + val sandboxPluginDir = + layout.buildDirectory + .dir("idea-sandbox/plugins/${project.name}") + + from("vendor/sweep-autocomplete-mlx") { + into("sweep-autocomplete-mlx") + } + into(sandboxPluginDir) + + doLast { + println("Copied bundled Python packages to sandbox") + } + } + // Hook the copy task to prepareSandbox prepareSandbox { - finalizedBy(copyRipgrepToSandbox) + finalizedBy(copyRipgrepToSandbox, copyBundledPackagesToSandbox) } buildPlugin { @@ -143,6 +158,11 @@ tasks { include("tools/ripgrep/**") into("lib/tools") } + + // Include bundled MLX autocomplete Python package + from("vendor/sweep-autocomplete-mlx") { + into("sweep-autocomplete-mlx") + } } runIde { diff --git a/src/main/kotlin/dev/sweep/assistant/components/SweepConfig.kt b/src/main/kotlin/dev/sweep/assistant/components/SweepConfig.kt index b300867..29f1d42 100644 --- a/src/main/kotlin/dev/sweep/assistant/components/SweepConfig.kt +++ b/src/main/kotlin/dev/sweep/assistant/components/SweepConfig.kt @@ -1041,6 +1041,12 @@ class SweepConfig( SweepSettings.getInstance().autocompleteLocalMode = enabled } + fun isAutocompleteLocalMlx(): Boolean = SweepSettings.getInstance().autocompleteLocalMlx + + fun updateAutocompleteLocalMlx(enabled: Boolean) { + SweepSettings.getInstance().autocompleteLocalMlx = enabled + } + fun getAutocompleteLocalPort(): Int = SweepSettings.getInstance().autocompleteLocalPort fun updateAutocompleteLocalPort(port: Int) { @@ -4836,6 +4842,28 @@ class SweepConfig( border = JBUI.Borders.emptyLeft(24) }, ) + if (System.getProperty("os.name").lowercase().contains("mac")) { + add(Box.createRigidArea(Dimension(0, 4.scaled))) + add( + JCheckBox("Use MLX runtime for NES model").apply { + isSelected = isAutocompleteLocalMlx() + withSweepFont(project) + border = JBUI.Borders.emptyLeft(24) + addActionListener { + updateAutocompleteLocalMlx(isSelected) + } + }, + ) + add(Box.createRigidArea(Dimension(0, 2.scaled))) + add( + JLabel("Uses MLX for faster inference on Apple Silicon (requires macOS + mlx-lm).").apply { + withSweepFont(project, scale = 0.85f) + foreground = JBColor.GRAY + font = font.deriveFont(Font.ITALIC) + border = JBUI.Borders.emptyLeft(48) + }, + ) + } }, gbc, ) diff --git a/src/main/kotlin/dev/sweep/assistant/services/LocalAutocompleteServerManager.kt b/src/main/kotlin/dev/sweep/assistant/services/LocalAutocompleteServerManager.kt index 6236063..2813d11 100644 --- a/src/main/kotlin/dev/sweep/assistant/services/LocalAutocompleteServerManager.kt +++ b/src/main/kotlin/dev/sweep/assistant/services/LocalAutocompleteServerManager.kt @@ -10,6 +10,9 @@ import com.intellij.openapi.project.Project import com.intellij.openapi.wm.ToolWindowManager import dev.sweep.assistant.agent.tools.TerminalApiWrapper import dev.sweep.assistant.settings.SweepSettings +import dev.sweep.assistant.utils.SweepConstants +import com.intellij.ide.plugins.PluginManagerCore +import com.intellij.openapi.extensions.PluginId import kotlinx.coroutines.* import org.jetbrains.plugins.terminal.TerminalToolWindowFactory import org.jetbrains.plugins.terminal.TerminalToolWindowManager @@ -130,8 +133,35 @@ class LocalAutocompleteServerManager : Disposable { private val isWindows = System.getProperty("os.name").lowercase().contains("win") - private fun buildUvxCommand(uvxPath: String, port: Int): List = - if (isWindows) { + /** + * Resolves the path to a bundled Python package inside the plugin directory. + * Returns the path if found, null otherwise. + */ + private fun getBundledPackagePath(packageDirName: String): String? { + val pluginId = PluginId.getId(SweepConstants.PLUGIN_ID) + val plugin = PluginManagerCore.getPlugin(pluginId) ?: return null + val pluginPath = plugin.pluginPath ?: return null + val pkgPath = pluginPath.resolve(packageDirName) + return if (pkgPath.toFile().exists()) pkgPath.toString() else null + } + + private fun buildUvxCommand(uvxPath: String, port: Int): List { + val useMlx = SweepSettings.getInstance().autocompleteLocalMlx + + if (useMlx) { + val mlxPath = getBundledPackagePath("sweep-autocomplete-mlx") + if (mlxPath != null) { + return listOf( + uvxPath, + "--from", mlxPath, + "sweep-autocomplete-mlx", + "--port", port.toString(), + ) + } + logger.warn("MLX package not found at plugin path, falling back to default GGUF") + } + + return if (isWindows) { listOf( uvxPath, "--python", "3.12", @@ -142,6 +172,7 @@ class LocalAutocompleteServerManager : Disposable { } else { listOf(uvxPath, "sweep-autocomplete", "--port", port.toString()) } + } private fun startServerProcess(uvxPath: String, onStatus: ((String) -> Unit)? = null) { val port = getPort() @@ -351,7 +382,9 @@ class LocalAutocompleteServerManager : Disposable { return null } } - return buildUvxCommand(uvxPath, getPort()).joinToString(" ") + return buildUvxCommand(uvxPath, getPort()).joinToString(" ") { arg -> + if (arg.contains(" ")) "\"$arg\"" else arg + } } /** @@ -372,7 +405,9 @@ class LocalAutocompleteServerManager : Disposable { .getToolWindow(TerminalToolWindowFactory.TOOL_WINDOW_ID) ?: return@invokeLater // Reuse existing terminal tab if one exists, otherwise create a new one - val existingContent = toolWindow.contentManager.findContent(TERMINAL_TAB_NAME) + val existingContent = toolWindow.contentManager.contents.firstOrNull { + it.displayName == TERMINAL_TAB_NAME + } val widget = if (existingContent != null) { TerminalToolWindowManager.findWidgetByContent(existingContent) } else { diff --git a/src/main/kotlin/dev/sweep/assistant/settings/SweepSettings.kt b/src/main/kotlin/dev/sweep/assistant/settings/SweepSettings.kt index e5b5f3e..034b143 100644 --- a/src/main/kotlin/dev/sweep/assistant/settings/SweepSettings.kt +++ b/src/main/kotlin/dev/sweep/assistant/settings/SweepSettings.kt @@ -178,6 +178,8 @@ class SweepSettings : PersistentStateComponent { var autocompleteLocalMode: Boolean = false + var autocompleteLocalMlx: Boolean = false + var autocompleteLocalPort: Int = 8081 fun ensureDefaultPromptsInitialized() { diff --git a/src/main/kotlin/dev/sweep/assistant/settings/SweepSettingsConfigurable.kt b/src/main/kotlin/dev/sweep/assistant/settings/SweepSettingsConfigurable.kt index 197c3a9..ce0f781 100644 --- a/src/main/kotlin/dev/sweep/assistant/settings/SweepSettingsConfigurable.kt +++ b/src/main/kotlin/dev/sweep/assistant/settings/SweepSettingsConfigurable.kt @@ -13,7 +13,6 @@ import java.awt.BorderLayout import java.awt.Component import java.awt.Dimension import java.awt.Font -import java.awt.event.WindowEvent import javax.swing.* class SweepSettingsConfigurable( @@ -53,9 +52,12 @@ class SweepSettingsConfigurable( ) } openConfigButton.addActionListener { - SwingUtilities.getWindowAncestor(openConfigButton)?.dispatchEvent( - WindowEvent(SwingUtilities.getWindowAncestor(openConfigButton), WindowEvent.WINDOW_CLOSING), - ) + // Close the settings dialog by finding and disposing the parent JDialog + val window = SwingUtilities.getWindowAncestor(openConfigButton) + if (window is java.awt.Dialog) { + window.isVisible = false + window.dispose() + } // open tool window and show config popup ToolWindowManager.getInstance(project).getToolWindow(SweepConstants.TOOLWINDOW_NAME)?.show() SweepConfig.getInstance(project).showConfigPopup() diff --git a/src/main/resources/META-INF/plugin.xml b/src/main/resources/META-INF/plugin.xml index d240621..ad5d0dd 100644 --- a/src/main/resources/META-INF/plugin.xml +++ b/src/main/resources/META-INF/plugin.xml @@ -14,7 +14,7 @@ Simple HTML elements (text formatting, paragraphs, and lists) can be added inside of tag. Guidelines: https://plugins.jetbrains.com/docs/marketplace/plugin-overview-page.html#plugin-description --> This plugin is only for Enterprise customers with private installations (i.e. you've been linked to this). Other users should use the standard extension in the JetBrains Marketplace. + Sweep AI - fast next-edit suggestions powered by local model. ]]> @@ -203,4 +203,4 @@ messages.sweep - \ No newline at end of file + diff --git a/vendor/sweep-autocomplete-mlx/pyproject.toml b/vendor/sweep-autocomplete-mlx/pyproject.toml new file mode 100644 index 0000000..8fc11c8 --- /dev/null +++ b/vendor/sweep-autocomplete-mlx/pyproject.toml @@ -0,0 +1,31 @@ +[build-system] +requires = ["setuptools>=68.0"] +build-backend = "setuptools.build_meta" + +[project] +name = "sweep-autocomplete-mlx" +version = "0.1.0" +description = "Local next-edit autocomplete server powered by MLX (Apple Silicon)" +requires-python = ">=3.10" +license = "Apache-2.0" + +dependencies = [ + "fastapi>=0.100.0", + "uvicorn[standard]>=0.23.0", + "python-multipart>=0.0.6", + "loguru>=0.7.0", + "requests>=2.31.0", + "numpy>=1.24.0", + "scipy>=1.11.0", + "regex>=2023.0", + "brotli>=1.1.0", + "pydantic>=2.0.0", + "mlx-lm>=0.20.0", + "huggingface-hub>=0.20.0", +] + +[project.scripts] +sweep-autocomplete-mlx = "sweep_autocomplete.cli:main" + +[tool.setuptools.packages.find] +include = ["sweep_autocomplete*"] diff --git a/vendor/sweep-autocomplete-mlx/sweep_autocomplete/__init__.py b/vendor/sweep-autocomplete-mlx/sweep_autocomplete/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vendor/sweep-autocomplete-mlx/sweep_autocomplete/app.py b/vendor/sweep-autocomplete-mlx/sweep_autocomplete/app.py new file mode 100644 index 0000000..b61ddd8 --- /dev/null +++ b/vendor/sweep-autocomplete-mlx/sweep_autocomplete/app.py @@ -0,0 +1,106 @@ +import json +import time +import traceback + +from fastapi import Body +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import StreamingResponse + +from sweep_autocomplete.autocomplete.next_edit_autocomplete import ( + AutocompleteMetadata, + fetch_next_edits, +) +from sweep_autocomplete.dataclasses.file_chunk_data import ( + EditorDiagnostic, + FileChunkData, + UserAction, +) +from sweep_autocomplete.utils.compression_middleware import RequestCompressionMiddleware +from loguru import logger + +app = FastAPI() + +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) +app.add_middleware(RequestCompressionMiddleware) + + +@app.get("/health") +def health(): + return {"status": "ok"} + + +@app.post("/backend/next_edit_autocomplete", include_in_schema=False) +def next_edit_autocomplete( + file_path: str = Body(...), + file_contents: str = Body(...), + original_file_contents: str = Body(None), + recent_changes: str = Body(...), + cursor_position: int = Body(...), + file_chunks: list[FileChunkData] = Body([]), + retrieval_chunks: list[FileChunkData] = Body([]), + recent_user_actions: list[UserAction] = Body([]), + multiple_suggestions: bool = Body(False), + recent_changes_high_res: str = Body(default=""), + changes_above_cursor: bool = Body(default=True), + editor_diagnostics: list[EditorDiagnostic] = Body(default=[]), +): + function_start_time = time.time() + + def stream(): + metadata: AutocompleteMetadata = AutocompleteMetadata() + + try: + for result, completions, formatted_prompt, metadata in fetch_next_edits( + file_path=file_path, + file_contents=file_contents, + recent_changes=recent_changes, + cursor_position=cursor_position, + original_file_contents=original_file_contents, + file_chunks=file_chunks, + retrieval_chunks=retrieval_chunks, + recent_user_actions=recent_user_actions, + recent_changes_high_res=recent_changes_high_res, + changes_above_cursor=changes_above_cursor, + is_new_user=False, + editor_diagnostics=editor_diagnostics, + ): + data = { + **result.__dict__, + "elapsed_time_ms": int((time.time() - function_start_time) * 1000), + } + logger.debug( + f"Next edit autocomplete took {data['elapsed_time_ms']}ms" + ) + + if multiple_suggestions: + data["completions"] = [ + completion.__dict__ for completion in completions + ] + yield json.dumps(data) + "\n" + + except BaseException as e: + logger.error(f"Next edit autocomplete error: {str(e)}") + yield json.dumps( + { + "status": "error", + "error": f"Next edit autocomplete error: {str(e)}", + "traceback": str(traceback.format_exc()), + } + ) + if not isinstance(e, GeneratorExit): + raise e + finally: + end_time = time.time() + latency_ms = (end_time - function_start_time) * 1000 + logger.debug( + f"Next edit autocomplete took {latency_ms:.2f}ms for finally block:{metadata.convert_to_string()}" + ) + + return StreamingResponse(stream(), media_type="application/x-ndjson") diff --git a/vendor/sweep-autocomplete-mlx/sweep_autocomplete/autocomplete/__init__.py b/vendor/sweep-autocomplete-mlx/sweep_autocomplete/autocomplete/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vendor/sweep-autocomplete-mlx/sweep_autocomplete/autocomplete/llm_local.py b/vendor/sweep-autocomplete-mlx/sweep_autocomplete/autocomplete/llm_local.py new file mode 100644 index 0000000..244cf64 --- /dev/null +++ b/vendor/sweep-autocomplete-mlx/sweep_autocomplete/autocomplete/llm_local.py @@ -0,0 +1,111 @@ +import threading +import time +from typing import Any + +from loguru import logger + +from sweep_autocomplete.config import MODEL_REPO + +_model = None +_tokenizer = None +_model_lock = threading.Lock() +_request_lock = threading.Lock() +_latest_request_id = 0 + + +class RequestCancelled(Exception): + """Raised when a queued request is superseded by a newer one.""" + pass + + +def get_model(): + global _model, _tokenizer + if _model is None: + import mlx_lm + + logger.info(f"Loading MLX model from {MODEL_REPO}") + _model, _tokenizer = mlx_lm.load(MODEL_REPO) + logger.info("MLX model loaded successfully") + return _model, _tokenizer + + +def generate_completion( + prompt: str, + stop: list[str], + max_tokens: int, + temperature: float, + prefix: str = "", +) -> tuple[str, int, list[Any], str | None]: + """Generate a completion using the local MLX model with stream_generate. + + Uses stream_generate for early stop-sequence detection so we don't + waste time generating tokens past a stop sequence. + + Only the latest request will actually run inference. If a newer request + arrives while this one is waiting for the model lock, this request is + cancelled (raises RequestCancelled). + + Returns (completion_text, elapsed_ms, logprobs, finish_reason) + """ + global _latest_request_id + from mlx_lm import stream_generate + from mlx_lm.sample_utils import make_sampler + + model, tokenizer = get_model() + full_prompt = prompt + prefix if prefix else prompt + + # Claim a request ID — always monotonically increasing + with _request_lock: + _latest_request_id += 1 + my_id = _latest_request_id + + # Wait for the model. When we get the lock, check if we're still latest. + with _model_lock: + if my_id != _latest_request_id: + logger.info(f"Request {my_id} cancelled (latest is {_latest_request_id})") + raise RequestCancelled() + + tokens = tokenizer.encode(full_prompt) + logger.info(f"Prompt length: {len(full_prompt)} chars, {len(tokens)} tokens") + + start = time.time() + + sampler = make_sampler(temp=temperature if temperature > 0 else 0.0) + + text_parts = [] + finish_reason = "stop" + hit_stop = False + + for response in stream_generate( + model=model, + tokenizer=tokenizer, + prompt=full_prompt, + max_tokens=max_tokens, + sampler=sampler, + prefill_step_size=4096, + ): + text_parts.append(response.text) + + # Check for stop sequences as tokens stream in + accumulated = "".join(text_parts) + for s in stop: + if s in accumulated: + text_parts = [accumulated[:accumulated.index(s)]] + hit_stop = True + break + if hit_stop: + break + + elapsed_ms = int((time.time() - start) * 1000) + + text = "".join(text_parts) + + if not hit_stop: + token_count = len(tokenizer.encode(text, add_special_tokens=False)) + if token_count >= max_tokens: + finish_reason = "length" + + if prefix: + text = prefix + text + + return text, elapsed_ms, [], finish_reason diff --git a/vendor/sweep-autocomplete-mlx/sweep_autocomplete/autocomplete/next_edit_autocomplete.py b/vendor/sweep-autocomplete-mlx/sweep_autocomplete/autocomplete/next_edit_autocomplete.py new file mode 100644 index 0000000..54ddf66 --- /dev/null +++ b/vendor/sweep-autocomplete-mlx/sweep_autocomplete/autocomplete/next_edit_autocomplete.py @@ -0,0 +1,2004 @@ +from __future__ import annotations + +import json +import math +import re +import time +import uuid +from collections import Counter +from dataclasses import dataclass, field, replace +from typing import Any, Literal, Optional + +import regex +import requests +from pydantic import BaseModel + +from sweep_autocomplete.autocomplete.next_edit_autocomplete_retrieval import ( + find_best_matching_block, +) +from sweep_autocomplete.autocomplete.next_edit_autocomplete_service import ( + next_edit_autocomplete_service, +) +from sweep_autocomplete.autocomplete.next_edit_autocomplete_utils import ( + AUTOCOMPLETE_OUTPUT_MAX_TOKENS, + PromptTooLongError, + adjust_cursor_position_from_utf16, + extract_diff_parts, + filter_whitespace_only_hunks, + get_line_number_from_position, + get_lines_around_cursor, + is_equal_ignoring_newlines, + is_large_diff_above_cursor, + parse_hunk, + should_disable_for_code_block, + split_into_diff_hunks, + split_into_hunks, + strip_leading_empty_newlines, + truncate_long_lines, +) +from sweep_autocomplete.autocomplete.llm_local import generate_completion, RequestCancelled +from sweep_autocomplete.config import NEXT_EDIT_AUTOCOMPLETE_ENDPOINT +from sweep_autocomplete.dataclasses.file_chunk_data import ( + EditorDiagnostic, + FileChunkData, + UserAction, +) +from loguru import logger +from sweep_autocomplete.utils.str_utils import pack_items_for_prompt +from sweep_autocomplete.utils.timer import Timer + +NUM_LINES_BEFORE = 2 +NUM_LINES_AFTER = 5 + +CHARS_PER_TOKEN = 3.5 + +def estimate_token_count(text: str) -> int: + """Estimate token count using character-based approximation.""" + return int(len(text) / CHARS_PER_TOKEN) + +MAX_INPUT_TOKENS_COUNT = (8192 * 4) - 256 # ~8k tokens at 3.5 chars/token, fits in 32k ctx +CHARACTER_BOUND_TO_CHECK_TOKENIZATION = (8192 * 2) - 256 # ~4k tokens +CHARACTER_BOUND_TO_SKIP_TOKENIZATION = (8192 * 4) * 2 # ~16k tokens, skip if clearly too long +MAX_RETRIEVAL_CHUNK_SIZE_LINES = 25 +DEBUG = False +# DEBUG = True + +MAX_RETRIEVAL_CHUNKS = 3 +MAX_TIMEOUT_MS = 2000 +NUM_RECENT_ACTIONS_TO_PRESERVE = 20 + +# source: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2/tokenization_qwen2.py#L39 +PRETOKENIZE_REGEX = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""" + +pretokenize_regex = regex.compile(PRETOKENIZE_REGEX) + + +class PromptTruncationRecord(BaseModel): + """Data container for prompt truncation logic. No S3 saving.""" + autocomplete_id: str = "" + original_prompt_length: int = 0 + final_prompt_length: int = 0 + max_seq_len: int = 0 + lower_bound: int = 0 + truncation_occurred: bool = False + truncation_reason: str = "" + retrieval_results_length: int = 0 + file_chunks_used: int = 0 + file_chunks_available: int = 0 + file_path: str = "" + cursor_position: float = 0 + initial_file: str = "" + retrieval_results: str = "" + recent_changes: str = "" + prev_section: str = "" + code_block: str = "" + suffix: str = "" + prefix: str = "" + prefill: str = "" + file_chunks: list[FileChunkData] = [] + start_line: int = 0 + end_line: int = 0 + + +def find_ast_based_prefill_start( + code_block: str, + cursor_position: int, + file_path: str, + file_contents: str, + block_start_index: int, +) -> int | None: + """ + Find the start of the AST node containing the cursor position and crawl up parent nodes + to determine the optimal prefill start position. Only use parent nodes if they start + within 1-2 lines above the cursor position. + + Args: + code_block: The code block containing the cursor + cursor_position: The cursor position within the code block (relative to code_block) + file_path: The file path to determine language + file_contents: The entire file contents for proper AST parsing + block_start_index: The absolute position where code_block starts in file_contents + + Returns: + The byte offset for the prefill start position relative to code_block, or None if AST parsing fails + """ + return None + + +recent_edits_format = """The user recently made the following changes: + + +{recent_changes} + + +""" + +prompt = """<|file_sep|>{file_path} +{initial_file}{retrieval_results} +{recent_changes} +<|file_sep|>original/{file_path}:{start_line}:{end_line} +{prev_section} +<|file_sep|>current/{file_path}:{start_line}:{end_line} +{code_block} +<|file_sep|>updated/{file_path}:{start_line}:{end_line}""" + +diff_format = """<|file_sep|>{file_path}:{start_line}:{end_line} +original: +{old_code} +updated: +{new_code}""" + +_session = None # Add session for connection pooling + + +def remove_last_line_from_string(s: str) -> str: + return s[: s.rfind("\n")] if "\n" in s else s + + +def check_early_return_condition( + accumulated_response: str, + prefill: str, + cleaned_lines: list[str], + cursor_pos: int, +) -> str | None: + """ + Check if early return condition is met based on accumulated response. + This is adapted from the production code. + + Returns: + The early completion string if early return should happen, None otherwise. + """ + cleaned_content = "".join(cleaned_lines) + cleaned_content = cleaned_content.removeprefix(prefill) + + lines = cleaned_content.splitlines(True) + if not lines: + # Handle empty content case + return None + first_line, *rest = lines + adjusted_cursor_pos = cursor_pos - len(prefill) + prefix = first_line[:adjusted_cursor_pos] + suffix = first_line[adjusted_cursor_pos:] + + if "\n" in accumulated_response: + first_accumulated_response_line = accumulated_response.splitlines(True)[0] + remainder = cleaned_content[adjusted_cursor_pos:] + if ( + first_accumulated_response_line.startswith(prefix) + and first_accumulated_response_line.endswith(suffix) + and len(first_accumulated_response_line) > len(first_line) + and not remainder.strip().startswith(first_line.strip()) + ): + return first_accumulated_response_line + "".join(rest) + return None + + +@dataclass +class AutocompleteResult: + """Represents the result of an autocomplete operation.""" + + start_index: int + end_index: int + completion: str + confidence: float + autocomplete_id: str + logprobs: list = None + finish_reason: Literal["stop", "length", "timeout", "cancelled", None] = None + + +@dataclass +class AutocompleteMetadata: + """Structured metadata for autocomplete events. + + This dataclass replaces the previous loose dict usage to make it + easy to add new fields while keeping backwards compatibility where needed. + Use the `extra` field for arbitrary key-value data that doesn't yet + warrant a dedicated top-level attribute. + """ + + exit_reason: str = "unknown" + reason: Optional[str] = "unknown" + # Context size metrics for GCP dashboard correlation + retrieval_chunks_count: int = -1 + retrieval_chunks_char_count: int = -1 + retrieval_chunks_line_count: int = -1 + file_chunks_count: int = -1 + file_chunks_char_count: int = -1 + file_chunks_line_count: int = -1 + is_retrieval_autocomplete: bool = False + extra: dict[str, Any] = field(default_factory=dict) + + def convert_to_string(self) -> str: + """Convert all metadata fields to a nicely formatted string. + + Returns: + A string representation of all key-value pairs in the format: + "key1=value1 | key2=value2 | ..." + """ + parts = [ + f"exit_reason={self.exit_reason}", + f"reason={self.reason}", + f"retrieval_chunks_count={self.retrieval_chunks_count}", + f"retrieval_chunks_char_count={self.retrieval_chunks_char_count}", + f"retrieval_chunks_line_count={self.retrieval_chunks_line_count}", + f"file_chunks_count={self.file_chunks_count}", + f"file_chunks_char_count={self.file_chunks_char_count}", + f"file_chunks_line_count={self.file_chunks_line_count}", + ] + for key, value in self.extra.items(): + parts.append(f"{key}={value}") + return " | ".join(parts) + + +def get_block_around_cursor_line( + lines: list[str], cursor_line: int, num_lines_before: int, num_lines_after: int +): + block_start = max(0, cursor_line - num_lines_before) + block_end = min( + len(lines), cursor_line + num_lines_after + 1 + ) # +1 to include the cursor line + + while block_start < block_end and (lines[block_start].strip() == ""): + block_start += 1 + if block_end < len(lines): + block_end += 1 + + while block_end > block_start and (lines[block_end - 1].strip() == ""): + block_end -= 1 + + current_block = "".join(lines[block_start:block_end]) + + prefix_start = max(0, block_start - 10) + prefix = "".join(lines[prefix_start:block_start]) + + suffix_end = min(len(lines), block_end + 10) + suffix = "".join(lines[block_end:suffix_end]) + + if current_block.endswith("\n"): + current_block = current_block.strip("\n") + "\n" + + return current_block, prefix.strip("\n"), suffix.strip("\n") + + +def truncate_code_block_by_tokens( + code_block: str, max_token_limit: int = int(AUTOCOMPLETE_OUTPUT_MAX_TOKENS / 2) +) -> str: + """ + Truncate a code block to fit within the specified token limit. + + Args: + code_block: The code block to truncate + max_token_limit: Maximum number of tokens allowed (default: AUTOCOMPLETE_OUTPUT_MAX_TOKENS / 2) + + Returns: + The truncated code block + """ + code_block_lines = code_block.splitlines(True) + prefilled_code_block = "".join(code_block_lines[:NUM_LINES_BEFORE]) + remaining_code_block = "".join(code_block_lines[NUM_LINES_BEFORE:]) + estimated_tokens = estimate_token_count(remaining_code_block) + + if estimated_tokens > max_token_limit: + max_chars = int(max_token_limit * CHARS_PER_TOKEN) + remaining_code_block = remaining_code_block[:max_chars] + + # Truncate to last complete line + remaining_code_block_lines = remaining_code_block.splitlines(True) + if remaining_code_block_lines: + remaining_code_block = "".join(remaining_code_block_lines[:-1]) + code_block = prefilled_code_block + remaining_code_block + + return code_block + + +def get_block_at_cursor( + file_contents: str, cursor_position: int +) -> tuple[str, str, str, int]: + """ + Extract the code block surrounding the cursor position. + Returns the code block, prefix, suffix, and block start index. + """ + # Find cursor line number + lines = file_contents.splitlines(True) + cursor_line = get_line_number_from_position(file_contents, cursor_position) + code_block, prefix, suffix = get_block_around_cursor_line( + lines, cursor_line, NUM_LINES_BEFORE, NUM_LINES_AFTER + ) + block_start_line = max(0, cursor_line - NUM_LINES_BEFORE) + block_start_index = sum(len(line) for line in lines[:block_start_line]) + + # Apply token-based truncation + code_block = truncate_code_block_by_tokens(code_block) + + # Print final token count + line_count = code_block.count("\n") + 1 + logger.debug(f"Line count: {line_count}") + + return code_block, prefix, suffix, block_start_index + + +def is_pure_insertion_above_cursor( + cleaned_code_block: str, completion: str, relative_cursor_position: int +) -> bool: + current_line_index = len( + cleaned_code_block[:relative_cursor_position].splitlines(True) + ) + code_block_lines = cleaned_code_block.splitlines(True) + cursor_line = code_block_lines[current_line_index - 1] + + if cleaned_code_block.strip() == completion.strip(): + return False + + if not cursor_line.strip(): + return False + + prefix_lines = code_block_lines[: current_line_index - 1] + prefix = "".join(prefix_lines) + suffix_lines = code_block_lines[current_line_index:] + suffix = "".join(suffix_lines) + + if completion.startswith(prefix) and completion.endswith(cursor_line + suffix): + return True + + return False + + +def format_recent_changes_and_prev_section( + recent_changes: str, current_section: str +) -> tuple[str, str, list[str]]: + hunks = split_into_hunks(recent_changes) + prev_section = current_section.replace("<|cursor|>", "") + hunks = [hunk for hunk in hunks if len(hunk.strip().splitlines()) > 1] + # Filter out pure whitespace changes + hunks = filter_whitespace_only_hunks(hunks) + prev_sections = [] + if hunks: + copied_hunks = hunks[::-1].copy() + # any_reverts_made = False + for hunk in copied_hunks: + first_line, *rest = hunk.splitlines(True) + file_path = first_line.removeprefix( + "File: " + ) + old_code, new_code = extract_diff_parts("".join(rest)) + old_code_with_context, new_code_with_context = extract_diff_parts( + hunk, num_context_lines=1 + ) + if new_code_with_context.strip() and new_code_with_context in prev_section: + prev_section = prev_section.replace( + new_code_with_context, old_code_with_context, 1 + ) + prev_sections.append(prev_section) + # any_reverts_made = True + elif new_code.strip() and new_code in prev_section: + prev_section = prev_section.replace(new_code, old_code, 1) + prev_sections.append(prev_section) + # any_reverts_made = True + else: + break + # if any_reverts_made: + # hunks = hunks[:-1] + result = "" + for hunk in hunks[-6:]: # only keep last three recent changes + first_line, *rest = hunk.splitlines(True) + file_path = first_line.removeprefix("File: ").rstrip("\n") + old_code, new_code = extract_diff_parts("".join(rest), num_context_lines=1) + _, _, start_line, lines = parse_hunk("".join(rest)) + end_line = start_line + len(lines) - 1 + if old_code.strip() or new_code.strip(): + result += ( + diff_format.format( + old_code=old_code.strip("\n"), + new_code=new_code.strip("\n"), + file_path=file_path, + start_line=start_line, + end_line=end_line, + ) + + "\n" + ) + return result.rstrip("\n"), prev_section, prev_sections + + +def get_latest_user_action_non_cursor_movement( + recent_user_actions: list[UserAction], +) -> UserAction | None: + for action in recent_user_actions[::-1]: + if action.action_type != "CURSOR_MOVEMENT": + return action + return None + + +def get_last_user_action_index_above_cursor( + recent_user_actions: list[UserAction], + cursor_position: int, + file_path: str, + block_start_index: int, +) -> int: + """ + Get the last index of a recent_user_action that occurred above the cursor_position in the current file. + + Args: + recent_user_actions: List of UserAction objects + cursor_position: Current cursor position in the file + file_path: Current file path to filter actions for this file + + Returns: + Index of the last action above cursor, or -1 if none found + """ + if not recent_user_actions: + return -1 + + # Find actions in the current file that are above the cursor position + for i in range( + len(recent_user_actions) - 1, + max(-1, len(recent_user_actions) - 1 - NUM_RECENT_ACTIONS_TO_PRESERVE), + -1, + ): + action = recent_user_actions[i] + if ( + action.file_path == file_path + and action.action_type != "CURSOR_MOVEMENT" + and action.offset < cursor_position + ): + return action.offset - block_start_index + + return -1 + + +def get_ghost_text_with_location( + completion: str, cleaned_code_block: str, relative_cursor_position: int +) -> str: + prefix = cleaned_code_block[:relative_cursor_position] + suffix = cleaned_code_block[relative_cursor_position:] + if completion.startswith(prefix) and completion.endswith(suffix): + # Handle empty suffix case: -0 would slice to beginning, so use conditional + if suffix: + ghost_text = completion[len(prefix) : -len(suffix)] + else: + ghost_text = completion[len(prefix) :] + if ghost_text: + return ghost_text + return "" + + +def find_ghost_text_non_local( + completion: str, cleaned_code_block: str, relative_cursor_position: int +) -> tuple[str, int]: + if len(cleaned_code_block) > len(completion): + return "", -1 + + # Find all valid ghost text positions and prioritize the one with the longest prefix match + valid_positions = [] + # Check all positions including len(cleaned_code_block) for empty suffix case + for pos in range(len(cleaned_code_block) + 1): + ghost_text = get_ghost_text_with_location(completion, cleaned_code_block, pos) + if ghost_text: + valid_positions.append((pos, ghost_text)) + + if valid_positions: + # Prioritize the position with the longest prefix (highest position value) + # This ensures we match the longest common prefix before the insertion + best_position, best_ghost_text = max(valid_positions, key=lambda x: x[0]) + return best_ghost_text, best_position + + return "", -1 + + +def is_single_line_ghost_text( + completion: str, cleaned_code_block: str, relative_cursor_position: int +): + if len(cleaned_code_block) < relative_cursor_position: + return "" + + prefix = cleaned_code_block[:relative_cursor_position] + suffix = cleaned_code_block[relative_cursor_position:] + if completion.startswith(prefix) and completion.endswith(suffix): + ghost_text = completion[len(prefix) : -len(suffix)] + if ghost_text and "\n" not in ghost_text: + return ghost_text + return "" + + +def apply_completions_to_code_block( + completions: list[AutocompleteResult], file_contents: str, cleaned_code_block: str +) -> str: + """ + Apply all completions to the cleaned_code_block and return the modified code section. + + Args: + completions: List of AutocompleteResult objects + file_contents: The original file contents + cleaned_code_block: The current code block to apply completions to + + Returns: + The cleaned_code_block with all completions applied + """ + if not completions: + return cleaned_code_block + + # Find the start index of the cleaned code block in the file + cleaned_code_start_index = file_contents.find(cleaned_code_block) + if cleaned_code_start_index == -1: + return cleaned_code_block + + # Create a working copy of the cleaned_code_block + modified_code_block = cleaned_code_block + + # Sort completions by start_index in descending order to avoid offset issues + sorted_completions = sorted(completions, key=lambda c: c.start_index, reverse=True) + + # Apply each completion to the cleaned_code_block using relative positioning + for completion in sorted_completions: + # Convert absolute file positions to relative positions within the code block + relative_start = completion.start_index - cleaned_code_start_index + relative_end = completion.end_index - cleaned_code_start_index + + # Check if the completion affects the code block area + if ( + relative_start >= 0 + and relative_start <= len(cleaned_code_block) + and relative_end <= len(cleaned_code_block) + ): + # Apply the completion to cleaned_code_block + modified_code_block = ( + modified_code_block[:relative_start] + + completion.completion + + modified_code_block[relative_end:] + ) + + return modified_code_block + + +def select_best_hunk_from_completion( + completion: str, + cleaned_code_block: str, + file_contents: str, + cursor_position: int, + autocomplete_id: str, + logprobs: list = None, +) -> list[AutocompleteResult]: + """ + Find the best hunk from the completion to suggest as an edit based on cursor position. + + Args: + completion: The generated completion text + cleaned_code_block: The original code block without any markers + file_contents: The entire file contents + cursor_position: The current cursor position in the file + + Returns: + AutocompleteResult containing start_offset, end_offset, new_text, confidence, and autocomplete_id + """ + completion = strip_leading_empty_newlines(completion) + cleaned_code_block = strip_leading_empty_newlines(cleaned_code_block) + + if completion == cleaned_code_block: + return [] + + block_start_offset = file_contents.find(cleaned_code_block) + if block_start_offset == -1: + return [] + + # Edge case: check for ghost text immediately: + relative_cursor_position = cursor_position - block_start_offset + ghost_text, ghost_text_position = find_ghost_text_non_local( + completion, cleaned_code_block, relative_cursor_position + ) + if ghost_text: + is_insert_next_line = ( + ghost_text_position == relative_cursor_position + 1 + and cleaned_code_block[ghost_text_position - 1] == "\n" + ) + insertion_starts_with_newline = ( + ghost_text.startswith("\n") + and ghost_text_position == relative_cursor_position + ) + if is_insert_next_line or insertion_starts_with_newline: + return [ + AutocompleteResult( + ghost_text_position + block_start_offset, + ghost_text_position + block_start_offset, + ghost_text, + 1.0, + f"{autocomplete_id}-0", + ) + ] + + first_line, *rest = ghost_text.splitlines(True) + remaining_ghost_text = "".join(rest) + + trailing_whitespace_len = len(first_line) - len(first_line.rstrip("\n")) + if trailing_whitespace_len > 0: + trailing_whitespace = first_line[-trailing_whitespace_len:] + remaining_ghost_text = trailing_whitespace + remaining_ghost_text + first_line = first_line.rstrip("\n") + + if remaining_ghost_text: + remaining_ghost_text = remaining_ghost_text.rstrip() + if ( + ghost_text_position < len(cleaned_code_block) + and cleaned_code_block[ghost_text_position] == "\n" + ): + return [ + AutocompleteResult( + ghost_text_position + block_start_offset, + ghost_text_position + block_start_offset, + first_line, + 1.0, + f"{autocomplete_id}-0", + ), + AutocompleteResult( + ghost_text_position + block_start_offset, + ghost_text_position + block_start_offset, + remaining_ghost_text, + 1.0, + f"{autocomplete_id}-1", + ), + ] + else: + return [ + AutocompleteResult( + ghost_text_position + block_start_offset, + ghost_text_position + block_start_offset, + ghost_text, + 1.0, + f"{autocomplete_id}-0", + ) + ] + + hunks = split_into_diff_hunks(cleaned_code_block, completion) + + if not hunks: + return [] + + # Process each hunk to get its absolute position in the file + processed_hunks = [] + original_lines = cleaned_code_block.splitlines(True) + + for input_start, input_lines, output_start, output_lines in hunks: + # Calculate start offset in the original text (convert 1-based to 0-based) + start_line_idx = input_start - 1 + start_offset = block_start_offset + for i in range(start_line_idx): + if i < len(original_lines): + start_offset += len(original_lines[i]) + + # Calculate end offset + end_offset = start_offset + for i in range(len(input_lines)): + line_idx = start_line_idx + i + if line_idx < len(original_lines): + end_offset += len(original_lines[line_idx]) + + new_text = "".join(output_lines) + + # Hack for end of file + if ( + start_offset == len(file_contents) + and file_contents[start_offset - 1] != "\n" + ): + new_text = "\n" + new_text + + if file_contents[start_offset:end_offset] != new_text: + processed_hunks.append((start_offset, end_offset, new_text)) + + # Find hunks before and after the cursor + hunks_after_cursor = [ + h for h in processed_hunks if h[1] >= cursor_position + ] # use end position to determine before or after + hunks_before_cursor = [h for h in processed_hunks if h[1] < cursor_position] + + if hunks_after_cursor: + hunks_after_cursor.sort(key=lambda h: h[0]) + first_hunk, *rest_hunks = hunks_after_cursor + start_offset, end_offset, new_text = first_hunk + start_line_position = file_contents[:cursor_position].rfind("\n") + 1 + + results = [] + + should_split = ( + start_line_position == start_offset <= cursor_position < end_offset + and new_text.count("\n") > 0 + ) + + if should_split: + # assume it happens in the first line + first_line, *rest = new_text.splitlines(True) + remaining_new_text = "".join(rest) + + # Find the end of the first line in the original text + original_text_section = file_contents[cursor_position:end_offset] + first_newline_pos = original_text_section.find("\n") + if first_newline_pos != -1: + first_line_end = ( + cursor_position + first_newline_pos + 1 + ) # +1 to include the newline + else: + first_line_end = end_offset # No newline found, use end_offset + + # Check if the first line matches the current cursor line + end_line_position = file_contents.find("\n", cursor_position) + if end_line_position == -1: + end_line_position = len(file_contents) + current_cursor_line_contents = file_contents[ + start_line_position:end_line_position + ] + if not first_line.startswith(current_cursor_line_contents): + should_split = False + + if should_split: + results.append( + AutocompleteResult( + start_offset, + first_line_end, + first_line, + 1.0, + f"{autocomplete_id}-0", + logprobs, + ) + ) + if ( + remaining_new_text + ): # Only add second result if there's remaining text + results.append( + AutocompleteResult( + first_line_end, + end_offset, + remaining_new_text, + 1.0, + f"{autocomplete_id}-1", + logprobs, + ) + ) + else: + results.append( + AutocompleteResult( + start_offset, + end_offset, + new_text, + 1.0, + f"{autocomplete_id}-0", + logprobs, + ) + ) + else: + results.append( + AutocompleteResult( + start_offset, + end_offset, + new_text, + 1.0, + f"{autocomplete_id}-0", + logprobs, + ) + ) + max_id = len(rest_hunks) + results.extend( + [ + AutocompleteResult( + start_offset, + end_offset, + new_text, + 1.0, + f"{autocomplete_id}-{max_id + i}", + logprobs, + ) + for i, (start_offset, end_offset, new_text) in enumerate(rest_hunks) + ] + ) + return results + + # Otherwise, handle hunks before the cursor + # Fuse hunks that are within 2 lines of each other + results = [] + if hunks_before_cursor: + # Sort hunks by start offset to process them in order + hunks_before_cursor.sort(key=lambda h: h[0]) + + # Group hunks that are within 2 lines of each other + fused_groups = [] + current_group = [hunks_before_cursor[0]] + + for i in range(1, len(hunks_before_cursor)): + prev_hunk = current_group[-1] + curr_hunk = hunks_before_cursor[i] + + # Get line numbers for the end of previous hunk and start of current hunk + prev_end_line = get_line_number_from_position(file_contents, prev_hunk[1]) + curr_start_line = get_line_number_from_position(file_contents, curr_hunk[0]) + + # If hunks are within 2 lines of each other, add to current group + if curr_start_line - prev_end_line <= 2: + current_group.append(curr_hunk) + else: + # Start a new group + fused_groups.append(current_group) + current_group = [curr_hunk] + + # Don't forget the last group + fused_groups.append(current_group) + + # Create AutocompleteResult for each fused group + for group_idx, group in enumerate(fused_groups): + # Use the first hunk's start and the last hunk's end + first_start_offset = group[0][0] + last_end_offset = group[-1][1] + + # Reconstruct the text by applying all hunks in sequence + combined_text_parts = [] + current_offset = first_start_offset + + for start_offset, end_offset, new_text in group: + # Add any unchanged text between hunks + if current_offset < start_offset: + combined_text_parts.append(file_contents[current_offset:start_offset]) + # Add the new text from this hunk + combined_text_parts.append(new_text) + current_offset = end_offset + + combined_new_text = "".join(combined_text_parts) + + results.append( + AutocompleteResult( + first_start_offset, + last_end_offset, + combined_new_text, + 1.0, + f"{autocomplete_id}-{group_idx}", + logprobs, + ) + ) + + # sort by proximity to cursor + results.sort(key=lambda x: abs(x.start_index - cursor_position)) + + return results + + return [] + + +def fetch_next_edits_http( + formatted_prompt: str, + stop: list, + cleaned_code_block: str, + file_contents: str, + cursor_position: int, + prefix: str = "", + prefill: str = "", + force_ghost_text: bool = False, + relative_cursor_line: int = 0, +) -> ( + tuple[str, int, list[Any], str] + | tuple[str | Any, int, list[Any] | Any, Any | None] + | tuple[str, int, list[Any] | Any, Any | None] +): + """Use HTTP streaming to fetch next edits from NEXT_EDIT_AUTOCOMPLETE_ENDPOINT.""" + session = get_session() + + request_data = { + "prompt": formatted_prompt, + "stop": stop, + "max_tokens": AUTOCOMPLETE_OUTPUT_MAX_TOKENS, + "temperature": 0.0, + "prefix": prefix, + } + + headers = {"Content-Type": "application/json"} + + active_endpoint = next_edit_autocomplete_service.active_endpoint + if not active_endpoint: + raise Exception("Autocomplete service not available") + + # Use /v1/completions endpoint (OpenAI-compatible) + completions_endpoint = active_endpoint.rstrip('/') + '/v1/completions' + + try: + response = session.post( + completions_endpoint, + json=request_data, + headers=headers, + timeout=5, + stream=False, + ) + response.raise_for_status() + + accumulated_response = "" + telemetry = {} + logprobs = [] + finish_reason = None + + # Parse the SGLang completions API response (OpenAI format) + obj = response.json() + + # Log the full response for debugging + logger.info(f"SGLang response: {obj}") + + # SGLang returns OpenAI-compatible format: {"choices": [{"text": "...", "finish_reason": "..."}], ...} + if "choices" in obj and len(obj["choices"]) > 0: + choice = obj["choices"][0] + accumulated_response = choice.get("text", "") + finish_reason = choice.get("finish_reason", None) + logprobs = choice.get("logprobs", []) + else: + accumulated_response = "" + finish_reason = None + logprobs = [] + + logger.info(f"Finish reason: {finish_reason}") + logger.info(f"Accumulated response length: {len(accumulated_response)}, content: {accumulated_response[:200] if accumulated_response else 'EMPTY'}") + + return ( + accumulated_response, + 0, # elapsed_time_ms not available in standard response + logprobs, + finish_reason, + ) + except requests.exceptions.HTTPError as e: + if e.response.status_code == 503: + logger.error( + f"Rate limit exceeded during next_edit_autocomplete HTTP request: {str(e)}" + ) + return "", 0, [], "rate_limited" + else: + logger.error( + f"HTTP error {e.response.status_code} during next_edit_autocomplete HTTP request: {str(e)}" + ) + return "", 0, [], "http_error" + except (requests.exceptions.Timeout, requests.exceptions.ConnectionError) as e: + logger.error( + f"Timeout or connection error during next_edit_autocomplete HTTP request: {str(e)}" + ) + return "", 0, [], "timeout" + + +def get_session(): + """Get or create a requests Session for connection pooling""" + global _session + if _session is None: + _session = requests.Session() + return _session + + +def is_typing_quickly( + recent_user_actions: list[UserAction], threshold_ms: int = 200, min_actions: int = 3 +) -> bool: + """ + Detect if the user is typing quickly by analyzing consecutive INSERT_CHAR actions. + + Args: + recent_user_actions: List of recent user actions + threshold_ms: Maximum average time between keystrokes to consider "quick typing" (default 200ms) + min_actions: Minimum number of consecutive INSERT_CHAR actions to analyze (default 3) + + Returns: + True if user appears to be typing quickly, False otherwise + """ + if not recent_user_actions or len(recent_user_actions) < min_actions: + return False + + # Find consecutive INSERT_CHAR actions from the end of the list + consecutive_insert_actions = [] + for action in reversed(recent_user_actions): + if action.action_type == "INSERT_CHAR" and action.timestamp > 0: + consecutive_insert_actions.insert( + 0, action + ) # Insert at beginning to maintain order + else: + break # Stop at first non-INSERT_CHAR action + + if len(consecutive_insert_actions) < min_actions: + return False + + # Calculate average time between consecutive INSERT_CHAR actions + time_diffs = [] + for i in range(1, len(consecutive_insert_actions)): + time_diff = ( + consecutive_insert_actions[i].timestamp + - consecutive_insert_actions[i - 1].timestamp + ) + time_diffs.append(time_diff) + + if not time_diffs: + return False + + avg_time_diff = sum(time_diffs) / len(time_diffs) + + return avg_time_diff <= threshold_ms + + +def _fetch_next_edits_core( + file_path: str, + file_contents: str, + recent_changes: str, + cursor_position: int, + original_file_contents: str | None, + code_block: str, + prefix: str, + suffix: str, + autocomplete_id: str, + block_start_index: int, + is_retrieval: bool, + file_chunks: list[FileChunkData] = None, + retrieval_chunks: list[FileChunkData] = None, + recent_user_actions: list[UserAction] = None, + recent_changes_high_res: str = "", + changes_above_cursor: bool = False, + do_insert_cursor: bool = True, +): + # Initialize mutable default arguments + if file_chunks is None: + file_chunks = [] + if retrieval_chunks is None: + retrieval_chunks = [] + if recent_user_actions is None: + recent_user_actions = [] + if not code_block: + metadata = AutocompleteMetadata( + exit_reason="no_code_block", is_retrieval_autocomplete=is_retrieval + ) + return ( + [AutocompleteResult(0, 0, "", 0.0, autocomplete_id)], + [], + "", + False, + metadata, + ) + + # disable non ghost-text for now, will rework later + force_ghost_text = False + + if recent_user_actions and recent_user_actions[-1].action_type == "INSERT_CHAR": + # Don't force ghost text if user is typing quickly (likely to make typos) + # if not is_typing_quickly(recent_user_actions): + force_ghost_text = True + + if not recent_user_actions: + force_ghost_text = True + + relative_cursor_position = cursor_position - file_contents.find(code_block) + cleaned_code_block = code_block + relative_cursor_line = get_line_number_from_position( + code_block, relative_cursor_position + ) + if do_insert_cursor: + code_block = ( + code_block[:relative_cursor_position] + + "<|cursor|>" + + code_block[relative_cursor_position:] + ) + only_changed_lines, prev_section, prev_sections = ( + format_recent_changes_and_prev_section(recent_changes, code_block) + ) + + if recent_changes_high_res: + _, _, prev_sections = format_recent_changes_and_prev_section( + recent_changes_high_res, code_block + ) + + # removing force ghost text for now, later, we should regenerate current line + # do not force ghost text at EOF + # disable partial prefill (ghost text on current line) for local model + is_at_eof = relative_cursor_position == len(cleaned_code_block) + if force_ghost_text and not is_at_eof and NEXT_EDIT_AUTOCOMPLETE_ENDPOINT: + prefill = cleaned_code_block[:relative_cursor_position] + pretokens = pretokenize_regex.findall(prefill) + regex_based_prefill = "".join(pretokens[:-1] if pretokens else []) + prefill = regex_based_prefill + forced_prefix = ( + cleaned_code_block[:relative_cursor_position].removeprefix(prefill) + if force_ghost_text + else "" + ) + else: + if changes_above_cursor: + forced_prefix = "" + prefill = cleaned_code_block[:relative_cursor_position] + prefilled_lines = prefill.splitlines(True) + + # Keep the first NUM_LINES_ABOVE lines + # Never rstrip newlines since it breaks tokenization + + NUM_LINES_ABOVE = 1 + before_split = "".join(prefilled_lines[:NUM_LINES_ABOVE]) + after_split = "".join(prefilled_lines[NUM_LINES_ABOVE:]) + + for char in after_split: + if char == "\n": + before_split += "\n" + else: + break + + prefill = "".join(before_split) + else: + prefill = "" + forced_prefix = "" + + # retrieval results are placed after the recent changes for optimal KV cache hit-rate + MAX_RETRIEVAL_TOKENS_COUNT = 2048 + packed_retrieval_chunks = pack_items_for_prompt( + retrieval_chunks, + string_function=lambda chunk: chunk.to_string(), + token_limit=MAX_RETRIEVAL_TOKENS_COUNT, + char_token_ratio=3.5, + truncate_from_end=False, + ) + retrieval_results = "".join( + [f"\n{chunk.to_string()}" for i, chunk in enumerate(packed_retrieval_chunks)] + ) + + # Count actual retrieval chunks that made it past truncation + retrieval_chunks_count = len(packed_retrieval_chunks) + retrieval_chunks_char_count = sum( + len(chunk.content) for chunk in packed_retrieval_chunks + ) + retrieval_chunks_line_count = sum( + len(chunk.content.splitlines()) for chunk in packed_retrieval_chunks + ) + + if code_block.endswith("\n") and prev_section.endswith("\n"): + code_block = code_block.removesuffix("\n") + prev_section = prev_section.removesuffix("\n") + + initial_file = get_lines_around_cursor(original_file_contents, cursor_position) + + formatted_prompt = ( + prompt.format( + file_path=file_path, + recent_changes=only_changed_lines, + prev_section=prev_section, + code_block=code_block, + retrieval_results=retrieval_results, + initial_file=initial_file, + start_line=relative_cursor_line + 1, + end_line=relative_cursor_line + len(code_block.splitlines()) + 1, + ) + + f"\n{prefill}" + ) + + truncation_record = PromptTruncationRecord( + autocomplete_id=autocomplete_id, + original_prompt_length=len(formatted_prompt), + max_seq_len=MAX_INPUT_TOKENS_COUNT, + lower_bound=CHARACTER_BOUND_TO_CHECK_TOKENIZATION, + file_path=file_path, + cursor_position=cursor_position, + initial_file=initial_file, + retrieval_results=retrieval_results, + retrieval_results_length=len(retrieval_results), + recent_changes=only_changed_lines, + prev_section=prev_section, + code_block=code_block, + suffix=suffix, + prefix=prefix, + prefill=prefill, + file_chunks=file_chunks if file_chunks else [], + file_chunks_available=len(file_chunks) if file_chunks else 0, + final_prompt_length=len(formatted_prompt), + truncation_occurred=False, + truncation_reason="", + file_chunks_used=0, + start_line=relative_cursor_line + 1, + end_line=relative_cursor_line + len(code_block.splitlines()) + 1, + ) + + formatted_file_chunks = "".join([chunk.to_string() for chunk in file_chunks]) + # Track file chunk metrics that made it past truncation + file_chunks_count = -1 + file_chunks_char_count = -1 + file_chunks_line_count = -1 + + if ( + len(formatted_prompt) + len(formatted_file_chunks) + > CHARACTER_BOUND_TO_CHECK_TOKENIZATION + ): + with Timer("Prompt Truncation", precision=4, min_time=0.0): + ( + formatted_prompt, + file_chunks_count, + file_chunks_char_count, + file_chunks_line_count, + ) = truncate_prompt_when_near_limit( + truncation_record=truncation_record, + ) + if not formatted_prompt: + metadata = AutocompleteMetadata( + exit_reason="prompt_truncation_failed", + retrieval_chunks_count=-1, + retrieval_chunks_char_count=-1, + retrieval_chunks_line_count=-1, + file_chunks_count=-1, + file_chunks_char_count=-1, + file_chunks_line_count=-1, + is_retrieval_autocomplete=is_retrieval, + ) + return ( + [AutocompleteResult(0, 0, "", 0.0, autocomplete_id)], + [], + "", + False, + metadata, + ) + else: + # add back the file chunks + formatted_prompt = formatted_file_chunks + formatted_prompt + # All file chunks made it + file_chunks_count = len(file_chunks) + file_chunks_char_count = sum(len(chunk.content) for chunk in file_chunks) + file_chunks_line_count = sum( + len(chunk.content.splitlines()) for chunk in file_chunks + ) + # end section on prompt truncation + + # Create base metadata with all chunk counts - will update exit_reason as needed + base_metadata = AutocompleteMetadata( + retrieval_chunks_count=retrieval_chunks_count, + retrieval_chunks_char_count=retrieval_chunks_char_count, + retrieval_chunks_line_count=retrieval_chunks_line_count, + file_chunks_count=file_chunks_count, + file_chunks_char_count=file_chunks_char_count, + file_chunks_line_count=file_chunks_line_count, + is_retrieval_autocomplete=is_retrieval, + ) + + data = {"prompt": formatted_prompt} + + stop = ["<|endoftext|>", "<|file_sep|>"] + + cursor_line = get_line_number_from_position( + cleaned_code_block, relative_cursor_position + ) + cursor_line_text = cleaned_code_block.splitlines(True)[cursor_line].strip("\n") + + # truncate long lines in formatted_prompt ~0.0001 seconds + formatted_prompt = truncate_long_lines(formatted_prompt) + + with Timer("Autocomplete", precision=4): + if NEXT_EDIT_AUTOCOMPLETE_ENDPOINT: + # Use remote HTTP endpoint + with Timer("Autocomplete HTTP", precision=4): + try: + completion, latency, logprobs, finish_reason = ( + fetch_next_edits_http( + formatted_prompt=formatted_prompt, + stop=stop, + cleaned_code_block=cleaned_code_block, + file_contents=file_contents, + cursor_position=relative_cursor_position, + prefix=forced_prefix, + prefill=prefill, + force_ghost_text=force_ghost_text, + relative_cursor_line=relative_cursor_line, + ) + ) + except PromptTooLongError as e: + logger.warning( + f"Prompt too long for line '{cursor_line_text}', error: {e}, returning empty results" + ) + metadata = replace(base_metadata, exit_reason="prompt_too_long") + return ( + [AutocompleteResult(0, 0, "", 0.0, autocomplete_id)], + [], + formatted_prompt, + False, + metadata, + ) + except Exception as e: + # Re-raise other exceptions + raise + else: + # Use local llama-cpp-python model + with Timer("Autocomplete Local", precision=4): + try: + completion, latency, logprobs, finish_reason = ( + generate_completion( + prompt=formatted_prompt, + stop=stop, + max_tokens=AUTOCOMPLETE_OUTPUT_MAX_TOKENS, + temperature=0.0, + prefix=forced_prefix, + ) + ) + except RequestCancelled: + logger.info("Request cancelled by newer request") + metadata = replace(base_metadata, exit_reason="cancelled") + return ( + [AutocompleteResult(0, 0, "", 0.0, autocomplete_id)], + [], + formatted_prompt, + False, + metadata, + ) + except Exception as e: + logger.error(f"Local model error: {e}") + raise + if not completion: + metadata = replace(base_metadata, exit_reason="no_completion_received") + return ( + [AutocompleteResult(0, 0, "", 0.0, autocomplete_id)], + [], + formatted_prompt, + False, + metadata, + ) + + # print(f"Received completion request for cursor line '{cursor_line_text}'") + # complete_completion = prefill + completion + # completion_line = complete_completion.splitlines(True)[cursor_line].strip("\n") + # print(f"Response: {completion_line}") + + if DEBUG: + print(formatted_prompt) + print("\n\n") + print(completion) + print("Forced prefix", forced_prefix) + print("Recent changes:") + print(recent_changes) + # print(f"Time taken: {(time.time() - start_time) * 1000} milliseconds") + + if completion == "": + # Bandaid fix -- the root cause is that the deployment may be down. + logger.warning( + f"Completion is empty for line '{cursor_line_text}', likely due to deployment issues." + ) + metadata = replace(base_metadata, exit_reason="empty_completion") + return ( + [AutocompleteResult(0, 0, "", 0.0, autocomplete_id)], + [], + formatted_prompt, + False, + metadata, + ) + + if completion.startswith("<|") or completion.removeprefix(forced_prefix).startswith( + "<|" + ): + # Bandaid fix -- root cause is it's probably a special token. + logger.warning( + f"Completion starts with special token for line '{cursor_line_text}'." + ) + metadata = replace(base_metadata, exit_reason="special_token_in_completion") + return ( + [AutocompleteResult(0, 0, "", 0.0, autocomplete_id)], + [], + formatted_prompt, + False, + metadata, + ) + + if not completion.startswith(forced_prefix): + # Sometimes forced prefix is not respected by CPC + logger.warning( + f"Forced prefix not respected by completion for line '{cursor_line_text}'." + ) + metadata = replace(base_metadata, exit_reason="forced_prefix_not_respected") + return ( + [AutocompleteResult(0, 0, "", 0.0, autocomplete_id)], + [], + formatted_prompt, + False, + metadata, + ) + + original_completion = completion + completion = prefill + completion + + if is_pure_insertion_above_cursor( + cleaned_code_block, completion, relative_cursor_position + ): + # Pure insertion above cursor detected, return empty completion + logger.warning(f"Pure insertion above cursor detected.") + metadata = replace(base_metadata, exit_reason="pure_insertion_above_cursor") + return ( + [AutocompleteResult(0, 0, "", 0.0, autocomplete_id)], + [], + formatted_prompt, + False, + metadata, + ) + + if is_large_diff_above_cursor( + cleaned_code_block, completion, relative_cursor_position + ): + # Large diff above cursor detected (>5 lines added with >1 line deleted), return empty completion + logger.warning(f"Large diff above cursor detected.") + metadata = replace(base_metadata, exit_reason="large_diff_above_cursor") + return ( + [AutocompleteResult(0, 0, "", 0.0, autocomplete_id)], + [], + formatted_prompt, + False, + metadata, + ) + + # there's a bug in the training data which causes this to be generated sometimes, will fix later + if completion.rstrip("\n").endswith(" No newline at end of file"): + completion, _ = completion.split(" No newline at end of file", maxsplit=1) + + completion = ( + # strip_leading_empty_newlines(data.get("response", "")).removesuffix("<|file_sep|>") or cleaned_code_block + strip_leading_empty_newlines(completion).removesuffix("<|file_sep|>") + or cleaned_code_block + ) + if "<|cursor|>" not in cleaned_code_block: + completion = completion.replace("<|cursor|>", "") + + cleaned_code_lines = cleaned_code_block.splitlines(True) + completion_lines = completion.splitlines(True) + if len(completion_lines) - len(cleaned_code_lines) > 20: + # cut the completion down to at most 20 additional lines + completion = "".join(completion_lines[: len(cleaned_code_lines) + 20]) + # confidence = data.get("confidence", 0.0) + + # if completion.startswith(cleaned_code_block) and completion.removeprefix(cleaned_code_block) in file_contents: + # logger.warning("Completion starts with cleaned code block and is in file contents.") + # completion = cleaned_code_block + + + # # multi-line deletions are probably bugs so let's disable it. + # if len(cleaned_code_block.splitlines()) - len(completion.splitlines()) > 2: + # completion = cleaned_code_block + + did_hit_max_tokens = finish_reason == "length" + + if did_hit_max_tokens: # here check + logger.warning("Completion length exceeds max_tokens") + metadata = replace(base_metadata, exit_reason="hit_max_tokens") + return ( + [AutocompleteResult(0, 0, "", 0.0, autocomplete_id)], + [], + formatted_prompt, + False, + metadata, + ) + + completions = select_best_hunk_from_completion( + completion, + cleaned_code_block, + file_contents, + cursor_position, + autocomplete_id, + logprobs, + ) + + if completion.strip("\n") == cleaned_code_block.strip("\n"): + logger.warning(f"No changes made") + should_continue = not did_hit_max_tokens + metadata = replace(base_metadata, exit_reason="no_changes_made") + return ( + completions, + completions, + formatted_prompt, + should_continue, + metadata, + ) + + # Parts of the completion may be removed in select_best_hunk_from_completion. Apply to cleaned_code_block to get actual completion + code_block_with_completions = apply_completions_to_code_block( + completions, file_contents, cleaned_code_block + ) + # ghost_text = is_single_line_ghost_text(code_block_with_completions, cleaned_code_block, relative_cursor_position) + + # Non ghost texts are annoying if they revert to a previous file state + for section in prev_sections: + if is_equal_ignoring_newlines(code_block_with_completions, section): + logger.warning(f"Revert detected for section '{section.strip()}'.") + metadata = replace(base_metadata, exit_reason="revert_detected") + return ( + [AutocompleteResult(0, 0, "", 0.0, autocomplete_id)], + [], + formatted_prompt, + False, + metadata, + ) + + # This checks if user is on an under-indented line + cleaned_lines = cleaned_code_block.splitlines(True) + current_line = ( + "" + if relative_cursor_line >= len(cleaned_lines) + else cleaned_lines[relative_cursor_line] + ) + is_pure_whitespace = current_line.strip() == "" and len(current_line) > 0 + + if is_pure_whitespace and completions: + code_block_with_first_completion = apply_completions_to_code_block( + [completions[0]], file_contents, cleaned_code_block + ) + + # Case 1: Current line is non-empty blank line and suggestion deletes pure whitespace at cursor position + first_completion = completions[0] + is_pure_whitespace_deleted = ( + first_completion.completion == "" + and file_contents[ + first_completion.start_index : first_completion.end_index + ].strip() + == "" + and first_completion.end_index in (cursor_position, cursor_position + 1) + ) + if is_pure_whitespace_deleted: + logger.warning( + f"Current line is non-empty blank line and suggestion deletes pure whitespace at cursor position." + ) + metadata = replace(base_metadata, exit_reason="pure_whitespace_deleted") + return ( + [AutocompleteResult(0, 0, "", 0.0, autocomplete_id)], + [], + formatted_prompt, + False, + metadata, + ) + + # Case 2: Current line is non-empty blank line and ghost text starts with whitespace + # Only apply first completion to get ghost text + ghost_text = is_single_line_ghost_text( + code_block_with_first_completion, + cleaned_code_block, + relative_cursor_position, + ) + if ghost_text and ghost_text.startswith(" "): + logger.warning( + "Current line is non-empty blank line and ghost text starts with whitespace." + ) + metadata = replace( + base_metadata, exit_reason="ghost_text_starts_with_whitespace" + ) + return ( + [AutocompleteResult(0, 0, "", 0.0, autocomplete_id)], + [], + formatted_prompt, + False, + metadata, + ) + + metadata = replace(base_metadata, exit_reason="success") + return completions, completions, formatted_prompt, True, metadata + + +def truncate_prompt_when_near_limit( + truncation_record: PromptTruncationRecord, +) -> tuple[str | None, int, int, int]: + """ + Truncate prompt when near token limit. + + Returns: + Tuple of (final_prompt, file_chunks_used, file_chunks_char_count, file_chunks_line_count) + """ + formatted_prompt_minimal = ( + prompt.format( + file_path=truncation_record.file_path, + recent_changes=truncation_record.recent_changes, + prev_section=truncation_record.prev_section, + code_block=truncation_record.code_block, + retrieval_results="", + initial_file=truncation_record.initial_file, + start_line=truncation_record.start_line, + end_line=truncation_record.end_line, + ) + + f"\n{truncation_record.prefill}" + ) + + # Case zero: this is too long, we should not even tokenize as it takes 200 ms in worst case + if len(formatted_prompt_minimal) > CHARACTER_BOUND_TO_SKIP_TOKENIZATION: + return None, 0, 0, 0 + formatted_prompt_minimal_token_count = estimate_token_count( + formatted_prompt_minimal + ) + retrieval_results_token_count = estimate_token_count( + truncation_record.retrieval_results + ) + chunks_token_count = [ + estimate_token_count(chunk.content) for chunk in truncation_record.file_chunks + ] + + # Case 1: everything fits; return the full prompt + if ( + formatted_prompt_minimal_token_count + + retrieval_results_token_count + + sum(chunks_token_count) + <= MAX_INPUT_TOKENS_COUNT + ): + formatted_file_chunks = "".join( + [chunk.to_string() for chunk in truncation_record.file_chunks] + ) + final_prompt = formatted_file_chunks + ( + prompt.format( + file_path=truncation_record.file_path, + recent_changes=truncation_record.recent_changes, + prev_section=truncation_record.prev_section, + code_block=truncation_record.code_block, + retrieval_results=truncation_record.retrieval_results, + initial_file=truncation_record.initial_file, + start_line=truncation_record.start_line, + end_line=truncation_record.end_line, + ) + + f"\n{truncation_record.prefill}" + ) + file_chunks_count = len(truncation_record.file_chunks) + file_chunks_char_count = sum( + len(chunk.content) for chunk in truncation_record.file_chunks + ) + file_chunks_line_count = sum( + len(chunk.content.splitlines()) for chunk in truncation_record.file_chunks + ) + # Case 2: minimal autocomplete is too long + elif formatted_prompt_minimal_token_count > MAX_INPUT_TOKENS_COUNT: + final_prompt = "" + file_chunks_count = 0 + file_chunks_char_count = 0 + file_chunks_line_count = 0 + # Case 3: drop all file chunks + elif ( + formatted_prompt_minimal_token_count + retrieval_results_token_count + > MAX_INPUT_TOKENS_COUNT + ): + final_prompt = ( + prompt.format( + file_path=truncation_record.file_path, + recent_changes=truncation_record.recent_changes, + prev_section=truncation_record.prev_section, + code_block=truncation_record.code_block, + retrieval_results="", + initial_file=truncation_record.initial_file, + start_line=truncation_record.start_line, + end_line=truncation_record.end_line, + ) + + f"\n{truncation_record.prefill}" + ) + file_chunks_count = 0 + file_chunks_char_count = 0 + file_chunks_line_count = 0 + # Case 4: drop some file chunks + else: + formatted_prompt_with_retrieval_chunks = ( + prompt.format( + file_path=truncation_record.file_path, + recent_changes=truncation_record.recent_changes, + prev_section=truncation_record.prev_section, + code_block=truncation_record.code_block, + retrieval_results=truncation_record.retrieval_results, + initial_file=truncation_record.initial_file, + start_line=truncation_record.start_line, + end_line=truncation_record.end_line, + ) + + f"\n{truncation_record.prefill}" + ) + current_token_count = ( + formatted_prompt_minimal_token_count + retrieval_results_token_count + ) + + partial_formatted_file_chunks = "" + all_chunks_token_count = 0 + chunks_that_fit = [] + for chunk, chunk_token_count in zip( + truncation_record.file_chunks, chunks_token_count + ): + current_chunk_str = chunk.to_string() + all_chunks_token_count += chunk_token_count + if current_token_count + all_chunks_token_count >= MAX_INPUT_TOKENS_COUNT: + break + partial_formatted_file_chunks += current_chunk_str + chunks_that_fit.append(chunk) + final_prompt = ( + partial_formatted_file_chunks + formatted_prompt_with_retrieval_chunks + ) + file_chunks_count = len(chunks_that_fit) + file_chunks_char_count = sum(len(chunk.content) for chunk in chunks_that_fit) + file_chunks_line_count = sum( + len(chunk.content.splitlines()) for chunk in chunks_that_fit + ) + + return ( + final_prompt, + file_chunks_count, + file_chunks_char_count, + file_chunks_line_count, + ) + + +hex_hash_pattern = re.compile(r"[a-f0-9]{32,}") +base64_pattern = re.compile(r"[A-Za-z0-9+/]{40,}={0,2}") # base64 strings 40+ chars + + +def should_disable_autocomplete(file_contents: str) -> tuple[bool, str]: + """ + Check if autocomplete should be disabled based on file characteristics. + + Args: + file_contents: The content of the file to check + + Returns: + Tuple of (should_disable, reason) where should_disable is True if autocomplete + should be disabled and reason explains why + """ + if not file_contents: + return False, "" + + # Check number of characters + num_chars = len(file_contents) + if num_chars > 10_000_000: # 10M characters + return True, f"file too large: {num_chars:,} characters > 10M" + + lines = file_contents.splitlines() + if not lines: + return False, "" + + # Check number of lines + num_lines = len(lines) + if num_lines > 50_000: # 50k lines + return True, f"too many lines: {num_lines:,} lines > 50k" + + # Check average line length + total_chars = sum(len(line) for line in lines) + avg_line_length = total_chars / num_lines + if avg_line_length > 240: + return True, f"average line length {avg_line_length:.1f} > 240" + + if num_lines > 1000: + length_counter = Counter(len(line) for line in lines) + if ( + sum(length_counter[length] for length in length_counter if length > 120) + > num_lines * 0.3 + ): + return True, f"30% of lines are > 120 chars" + + hash_lines = sum( + 1 + for line in lines + if hex_hash_pattern.search(line) or base64_pattern.search(line) + ) + percentage_hash_lines = (hash_lines / num_lines) * 100 if num_lines > 0 else 0 + + if percentage_hash_lines > 10: + return True, f"10% of lines are hashes" + + return False, "" + + +def fetch_next_edits( + file_path: str, + file_contents: str, + recent_changes: str, + cursor_position: int, + original_file_contents: str | None = None, + file_chunks: list[FileChunkData] = None, + retrieval_chunks: list[FileChunkData] = None, + recent_user_actions: list[UserAction] = None, + recent_changes_high_res: str = "", + changes_above_cursor: bool = False, + is_new_user: bool = False, + editor_diagnostics: list[EditorDiagnostic] = None, +): + if is_new_user: + logger.debug(f"New user detected, disabling changes_above_cursor") + changes_above_cursor = False + + # Check if autocomplete should be disabled based on file characteristics + should_disable, reason = should_disable_autocomplete(file_contents) + if should_disable: + logger.debug(f"Disabling autocomplete: {reason}") + autocomplete_id = uuid.uuid4().hex + yield ( + AutocompleteResult(0, 0, "", 0, autocomplete_id), + [], + "", + AutocompleteMetadata( + exit_reason="autocomplete_disabled", + reason=reason, + is_retrieval_autocomplete=False, + ), + ) + return + + autocomplete_id = uuid.uuid4().hex + if original_file_contents is None: + original_file_contents = file_contents + if file_chunks is None: + file_chunks = [] + + cursor_position = adjust_cursor_position_from_utf16( + file_contents, cursor_position + ) + code_block, prefix, suffix, block_start_index = get_block_at_cursor( + file_contents, cursor_position + ) + + if should_disable_for_code_block(code_block): + logger.debug(f"Disabling autocomplete: long lines") + autocomplete_id = uuid.uuid4().hex + yield ( + AutocompleteResult(0, 0, "", 0, autocomplete_id), + [], + "", + AutocompleteMetadata( + exit_reason="autocomplete_disabled_long_lines", + reason="long_lines_code_block", + is_retrieval_autocomplete=False, + ), + ) + return + + # truncate each retrieval_chunk to MAX_RETRIEVAL_CHUNK_SIZE_LINES + for retrieval_chunk in retrieval_chunks: + retrieval_chunk.content = "".join( + retrieval_chunk.content.splitlines(True)[:MAX_RETRIEVAL_CHUNK_SIZE_LINES] + ) + + # Limit chunks for local model to reduce prompt eval latency + if not NEXT_EDIT_AUTOCOMPLETE_ENDPOINT: + file_chunks = file_chunks[:1] + retrieval_chunks = retrieval_chunks[:1] + + completions, all_completions, formatted_prompt, should_continue, metadata = ( + _fetch_next_edits_core( + file_path=file_path, + file_contents=file_contents, + recent_changes=recent_changes, + cursor_position=cursor_position, + original_file_contents=original_file_contents, + code_block=code_block, + prefix=prefix, + suffix=suffix, + autocomplete_id=autocomplete_id, + block_start_index=block_start_index, + is_retrieval=False, + file_chunks=file_chunks, + retrieval_chunks=retrieval_chunks, + recent_user_actions=recent_user_actions, + recent_changes_high_res=recent_changes_high_res, + changes_above_cursor=changes_above_cursor, + ) + ) + + if not should_continue: + if all_completions: + yield all_completions[0], all_completions, formatted_prompt, metadata + else: + yield ( + AutocompleteResult(0, 0, "", 0, autocomplete_id), + [], + formatted_prompt, + metadata, + ) + return + + if all_completions and not all( + ( + not completion.completion.strip("\n") + and completion.start_index == completion.end_index + ) + for completion in all_completions + ): + yield all_completions[0], all_completions, formatted_prompt, metadata + return + + with Timer(min_time=0.001, precision=3, name="find_best_matching_block"): + retrieved_code_block, block_start_offset, is_block_after_cursor, diagnostic = ( + find_best_matching_block( + file_contents, + recent_changes, + cursor_position=cursor_position, + block_size=6, + editor_diagnostics=editor_diagnostics, + ) + ) + + # if diagnostic, pass it in as an additional retrieval chunk + if diagnostic: + file_contents_lines = file_contents.splitlines() + diagnostic_line = file_contents_lines[diagnostic.line_number] if diagnostic.line_number < len(file_contents_lines) else "" + # add it as the first one + retrieval_chunks = [ + FileChunkData( + content=f"{diagnostic.message} at line {diagnostic.line_number}:\n{diagnostic_line}", + file_path="diagnostics", + start_line=1, + end_line=2, + ) + ] + retrieval_chunks + + if not retrieved_code_block: + yield ( + AutocompleteResult(0, 0, "", 0, autocomplete_id), + [], + formatted_prompt, + AutocompleteMetadata( + exit_reason="no_retrieved_code_block", is_retrieval_autocomplete=True + ), + ) + return + + prefix_lines = file_contents[:block_start_offset].splitlines(True) + retrieved_prefix = "".join(prefix_lines[-NUM_LINES_BEFORE:]) + + num_retrieved_lines = len(retrieved_code_block.splitlines()) + num_suffix_lines = max( + 0, NUM_LINES_AFTER + 1 - num_retrieved_lines + ) # +1 to include the cursor line + + suffix_lines = file_contents[ + block_start_offset + len(retrieved_code_block) : + ].splitlines(True) + retrieved_suffix = "".join(suffix_lines[:num_suffix_lines]) + cursor_position_in_block = block_start_offset + len( + retrieved_code_block.splitlines()[0] + ) + full_block = retrieved_prefix + truncate_code_block_by_tokens( + retrieved_code_block + retrieved_suffix + ) + + if should_disable_for_code_block(full_block): + logger.debug(f"Disabling autocomplete: long lines") + autocomplete_id = uuid.uuid4().hex + yield ( + AutocompleteResult(0, 0, "", 0, autocomplete_id), + [], + "", + AutocompleteMetadata( + exit_reason="autocomplete_disabled_long_lines", + reason="long_lines_code_block", + is_retrieval_autocomplete=True, + ), + ) + return + + # Set cursor position to end of the suffix_lines assignment line + completions, all_completions, formatted_prompt, _, metadata = ( + _fetch_next_edits_core( + file_path=file_path, + file_contents=file_contents, + recent_changes=recent_changes, + cursor_position=cursor_position_in_block, + original_file_contents=original_file_contents, + code_block=full_block, + prefix=retrieved_prefix, + suffix=retrieved_suffix, + autocomplete_id=autocomplete_id, + block_start_index=block_start_offset, + is_retrieval=True, + file_chunks=file_chunks, + retrieval_chunks=retrieval_chunks, + recent_user_actions=recent_user_actions + + [ + UserAction( + action_type="CURSOR_MOVEMENT", + offset=cursor_position_in_block, + line_number=get_line_number_from_position( + file_contents=file_contents, position=cursor_position_in_block + ), + file_path=file_path, + ) + ], + recent_changes_high_res=recent_changes_high_res, + changes_above_cursor=changes_above_cursor, + ) + ) + if all_completions and recent_changes and all_completions[0].completion.strip(): + yield all_completions[0], all_completions, formatted_prompt, metadata + else: + yield ( + AutocompleteResult(0, 0, "", 0, autocomplete_id), + [], + formatted_prompt, + metadata, + ) + return diff --git a/vendor/sweep-autocomplete-mlx/sweep_autocomplete/autocomplete/next_edit_autocomplete_retrieval.py b/vendor/sweep-autocomplete-mlx/sweep_autocomplete/autocomplete/next_edit_autocomplete_retrieval.py new file mode 100644 index 0000000..3c773e6 --- /dev/null +++ b/vendor/sweep-autocomplete-mlx/sweep_autocomplete/autocomplete/next_edit_autocomplete_retrieval.py @@ -0,0 +1,335 @@ +from __future__ import annotations + +import difflib +import re +from functools import lru_cache + +import numpy as np +from scipy.sparse import csr_matrix + +from sweep_autocomplete.autocomplete.next_edit_autocomplete_utils import ( + extract_diff_parts, + parse_hunk, + split_into_hunks, + get_line_number_from_position, +) +from sweep_autocomplete.dataclasses.file_chunk_data import EditorDiagnostic +from loguru import logger +from sweep_autocomplete.utils.timer import Timer + +# Precompile regex for better performance with case-insensitive flag +_WORD_PATTERN = re.compile(r"\w+", re.IGNORECASE) + + +@lru_cache(maxsize=2048) +def simple_tokenizer(text): + """Cached tokenizer that extracts words (case-insensitive).""" + # Use case-insensitive regex to avoid .lower() call + tokens = _WORD_PATTERN.findall(text) + return [token for token in tokens] + + +@lru_cache(maxsize=2048) +def simple_tokenizer_with_offsets(text) -> list[tuple[str, int, int]]: + """Cached tokenizer that extracts words with their character offsets (case-insensitive). + + Returns: + List of tuples (token, start_offset, end_offset) + """ + # Use finditer to get match objects with position information + tokens_with_offsets = [ + (match.group(), match.start(), match.end()) + for match in _WORD_PATTERN.finditer(text) + ] + return tokens_with_offsets + + +def extract_added_and_deleted_code_from_recent_changes( + recent_changes: str, file_tokens_set: set[str] +) -> tuple[list[str], list[str]]: + """ + Extracts the deleted code section from the recent_changes diff. + Returns the concatenated words that were changed (deleted or added) at the word level. + If no word-level changes exist, returns the largest deleted or added code block. + """ + hunks = split_into_hunks(recent_changes) + hunks = [hunk for hunk in hunks if len(hunk.strip().splitlines()) > 1] + if not hunks: + return [], [] + added_words = [] + deleted_words = [] + for hunk in hunks[::-1]: + first_hunk_line = hunk.splitlines()[0] + if "." not in first_hunk_line: + extension = "" + else: + extension = first_hunk_line.split(".")[1].lower() + added_words, deleted_words = extract_added_and_deleted_from_hunk( + hunk, extension=extension + ) + deleted_words = [ + word for word in deleted_words if len(word) > 1 and word in file_tokens_set + ] + if len(deleted_words) == 1: + return added_words, deleted_words + return added_words, deleted_words + + +def extract_added_and_deleted_from_hunk( + hunk: str, extension: str +) -> tuple[list[str], list[str]]: + old_code, new_code = extract_diff_parts( + "".join(line for line in hunk.splitlines(True) if not line.startswith("File: ")) + ) + # Split on word boundaries while preserving whitespace and punctuation + old_words = re.findall(r"\w+|\s+|[^\w\s]", old_code) + new_words = re.findall(r"\w+|\s+|[^\w\s]", new_code) + sm = difflib.SequenceMatcher(None, old_words, new_words) + original_deleted_words = [] + original_added_words = [] + for tag, i1, i2, j1, j2 in sm.get_opcodes(): + if tag in ("replace", "delete"): + original_deleted_words.extend(old_words[i1:i2]) + if tag in ("replace", "insert"): + original_added_words.extend(new_words[j1:j2]) + added_words = [ + word + for word in original_added_words + if len(word) > 1 + ] + deleted_words = [ + word + for word in original_deleted_words + if len(word) > 1 + ] + added_words = list(set(added_words)) + deleted_words = list(set(deleted_words)) + logger.info( + f"Original added words: {original_added_words}, added words: {added_words}" + ) + logger.info( + f"Original deleted words: {original_deleted_words}, deleted words: {deleted_words}" + ) + return added_words, deleted_words + + +def extract_deleted_lines_from_recent_changes( + recent_changes: str, +) -> list[tuple[str, int]]: + """ + Extracts the full deleted lines from the recent_changes diff with their line numbers. + Returns a list of tuples (deleted_line, line_number) where deleted_line is stripped of + leading '-' and whitespace, and line_number is the original line number in the file. + """ + hunks = split_into_hunks(recent_changes) + deleted_lines_with_numbers = [] + + for hunk in hunks: + # Skip hunks that don't have @@ markers (e.g., just file headers) + if "@@" not in hunk: + continue + + # Parse the hunk to get line numbers + first_line, *rest = hunk.splitlines(True) + diff_hunk = "".join(rest) + input_start_line, input_lines, output_start_line, output_lines = parse_hunk( + diff_hunk + ) + + # Track current line number in the input (original file) + current_line = input_start_line + + for line in input_lines: + # Check if this is a deleted line (exists in input but not in output) + line_stripped = line.strip() + if line_stripped and line not in output_lines: + deleted_lines_with_numbers.append((line_stripped, current_line)) + current_line += 1 + + return deleted_lines_with_numbers + + +def find_deleted_line_match( + file_contents: str, deleted_lines: list[tuple[str, int]] +) -> tuple[str, int] | None: + """ + Searches for deleted lines in the file contents and returns a code block around the match. + Avoids matching the same line that was deleted (based on line number). + + Args: + file_contents: The current file contents to search in + deleted_lines: List of tuples (deleted_line, original_line_number) from recent changes + + Returns: + Tuple of (retrieved_code_block, block_start_offset) if a match is found, None otherwise. + The code block includes 3 lines above and 3 lines below the matched line. + """ + file_lines = file_contents.splitlines(keepends=True) + + for deleted_line, original_line_number in deleted_lines: + # Extract distinct terms from the deleted line + line_tokens = simple_tokenizer(deleted_line) + distinct_terms = set(line_tokens) + + # Check if the line has at least 3 distinct terms + if len(distinct_terms) >= 3: + # Check if this exact line exists in the file contents + for line_index, file_line in enumerate(file_lines): + # Skip if this is the same line that was deleted + if line_index == original_line_number: + continue + + if deleted_line in file_line.strip(): + start_line = line_index + end_line = min( + len(file_lines), line_index + 4 + ) # +4 to include matched line + 3 below + retrieved_code_block = "".join(file_lines[start_line:end_line]) + logger.info( + f"Retrieved code block from deleted line match: {retrieved_code_block}" + ) + + # Convert line_index to offset + block_start_offset = sum( + len(file_lines[j]) for j in range(line_index) + ) + return retrieved_code_block, block_start_offset + return None + + +def find_best_matching_block( + file_contents: str, + recent_changes: str, + cursor_position: int, + block_size: int = 6, + editor_diagnostics: list[EditorDiagnostic] = None, +) -> tuple[str, int, bool, EditorDiagnostic | None]: + # Extract deleted lines + with Timer( + min_time=0.001, precision=3, name="extract_deleted_lines_from_recent_changes" + ): + # Extract deleted lines from recent_changes + deleted_lines = extract_deleted_lines_from_recent_changes(recent_changes) + # Check if any deleted line is long enough (3+ distinct terms) and exists elsewhere in the file + # This takes precedence over the single-word heuristic + deleted_line_match = find_deleted_line_match(file_contents, deleted_lines) + if deleted_line_match is not None: + retrieved_code_block, block_start_offset = deleted_line_match + return retrieved_code_block, block_start_offset, False, None + + # simple filter for query tokens + # takes 5ms for 4k lines + with Timer(min_time=0.001, precision=3, name="simple_tokenizer_with_offsets"): + file_tokens_with_offsets: list[tuple[str, int, int]] = ( + simple_tokenizer_with_offsets(file_contents) + ) + file_tokens = [token for token, _, _ in file_tokens_with_offsets] + + with Timer( + min_time=0.001, precision=3, name="extract_added_and_deleted_code_from_recent_changes" + ): + added_words, deleted_words = extract_added_and_deleted_code_from_recent_changes( + recent_changes, set(file_tokens) + ) + + current_cursor_line_number = get_line_number_from_position( + file_contents, cursor_position + ) + + # 1. if deleted words is of length 1, use that word to determine direction + # 2. if added words exist, for each keyword check if it's a good query term + # - a good query term is one that appears in the file_contents >= 1 + # - it also can't occur too many times. something like 3-5 times at most depending on line count + if len(deleted_words) == 1: + logger.info(f"Retrieved deleted word: {deleted_words[0]}") + query_token = deleted_words[0] + else: + for word in added_words: + query_token_line_numbers = [get_line_number_from_position(file_contents, offset) for token, offset, _ in file_tokens_with_offsets if token == word] + if word in file_tokens and 5 >= file_tokens.count(word) > 1 and \ + any(abs(line_number - current_cursor_line_number) > 10 for line_number in query_token_line_numbers): + query_token = word + break + else: + query_token = None + + # find the closest match in file_tokens + # get all indices in file_contents which match the query_token + query_token_offsets = [ + offset for token, offset, _ in file_tokens_with_offsets if token == query_token and abs(get_line_number_from_position(file_contents, offset) - current_cursor_line_number) > 10 + ] + + closest_error = None + if editor_diagnostics: + # if any are errors, we can use the closest diagnostic as the offset. this can take priority over query_token + # filter to the closest diagnostic that's not within 10 lines of cursor (in line numbers) + filtered_error_diagnostics = [ + diagnostic + for diagnostic in editor_diagnostics + if diagnostic.severity == "ERROR" + and abs(current_cursor_line_number - diagnostic.line_number) > 10 + ] + if filtered_error_diagnostics: + + closest_error = min( + filtered_error_diagnostics, + key=lambda x: abs(cursor_position - x.start_offset), + ) + + if closest_error: + # get a block using closest_error.start_offset + lines = file_contents.splitlines(keepends=True) + cumulative_offset = 0 + start_line = 0 + for i, line in enumerate(lines): + if cumulative_offset + len(line) > closest_error.start_offset: + start_line = i + break + cumulative_offset += len(line) + end_line = min(len(lines), start_line + 1) + retrieved_code_block = "".join(lines[start_line:end_line]) + block_start_offset = sum(len(lines[i]) for i in range(start_line)) + return retrieved_code_block, block_start_offset, False, closest_error + elif query_token_offsets: + closest_offset = min( + query_token_offsets, key=lambda x: abs(cursor_position - x) + ) + + # takes less than 1ms for 4k lines + with Timer(min_time=0.001, precision=3, name="find_line_containing_offset"): + # Find the line containing closest_offset + lines = file_contents.splitlines(keepends=True) + cumulative_offset = 0 + line_index = 0 + for i, line in enumerate(lines): + if cumulative_offset + len(line) > closest_offset: + line_index = i + break + cumulative_offset += len(line) + + end_line = min(len(lines), line_index + 1) + retrieved_code_block = "".join(lines[line_index:end_line]) + logger.info(f"Retrieved code block: {retrieved_code_block}") + # convert line_index to offset + block_start_offset = sum(len(lines[i]) for i in range(line_index)) + return retrieved_code_block, block_start_offset, False, closest_error + else: + # Go down 6 lines from cursor - this is the "block right after current block" case + suffix = file_contents[cursor_position:] + suffix_start = suffix.find("\n") + if suffix_start != -1: + suffix = suffix[suffix_start + 1 :] + lines = suffix.splitlines(keepends=True) + cursor_position += suffix_start + 1 + len("".join(lines[:block_size])) + lines = lines[block_size:] + fallback_block = "".join(lines[:block_size]) + + logger.debug( + f"[SIMPLIFIED] Using block after cursor, block_size={len(fallback_block)}chars" + ) + return ( + fallback_block, + cursor_position, + True, + None, + ) # True indicates block after cursor diff --git a/vendor/sweep-autocomplete-mlx/sweep_autocomplete/autocomplete/next_edit_autocomplete_service.py b/vendor/sweep-autocomplete-mlx/sweep_autocomplete/autocomplete/next_edit_autocomplete_service.py new file mode 100644 index 0000000..b715cdd --- /dev/null +++ b/vendor/sweep-autocomplete-mlx/sweep_autocomplete/autocomplete/next_edit_autocomplete_service.py @@ -0,0 +1,8 @@ +from sweep_autocomplete.config import NEXT_EDIT_AUTOCOMPLETE_ENDPOINT + + +class NextEditAutocompleteService: + active_endpoint = NEXT_EDIT_AUTOCOMPLETE_ENDPOINT + + +next_edit_autocomplete_service = NextEditAutocompleteService() diff --git a/vendor/sweep-autocomplete-mlx/sweep_autocomplete/autocomplete/next_edit_autocomplete_utils.py b/vendor/sweep-autocomplete-mlx/sweep_autocomplete/autocomplete/next_edit_autocomplete_utils.py new file mode 100644 index 0000000..77c3bd1 --- /dev/null +++ b/vendor/sweep-autocomplete-mlx/sweep_autocomplete/autocomplete/next_edit_autocomplete_utils.py @@ -0,0 +1,631 @@ +from __future__ import annotations + +import difflib +from dataclasses import dataclass + +from loguru import logger + +CHARS_PER_TOKEN = 3.5 +AUTOCOMPLETE_OUTPUT_MAX_TOKENS = 1024 +AUTOCOMPLETE_TRUNCATION_LINE_LENGTH = 600 +AUTOCOMPLETE_MAXIMUM_LINE_LENGTH = 1000 + + +@dataclass +class PromptTooLongError(Exception): + message: str + + +def adjust_cursor_position_from_utf16( + file_contents: str, utf16_cursor_position: int +) -> int: + """ + Convert cursor position from UTF-16 (used by IntelliJ/JVM) to UTF-8 byte position. + + IntelliJ uses UTF-16 encoding where emojis and other Unicode characters may take 2 bytes, + but Python uses UTF-8 where they can take 3-4 bytes. This function converts the cursor + position to match Python's string indexing. + + Args: + file_contents: The file content as a Python string + utf16_cursor_position: Cursor position as reported by IntelliJ (UTF-16 based) + + Returns: + Adjusted cursor position for Python string indexing + """ + if utf16_cursor_position <= 0: + return utf16_cursor_position + + # Convert string to UTF-16 to match IntelliJ's encoding + utf16_bytes = file_contents.encode("utf-16le") + + # Ensure we don't exceed the actual content length + max_utf16_pos = len(utf16_bytes) // 2 # Each UTF-16 character is 2 bytes + actual_utf16_pos = min(utf16_cursor_position, max_utf16_pos) + + # Get the substring up to the cursor position in UTF-16 + utf16_substring_bytes = utf16_bytes[: actual_utf16_pos * 2] + + try: + # Decode back to get the corresponding Python string + utf16_substring = utf16_substring_bytes.decode("utf-16le") + # The length of this substring is the correct cursor position in Python + return len(utf16_substring) + except UnicodeDecodeError: + # If we hit a decode error, we might be in the middle of a surrogate pair + # Try backing up by one UTF-16 unit + if actual_utf16_pos > 0: + try: + utf16_substring_bytes = utf16_bytes[: (actual_utf16_pos - 1) * 2] + utf16_substring = utf16_substring_bytes.decode("utf-16le") + return len(utf16_substring) + except UnicodeDecodeError: + pass + + # Fallback: return the original position + return utf16_cursor_position + + +def extract_diff_parts(hunk: str, num_context_lines: int = 0) -> tuple[str, str]: + """Extract the old and new code from a diff hunk. + + Args: + hunk: The diff hunk string starting with @@ markers + num_context_lines: Number of context lines to keep (default: 0, set to -1 for all) + + Returns: + Tuple of (old_code, new_code) with the - and + markers removed + """ + lines = hunk.splitlines(True) + + # Skip the @@ header line + content_lines = [line for line in lines if not line.startswith("@@")] + + if num_context_lines == -1: + # Keep all lines + old_code = [] + new_code = [] + for line in content_lines: + if line.startswith("-"): + old_code.append(line[1:]) + elif line.startswith("+"): + new_code.append(line[1:]) + else: + old_code.append(line[1:]) + new_code.append(line[1:]) + + return "".join(old_code), "".join(new_code) + + # Find the range of changed lines (- and + lines) + changed_indices = [] + for i, line in enumerate(content_lines): + if line.startswith("-") or line.startswith("+"): + changed_indices.append(i) + + if not changed_indices: + # No changes, return empty + return "", "" + + # Determine the range to include with context + start_change = min(changed_indices) + end_change = max(changed_indices) + + # Calculate context range + start_idx = max(0, start_change - num_context_lines) + end_idx = min(len(content_lines), end_change + num_context_lines + 1) + + # Extract the relevant lines + relevant_lines = content_lines[start_idx:end_idx] + + old_code = [] + new_code = [] + + for line in relevant_lines: + if line.startswith("-"): + old_code.append(line[1:]) + elif line.startswith("+"): + new_code.append(line[1:]) + else: + # Context line - add to both + old_code.append(line[1:]) + new_code.append(line[1:]) + + return "".join(old_code), "".join(new_code) + + +def filter_whitespace_only_hunks(hunks: list[str]) -> list[str]: + """ + Filter out hunks that only contain whitespace changes. + + A hunk is considered whitespace-only if the old and new code are identical + when stripped of whitespace. + + Args: + hunks: List of diff hunks, each starting with "File: " marker + + Returns: + List of hunks with whitespace-only changes removed + """ + filtered_hunks = [] + for hunk in hunks: + first_line, *rest = hunk.splitlines(True) + old_code, new_code = extract_diff_parts("".join(rest)) + # Skip hunks where the only difference is whitespace + if old_code.strip() != new_code.strip(): + filtered_hunks.append(hunk) + return filtered_hunks + + +def split_into_hunks(diff: str) -> list[str]: + """Split a diff string into individual hunks. + + Args: + diff: The full diff string + + Returns: + List of individual diff hunks, each starting with @@ marker + """ + hunks = [] + current_hunk = [] + + for line in diff.splitlines(): + if line.startswith("File: ") and current_hunk: + hunks.append("\n".join(current_hunk)) + current_hunk = [] + current_hunk.append(line) + + if current_hunk: + hunks.append("\n".join(current_hunk)) + + return hunks + + +def get_line_number_from_position(file_contents: str, position: int) -> int: + """ + Convert a character position to a line number in the file. + Optimized to avoid scanning through all lines. + + Args: + file_contents: The full contents of the file + position: The character position in the file + + Returns: + The line number (0-indexed) corresponding to the position + """ + if position <= 0: + return 0 + + if position >= len(file_contents): + return len(file_contents.splitlines()) - 1 + + # Count newlines up to the position - much faster than splitting into lines + return file_contents[:position].count("\n") + + +def get_lines_around_cursor(file_contents: str, cursor_position: int) -> str: + """ + Return a fixed span sliced into overlapping CHUNK_SIZE length chunks with STRIDE, + choosing the chunk whose center is closest to the cursor line. + + The file is conceptually chunked starting at line 0, then STRIDE, 2*STRIDE, ... Each chunk + contains up to CHUNK_SIZE lines. We pick the stride-aligned chunk that best "centers" + around the cursor (i.e., whose center is closest to the cursor line). Chunks are + truncated at EOF if needed. + + Notes: + - For files with <= CHUNK_SIZE lines, return the full file. + + Args: + file_contents: The full contents of the file as a string. + cursor_position: The cursor position as a character offset. + Returns: + A string containing the fixed chunk that contains the cursor line. + """ + lines = file_contents.splitlines() + + CHUNK_SIZE = 300 + STRIDE = CHUNK_SIZE // 2 + LIMIT_TO_CHUNK = 800 + + # Small files: just return the entire contents + if len(lines) <= LIMIT_TO_CHUNK: + return file_contents + + # Find the line number for the cursor position + cursor_line = get_line_number_from_position(file_contents, cursor_position) + + # Choose the stride-aligned chunk whose center is nearest to the cursor + # Ideal centered start (not necessarily stride-aligned) + ideal_start = cursor_line - CHUNK_SIZE // 2 + # Convert to nearest stride-aligned index (banker's rounding acceptable) + chunk_index = int(round(ideal_start / STRIDE)) + # Clamp to non-negative + chunk_index = max(0, chunk_index) + start_line = chunk_index * STRIDE # multiple of 150: 0, 150, 300, 450, ... + end_line = min(len(lines), start_line + CHUNK_SIZE) + + return "\n".join(lines[start_line:end_line]) + + +def strip_leading_empty_newlines(completion: str) -> str: + lines = completion.split("\n") + + start_index = 0 + while start_index < len(lines) and not lines[start_index].strip(): + start_index += 1 + + return "\n".join(lines[start_index:]) + + +def keep_only_changing_lines(changes: str) -> str: + if not changes: + return "" + + hunks = changes.split("\n@@") + + processed_hunks = [] + for i, hunk in enumerate(hunks): + if not hunk.strip(): + continue + lines = hunk.splitlines() + if i > 0: + lines = lines[1:] + filtered_lines = [ + line + for line in lines + if line.startswith("+") or line.startswith("-") and len(line.strip()) > 1 + ] + if filtered_lines: + processed_hunks.append("\n".join(filtered_lines)) + + return "\n".join(processed_hunks) + + +def extract_minimal_diff(original_code: str, new_code: str) -> tuple[str, int, int]: + """Extract the minimal diff between original and new code with one line of context. + + Args: + original_code: The original code block + new_code: The new code block with changes + + Returns: + Tuple of (minimal_diff, start_offset, end_offset) where offsets are relative to original_code + """ + original_lines = original_code.splitlines(keepends=True) + new_lines = new_code.splitlines(keepends=True) + + start_diff = 0 + while ( + start_diff < min(len(original_lines), len(new_lines)) + and original_lines[start_diff] == new_lines[start_diff] + ): + start_diff += 1 + + # Find last differing line (from the end) + end_diff_orig = len(original_lines) - 1 + end_diff_new = len(new_lines) - 1 + while ( + end_diff_orig >= 0 + and end_diff_new >= 0 + and end_diff_orig >= start_diff + and end_diff_new >= start_diff + and original_lines[end_diff_orig] == new_lines[end_diff_new] + ): + end_diff_orig -= 1 + end_diff_new -= 1 + + start_context = max(0, start_diff - 1) + end_context_orig = min(len(original_lines) - 1, end_diff_orig + 1) + end_context_new = min(len(new_lines) - 1, end_diff_new + 1) + + start_offset = sum(len(line) for line in original_lines[:start_context]) + end_offset = sum(len(line) for line in original_lines[: end_context_orig + 1]) + + minimal_new = "".join(new_lines[start_context : end_context_new + 1]) + if minimal_new.startswith("\n"): # hacky but works + minimal_new = minimal_new[1:] + start_offset += 1 + + return minimal_new, start_offset, end_offset + + +def parse_hunk(hunk: str) -> tuple[int, list[str], int, list[str]]: + """ + Parse a single diff hunk and return the input/output line numbers and content. + + Args: + hunk (str): The complete hunk string including header and diff lines + + Returns: + tuple: (input_start_line, input_lines, output_start_line, output_lines) + """ + lines = hunk.splitlines(keepends=True) + hunk_header = lines[0] + diff_lines = lines[2:] + + parts = hunk_header.split() + input_range = parts[1].lstrip("-") + output_range = parts[2].lstrip("+") + + input_parts = input_range.split(",") + output_parts = output_range.split(",") + + input_start = int(input_parts[0]) + output_start = int(output_parts[0]) + + input_lines = [] + output_lines = [] + + for line in diff_lines: + if line.startswith("-"): + input_lines.append(line[1:]) + elif line.startswith("+"): + output_lines.append(line[1:]) + else: + # Context line, add to both + input_lines.append(line[1:]) + output_lines.append(line[1:]) + + # EDGE CASE -- Python STD's difflib has an off-by-one error when n=0 and the a-lines are empty + if not input_lines: + input_start += 1 + + return input_start, input_lines, output_start, output_lines + + +def split_into_diff_hunks(input_content: str, output_content: str): + """ + Split two files into diff hunks. + + Args: + input_content (str): Content of the input file + output_content (str): Content of the output file + + Returns: + list: A list of tuples where each tuple contains: + (input_start_line, input_lines, output_start_line, output_lines) + """ + + input_lines = input_content.splitlines() + output_lines = output_content.splitlines() + + diff = list( + difflib.unified_diff( + input_lines, + output_lines, + "input", + "output", + n=0, + ) + ) + + hunks = [] + current_hunk_lines = [] + + for line in diff: + if line.startswith("@@"): + if current_hunk_lines: + hunks.append(parse_hunk("\n".join(current_hunk_lines) + "\n")) + current_hunk_lines = [line] + elif current_hunk_lines: + current_hunk_lines.append(line) + + if current_hunk_lines: + hunks.append(parse_hunk("\n".join(current_hunk_lines) + "\n")) + + return hunks + + +def is_large_diff_above_cursor( + original: str, completion: str, relative_cursor_position: int +) -> bool: + """ + Check if the completion has a large diff above the cursor position. + A large diff is defined as >5 lines added with >5 lines deleted. + + NOTE(wzeng): I'm actually approximating this as the AND of: + 1. there's a change before the user's current position + 2. the *entire* block (including after the cursor) has at least 5 lines added and 5 lines deleted + + Args: + original: The original code block + completion: The completed code block after changes + relative_cursor_position: The cursor position relative to the code block + + Returns: + True if there's a large diff above the cursor, False otherwise + """ + if original == completion: + return False + + # Get the content above the cursor for both original and completion + original_above_cursor = original[:relative_cursor_position] + completion_above_cursor = completion[:relative_cursor_position] + + if original_above_cursor == completion_above_cursor: + # There's no change above the cursor + return False + + diff = list( + difflib.unified_diff( + original.splitlines(keepends=True), + completion.splitlines(keepends=True), + n=0, + ) + ) + + additions = 0 + deletions = 0 + + for line in diff: + if line.startswith("+++") or line.startswith("---") or line.startswith("@@"): + continue + elif line.startswith("+"): + additions += 1 + elif line.startswith("-"): + deletions += 1 + is_large = additions > 5 and deletions > 5 + if is_large: + logger.debug( + f"Large diff above cursor detected: {additions} additions, {deletions} deletions" + ) + + return is_large + + +def is_completion_max_tokens(completion: str, model: str = "qwen") -> bool: + """ + Check if the completion length equals max_tokens using the appropriate tokenizer. + + Args: + completion: The completion text to check + model: The model name to use for tokenization (defaults to "qwen") + + Returns: + True if completion length equals AUTOCOMPLETE_OUTPUT_MAX_TOKENS, False otherwise + """ + try: + token_count = int(len(completion) / CHARS_PER_TOKEN) + return token_count >= AUTOCOMPLETE_OUTPUT_MAX_TOKENS + except Exception as e: + logger.warning(f"Failed to count tokens for completion: {e}") + return False + + +# this is fairly well tested +def detect_and_revert_end_deletion(original: str, completion: str) -> tuple[str, bool]: + """ + Detect if completion has a large deletion from the end and revert it. + Uses suffix anchor approach to handle cases where completion has modifications but also end deletions. + """ + + # Only consider cases where completion is significantly shorter (potential deletion) + if len(completion) >= len(original) * 0.8: # Less than 20% deleted + return completion, False + + # Only run end deletion detection if completion length equals max_tokens, after length check as this costs some ms + if not is_completion_max_tokens(completion): + return completion, False + + # Find a substantial suffix from completion that appears uniquely in original + MIN_SUFFIX_LENGTH = 50 # Increased for better uniqueness + MAX_SUFFIX_LENGTH = min( + 200, len(completion) // 2 + ) # Don't use more than half the completion + + for suffix_len in range(MIN_SUFFIX_LENGTH, MAX_SUFFIX_LENGTH + 1): + if suffix_len > len(completion): + break + + suffix = completion[-suffix_len:] + + # Must be unique in original + if original.count(suffix) != 1: + continue + + suffix_pos = original.find(suffix) + + # Only consider if there's substantial content after the suffix in original + potential_deletion = original[suffix_pos + suffix_len :] + if len(potential_deletion) < 50: # Not a substantial deletion + continue + + # Check if this looks like a real end deletion (multiple lines or significant content) + if potential_deletion.count("\n") >= 2 or len(potential_deletion) > 100: + logger.warning(f"Detected end deletion, adding back {potential_deletion}") + return completion + potential_deletion, True + + return completion, True + + +def truncate_long_lines(content: str) -> str: + """ + Truncate lines that exceed max_line_length while preserving structure. + + Args: + content: The content to process + max_line_length: Maximum allowed line length (default: 300) + + Returns: + Content with long lines truncated + Mapping of truncated lines to their full contents + """ + lines = content.splitlines(True) # Keep line endings + truncated_lines = [] + + for original_line in lines: + if len(original_line) > AUTOCOMPLETE_TRUNCATION_LINE_LENGTH: + # Keep the line ending if present + has_newline = original_line.endswith("\n") + line_without_newline = original_line.rstrip("\n") + + # Truncate and add ellipsis + truncated = ( + line_without_newline[:AUTOCOMPLETE_TRUNCATION_LINE_LENGTH] + "..." + ) + + # Restore newline if it was there + if has_newline: + truncated += "\n" + + truncated_lines.append(truncated) + else: + truncated_lines.append(original_line) + + return "".join(truncated_lines) + + +def should_disable_for_code_block(code_block: str) -> bool: + """ + Check if the code block should be disabled due to long lines, this uses a larger threshold. + """ + lines = code_block.splitlines() + return any(len(line) > AUTOCOMPLETE_MAXIMUM_LINE_LENGTH for line in lines) + + +def normalize_newlines(text: str) -> str: + """ + Normalize consecutive newlines in a string by collapsing multiple newlines into single newlines. + This makes string comparisons newline-agnostic, so that "a\\n\\nb" is considered equal to "a\\nb". + + Args: + text: The text to normalize + + Returns: + Text with consecutive newlines collapsed to single newlines + + Examples: + >>> normalize_newlines("a\\n\\nb") + 'a\\nb' + >>> normalize_newlines("a\\n\\n\\nb") + 'a\\nb' + >>> normalize_newlines("hello\\n\\nworld\\n\\ntest") + 'hello\\nworld\\ntest' + """ + import re + + # Replace multiple consecutive newlines with a single newline + return re.sub(r"\n+", "\n", text) + + +def is_equal_ignoring_newlines(text1: str, text2: str) -> bool: + """ + Compare two strings for equality while ignoring differences in consecutive newlines. + This treats "a\\n\\nb" as equal to "a\\nb". + + Args: + text1: First string to compare + text2: Second string to compare + + Returns: + True if the strings are equal after normalizing newlines, False otherwise + + Examples: + >>> is_equal_ignoring_newlines("a\\n\\nb", "a\\nb") + True + >>> is_equal_ignoring_newlines("hello\\nworld", "hello\\n\\nworld") + True + >>> is_equal_ignoring_newlines("hello\\nworld", "hello world") + False + """ + return normalize_newlines(text1) == normalize_newlines(text2) diff --git a/vendor/sweep-autocomplete-mlx/sweep_autocomplete/cli.py b/vendor/sweep-autocomplete-mlx/sweep_autocomplete/cli.py new file mode 100644 index 0000000..9b496bb --- /dev/null +++ b/vendor/sweep-autocomplete-mlx/sweep_autocomplete/cli.py @@ -0,0 +1,15 @@ +import argparse +import uvicorn + + +def main(): + parser = argparse.ArgumentParser(description="Sweep Autocomplete Server") + parser.add_argument("--host", default="0.0.0.0", help="Bind host (default: 0.0.0.0)") + parser.add_argument("--port", type=int, default=8081, help="Bind port (default: 8081)") + args = parser.parse_args() + + uvicorn.run("sweep_autocomplete.app:app", host=args.host, port=args.port) + + +if __name__ == "__main__": + main() diff --git a/vendor/sweep-autocomplete-mlx/sweep_autocomplete/config.py b/vendor/sweep-autocomplete-mlx/sweep_autocomplete/config.py new file mode 100644 index 0000000..dc50502 --- /dev/null +++ b/vendor/sweep-autocomplete-mlx/sweep_autocomplete/config.py @@ -0,0 +1,9 @@ +import os + +NEXT_EDIT_AUTOCOMPLETE_ENDPOINT = os.environ.get( + "NEXT_EDIT_AUTOCOMPLETE_ENDPOINT", None +) + +# MLX model — use the HuggingFace repo directly (mlx_lm.load handles download & caching) +# Override with MODEL_REPO env var to use a different model (e.g. a local path or custom conversion) +MODEL_REPO = os.environ.get("MODEL_REPO", "Chris-Kode/sweep-next-edit-1.5b-mlx") diff --git a/vendor/sweep-autocomplete-mlx/sweep_autocomplete/dataclasses/__init__.py b/vendor/sweep-autocomplete-mlx/sweep_autocomplete/dataclasses/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vendor/sweep-autocomplete-mlx/sweep_autocomplete/dataclasses/file_chunk_data.py b/vendor/sweep-autocomplete-mlx/sweep_autocomplete/dataclasses/file_chunk_data.py new file mode 100644 index 0000000..f0f4599 --- /dev/null +++ b/vendor/sweep-autocomplete-mlx/sweep_autocomplete/dataclasses/file_chunk_data.py @@ -0,0 +1,35 @@ +from dataclasses import dataclass + + +@dataclass +class FileChunkData: + file_path: str + start_line: int + end_line: int + content: str + + def to_string(self) -> str: + return f"<|file_sep|>{self.file_path}\n{self.content}\n" + + +@dataclass +class UserAction: + action_type: str + line_number: int + offset: int + file_path: str + timestamp: int = 0 + + +@dataclass +class EditorDiagnostic: + line: int # this is 1-indexed, use line_number to get 0-indexed + start_offset: int + end_offset: int + severity: str + message: str + timestamp: int = 0 + + @property + def line_number(self) -> int: + return self.line - 1 diff --git a/vendor/sweep-autocomplete-mlx/sweep_autocomplete/utils/__init__.py b/vendor/sweep-autocomplete-mlx/sweep_autocomplete/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vendor/sweep-autocomplete-mlx/sweep_autocomplete/utils/compression_middleware.py b/vendor/sweep-autocomplete-mlx/sweep_autocomplete/utils/compression_middleware.py new file mode 100644 index 0000000..3c85003 --- /dev/null +++ b/vendor/sweep-autocomplete-mlx/sweep_autocomplete/utils/compression_middleware.py @@ -0,0 +1,74 @@ +import gzip + +import brotli +from starlette.types import ASGIApp, Receive, Scope, Send + +from loguru import logger + + +class RequestCompressionMiddleware: + def __init__(self, app: ASGIApp): + self.app = app + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] != "http": + await self.app(scope, receive, send) + return + + headers = dict(scope.get("headers", [])) + content_encoding = ( + headers.get(b"content-encoding", b"").decode("latin1").lower() + ) + if "gzip" not in content_encoding and "br" not in content_encoding: + await self.app(scope, receive, send) + return + + async def receive_with_decompression(): + message = await receive() + + if message["type"] == "http.request": + body = message.get("body", b"") + more_body = message.get("more_body", False) + + body_parts = [body] if body else [] + + while more_body: + current_message = await receive() + body_part = current_message.get("body", b"") + if body_part: + body_parts.append(body_part) + more_body = current_message.get("more_body", False) + + full_body = b"".join(body_parts) + + if full_body: + try: + if "gzip" in content_encoding: + decompressed_body = gzip.decompress(full_body) + elif "br" in content_encoding: + decompressed_body = brotli.decompress(full_body) + else: + decompressed_body = full_body + except Exception as e: + logger.error(f"Decompression failed: {str(e)}") + raise + + message["body"] = decompressed_body + + new_headers = [] + for name, value in scope.get("headers", []): + if name.lower() == b"content-length": + new_headers.append( + (name, str(len(decompressed_body)).encode()) + ) + elif name.lower() == b"content-encoding": + continue + else: + new_headers.append((name, value)) + + scope["headers"] = new_headers + message["more_body"] = False + + return message + + await self.app(scope, receive_with_decompression, send) diff --git a/vendor/sweep-autocomplete-mlx/sweep_autocomplete/utils/str_utils.py b/vendor/sweep-autocomplete-mlx/sweep_autocomplete/utils/str_utils.py new file mode 100644 index 0000000..5aa4fbe --- /dev/null +++ b/vendor/sweep-autocomplete-mlx/sweep_autocomplete/utils/str_utils.py @@ -0,0 +1,40 @@ +from typing import Callable, Iterable, Union + +from loguru import logger + + +def pack_items_for_prompt( + iterable: Iterable, + string_function: Union[callable, None], + token_limit: int, + char_token_ratio: int = 3.5, + truncate_from_end: bool = True, +) -> list: + """ + Packs items from an iterable into a list of strings, using a string function to convert each item to a string. + The total number of tokens in the packed items will not exceed the token limit. + Truncates from the end if truncate_from_end is True, otherwise from the beginning. + """ + char_limit = token_limit * char_token_ratio + packed_items = [] + current_str = "" + if truncate_from_end: + for item in iterable: + item_str = string_function(item) if string_function else str(item) + if len(current_str) + len(item_str) <= char_limit: + packed_items.append(item) + current_str += item_str + else: + break + else: + for item in reversed(iterable): + item_str = string_function(item) if string_function else str(item) + if len(current_str) + len(item_str) <= char_limit: + packed_items.insert(0, item) + current_str = item_str + current_str + else: + break + logger.info( + f"Removed {len(iterable) - len(packed_items)} items to fit within the token limit ({len(packed_items)} items remaining). Final token estimate: {int(len(current_str) // char_token_ratio)}" + ) + return packed_items diff --git a/vendor/sweep-autocomplete-mlx/sweep_autocomplete/utils/timer.py b/vendor/sweep-autocomplete-mlx/sweep_autocomplete/utils/timer.py new file mode 100644 index 0000000..29ab152 --- /dev/null +++ b/vendor/sweep-autocomplete-mlx/sweep_autocomplete/utils/timer.py @@ -0,0 +1,59 @@ +import time +from contextlib import contextmanager +from dataclasses import dataclass, field + +from loguru import logger + + +@dataclass +class Timer: + name: str = "" + min_time: float = 0.01 + start: float = 0 + end: float = 0 + time_elapsed: float = -1 + do_print: bool = True + precision: int = 2 + steps: list[tuple[str, float]] = field(default_factory=list) + max_expected_time: float = float("inf") + + def __enter__(self): + self.start = time.time() + return self + + def print(self, /, name: str | None = None, time_elapsed: float | None = None): + if time_elapsed is None: + time_elapsed = self.time_elapsed + if name is None: + name = self.name + if time_elapsed > self.max_expected_time: + log = logger.warning + else: + log = logger.debug + if name: + log(f"Timer {name} elapsed: {time_elapsed:.{self.precision}f} seconds") + else: + log(f"Timer elapsed: {time_elapsed:.{self.precision}f} seconds") + + @contextmanager + def step(self, name: str): + start = time.time() + with Timer(name=name, do_print=False) as timer: + yield timer + end = time.time() + time_elapsed = end - start + if self.do_print and time_elapsed > self.min_time: + self.print(name=name, time_elapsed=time_elapsed) + self.steps.append((name, time_elapsed)) + + def __exit__(self, exc_type, exc_value, traceback): + self.end = time.time() + self.time_elapsed = self.end - self.start + if self.steps: + logger.debug( + f"Breakdown of {self.name} ({self.time_elapsed:.{self.precision}f} seconds):" + ) + for name, time_elapsed in self.steps: + logger.debug(f" {name}: {time_elapsed:.{self.precision}f} seconds") + elif self.do_print and self.time_elapsed > self.min_time: + self.print() From 086dfee8ea3baf9a138d35933cbcd501c91f5b46 Mon Sep 17 00:00:00 2001 From: Stefan Bethge Date: Tue, 14 Apr 2026 21:47:57 +0200 Subject: [PATCH 4/7] Add native NES engine with direct llama-server integration Port the Python sweep-autocomplete prompt construction and completion parsing logic to Kotlin, eliminating the Python server dependency. The plugin now constructs NES prompts in-process and calls llama-server's /v1/completions endpoint directly. Key components: - NesUtils, NesRetrieval, NesCompletionParser, NesPromptBuilder: Pure Kotlin port of the Python NES logic with 24 unit tests verified against Python-generated fixtures for parity - LlamaServerClient: HTTP client with SSE streaming, early abort on oversized completions, and request cancellation via thread interrupt - NextEditAutocompleteEngine: Top-level orchestrator with two-pass autocomplete (cursor-based + retrieval-based) - NesModelConfig: Model selector supporting 0.5B, 1.5B, and 7B variants - LocalAutocompleteServerManager: Launches llama-server with ngram-mod speculative decoding (--spec-type ngram-mod), auto-downloads model via hf CLI or curl fallback, auto-restarts on model change Performance: ~125ms median latency (2.7x faster than Python path) with llama-server + ngram speculative decoding on Apple Silicon. Includes benchmark script (bin/benchmark_autocomplete.py) for comparing Python vs native engine performance across multiple scenarios. Co-Authored-By: Claude Opus 4.6 (1M context) --- bin/benchmark_autocomplete.py | 515 ++++++++++++++++++ bin/generate_parser_fixtures.py | 112 ++++ bin/generate_test_fixtures.py | 222 ++++++++ .../edit/EditAutocompleteModels.kt | 4 + .../edit/engine/LlamaServerClient.kt | 180 ++++++ .../edit/engine/NesCompletionParser.kt | 346 ++++++++++++ .../autocomplete/edit/engine/NesConstants.kt | 63 +++ .../edit/engine/NesModelConfig.kt | 49 ++ .../edit/engine/NesPromptBuilder.kt | 389 +++++++++++++ .../autocomplete/edit/engine/NesRetrieval.kt | 301 ++++++++++ .../autocomplete/edit/engine/NesUtils.kt | 407 ++++++++++++++ .../edit/engine/NextEditAutocompleteEngine.kt | 358 ++++++++++++ .../sweep/assistant/components/SweepConfig.kt | 69 +++ .../services/AutocompleteIpResolverService.kt | 101 +++- .../LocalAutocompleteServerManager.kt | 191 ++++++- .../sweep/assistant/settings/SweepSettings.kt | 6 +- .../edit/engine/NesCompletionParserTest.kt | 149 +++++ .../autocomplete/edit/engine/NesUtilsTest.kt | 276 ++++++++++ .../nes_fixtures/parser_fixtures.json | 77 +++ .../nes_fixtures/utils_fixtures.json | 183 +++++++ 20 files changed, 3993 insertions(+), 5 deletions(-) create mode 100755 bin/benchmark_autocomplete.py create mode 100644 bin/generate_parser_fixtures.py create mode 100644 bin/generate_test_fixtures.py create mode 100644 src/main/kotlin/dev/sweep/assistant/autocomplete/edit/engine/LlamaServerClient.kt create mode 100644 src/main/kotlin/dev/sweep/assistant/autocomplete/edit/engine/NesCompletionParser.kt create mode 100644 src/main/kotlin/dev/sweep/assistant/autocomplete/edit/engine/NesConstants.kt create mode 100644 src/main/kotlin/dev/sweep/assistant/autocomplete/edit/engine/NesModelConfig.kt create mode 100644 src/main/kotlin/dev/sweep/assistant/autocomplete/edit/engine/NesPromptBuilder.kt create mode 100644 src/main/kotlin/dev/sweep/assistant/autocomplete/edit/engine/NesRetrieval.kt create mode 100644 src/main/kotlin/dev/sweep/assistant/autocomplete/edit/engine/NesUtils.kt create mode 100644 src/main/kotlin/dev/sweep/assistant/autocomplete/edit/engine/NextEditAutocompleteEngine.kt create mode 100644 src/test/kotlin/dev/sweep/assistant/autocomplete/edit/engine/NesCompletionParserTest.kt create mode 100644 src/test/kotlin/dev/sweep/assistant/autocomplete/edit/engine/NesUtilsTest.kt create mode 100644 src/test/resources/nes_fixtures/parser_fixtures.json create mode 100644 src/test/resources/nes_fixtures/utils_fixtures.json diff --git a/bin/benchmark_autocomplete.py b/bin/benchmark_autocomplete.py new file mode 100755 index 0000000..637b930 --- /dev/null +++ b/bin/benchmark_autocomplete.py @@ -0,0 +1,515 @@ +#!/usr/bin/env python3 +""" +Benchmark for local next-edit autocomplete: Python server vs llama-server direct. + +Usage: + # Python path (sweep-autocomplete on port 8081): + uvx sweep-autocomplete --port 8081 + python bin/benchmark_autocomplete.py --port 8081 --mode python + + # llama-server direct (llama-server on port 8081): + llama-server -m model.gguf --port 8081 -ngl -1 --flash-attn + python bin/benchmark_autocomplete.py --port 8081 --mode llama-server + + # Both (Python on 8081, llama-server on 8082): + python bin/benchmark_autocomplete.py --python-port 8081 --llama-port 8082 +""" + +import argparse +import json +import time +import requests +import statistics + +# --- Test fixtures --- + +FILE_PATH = "src/utils/data_processor.py" + +FILE_CONTENTS = '''\ +import json +import os +from dataclasses import dataclass +from typing import Optional + + +@dataclass +class Config: + input_path: str + output_path: str + batch_size: int = 32 + verbose: bool = False + + +def load_config(path: str) -> Config: + with open(path) as f: + data = json.load(f) + return Config(**data) + + +def process_batch(items: list[dict], config: Config) -> list[dict]: + results = [] + for item in items: + transformed = { + "id": item["id"], + "name": item["name"].strip().lower(), + "score": round(item.get("score", 0) * 100, 2), + } + if config.verbose: + print(f"Processing {transformed['id']}") + results.append(transformed) + return results + + +def save_results(results: list[dict], path: str) -> None: + os.makedirs(os.path.dirname(path), exist_ok=True) + with open(path, "w") as f: + json.dump(results, f, indent=2) + + +def main(): + config = load_config("config.json") + data = json.load(open(config.input_path)) + + all_results = [] + for i in range(0, len(data), config.batch_size): + batch = data[i:i + config.batch_size] + results = process_batch(batch, config) + all_results.extend(results) + + save_results(all_results, config.output_path) + print(f"Processed {len(all_results)} items") + + +if __name__ == "__main__": + main() +''' + +CURSOR_POSITION = FILE_CONTENTS.index('results.append(transformed)') + len('results.append(transformed)') + +RECENT_CHANGES = """\ +File: src/utils/data_processor.py +@@ -28,7 +28,7 @@ + transformed = { + "id": item["id"], + "name": item["name"].strip().lower(), +- "score": round(item.get("score", 0) * 100, 2), ++ "value": round(item.get("score", 0) * 100, 2), + } +""" + +CURSOR_POSITION_2 = len(FILE_CONTENTS) + +RECENT_CHANGES_2 = """\ +File: src/utils/data_processor.py +@@ -48,3 +48,6 @@ + data = json.load(open(config.input_path)) + + all_results = [] ++ for i in range(0, len(data), config.batch_size): ++ batch = data[i:i + config.batch_size] ++ results = process_batch(batch, config) +""" + +# Pre-built prompt matching NES format (for llama-server direct testing) +# This is what the Kotlin engine would construct +LLAMA_PROMPT_SCENARIO_1 = f"""<|file_sep|>{FILE_PATH} +{FILE_CONTENTS} + +<|file_sep|>{FILE_PATH}:1:1 +original: + "score": round(item.get("score", 0) * 100, 2), +updated: + "value": round(item.get("score", 0) * 100, 2), +<|file_sep|>original/{FILE_PATH}:6:14 + for item in items: + transformed = {{ + "id": item["id"], + "name": item["name"].strip().lower(), + "score": round(item.get("score", 0) * 100, 2), + }} + if config.verbose: + print(f"Processing {{transformed['id']}}") + results.append(transformed) +<|file_sep|>current/{FILE_PATH}:6:14 + for item in items: + transformed = {{ + "id": item["id"], + "name": item["name"].strip().lower(), + "value": round(item.get("score", 0) * 100, 2), + }} + if config.verbose: + print(f"Processing {{transformed['id']}}") + results.append(transformed)<|cursor|> +<|file_sep|>updated/{FILE_PATH}:6:14 +""" + +LLAMA_PROMPT_SCENARIO_3 = f"""<|file_sep|>{FILE_PATH} +{FILE_CONTENTS} + +<|file_sep|>original/{FILE_PATH}:12:20 + save_results(all_results, config.output_path) + print(f"Processed {{len(all_results)}} items") + + +if __name__ == "__main__": + main() +<|file_sep|>current/{FILE_PATH}:12:20 + save_results(all_results, config.output_path) + print(f"Processed {{len(all_results)}} items") + + +if __name__ == "__main__": + main() +<|file_sep|>updated/{FILE_PATH}:12:20 +""" + + +def make_python_request(port, cursor_position, recent_changes): + """Send a request to the Python sweep-autocomplete server.""" + url = f"http://localhost:{port}/backend/next_edit_autocomplete" + payload = { + "file_path": FILE_PATH, + "file_contents": FILE_CONTENTS, + "original_file_contents": FILE_CONTENTS, + "recent_changes": recent_changes, + "cursor_position": cursor_position, + "file_chunks": [], + "retrieval_chunks": [], + "recent_user_actions": [], + "multiple_suggestions": True, + "recent_changes_high_res": "", + "changes_above_cursor": True, + "editor_diagnostics": [], + } + + start = time.perf_counter() + resp = requests.post(url, json=payload, timeout=30) + elapsed_ms = (time.perf_counter() - start) * 1000 + resp.raise_for_status() + + result = None + for line in resp.text.strip().split("\n"): + if line.strip(): + try: + result = json.loads(line) + except json.JSONDecodeError: + pass + + completions = result.get("completions", []) if result else [] + completion_text = completions[0]["completion"] if completions else "" + server_elapsed = result.get("elapsed_time_ms", -1) if result else -1 + + return { + "round_trip_ms": round(elapsed_ms, 1), + "server_ms": server_elapsed, + "completion_length": len(completion_text), + "completion_preview": completion_text[:80].replace("\n", "\\n") if completion_text else "(empty)", + "num_completions": len(completions), + } + + +def make_llama_request(port, prompt): + """Send a request directly to llama-server /v1/completions.""" + url = f"http://localhost:{port}/v1/completions" + payload = { + "prompt": prompt, + "stop": ["<|endoftext|>", "<|file_sep|>"], + "max_tokens": 1024, + "temperature": 0.0, + "n_predict": 1024, + } + + start = time.perf_counter() + resp = requests.post(url, json=payload, timeout=30) + elapsed_ms = (time.perf_counter() - start) * 1000 + resp.raise_for_status() + + obj = resp.json() + text = "" + finish_reason = None + if "choices" in obj and len(obj["choices"]) > 0: + text = obj["choices"][0].get("text", "") + finish_reason = obj["choices"][0].get("finish_reason") + + return { + "round_trip_ms": round(elapsed_ms, 1), + "server_ms": round(elapsed_ms, 1), # no separate server timing for llama-server + "completion_length": len(text), + "completion_preview": text[:80].replace("\n", "\\n") if text else "(empty)", + "num_completions": 1 if text else 0, + "finish_reason": finish_reason, + } + + +def run_scenario(name, request_fn, runs, warmup): + """Run a benchmark scenario multiple times and print stats.""" + print(f"\n{'=' * 60}") + print(f"Scenario: {name}") + print(f"{'=' * 60}") + + for i in range(warmup): + r = request_fn() + print(f" warmup {i+1}: {r['round_trip_ms']:.0f}ms (server: {r['server_ms']}ms)") + + results = [] + for i in range(runs): + r = request_fn() + results.append(r) + preview = r['completion_preview'] + extra = f" finish={r.get('finish_reason', 'n/a')}" if 'finish_reason' in r else "" + print(f" run {i+1}: {r['round_trip_ms']:.0f}ms (server: {r['server_ms']}ms) " + f"completions={r['num_completions']}{extra} [{preview}]") + + round_trips = [r["round_trip_ms"] for r in results] + server_times = [r["server_ms"] for r in results if r["server_ms"] > 0] + + print(f"\n Round-trip: median={statistics.median(round_trips):.0f}ms " + f"mean={statistics.mean(round_trips):.0f}ms " + f"min={min(round_trips):.0f}ms max={max(round_trips):.0f}ms") + if server_times: + print(f" Server: median={statistics.median(server_times):.0f}ms " + f"mean={statistics.mean(server_times):.0f}ms " + f"min={min(server_times):.0f}ms max={max(server_times):.0f}ms") + + return results + + +def main(): + parser = argparse.ArgumentParser(description="Benchmark local autocomplete: Python vs llama-server direct") + parser.add_argument("--port", type=int, default=None, help="Server port (for single-mode testing)") + parser.add_argument("--python-port", type=int, default=None, help="Python sweep-autocomplete port") + parser.add_argument("--llama-port", type=int, default=None, help="llama-server port") + parser.add_argument("--mode", choices=["python", "llama-server", "both"], default="both", + help="Which server to benchmark") + parser.add_argument("--runs", type=int, default=5, help="Timed runs per scenario (default: 5)") + parser.add_argument("--warmup", type=int, default=2, help="Warmup runs per scenario (default: 2)") + args = parser.parse_args() + + # Resolve ports + python_port = args.python_port or args.port + llama_port = args.llama_port or args.port + + if args.mode in ("python", "both") and python_port: + try: + resp = requests.get(f"http://localhost:{python_port}/health", timeout=5) + resp.raise_for_status() + print(f"Python server healthy at port {python_port}") + except Exception as e: + print(f"Warning: Python server not reachable at port {python_port}: {e}") + if args.mode == "python": + return + args.mode = "llama-server" + + if args.mode in ("llama-server", "both") and llama_port: + try: + resp = requests.get(f"http://localhost:{llama_port}/health", timeout=5) + resp.raise_for_status() + print(f"llama-server healthy at port {llama_port}") + except Exception as e: + print(f"Warning: llama-server not reachable at port {llama_port}: {e}") + if args.mode == "llama-server": + return + args.mode = "python" + + # ===================== + # Python server benchmarks + # ===================== + if args.mode in ("python", "both") and python_port: + print(f"\n{'#' * 60}") + print(f"# PYTHON SERVER (port {python_port})") + print(f"# Full stack: plugin → HTTP → FastAPI → prompt → llama-cpp-python → Metal") + print(f"{'#' * 60}") + + run_scenario( + "Python: Rename variable", + lambda: make_python_request(python_port, CURSOR_POSITION, RECENT_CHANGES), + args.runs, args.warmup, + ) + + run_scenario( + "Python: Repeated identical request", + lambda: make_python_request(python_port, CURSOR_POSITION, RECENT_CHANGES), + args.runs, 0, + ) + + run_scenario( + "Python: New function at EOF", + lambda: make_python_request(python_port, CURSOR_POSITION_2, RECENT_CHANGES_2), + args.runs, args.warmup, + ) + + # ===================== + # llama-server direct benchmarks + # ===================== + if args.mode in ("llama-server", "both") and llama_port: + print(f"\n{'#' * 60}") + print(f"# LLAMA-SERVER DIRECT (port {llama_port})") + print(f"# Inference only: plugin → HTTP → llama-server (C++) → Metal") + print(f"{'#' * 60}") + + run_scenario( + "llama-server: Rename variable (prompt constructed in plugin)", + lambda: make_llama_request(llama_port, LLAMA_PROMPT_SCENARIO_1), + args.runs, args.warmup, + ) + + run_scenario( + "llama-server: Repeated identical prompt (KV cache hit)", + lambda: make_llama_request(llama_port, LLAMA_PROMPT_SCENARIO_1), + args.runs, 0, + ) + + run_scenario( + "llama-server: New function at EOF", + lambda: make_llama_request(llama_port, LLAMA_PROMPT_SCENARIO_3), + args.runs, args.warmup, + ) + + # ===================== + # Iterative editing benchmark (ngram-mod cross-request learning) + # ===================== + if args.mode in ("llama-server", "both") and llama_port: + print(f"\n{'#' * 60}") + print(f"# ITERATIVE EDITING (llama-server port {llama_port})") + print(f"# Simulates adding lines one at a time — tests ngram-mod learning") + print(f"{'#' * 60}") + + run_iterative_scenario(llama_port, args.runs) + + if args.mode in ("python", "both") and python_port: + print(f"\n{'#' * 60}") + print(f"# ITERATIVE EDITING (Python port {python_port})") + print(f"# Same scenario through Python server for comparison") + print(f"{'#' * 60}") + + run_iterative_scenario_python(python_port, args.runs) + + print(f"\n{'=' * 60}") + print("Done.") + + +# --- Iterative editing scenarios --- + +ITERATIVE_BASE = '''\ +import json +import os + +def process_items(items: list[dict]) -> list[dict]: + results = [] + for item in items: + results.append(item) + return results +''' + +# Simulate adding validation lines one at a time +ITERATIVE_ADDITIONS = [ + (' if not item.get("id"):\n continue\n', "skip items without id"), + (' if not item.get("name"):\n continue\n', "skip items without name"), + (' if item.get("score", 0) < 0:\n continue\n', "skip negative scores"), + (' if len(item.get("name", "")) > 100:\n continue\n', "skip long names"), + (' if item.get("deleted", False):\n continue\n', "skip deleted items"), + (' if not isinstance(item.get("score", 0), (int, float)):\n continue\n', "skip non-numeric scores"), + (' if item.get("name", "").strip() == "":\n continue\n', "skip empty names"), + (' if item.get("id") in seen_ids:\n continue\n seen_ids.add(item.get("id"))\n', "skip duplicates"), +] + + +def build_iterative_prompt(file_contents, cursor_position, recent_change, prev_file): + """Build a NES prompt for iterative editing.""" + lines = file_contents.split("\n") + cursor_line = file_contents[:cursor_position].count("\n") + + # Build code block around cursor (simplified) + start = max(0, cursor_line - 2) + end = min(len(lines), cursor_line + 6) + code_block_lines = lines[start:end] + code_block = "\n".join(code_block_lines) + + prev_block = prev_file.split("\n") + prev_start = max(0, cursor_line - 2) + prev_end = min(len(prev_block), cursor_line + 6) + prev_section = "\n".join(prev_block[prev_start:prev_end]) + + return f"""<|file_sep|>process_items.py +{prev_file} +{recent_change} +<|file_sep|>original/process_items.py:{start+1}:{end+1} +{prev_section} +<|file_sep|>current/process_items.py:{start+1}:{end+1} +{code_block} +<|file_sep|>updated/process_items.py:{start+1}:{end+1} +""" + + +def run_iterative_scenario(llama_port, runs): + """Simulate iterative line-by-line additions to test ngram-mod learning.""" + print(f"\n{'=' * 60}") + print(f"Scenario: Iterative additions ({len(ITERATIVE_ADDITIONS)} steps x {runs} runs)") + print(f"{'=' * 60}") + + insert_point = ITERATIVE_BASE.index(" results.append(item)") + + for run in range(runs): + current_file = ITERATIVE_BASE + prev_file = ITERATIVE_BASE + run_times = [] + + for i, (addition, desc) in enumerate(ITERATIVE_ADDITIONS): + # Insert the new line before results.append + current_insert = current_file.index(" results.append(item)") + new_file = current_file[:current_insert] + addition + current_file[current_insert:] + + # Build diff for recent_changes + recent_change = f"<|file_sep|>process_items.py:1:1\noriginal:\n results.append(item)\nupdated:\n{addition} results.append(item)" + + cursor_pos = current_insert + len(addition) + prompt = build_iterative_prompt(new_file, cursor_pos, recent_change, prev_file) + + result = make_llama_request(llama_port, prompt) + run_times.append(result["round_trip_ms"]) + + prev_file = current_file + current_file = new_file + + times_str = " ".join(f"{t:.0f}" for t in run_times) + avg = statistics.mean(run_times) + print(f" run {run+1}: [{times_str}] ms avg={avg:.0f}ms trend={'↓' if run_times[-1] < run_times[0] else '→'}") + + print(f"\n If ngram-mod is learning, later additions should be faster than earlier ones.") + + +def run_iterative_scenario_python(python_port, runs): + """Same iterative scenario through Python server.""" + print(f"\n{'=' * 60}") + print(f"Scenario: Iterative additions ({len(ITERATIVE_ADDITIONS)} steps x {runs} runs)") + print(f"{'=' * 60}") + + for run in range(runs): + current_file = ITERATIVE_BASE + prev_file = ITERATIVE_BASE + run_times = [] + + for i, (addition, desc) in enumerate(ITERATIVE_ADDITIONS): + current_insert = current_file.index(" results.append(item)") + new_file = current_file[:current_insert] + addition + current_file[current_insert:] + + recent_change = f"""File: process_items.py +@@ -{6+i},{1} +{6+i},{addition.count(chr(10))+1} @@ + results.append(item) ++{addition.rstrip()}""" + + cursor_pos = current_insert + len(addition) + + result = make_python_request(python_port, cursor_pos, recent_change) + run_times.append(result["round_trip_ms"]) + + prev_file = current_file + current_file = new_file + + times_str = " ".join(f"{t:.0f}" for t in run_times) + avg = statistics.mean(run_times) + print(f" run {run+1}: [{times_str}] ms avg={avg:.0f}ms") + + +if __name__ == "__main__": + main() diff --git a/bin/generate_parser_fixtures.py b/bin/generate_parser_fixtures.py new file mode 100644 index 0000000..3245907 --- /dev/null +++ b/bin/generate_parser_fixtures.py @@ -0,0 +1,112 @@ +#!/usr/bin/env python3 +"""Generate test fixtures for NesCompletionParser.""" +import json +import sys +sys.path.insert(0, "/var/folders/9b/jr54dbsj28128y6wjclqrs9h0000gn/T/sweep-pkg/sweep_src") + +from sweep_autocomplete.autocomplete.next_edit_autocomplete import ( + get_ghost_text_with_location, + find_ghost_text_non_local, + is_single_line_ghost_text, + is_pure_insertion_above_cursor, + apply_completions_to_code_block, + select_best_hunk_from_completion, + AutocompleteResult, +) + + +def generate_fixtures(): + fixtures = {} + + # --- ghost text tests --- + code_block = "def hello():\n print('hi')\n return True\n" + # Insertion at cursor position 25 (after "print('hi')") + completion_ghost = "def hello():\n print('hi')\n print('bye')\n return True\n" + cursor_pos = len("def hello():\n print('hi')\n") + + fixtures["ghost_text_with_location"] = { + "input": {"completion": completion_ghost, "code_block": code_block, "cursor_pos": cursor_pos}, + "output": get_ghost_text_with_location(completion_ghost, code_block, cursor_pos), + } + + fixtures["find_ghost_text_non_local"] = { + "input": {"completion": completion_ghost, "code_block": code_block, "cursor_pos": cursor_pos}, + "output": list(find_ghost_text_non_local(completion_ghost, code_block, cursor_pos)), + } + + # Single line ghost text + code_single = "x = 1\ny = \nz = 3\n" + comp_single = "x = 1\ny = 2\nz = 3\n" + cursor_single = len("x = 1\ny = ") + fixtures["single_line_ghost_text"] = { + "input": {"completion": comp_single, "code_block": code_single, "cursor_pos": cursor_single}, + "output": is_single_line_ghost_text(comp_single, code_single, cursor_single), + } + + # No ghost text (different text) + fixtures["single_line_ghost_text_empty"] = { + "input": {"completion": "totally different", "code_block": code_single, "cursor_pos": cursor_single}, + "output": is_single_line_ghost_text("totally different", code_single, cursor_single), + } + + # --- pure insertion above cursor --- + code_pure = "line1\nline2\nline3\n" + comp_pure = "line1\nnew_line\nline2\nline3\n" + fixtures["is_pure_insertion_above_cursor_true"] = { + "input": {"code_block": code_pure, "completion": comp_pure, "cursor_pos": len("line1\nline2\n")}, + "output": is_pure_insertion_above_cursor(code_pure, comp_pure, len("line1\nline2\n")), + } + + # --- select_best_hunk_from_completion --- + file_contents = "import os\n\ndef process():\n x = 1\n y = 2\n z = x + y\n return z\n" + code_block_2 = " x = 1\n y = 2\n z = x + y\n return z\n" + completion_2 = " x = 10\n y = 2\n z = x + y\n return z\n" + cursor_2 = file_contents.index(" x = 1") + len(" x = 1") + block_start = file_contents.index(code_block_2) + + results = select_best_hunk_from_completion( + completion_2, code_block_2, file_contents, cursor_2, "test-id" + ) + fixtures["select_best_hunk_simple_change"] = { + "input": { + "completion": completion_2, + "cleaned_code_block": code_block_2, + "file_contents": file_contents, + "cursor_position": cursor_2, + "autocomplete_id": "test-id", + }, + "output": [ + {"start_index": r.start_index, "end_index": r.end_index, + "completion": r.completion, "confidence": r.confidence, + "autocomplete_id": r.autocomplete_id} + for r in results + ], + } + + # --- apply_completions_to_code_block --- + completions = [AutocompleteResult( + start_index=block_start, + end_index=block_start + len(" x = 1"), + completion=" x = 10", + confidence=1.0, + autocomplete_id="test-0", + )] + applied = apply_completions_to_code_block(completions, file_contents, code_block_2) + fixtures["apply_completions_to_code_block"] = { + "input": { + "file_contents": file_contents, + "cleaned_code_block": code_block_2, + "completions": [{"start_index": c.start_index, "end_index": c.end_index, + "completion": c.completion} for c in completions], + }, + "output": applied, + } + + output_path = "src/test/resources/nes_fixtures/parser_fixtures.json" + with open(output_path, "w") as f: + json.dump(fixtures, f, indent=2, ensure_ascii=False) + print(f"Wrote {len(fixtures)} fixtures to {output_path}") + + +if __name__ == "__main__": + generate_fixtures() diff --git a/bin/generate_test_fixtures.py b/bin/generate_test_fixtures.py new file mode 100644 index 0000000..67a991b --- /dev/null +++ b/bin/generate_test_fixtures.py @@ -0,0 +1,222 @@ +#!/usr/bin/env python3 +""" +Generate JSON test fixtures by running the Python NES functions. +These fixtures are used by Kotlin unit tests for parity verification. + +Usage: + uvx --no-env-file --with sweep-autocomplete python3 bin/generate_test_fixtures.py +""" +import json +import sys +sys.path.insert(0, "/var/folders/9b/jr54dbsj28128y6wjclqrs9h0000gn/T/sweep-pkg/sweep_src") + +from sweep_autocomplete.autocomplete.next_edit_autocomplete_utils import ( + extract_diff_parts, + filter_whitespace_only_hunks, + split_into_hunks, + get_line_number_from_position, + get_lines_around_cursor, + strip_leading_empty_newlines, + parse_hunk, + split_into_diff_hunks, + is_large_diff_above_cursor, + truncate_long_lines, + should_disable_for_code_block, + is_equal_ignoring_newlines, +) +from sweep_autocomplete.autocomplete.next_edit_autocomplete import ( + format_recent_changes_and_prev_section, + get_block_at_cursor, + select_best_hunk_from_completion, + AutocompleteResult, +) +from sweep_autocomplete.utils.str_utils import pack_items_for_prompt + +import regex +PRETOKENIZE_REGEX = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""" +pretokenize_regex = regex.compile(PRETOKENIZE_REGEX) + + +def generate_fixtures(): + fixtures = {} + + # --- extract_diff_parts --- + hunk1 = """@@ -28,7 +28,7 @@ + transformed = { + "id": item["id"], +- "score": round(item.get("score", 0) * 100, 2), ++ "value": round(item.get("score", 0) * 100, 2), + }""" + fixtures["extract_diff_parts_no_context"] = { + "input": {"hunk": hunk1, "num_context_lines": 0}, + "output": list(extract_diff_parts(hunk1, 0)), + } + fixtures["extract_diff_parts_one_context"] = { + "input": {"hunk": hunk1, "num_context_lines": 1}, + "output": list(extract_diff_parts(hunk1, 1)), + } + fixtures["extract_diff_parts_all"] = { + "input": {"hunk": hunk1, "num_context_lines": -1}, + "output": list(extract_diff_parts(hunk1, -1)), + } + + # --- split_into_hunks --- + diff1 = """File: src/main.py +@@ -1,3 +1,3 @@ +-old ++new +File: src/utils.py +@@ -5,2 +5,2 @@ +-foo ++bar""" + fixtures["split_into_hunks"] = { + "input": diff1, + "output": split_into_hunks(diff1), + } + + # --- get_line_number_from_position --- + text1 = "line0\nline1\nline2\nline3" + fixtures["get_line_number_from_position"] = { + "input": {"text": text1}, + "outputs": { + "pos_0": get_line_number_from_position(text1, 0), + "pos_6": get_line_number_from_position(text1, 6), + "pos_12": get_line_number_from_position(text1, 12), + "pos_end": get_line_number_from_position(text1, len(text1)), + "pos_negative": get_line_number_from_position(text1, -1), + } + } + + # --- strip_leading_empty_newlines --- + fixtures["strip_leading_empty_newlines"] = { + "input": "\n\n\nhello\nworld", + "output": strip_leading_empty_newlines("\n\n\nhello\nworld"), + } + + # --- filter_whitespace_only_hunks --- + hunks_ws = [ + "File: a.py\n@@ -1,1 +1,1 @@\n- foo\n+ foo", # whitespace only + "File: b.py\n@@ -1,1 +1,1 @@\n-old\n+new", # real change + ] + fixtures["filter_whitespace_only_hunks"] = { + "input": hunks_ws, + "output": filter_whitespace_only_hunks(hunks_ws), + } + + # --- split_into_diff_hunks --- + input_code = "line1\nline2\nline3\nline4\n" + output_code = "line1\nchanged2\nline3\nnew_line4\nextra_line\n" + hunks_result = split_into_diff_hunks(input_code, output_code) + fixtures["split_into_diff_hunks"] = { + "input": {"input_content": input_code, "output_content": output_code}, + "output": [ + { + "input_start": h[0], + "input_lines": list(h[1]), + "output_start": h[2], + "output_lines": list(h[3]), + } + for h in hunks_result + ], + } + + # --- is_large_diff_above_cursor --- + orig = "line1\nline2\nline3\nline4\nline5\nline6\nline7\nline8\n" + comp_large = "a\nb\nc\nd\ne\nf\ng\nh\nline7\nline8\n" + comp_small = "line1\nchanged\nline3\nline4\nline5\nline6\nline7\nline8\n" + fixtures["is_large_diff_above_cursor"] = { + "large_diff": is_large_diff_above_cursor(orig, comp_large, 20), + "small_diff": is_large_diff_above_cursor(orig, comp_small, 20), + "same": is_large_diff_above_cursor(orig, orig, 20), + } + + # --- truncate_long_lines --- + long_line = "x" * 700 + fixtures["truncate_long_lines"] = { + "input": f"short\n{long_line}\nshort2", + "output": truncate_long_lines(f"short\n{long_line}\nshort2"), + } + + # --- should_disable_for_code_block --- + fixtures["should_disable_for_code_block"] = { + "short_lines": should_disable_for_code_block("short\nlines\n"), + "long_line": should_disable_for_code_block("short\n" + "x" * 1001 + "\nshort"), + } + + # --- is_equal_ignoring_newlines --- + fixtures["is_equal_ignoring_newlines"] = { + "equal": is_equal_ignoring_newlines("a\n\nb", "a\nb"), + "not_equal": is_equal_ignoring_newlines("a\nb", "a b"), + } + + # --- pretokenize --- + text_pt = "def hello_world(x: int) -> str:\n return f'Hello {x}'" + tokens_pt = pretokenize_regex.findall(text_pt) + fixtures["pretokenize"] = { + "input": text_pt, + "output": tokens_pt, + } + + # --- get_lines_around_cursor (small file) --- + small_file = "\n".join(f"line{i}" for i in range(50)) + fixtures["get_lines_around_cursor_small"] = { + "input": {"text": small_file, "cursor_position": 30}, + "output": get_lines_around_cursor(small_file, 30), + } + + # --- get_block_at_cursor --- + file_contents = """import json + +def process(items): + results = [] + for item in items: + transformed = {"id": item["id"], "name": item["name"]} + results.append(transformed) + return results + +def main(): + data = json.load(open("data.json")) + print(process(data)) +""" + cursor_pos = file_contents.index("results.append") + len("results.append") + code_block, prefix, suffix, block_start = get_block_at_cursor(file_contents, cursor_pos) + fixtures["get_block_at_cursor"] = { + "input": {"file_contents": file_contents, "cursor_position": cursor_pos}, + "output": { + "code_block": code_block, + "prefix": prefix, + "suffix": suffix, + "block_start_index": block_start, + }, + } + + # --- pack_items_for_prompt --- + items = ["short", "medium length text", "a very long string that takes up space"] + packed = pack_items_for_prompt(items, str, token_limit=10, char_token_ratio=3.5) + fixtures["pack_items_for_prompt"] = { + "input": {"items": items, "token_limit": 10}, + "output": packed, + } + + # --- parse_hunk --- + hunk_str = "@@ -10,3 +10,4 @@\n \n context line\n-old line\n+new line1\n+new line2\n context after\n" + result = parse_hunk(hunk_str) + fixtures["parse_hunk"] = { + "input": hunk_str, + "output": { + "input_start": result[0], + "input_lines": list(result[1]), + "output_start": result[2], + "output_lines": list(result[3]), + }, + } + + # Write all fixtures + output_path = "src/test/resources/nes_fixtures/utils_fixtures.json" + with open(output_path, "w") as f: + json.dump(fixtures, f, indent=2, ensure_ascii=False) + print(f"Wrote {len(fixtures)} fixtures to {output_path}") + + +if __name__ == "__main__": + generate_fixtures() diff --git a/src/main/kotlin/dev/sweep/assistant/autocomplete/edit/EditAutocompleteModels.kt b/src/main/kotlin/dev/sweep/assistant/autocomplete/edit/EditAutocompleteModels.kt index 585c744..cb37102 100644 --- a/src/main/kotlin/dev/sweep/assistant/autocomplete/edit/EditAutocompleteModels.kt +++ b/src/main/kotlin/dev/sweep/assistant/autocomplete/edit/EditAutocompleteModels.kt @@ -119,8 +119,12 @@ data class NextEditAutocompleteResponse( val elapsed_time_ms: Long? = null, // this is the new completion var completions: List, + // When true, indices are already JVM-native (from native Kotlin engine) — skip Python→JVM conversion + @kotlinx.serialization.Transient + var nativeIndices: Boolean = false, ) { fun adjustIndices(text: String) { + if (nativeIndices) return // indices already in JVM string space completions.forEach { it.adjustIndices(text) } } diff --git a/src/main/kotlin/dev/sweep/assistant/autocomplete/edit/engine/LlamaServerClient.kt b/src/main/kotlin/dev/sweep/assistant/autocomplete/edit/engine/LlamaServerClient.kt new file mode 100644 index 0000000..98a7d26 --- /dev/null +++ b/src/main/kotlin/dev/sweep/assistant/autocomplete/edit/engine/LlamaServerClient.kt @@ -0,0 +1,180 @@ +package dev.sweep.assistant.autocomplete.edit.engine + +import com.google.gson.Gson +import com.google.gson.JsonObject +import com.google.gson.JsonParser +import com.intellij.openapi.diagnostic.Logger +import java.io.BufferedReader +import java.io.InputStreamReader +import java.net.URI +import java.net.http.HttpClient +import java.net.http.HttpRequest +import java.net.http.HttpResponse +import java.time.Duration +import java.util.concurrent.atomic.AtomicLong +import java.util.concurrent.atomic.AtomicReference + +/** + * HTTP client for llama-server's OpenAI-compatible /v1/completions endpoint. + * + * Uses SSE streaming to enable early abort when the completion exceeds + * the expected code block size, saving GPU time on bad generations. + */ +class LlamaServerClient( + private val baseUrl: String, + private val timeoutMs: Long = 10_000, +) { + private val logger = Logger.getInstance(LlamaServerClient::class.java) + private val httpClient = HttpClient.newBuilder() + .connectTimeout(Duration.ofMillis(5000)) + .build() + private val gson = Gson() + private val requestCounter = AtomicLong(0) + private val currentRequestThread = AtomicReference(null) + + data class CompletionResult( + val text: String, + val elapsedMs: Long, + val finishReason: String?, + ) + + class RequestCancelledException : Exception("Request cancelled by newer request") + + /** + * Generate a completion from llama-server using SSE streaming. + * + * Streams tokens and aborts early if: + * - The output exceeds maxOutputChars (estimated from code block size) + * - A stop token is detected in the accumulated text + * - A newer request has been enqueued (thread interrupt) + * + * @param maxOutputChars Abort if accumulated text exceeds this length. + * Set to 0 to disable early abort (use max_tokens only). + */ + fun generateCompletion( + prompt: String, + stop: List = NesConstants.STOP_TOKENS, + maxTokens: Int = NesConstants.AUTOCOMPLETE_OUTPUT_MAX_TOKENS, + temperature: Float = 0.0f, + maxOutputChars: Int = 0, + ): CompletionResult { + val myId = requestCounter.incrementAndGet() + + // Cancel any in-flight request by interrupting its thread + currentRequestThread.getAndSet(Thread.currentThread())?.interrupt() + + if (myId != requestCounter.get()) { + throw RequestCancelledException() + } + + val requestBody = mapOf( + "prompt" to prompt, + "stop" to stop, + "max_tokens" to maxTokens, + "temperature" to temperature, + "n_predict" to maxTokens, + "stream" to true, + ) + + val json = gson.toJson(requestBody) + val url = "$baseUrl/v1/completions" + + val request = HttpRequest.newBuilder() + .uri(URI.create(url)) + .timeout(Duration.ofMillis(timeoutMs)) + .header("Content-Type", "application/json") + .POST(HttpRequest.BodyPublishers.ofString(json)) + .build() + + val start = System.currentTimeMillis() + + val response = try { + httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream()) + } catch (e: java.io.IOException) { + if (Thread.interrupted() || myId != requestCounter.get()) { + throw RequestCancelledException() + } + throw e + } catch (e: InterruptedException) { + Thread.currentThread().interrupt() + throw RequestCancelledException() + } finally { + currentRequestThread.compareAndSet(Thread.currentThread(), null) + } + + Thread.interrupted() // clear interrupt flag + + if (response.statusCode() != 200) { + val body = response.body().bufferedReader().readText() + logger.warn("llama-server returned ${response.statusCode()}: $body") + return CompletionResult("", System.currentTimeMillis() - start, "error") + } + + // Stream SSE events, accumulate text, abort early if needed + val accumulated = StringBuilder() + var finishReason: String? = null + var abortedEarly = false + + try { + BufferedReader(InputStreamReader(response.body())).use { reader -> + var line: String? + while (reader.readLine().also { line = it } != null) { + val l = line ?: continue + if (!l.startsWith("data: ")) continue + val data = l.removePrefix("data: ").trim() + if (data == "[DONE]") break + + try { + val event = JsonParser.parseString(data).asJsonObject + val choices = event.getAsJsonArray("choices") + if (choices != null && choices.size() > 0) { + val choice = choices[0].asJsonObject + val text = choice.get("text")?.asString ?: "" + accumulated.append(text) + finishReason = choice.get("finish_reason")?.let { + if (it.isJsonNull) null else it.asString + } + } + } catch (_: Exception) { + // Skip malformed SSE events + } + + // Early abort: output exceeds expected code block size + if (maxOutputChars > 0 && accumulated.length > maxOutputChars) { + logger.info("Early abort: output ${accumulated.length} chars > limit $maxOutputChars") + abortedEarly = true + finishReason = "length" + break + } + } + } + } catch (e: java.io.IOException) { + if (Thread.interrupted() || myId != requestCounter.get()) { + throw RequestCancelledException() + } + // Stream closed — use whatever we accumulated + } + + val elapsedMs = System.currentTimeMillis() - start + val text = accumulated.toString() + + logger.info("llama-server completion: ${text.length} chars, ${elapsedMs}ms, finish=$finishReason${if (abortedEarly) " (early abort)" else ""}") + + return CompletionResult(text, elapsedMs, finishReason) + } + + /** Health check — returns true if llama-server is reachable. */ + fun isHealthy(): Boolean { + return try { + val request = HttpRequest.newBuilder() + .uri(URI.create("$baseUrl/health")) + .timeout(Duration.ofMillis(3000)) + .GET() + .build() + val response = httpClient.send(request, HttpResponse.BodyHandlers.ofString()) + response.statusCode() == 200 + } catch (e: Exception) { + false + } + } +} diff --git a/src/main/kotlin/dev/sweep/assistant/autocomplete/edit/engine/NesCompletionParser.kt b/src/main/kotlin/dev/sweep/assistant/autocomplete/edit/engine/NesCompletionParser.kt new file mode 100644 index 0000000..ee0c82b --- /dev/null +++ b/src/main/kotlin/dev/sweep/assistant/autocomplete/edit/engine/NesCompletionParser.kt @@ -0,0 +1,346 @@ +package dev.sweep.assistant.autocomplete.edit.engine + +import kotlin.math.abs + +/** + * Post-processing of LLM completions into concrete edit suggestions. + * Ported from Python next_edit_autocomplete.py (select_best_hunk_from_completion and helpers). + * No IntelliJ dependencies — fully unit-testable. + */ +object NesCompletionParser { + + data class AutocompleteResult( + val startIndex: Int, + val endIndex: Int, + val completion: String, + val confidence: Float, + val autocompleteId: String, + ) + + /** + * Check if the completion is a pure ghost text insertion at the given position. + * Returns the ghost text if found, empty string otherwise. + */ + fun getGhostTextWithLocation( + completion: String, + cleanedCodeBlock: String, + relativeCursorPosition: Int, + ): String { + val prefix = cleanedCodeBlock.substring(0, relativeCursorPosition) + val suffix = cleanedCodeBlock.substring(relativeCursorPosition) + if (completion.startsWith(prefix) && completion.endsWith(suffix)) { + val ghostText = if (suffix.isNotEmpty()) { + completion.substring(prefix.length, completion.length - suffix.length) + } else { + completion.substring(prefix.length) + } + if (ghostText.isNotEmpty()) return ghostText + } + return "" + } + + /** + * Find the best ghost text position by checking all possible split points. + * Returns (ghostText, position) or ("", -1) if not found. + */ + fun findGhostTextNonLocal( + completion: String, + cleanedCodeBlock: String, + relativeCursorPosition: Int, + ): Pair { + if (cleanedCodeBlock.length > completion.length) return "" to -1 + + val validPositions = mutableListOf>() + for (pos in 0..cleanedCodeBlock.length) { + val ghostText = getGhostTextWithLocation(completion, cleanedCodeBlock, pos) + if (ghostText.isNotEmpty()) { + validPositions.add(pos to ghostText) + } + } + + if (validPositions.isNotEmpty()) { + val (bestPos, bestText) = validPositions.maxByOrNull { it.first }!! + return bestText to bestPos + } + return "" to -1 + } + + /** + * Check if the completion is a single-line ghost text at the cursor position. + */ + fun isSingleLineGhostText( + completion: String, + cleanedCodeBlock: String, + relativeCursorPosition: Int, + ): String { + if (cleanedCodeBlock.length < relativeCursorPosition) return "" + + val prefix = cleanedCodeBlock.substring(0, relativeCursorPosition) + val suffix = cleanedCodeBlock.substring(relativeCursorPosition) + if (completion.startsWith(prefix) && completion.endsWith(suffix)) { + val ghostText = if (suffix.isNotEmpty()) { + completion.substring(prefix.length, completion.length - suffix.length) + } else { + completion.substring(prefix.length) + } + if (ghostText.isNotEmpty() && '\n' !in ghostText) return ghostText + } + return "" + } + + /** + * Check if completion is a pure insertion above the cursor position. + */ + fun isPureInsertionAboveCursor( + cleanedCodeBlock: String, + completion: String, + relativeCursorPosition: Int, + ): Boolean { + val currentLineIndex = cleanedCodeBlock.substring(0, relativeCursorPosition) + .linesSplitKeepEnds().size + val codeBlockLines = cleanedCodeBlock.linesSplitKeepEnds() + if (currentLineIndex < 1 || currentLineIndex > codeBlockLines.size) return false + val cursorLine = codeBlockLines[currentLineIndex - 1] + + if (cleanedCodeBlock.trim() == completion.trim()) return false + if (cursorLine.trim().isEmpty()) return false + + val prefixLines = codeBlockLines.take(currentLineIndex - 1) + val prefix = prefixLines.joinToString("") + val suffixLines = codeBlockLines.drop(currentLineIndex) + val suffix = suffixLines.joinToString("") + + return completion.startsWith(prefix) && completion.endsWith(cursorLine + suffix) + } + + /** + * Apply completions to the code block and return the modified version. + */ + fun applyCompletionsToCodeBlock( + completions: List, + fileContents: String, + cleanedCodeBlock: String, + ): String { + if (completions.isEmpty()) return cleanedCodeBlock + + val cleanedCodeStartIndex = fileContents.indexOf(cleanedCodeBlock) + if (cleanedCodeStartIndex == -1) return cleanedCodeBlock + + var modifiedCodeBlock = cleanedCodeBlock + val sorted = completions.sortedByDescending { it.startIndex } + + for (comp in sorted) { + val relativeStart = comp.startIndex - cleanedCodeStartIndex + val relativeEnd = comp.endIndex - cleanedCodeStartIndex + if (relativeStart >= 0 && relativeStart <= cleanedCodeBlock.length && relativeEnd <= cleanedCodeBlock.length) { + modifiedCodeBlock = modifiedCodeBlock.substring(0, relativeStart) + + comp.completion + + modifiedCodeBlock.substring(relativeEnd) + } + } + return modifiedCodeBlock + } + + /** + * Find the best hunk from the completion to suggest as an edit. + * This is the main post-processing function. + * + * Ported from Python select_best_hunk_from_completion(). + */ + fun selectBestHunkFromCompletion( + completionRaw: String, + cleanedCodeBlockRaw: String, + fileContents: String, + cursorPosition: Int, + autocompleteId: String, + ): List { + val completion = NesUtils.stripLeadingEmptyNewlines(completionRaw) + val cleanedCodeBlock = NesUtils.stripLeadingEmptyNewlines(cleanedCodeBlockRaw) + + if (completion == cleanedCodeBlock) return emptyList() + + val blockStartOffset = fileContents.indexOf(cleanedCodeBlock) + if (blockStartOffset == -1) return emptyList() + + val relativeCursorPosition = cursorPosition - blockStartOffset + + // Check for ghost text + val (ghostText, ghostTextPosition) = findGhostTextNonLocal( + completion, cleanedCodeBlock, relativeCursorPosition + ) + if (ghostText.isNotEmpty()) { + val isInsertNextLine = ghostTextPosition == relativeCursorPosition + 1 && + cleanedCodeBlock[ghostTextPosition - 1] == '\n' + val insertionStartsWithNewline = ghostText.startsWith("\n") && + ghostTextPosition == relativeCursorPosition + + if (isInsertNextLine || insertionStartsWithNewline) { + return listOf( + AutocompleteResult( + ghostTextPosition + blockStartOffset, + ghostTextPosition + blockStartOffset, + ghostText, 1.0f, "$autocompleteId-0" + ) + ) + } + + val ghostLines = ghostText.linesSplitKeepEnds() + val firstLine = ghostLines.first() + val remainingGhostText = ghostLines.drop(1).joinToString("") + + if (remainingGhostText.isNotEmpty()) { + val trimmedRemaining = remainingGhostText.trimEnd() + if (ghostTextPosition < cleanedCodeBlock.length && + cleanedCodeBlock[ghostTextPosition] == '\n' + ) { + val trailingLen = firstLine.length - firstLine.trimEnd('\n').length + val trailing = if (trailingLen > 0) firstLine.takeLast(trailingLen) else "" + val firstLineClean = firstLine.trimEnd('\n') + val fullRemaining = trailing + trimmedRemaining + + return listOf( + AutocompleteResult( + ghostTextPosition + blockStartOffset, + ghostTextPosition + blockStartOffset, + firstLineClean, 1.0f, "$autocompleteId-0" + ), + AutocompleteResult( + ghostTextPosition + blockStartOffset, + ghostTextPosition + blockStartOffset, + fullRemaining, 1.0f, "$autocompleteId-1" + ), + ) + } + } + + return listOf( + AutocompleteResult( + ghostTextPosition + blockStartOffset, + ghostTextPosition + blockStartOffset, + ghostText, 1.0f, "$autocompleteId-0" + ) + ) + } + + // Fall through to diff-based hunk selection + val hunks = NesUtils.splitIntoDiffHunks(cleanedCodeBlock, completion) + if (hunks.isEmpty()) return emptyList() + + // Process each hunk to get absolute positions + val originalLines = cleanedCodeBlock.linesSplitKeepEnds() + val processedHunks = mutableListOf>() + + for (hunk in hunks) { + val startLineIdx = hunk.inputStart - 1 + var startOffset = blockStartOffset + for (i in 0 until startLineIdx) { + if (i < originalLines.size) startOffset += originalLines[i].length + } + + var endOffset = startOffset + for (i in 0 until hunk.inputLines.size) { + val lineIdx = startLineIdx + i + if (lineIdx < originalLines.size) endOffset += originalLines[lineIdx].length + } + + var newText = hunk.outputLines.joinToString("") + + // Hack for end of file + if (startOffset == fileContents.length && fileContents[startOffset - 1] != '\n') { + newText = "\n$newText" + } + + if (fileContents.substring(startOffset, endOffset) != newText) { + processedHunks.add(Triple(startOffset, endOffset, newText)) + } + } + + val hunksAfterCursor = processedHunks.filter { it.second >= cursorPosition } + .sortedBy { it.first } + val hunksBeforeCursor = processedHunks.filter { it.second < cursorPosition } + .sortedBy { it.first } + + // Handle hunks after cursor + if (hunksAfterCursor.isNotEmpty()) { + val (startOffset, endOffset, newText) = hunksAfterCursor.first() + val restHunks = hunksAfterCursor.drop(1) + val startLinePosition = fileContents.substring(0, cursorPosition).lastIndexOf('\n') + 1 + + val results = mutableListOf() + + val shouldSplit = startLinePosition == startOffset && + startOffset <= cursorPosition && cursorPosition < endOffset && + newText.count { it == '\n' } > 0 + + if (shouldSplit) { + val newTextLines = newText.linesSplitKeepEnds() + val firstLine = newTextLines.first() + val remainingNewText = newTextLines.drop(1).joinToString("") + + val originalTextSection = fileContents.substring(cursorPosition, endOffset) + val firstNewlinePos = originalTextSection.indexOf('\n') + val firstLineEnd = if (firstNewlinePos != -1) cursorPosition + firstNewlinePos + 1 else endOffset + + val endLinePosition = fileContents.indexOf('\n', cursorPosition).let { if (it == -1) fileContents.length else it } + val currentCursorLineContents = fileContents.substring(startLinePosition, endLinePosition) + + if (firstLine.startsWith(currentCursorLineContents)) { + results.add(AutocompleteResult(startOffset, firstLineEnd, firstLine, 1.0f, "$autocompleteId-0")) + if (remainingNewText.isNotEmpty()) { + results.add(AutocompleteResult(firstLineEnd, endOffset, remainingNewText, 1.0f, "$autocompleteId-1")) + } + } else { + results.add(AutocompleteResult(startOffset, endOffset, newText, 1.0f, "$autocompleteId-0")) + } + } else { + results.add(AutocompleteResult(startOffset, endOffset, newText, 1.0f, "$autocompleteId-0")) + } + + val maxId = restHunks.size + results.addAll(restHunks.mapIndexed { i, (s, e, t) -> + AutocompleteResult(s, e, t, 1.0f, "$autocompleteId-${maxId + i}") + }) + return results + } + + // Handle hunks before cursor — fuse nearby hunks + if (hunksBeforeCursor.isNotEmpty()) { + val fusedGroups = mutableListOf(mutableListOf(hunksBeforeCursor.first())) + + for (i in 1 until hunksBeforeCursor.size) { + val prevHunk = fusedGroups.last().last() + val currHunk = hunksBeforeCursor[i] + val prevEndLine = NesUtils.getLineNumberFromPosition(fileContents, prevHunk.second) + val currStartLine = NesUtils.getLineNumberFromPosition(fileContents, currHunk.first) + + if (currStartLine - prevEndLine <= 2) { + fusedGroups.last().add(currHunk) + } else { + fusedGroups.add(mutableListOf(currHunk)) + } + } + + val results = fusedGroups.mapIndexed { groupIdx, group -> + val firstStart = group.first().first + val lastEnd = group.last().second + + val combinedParts = mutableListOf() + var currentOffset = firstStart + for ((s, e, t) in group) { + if (currentOffset < s) combinedParts.add(fileContents.substring(currentOffset, s)) + combinedParts.add(t) + currentOffset = e + } + + AutocompleteResult( + firstStart, lastEnd, combinedParts.joinToString(""), + 1.0f, "$autocompleteId-$groupIdx" + ) + }.sortedBy { abs(it.startIndex - cursorPosition) } + + return results + } + + return emptyList() + } +} diff --git a/src/main/kotlin/dev/sweep/assistant/autocomplete/edit/engine/NesConstants.kt b/src/main/kotlin/dev/sweep/assistant/autocomplete/edit/engine/NesConstants.kt new file mode 100644 index 0000000..47ae623 --- /dev/null +++ b/src/main/kotlin/dev/sweep/assistant/autocomplete/edit/engine/NesConstants.kt @@ -0,0 +1,63 @@ +package dev.sweep.assistant.autocomplete.edit.engine + +/** + * Constants for the Next Edit Suggestion (NES) engine. + * Ported from Python sweep_autocomplete. + */ +object NesConstants { + // Block extraction around cursor + const val NUM_LINES_BEFORE = 2 + const val NUM_LINES_AFTER = 5 + + // Token estimation + const val CHARS_PER_TOKEN = 3.5 + + // Prompt truncation limits + val MAX_INPUT_TOKENS_COUNT = (8192 * 4) - 256 // ~8K tokens at 3.5 chars/token + val CHARACTER_BOUND_TO_CHECK_TOKENIZATION = (8192 * 2) - 256 + val CHARACTER_BOUND_TO_SKIP_TOKENIZATION = (8192 * 4) * 2 + + // Retrieval + const val MAX_RETRIEVAL_CHUNK_SIZE_LINES = 25 + const val MAX_RETRIEVAL_CHUNKS = 3 + const val MAX_RETRIEVAL_TOKENS_COUNT = 2048 + + // Generation + const val AUTOCOMPLETE_OUTPUT_MAX_TOKENS = 1024 + + // Line length limits + const val AUTOCOMPLETE_TRUNCATION_LINE_LENGTH = 600 + const val AUTOCOMPLETE_MAXIMUM_LINE_LENGTH = 1000 + + // User actions + const val NUM_RECENT_ACTIONS_TO_PRESERVE = 20 + + // Stop tokens + val STOP_TOKENS = listOf("<|endoftext|>", "<|file_sep|>") + + // Prompt templates — must match Python exactly + val PROMPT_TEMPLATE = """<|file_sep|>{file_path} +{initial_file}{retrieval_results} +{recent_changes} +<|file_sep|>original/{file_path}:{start_line}:{end_line} +{prev_section} +<|file_sep|>current/{file_path}:{start_line}:{end_line} +{code_block} +<|file_sep|>updated/{file_path}:{start_line}:{end_line}""" + + val DIFF_FORMAT = """<|file_sep|>{file_path}:{start_line}:{end_line} +original: +{old_code} +updated: +{new_code}""" + + // Qwen2 pretokenizer regex (used for prefill computation) + // Source: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2/tokenization_qwen2.py + const val PRETOKENIZE_REGEX = + """(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""" + + // Chunk size for getLinesAroundCursor + const val CHUNK_SIZE = 300 + const val CHUNK_STRIDE = CHUNK_SIZE / 2 + const val LIMIT_TO_CHUNK = 800 +} diff --git a/src/main/kotlin/dev/sweep/assistant/autocomplete/edit/engine/NesModelConfig.kt b/src/main/kotlin/dev/sweep/assistant/autocomplete/edit/engine/NesModelConfig.kt new file mode 100644 index 0000000..502c865 --- /dev/null +++ b/src/main/kotlin/dev/sweep/assistant/autocomplete/edit/engine/NesModelConfig.kt @@ -0,0 +1,49 @@ +package dev.sweep.assistant.autocomplete.edit.engine + +/** + * Available NES model configurations for llama-server. + */ +data class NesModel( + val id: String, + val displayName: String, + val repo: String, + val filename: String, + val description: String, +) + +object NesModelConfig { + val MODELS = listOf( + NesModel( + id = "sweep-0.5B", + displayName = "Sweep 0.5B (Q8, fastest)", + repo = "sweepai/sweep-next-edit-0.5B", + filename = "sweep-next-edit-0.5b.q8_0.gguf", + description = "Smallest and fastest model, good for most edits", + ), + NesModel( + id = "sweep-1.5B", + displayName = "Sweep 1.5B (Q8)", + repo = "sweepai/sweep-next-edit-1.5B", + filename = "sweep-next-edit-1.5b.q8_0.v2.gguf", + description = "Better quality, slightly slower", + ), + NesModel( + id = "sweep-7B-q5", + displayName = "Sweep 7B v2 (Q5_K_M)", + repo = "henrik3/sweep-next-edit-v2-7B-GGUF", + filename = "q5_k_m.gguf", + description = "Highest quality, requires more RAM", + ), + NesModel( + id = "sweep-7B-q4", + displayName = "Sweep 7B v2 (Q4_K_M)", + repo = "henrik3/sweep-next-edit-v2-7B-GGUF", + filename = "q4_k_m.gguf", + description = "High quality, less RAM than Q5", + ), + ) + + val DEFAULT_MODEL_ID = "sweep-0.5B" + + fun getModel(id: String): NesModel = MODELS.find { it.id == id } ?: MODELS.first() +} diff --git a/src/main/kotlin/dev/sweep/assistant/autocomplete/edit/engine/NesPromptBuilder.kt b/src/main/kotlin/dev/sweep/assistant/autocomplete/edit/engine/NesPromptBuilder.kt new file mode 100644 index 0000000..1424806 --- /dev/null +++ b/src/main/kotlin/dev/sweep/assistant/autocomplete/edit/engine/NesPromptBuilder.kt @@ -0,0 +1,389 @@ +package dev.sweep.assistant.autocomplete.edit.engine + +import dev.sweep.assistant.autocomplete.edit.engine.NesConstants.CHARS_PER_TOKEN +import dev.sweep.assistant.autocomplete.edit.engine.NesConstants.NUM_LINES_AFTER +import dev.sweep.assistant.autocomplete.edit.engine.NesConstants.NUM_LINES_BEFORE +import dev.sweep.assistant.autocomplete.edit.engine.NesConstants.AUTOCOMPLETE_OUTPUT_MAX_TOKENS +import dev.sweep.assistant.autocomplete.edit.engine.NesConstants.CHARACTER_BOUND_TO_CHECK_TOKENIZATION +import dev.sweep.assistant.autocomplete.edit.engine.NesConstants.CHARACTER_BOUND_TO_SKIP_TOKENIZATION +import dev.sweep.assistant.autocomplete.edit.engine.NesConstants.MAX_INPUT_TOKENS_COUNT +import dev.sweep.assistant.autocomplete.edit.engine.NesConstants.MAX_RETRIEVAL_TOKENS_COUNT +import kotlin.math.max +import kotlin.math.min + +/** + * Constructs prompts for the NES model from editor state. + * Ported from Python _fetch_next_edits_core() prompt construction logic. + * No IntelliJ dependencies — fully unit-testable. + */ +object NesPromptBuilder { + + data class BlockAtCursor( + val codeBlock: String, + val prefix: String, + val suffix: String, + val blockStartIndex: Int, + ) + + data class FileChunkData( + val filePath: String, + val content: String, + val startLine: Int, + val endLine: Int, + ) { + fun toPromptString(): String = "<|file_sep|>$filePath\n$content\n" + } + + data class PromptBuildResult( + val formattedPrompt: String, + val cleanedCodeBlock: String, + val prefill: String, + val forcedPrefix: String, + val prevSections: List, + val relativeCursorPosition: Int, + val relativeCursorLine: Int, + val blockStartIndex: Int, + ) + + /** + * Extract the code block surrounding the cursor position. + * Ported from Python get_block_at_cursor(). + */ + fun getBlockAtCursor(fileContents: String, cursorPosition: Int): BlockAtCursor { + val lines = fileContents.linesSplitKeepEnds() + val cursorLine = NesUtils.getLineNumberFromPosition(fileContents, cursorPosition) + val (codeBlock, prefix, suffix) = getBlockAroundCursorLine( + lines, cursorLine, NUM_LINES_BEFORE, NUM_LINES_AFTER + ) + val blockStartLine = max(0, cursorLine - NUM_LINES_BEFORE) + val blockStartIndex = lines.take(blockStartLine).sumOf { it.length } + + val truncatedBlock = truncateCodeBlockByTokens(codeBlock) + + return BlockAtCursor(truncatedBlock, prefix, suffix, blockStartIndex) + } + + private fun getBlockAroundCursorLine( + lines: List, + cursorLine: Int, + numLinesBefore: Int, + numLinesAfter: Int, + ): Triple { + var blockStart = max(0, cursorLine - numLinesBefore) + var blockEnd = min(lines.size, cursorLine + numLinesAfter + 1) + + while (blockStart < blockEnd && lines[blockStart].trim().isEmpty()) { + blockStart++ + if (blockEnd < lines.size) blockEnd++ + } + while (blockEnd > blockStart && lines[blockEnd - 1].trim().isEmpty()) { + blockEnd-- + } + + var currentBlock = lines.subList(blockStart, blockEnd).joinToString("") + val prefixStart = max(0, blockStart - 10) + val prefix = lines.subList(prefixStart, blockStart).joinToString("").trim('\n') + val suffixEnd = min(lines.size, blockEnd + 10) + val suffix = lines.subList(blockEnd, suffixEnd).joinToString("").trim('\n') + + if (currentBlock.endsWith("\n")) { + currentBlock = currentBlock.trimEnd('\n') + "\n" + } + + return Triple(currentBlock, prefix, suffix) + } + + /** Public access for the engine's retrieval pass. */ + fun truncateCodeBlockByTokensPublic( + codeBlock: String, + maxTokenLimit: Int = AUTOCOMPLETE_OUTPUT_MAX_TOKENS / 2, + ): String = truncateCodeBlockByTokens(codeBlock, maxTokenLimit) + + private fun truncateCodeBlockByTokens( + codeBlock: String, + maxTokenLimit: Int = AUTOCOMPLETE_OUTPUT_MAX_TOKENS / 2, + ): String { + val codeBlockLines = codeBlock.linesSplitKeepEnds() + val prefilledCodeBlock = codeBlockLines.take(NUM_LINES_BEFORE).joinToString("") + val remainingCodeBlock = codeBlockLines.drop(NUM_LINES_BEFORE).joinToString("") + val estimatedTokens = NesUtils.estimateTokenCount(remainingCodeBlock) + + if (estimatedTokens > maxTokenLimit) { + val maxChars = (maxTokenLimit * CHARS_PER_TOKEN).toInt() + val truncated = remainingCodeBlock.substring(0, min(maxChars, remainingCodeBlock.length)) + val truncatedLines = truncated.linesSplitKeepEnds() + if (truncatedLines.size > 1) { + return prefilledCodeBlock + truncatedLines.dropLast(1).joinToString("") + } + } + return codeBlock + } + + /** + * Format recent changes into diff format and compute the previous section. + * Ported from Python format_recent_changes_and_prev_section(). + */ + fun formatRecentChangesAndPrevSection( + recentChanges: String, + currentSection: String, + ): Triple> { + val hunks = NesUtils.splitIntoHunks(recentChanges) + .filter { it.trim().lines().size > 1 } + .let { NesUtils.filterWhitespaceOnlyHunks(it) } + + var prevSection = currentSection.replace("<|cursor|>", "") + val prevSections = mutableListOf() + + if (hunks.isNotEmpty()) { + for (hunk in hunks.reversed()) { + val firstLine = hunk.lines().first() + val filePath = firstLine.removePrefix("File: ").trimEnd('\n') + val rest = hunk.linesSplitKeepEnds().drop(1).joinToString("") + val (oldCode, newCode) = NesUtils.extractDiffParts(rest) + val (oldCodeCtx, newCodeCtx) = NesUtils.extractDiffParts(rest, 1) + val parsed = NesUtils.parseHunk(rest) + val startLine = parsed.inputStart + val endLine = startLine + parsed.inputLines.size - 1 + + if (newCodeCtx.trim().isNotEmpty() && newCodeCtx in prevSection) { + prevSection = prevSection.replaceFirst(newCodeCtx, oldCodeCtx) + prevSections.add(prevSection) + } else if (newCode.trim().isNotEmpty() && newCode in prevSection) { + prevSection = prevSection.replaceFirst(newCode, oldCode) + prevSections.add(prevSection) + } else { + break + } + } + } + + // Format as diff_format + var result = "" + for (hunk in hunks.takeLast(6)) { + val firstLine = hunk.lines().first() + val filePath = firstLine.removePrefix("File: ").trimEnd('\n') + val rest = hunk.linesSplitKeepEnds().drop(1).joinToString("") + val (oldCode, newCode) = NesUtils.extractDiffParts(rest, 1) + val parsed = NesUtils.parseHunk(rest) + val startLine = parsed.inputStart + val endLine = startLine + parsed.inputLines.size - 1 + + if (oldCode.trim().isNotEmpty() || newCode.trim().isNotEmpty()) { + result += NesConstants.DIFF_FORMAT + .replace("{old_code}", oldCode.trim('\n')) + .replace("{new_code}", newCode.trim('\n')) + .replace("{file_path}", filePath) + .replace("{start_line}", startLine.toString()) + .replace("{end_line}", endLine.toString()) + "\n" + } + } + + return Triple(result.trimEnd('\n'), prevSection, prevSections) + } + + /** + * Build the full prompt for the NES model. + * + * @param filePath The file path + * @param fileContents Current file contents + * @param originalFileContents Original file contents (before recent edits) + * @param recentChanges Recent diff changes + * @param cursorPosition Cursor position in the file + * @param codeBlock Code block around cursor (from getBlockAtCursor) + * @param blockStartIndex Start index of the code block in the file + * @param fileChunks Additional file chunks for context + * @param retrievalChunks Retrieval chunks from similar code + * @param recentChangesHighRes High-resolution recent changes + * @param changesAboveCursor Whether recent changes are above the cursor + * @param forceGhostText Whether to force ghost text mode + * @param useRemoteEndpoint Whether a remote endpoint is being used (affects prefill logic) + */ + fun buildPrompt( + filePath: String, + fileContents: String, + originalFileContents: String, + recentChanges: String, + cursorPosition: Int, + codeBlock: String, + prefix: String, + suffix: String, + blockStartIndex: Int, + fileChunks: List = emptyList(), + retrievalChunks: List = emptyList(), + recentChangesHighRes: String = "", + changesAboveCursor: Boolean = false, + forceGhostText: Boolean = false, + useRemoteEndpoint: Boolean = false, + ): PromptBuildResult { + val relativeCursorPosition = cursorPosition - fileContents.indexOf(codeBlock) + var cleanedCodeBlock = codeBlock + val relativeCursorLine = NesUtils.getLineNumberFromPosition(codeBlock, relativeCursorPosition) + + // Insert cursor marker + val codeBlockWithCursor = codeBlock.substring(0, relativeCursorPosition) + + "<|cursor|>" + + codeBlock.substring(relativeCursorPosition) + + val (onlyChangedLines, prevSection, prevSections0) = + formatRecentChangesAndPrevSection(recentChanges, codeBlockWithCursor) + + val prevSections = if (recentChangesHighRes.isNotEmpty()) { + formatRecentChangesAndPrevSection(recentChangesHighRes, codeBlockWithCursor).third + } else { + prevSections0 + } + + // Compute prefill and forced prefix + val isAtEof = relativeCursorPosition == cleanedCodeBlock.length + val prefill: String + val forcedPrefix: String + + if (forceGhostText && !isAtEof && useRemoteEndpoint) { + val prefillCandidate = cleanedCodeBlock.substring(0, relativeCursorPosition) + val pretokens = NesUtils.pretokenize(prefillCandidate) + val regexBasedPrefill = if (pretokens.size > 1) pretokens.dropLast(1).joinToString("") else "" + prefill = regexBasedPrefill + forcedPrefix = cleanedCodeBlock.substring(0, relativeCursorPosition).removePrefix(prefill) + } else if (changesAboveCursor) { + forcedPrefix = "" + val prefillFull = cleanedCodeBlock.substring(0, relativeCursorPosition) + val prefillLines = prefillFull.linesSplitKeepEnds() + val numLinesAbove = 1 + var beforeSplit = prefillLines.take(numLinesAbove).joinToString("") + val afterSplit = prefillLines.drop(numLinesAbove).joinToString("") + for (char in afterSplit) { + if (char == '\n') beforeSplit += "\n" else break + } + prefill = beforeSplit + } else { + prefill = "" + forcedPrefix = "" + } + + // Pack retrieval chunks + val packedRetrievalChunks = NesUtils.packItemsForPrompt( + retrievalChunks, + { it.toPromptString() }, + MAX_RETRIEVAL_TOKENS_COUNT, + truncateFromEnd = false, + ) + val retrievalResults = packedRetrievalChunks.joinToString("") { "\n${it.toPromptString()}" } + + // Format code block and prev section + var formattedCodeBlock = codeBlockWithCursor + var formattedPrevSection = prevSection + if (formattedCodeBlock.endsWith("\n") && formattedPrevSection.endsWith("\n")) { + formattedCodeBlock = formattedCodeBlock.removeSuffix("\n") + formattedPrevSection = formattedPrevSection.removeSuffix("\n") + } + + val initialFile = NesUtils.getLinesAroundCursor(originalFileContents, cursorPosition) + + var formattedPrompt = NesConstants.PROMPT_TEMPLATE + .replace("{file_path}", filePath) + .replace("{recent_changes}", onlyChangedLines) + .replace("{prev_section}", formattedPrevSection) + .replace("{code_block}", formattedCodeBlock) + .replace("{retrieval_results}", retrievalResults) + .replace("{initial_file}", initialFile) + .replace("{start_line}", (relativeCursorLine + 1).toString()) + .replace("{end_line}", (relativeCursorLine + formattedCodeBlock.lines().size + 1).toString()) + + "\n$prefill" + + // Truncation logic + val formattedFileChunks = fileChunks.joinToString("") { it.toPromptString() } + + if (formattedPrompt.length + formattedFileChunks.length > CHARACTER_BOUND_TO_CHECK_TOKENIZATION) { + formattedPrompt = truncatePrompt( + filePath, onlyChangedLines, formattedPrevSection, formattedCodeBlock, + retrievalResults, initialFile, prefill, fileChunks, + relativeCursorLine, + ) ?: return PromptBuildResult( + "", cleanedCodeBlock, prefill, forcedPrefix, prevSections, + relativeCursorPosition, relativeCursorLine, blockStartIndex, + ) + } else { + formattedPrompt = formattedFileChunks + formattedPrompt + } + + // Truncate long lines + formattedPrompt = NesUtils.truncateLongLines(formattedPrompt) + + return PromptBuildResult( + formattedPrompt, cleanedCodeBlock, prefill, forcedPrefix, prevSections, + relativeCursorPosition, relativeCursorLine, blockStartIndex, + ) + } + + private fun truncatePrompt( + filePath: String, + recentChanges: String, + prevSection: String, + codeBlock: String, + retrievalResults: String, + initialFile: String, + prefill: String, + fileChunks: List, + relativeCursorLine: Int, + ): String? { + val minimalPrompt = NesConstants.PROMPT_TEMPLATE + .replace("{file_path}", filePath) + .replace("{recent_changes}", recentChanges) + .replace("{prev_section}", prevSection) + .replace("{code_block}", codeBlock) + .replace("{retrieval_results}", "") + .replace("{initial_file}", initialFile) + .replace("{start_line}", (relativeCursorLine + 1).toString()) + .replace("{end_line}", (relativeCursorLine + codeBlock.lines().size + 1).toString()) + + "\n$prefill" + + if (minimalPrompt.length > CHARACTER_BOUND_TO_SKIP_TOKENIZATION) return null + + val minimalTokens = NesUtils.estimateTokenCount(minimalPrompt) + val retrievalTokens = NesUtils.estimateTokenCount(retrievalResults) + val chunkTokens = fileChunks.map { NesUtils.estimateTokenCount(it.content) } + + // Case 1: everything fits + if (minimalTokens + retrievalTokens + chunkTokens.sum() <= MAX_INPUT_TOKENS_COUNT) { + val fullPrompt = NesConstants.PROMPT_TEMPLATE + .replace("{file_path}", filePath) + .replace("{recent_changes}", recentChanges) + .replace("{prev_section}", prevSection) + .replace("{code_block}", codeBlock) + .replace("{retrieval_results}", retrievalResults) + .replace("{initial_file}", initialFile) + .replace("{start_line}", (relativeCursorLine + 1).toString()) + .replace("{end_line}", (relativeCursorLine + codeBlock.lines().size + 1).toString()) + + "\n$prefill" + return fileChunks.joinToString("") { it.toPromptString() } + fullPrompt + } + + // Case 2: minimal prompt too long + if (minimalTokens > MAX_INPUT_TOKENS_COUNT) return null + + // Case 3: drop all file chunks + if (minimalTokens + retrievalTokens > MAX_INPUT_TOKENS_COUNT) return minimalPrompt + + // Case 4: fit as many file chunks as possible + val promptWithRetrieval = NesConstants.PROMPT_TEMPLATE + .replace("{file_path}", filePath) + .replace("{recent_changes}", recentChanges) + .replace("{prev_section}", prevSection) + .replace("{code_block}", codeBlock) + .replace("{retrieval_results}", retrievalResults) + .replace("{initial_file}", initialFile) + .replace("{start_line}", (relativeCursorLine + 1).toString()) + .replace("{end_line}", (relativeCursorLine + codeBlock.lines().size + 1).toString()) + + "\n$prefill" + + var currentTokenCount = minimalTokens + retrievalTokens + val partialChunks = StringBuilder() + for ((chunk, tokens) in fileChunks.zip(chunkTokens)) { + if (currentTokenCount + tokens >= MAX_INPUT_TOKENS_COUNT) break + partialChunks.append(chunk.toPromptString()) + currentTokenCount += tokens + } + + return partialChunks.toString() + promptWithRetrieval + } +} diff --git a/src/main/kotlin/dev/sweep/assistant/autocomplete/edit/engine/NesRetrieval.kt b/src/main/kotlin/dev/sweep/assistant/autocomplete/edit/engine/NesRetrieval.kt new file mode 100644 index 0000000..582203e --- /dev/null +++ b/src/main/kotlin/dev/sweep/assistant/autocomplete/edit/engine/NesRetrieval.kt @@ -0,0 +1,301 @@ +package dev.sweep.assistant.autocomplete.edit.engine + +import com.github.difflib.DiffUtils +import java.util.concurrent.ConcurrentHashMap +import java.util.regex.Pattern +import kotlin.math.abs + +/** + * Retrieval logic for finding code blocks related to recent edits. + * Ported from Python next_edit_autocomplete_retrieval.py. + * No IntelliJ dependencies — fully unit-testable. + */ +object NesRetrieval { + + private val WORD_PATTERN = Pattern.compile("\\w+", Pattern.CASE_INSENSITIVE or Pattern.UNICODE_CHARACTER_CLASS) + private val WORD_BOUNDARY_PATTERN = Pattern.compile("\\w+|\\s+|[^\\w\\s]") + + // Simple LRU-ish caches (bounded ConcurrentHashMaps) + private val tokenizerCache = ConcurrentHashMap>(256) + private val tokenizerWithOffsetsCache = ConcurrentHashMap>>(256) + + /** Extract words from text (case-insensitive). Cached. */ + fun simpleTokenizer(text: String): List { + return tokenizerCache.getOrPut(text) { + val matcher = WORD_PATTERN.matcher(text) + val tokens = mutableListOf() + while (matcher.find()) tokens.add(matcher.group()) + tokens + } + } + + /** Extract words with character offsets. Cached. */ + fun simpleTokenizerWithOffsets(text: String): List> { + return tokenizerWithOffsetsCache.getOrPut(text) { + val matcher = WORD_PATTERN.matcher(text) + val tokens = mutableListOf>() + while (matcher.find()) { + tokens.add(Triple(matcher.group(), matcher.start(), matcher.end())) + } + tokens + } + } + + /** Clear caches (call between requests if memory is a concern). */ + fun clearCaches() { + tokenizerCache.clear() + tokenizerWithOffsetsCache.clear() + } + + /** + * Extract added and deleted words from a single diff hunk at the word level. + * Uses SequenceMatcher-like diffing on word tokens. + */ + fun extractAddedAndDeletedFromHunk(hunk: String, extension: String): Pair, List> { + val linesWithoutFileHeader = hunk.linesSplitKeepEnds() + .filter { !it.startsWith("File: ") } + .joinToString("") + val (oldCode, newCode) = NesUtils.extractDiffParts(linesWithoutFileHeader) + + val oldWords = tokenizeWithBoundaries(oldCode) + val newWords = tokenizeWithBoundaries(newCode) + + val patch = DiffUtils.diff(oldWords, newWords) + val originalDeletedWords = mutableListOf() + val originalAddedWords = mutableListOf() + + for (delta in patch.deltas) { + when (delta.type) { + com.github.difflib.patch.DeltaType.DELETE -> { + originalDeletedWords.addAll(delta.source.lines) + } + com.github.difflib.patch.DeltaType.INSERT -> { + originalAddedWords.addAll(delta.target.lines) + } + com.github.difflib.patch.DeltaType.CHANGE -> { + originalDeletedWords.addAll(delta.source.lines) + originalAddedWords.addAll(delta.target.lines) + } + else -> {} + } + } + + val addedWords = originalAddedWords.filter { it.length > 1 }.toSet().toList() + val deletedWords = originalDeletedWords.filter { it.length > 1 }.toSet().toList() + return addedWords to deletedWords + } + + /** + * Extract added and deleted words from recent changes diff. + * Returns (addedWords, deletedWords). + */ + fun extractAddedAndDeletedCodeFromRecentChanges( + recentChanges: String, + fileTokensSet: Set, + ): Pair, List> { + val hunks = NesUtils.splitIntoHunks(recentChanges) + .filter { it.trim().lines().size > 1 } + if (hunks.isEmpty()) return emptyList() to emptyList() + + var addedWords = emptyList() + var deletedWords = emptyList() + + for (hunk in hunks.reversed()) { + val firstLine = hunk.lines().first() + val extension = if ("." in firstLine) firstLine.substringAfter(".").lowercase() else "" + val (added, deleted) = extractAddedAndDeletedFromHunk(hunk, extension) + addedWords = added + deletedWords = deleted.filter { it.length > 1 && it in fileTokensSet } + if (deletedWords.size == 1) return addedWords to deletedWords + } + return addedWords to deletedWords + } + + /** + * Extract full deleted lines from recent changes with their line numbers. + */ + fun extractDeletedLinesFromRecentChanges(recentChanges: String): List> { + val hunks = NesUtils.splitIntoHunks(recentChanges) + val deletedLinesWithNumbers = mutableListOf>() + + for (hunk in hunks) { + if ("@@" !in hunk) continue + + val lines = hunk.linesSplitKeepEnds() + val rest = lines.drop(1).joinToString("") + val parsed = NesUtils.parseHunk(rest) + + var currentLine = parsed.inputStart + for (line in parsed.inputLines) { + val stripped = line.trim() + if (stripped.isNotEmpty() && line !in parsed.outputLines) { + deletedLinesWithNumbers.add(stripped to currentLine) + } + currentLine++ + } + } + return deletedLinesWithNumbers + } + + /** + * Search for deleted lines in the file contents and return a code block around the match. + */ + fun findDeletedLineMatch( + fileContents: String, + deletedLines: List>, + ): Pair? { + val fileLines = fileContents.linesSplitKeepEnds() + + for ((deletedLine, originalLineNumber) in deletedLines) { + val lineTokens = simpleTokenizer(deletedLine) + val distinctTerms = lineTokens.toSet() + + if (distinctTerms.size >= 3) { + for ((lineIndex, fileLine) in fileLines.withIndex()) { + if (lineIndex == originalLineNumber) continue + if (deletedLine in fileLine.trim()) { + val startLine = lineIndex + val endLine = minOf(fileLines.size, lineIndex + 4) + val retrievedBlock = fileLines.subList(startLine, endLine).joinToString("") + val blockStartOffset = fileLines.take(lineIndex).sumOf { it.length } + return retrievedBlock to blockStartOffset + } + } + } + } + return null + } + + data class RetrievalResult( + val codeBlock: String, + val blockStartOffset: Int, + val isBlockAfterCursor: Boolean, + val diagnostic: EditorDiagnosticData?, + ) + + data class EditorDiagnosticData( + val line: Int, + val lineNumber: Int, + val startOffset: Int, + val endOffset: Int, + val severity: String, + val message: String, + ) + + /** + * Find the best matching code block for retrieval-based autocomplete. + * Ported from Python find_best_matching_block(). + */ + fun findBestMatchingBlock( + fileContents: String, + recentChanges: String, + cursorPosition: Int, + blockSize: Int = 6, + editorDiagnostics: List? = null, + ): RetrievalResult { + // Try deleted line match first + val deletedLines = extractDeletedLinesFromRecentChanges(recentChanges) + val deletedLineMatch = findDeletedLineMatch(fileContents, deletedLines) + if (deletedLineMatch != null) { + return RetrievalResult(deletedLineMatch.first, deletedLineMatch.second, false, null) + } + + // Tokenize file + val fileTokensWithOffsets = simpleTokenizerWithOffsets(fileContents) + val fileTokens = fileTokensWithOffsets.map { it.first } + + val (addedWords, deletedWords) = extractAddedAndDeletedCodeFromRecentChanges( + recentChanges, fileTokens.toSet() + ) + + val currentCursorLineNumber = NesUtils.getLineNumberFromPosition(fileContents, cursorPosition) + + // Determine query token + val queryToken: String? = if (deletedWords.size == 1) { + deletedWords[0] + } else { + addedWords.firstOrNull { word -> + val lineNumbers = fileTokensWithOffsets + .filter { it.first == word } + .map { NesUtils.getLineNumberFromPosition(fileContents, it.second) } + word in fileTokens && + fileTokens.count { it == word } in 2..5 && + lineNumbers.any { abs(it - currentCursorLineNumber) > 10 } + } + } + + // Find offsets for query token (far from cursor) + val queryTokenOffsets = fileTokensWithOffsets + .filter { + it.first == queryToken && + abs(NesUtils.getLineNumberFromPosition(fileContents, it.second) - currentCursorLineNumber) > 10 + } + .map { it.second } + + // Check for closest error diagnostic + var closestError: EditorDiagnosticData? = null + if (editorDiagnostics != null) { + val filteredErrors = editorDiagnostics.filter { + it.severity == "ERROR" && abs(currentCursorLineNumber - it.lineNumber) > 10 + } + closestError = filteredErrors.minByOrNull { abs(cursorPosition - it.startOffset) } + } + + val lines = fileContents.linesSplitKeepEnds() + + if (closestError != null) { + var cumOffset = 0 + var startLine = 0 + for ((i, line) in lines.withIndex()) { + if (cumOffset + line.length > closestError.startOffset) { + startLine = i + break + } + cumOffset += line.length + } + val endLine = minOf(lines.size, startLine + 1) + val block = lines.subList(startLine, endLine).joinToString("") + val offset = lines.take(startLine).sumOf { it.length } + return RetrievalResult(block, offset, false, closestError) + } else if (queryTokenOffsets.isNotEmpty()) { + val closestOffset = queryTokenOffsets.minByOrNull { abs(cursorPosition - it) }!! + + var lineIndex = 0 + var cumOffset = 0 + for ((i, line) in lines.withIndex()) { + if (cumOffset + line.length > closestOffset) { + lineIndex = i + break + } + cumOffset += line.length + } + val endLine = minOf(lines.size, lineIndex + 1) + val block = lines.subList(lineIndex, endLine).joinToString("") + val offset = lines.take(lineIndex).sumOf { it.length } + return RetrievalResult(block, offset, false, null) + } else { + // Fallback: block after cursor + var suffix = fileContents.substring(cursorPosition) + val suffixStart = suffix.indexOf('\n') + var adjustedCursorPosition = cursorPosition + if (suffixStart != -1) { + suffix = suffix.substring(suffixStart + 1) + adjustedCursorPosition += suffixStart + 1 + } + val suffixLines = suffix.linesSplitKeepEnds() + adjustedCursorPosition += suffixLines.take(blockSize).sumOf { it.length } + val remainingLines = suffixLines.drop(blockSize) + val fallbackBlock = remainingLines.take(blockSize).joinToString("") + + return RetrievalResult(fallbackBlock, adjustedCursorPosition, true, null) + } + } + + /** Split text into word and non-word tokens (preserving whitespace/punctuation). */ + private fun tokenizeWithBoundaries(text: String): List { + val matcher = WORD_BOUNDARY_PATTERN.matcher(text) + val tokens = mutableListOf() + while (matcher.find()) tokens.add(matcher.group()) + return tokens + } +} diff --git a/src/main/kotlin/dev/sweep/assistant/autocomplete/edit/engine/NesUtils.kt b/src/main/kotlin/dev/sweep/assistant/autocomplete/edit/engine/NesUtils.kt new file mode 100644 index 0000000..62c064c --- /dev/null +++ b/src/main/kotlin/dev/sweep/assistant/autocomplete/edit/engine/NesUtils.kt @@ -0,0 +1,407 @@ +package dev.sweep.assistant.autocomplete.edit.engine + +import com.github.difflib.DiffUtils +import com.github.difflib.UnifiedDiffUtils +import dev.sweep.assistant.autocomplete.edit.engine.NesConstants.AUTOCOMPLETE_MAXIMUM_LINE_LENGTH +import dev.sweep.assistant.autocomplete.edit.engine.NesConstants.AUTOCOMPLETE_OUTPUT_MAX_TOKENS +import dev.sweep.assistant.autocomplete.edit.engine.NesConstants.AUTOCOMPLETE_TRUNCATION_LINE_LENGTH +import dev.sweep.assistant.autocomplete.edit.engine.NesConstants.CHARS_PER_TOKEN +import dev.sweep.assistant.autocomplete.edit.engine.NesConstants.CHUNK_SIZE +import dev.sweep.assistant.autocomplete.edit.engine.NesConstants.CHUNK_STRIDE +import dev.sweep.assistant.autocomplete.edit.engine.NesConstants.LIMIT_TO_CHUNK +import dev.sweep.assistant.autocomplete.edit.engine.NesConstants.NUM_LINES_BEFORE +import java.util.regex.Pattern +import kotlin.math.abs +import kotlin.math.max +import kotlin.math.min +import kotlin.math.roundToInt + +/** + * Pure utility functions for the NES engine. + * Ported from Python next_edit_autocomplete_utils.py. + * No IntelliJ dependencies — fully unit-testable. + */ +object NesUtils { + + /** Estimate token count using character-based approximation. */ + fun estimateTokenCount(text: String): Int = (text.length / CHARS_PER_TOKEN).toInt() + + /** + * Convert a character position to a 0-indexed line number. + * Equivalent to Python: file_contents[:position].count("\n") + */ + fun getLineNumberFromPosition(fileContents: String, position: Int): Int { + if (position <= 0) return 0 + if (position >= fileContents.length) { + return fileContents.lines().size - 1 + } + return fileContents.substring(0, position).count { it == '\n' } + } + + /** + * Return a chunk of the file centered around the cursor position. + * For files <= LIMIT_TO_CHUNK lines, returns the full file. + * Ported from Python get_lines_around_cursor(). + */ + fun getLinesAroundCursor(fileContents: String, cursorPosition: Int): String { + val lines = fileContents.split("\n") + + if (lines.size <= LIMIT_TO_CHUNK) return fileContents + + val cursorLine = getLineNumberFromPosition(fileContents, cursorPosition) + val idealStart = cursorLine - CHUNK_SIZE / 2 + val chunkIndex = max(0, (idealStart.toDouble() / CHUNK_STRIDE).roundToInt()) + val startLine = chunkIndex * CHUNK_STRIDE + val endLine = min(lines.size, startLine + CHUNK_SIZE) + + return lines.subList(startLine, endLine).joinToString("\n") + } + + /** + * Extract old and new code from a unified diff hunk. + * @param hunk The diff hunk string (may start with @@ header) + * @param numContextLines Number of context lines to keep (0 = changed lines only, -1 = all) + * @return Pair of (oldCode, newCode) + */ + fun extractDiffParts(hunk: String, numContextLines: Int = 0): Pair { + // Use splitlines-keep-ends to match Python's splitlines(True) + val allLines = hunk.linesSplitKeepEnds() + val contentLines = allLines.filter { !it.startsWith("@@") } + + if (numContextLines == -1) { + val oldCode = mutableListOf() + val newCode = mutableListOf() + for (line in contentLines) { + when { + line.startsWith("-") -> oldCode.add(line.substring(1)) + line.startsWith("+") -> newCode.add(line.substring(1)) + line.isNotEmpty() -> { + val content = if (line.startsWith(" ")) line.substring(1) else line + oldCode.add(content) + newCode.add(content) + } + } + } + return oldCode.joinToString("") to newCode.joinToString("") + } + + val changedIndices = contentLines.indices.filter { + contentLines[it].startsWith("-") || contentLines[it].startsWith("+") + } + + if (changedIndices.isEmpty()) return "" to "" + + val startChange = changedIndices.min() + val endChange = changedIndices.max() + val startIdx = max(0, startChange - numContextLines) + val endIdx = min(contentLines.size, endChange + numContextLines + 1) + + val relevantLines = contentLines.subList(startIdx, endIdx) + val oldCode = mutableListOf() + val newCode = mutableListOf() + + for (line in relevantLines) { + when { + line.startsWith("-") -> oldCode.add(line.substring(1)) + line.startsWith("+") -> newCode.add(line.substring(1)) + line.isNotEmpty() -> { + val content = if (line.startsWith(" ")) line.substring(1) else line + oldCode.add(content) + newCode.add(content) + } + } + } + + return oldCode.joinToString("") to newCode.joinToString("") + } + + /** + * Filter out hunks that only contain whitespace changes. + */ + fun filterWhitespaceOnlyHunks(hunks: List): List { + return hunks.filter { hunk -> + val lines = hunk.lines() + val rest = lines.drop(1).joinToString("\n") + val (oldCode, newCode) = extractDiffParts(rest) + oldCode.trim() != newCode.trim() + } + } + + /** + * Split a diff string into individual hunks by "File: " markers. + */ + fun splitIntoHunks(diff: String): List { + val hunks = mutableListOf() + val currentHunk = mutableListOf() + + for (line in diff.lines()) { + if (line.startsWith("File: ") && currentHunk.isNotEmpty()) { + hunks.add(currentHunk.joinToString("\n")) + currentHunk.clear() + } + currentHunk.add(line) + } + if (currentHunk.isNotEmpty()) { + hunks.add(currentHunk.joinToString("\n")) + } + return hunks + } + + /** Strip leading empty lines from a completion string. */ + fun stripLeadingEmptyNewlines(completion: String): String { + val lines = completion.split("\n") + val startIndex = lines.indexOfFirst { it.trim().isNotEmpty() } + return if (startIndex >= 0) lines.drop(startIndex).joinToString("\n") else completion + } + + data class HunkParseResult( + val inputStart: Int, + val inputLines: List, + val outputStart: Int, + val outputLines: List, + ) + + /** + * Parse a single diff hunk and return input/output line numbers and content. + * Ported from Python parse_hunk(). + */ + fun parseHunk(hunk: String): HunkParseResult { + val lines = hunk.lines() + val hunkHeader = lines[0] + val diffLines = if (lines.size > 2) lines.drop(2) else emptyList() + + val parts = hunkHeader.split(" ") + val inputRange = parts[1].trimStart('-') + val outputRange = parts[2].trimStart('+') + + val inputParts = inputRange.split(",") + val outputParts = outputRange.split(",") + + var inputStart = inputParts[0].toInt() + val outputStart = outputParts[0].toInt() + + val inputLinesList = mutableListOf() + val outputLinesList = mutableListOf() + + for (line in diffLines) { + when { + line.startsWith("-") -> inputLinesList.add(line.substring(1)) + line.startsWith("+") -> outputLinesList.add(line.substring(1)) + line.isNotEmpty() -> { + val content = if (line.startsWith(" ")) line.substring(1) else line + inputLinesList.add(content) + outputLinesList.add(content) + } + } + } + + // Edge case: Python's difflib has an off-by-one when input_lines is empty + if (inputLinesList.isEmpty()) { + inputStart += 1 + } + + return HunkParseResult(inputStart, inputLinesList, outputStart, outputLinesList) + } + + /** + * Split two strings into diff hunks using java-diff-utils. + * Equivalent to Python split_into_diff_hunks() with difflib.unified_diff(n=0). + * + * Uses the Patch deltas directly rather than parsing unified diff text, + * to avoid line-number format differences between java-diff-utils and Python difflib. + */ + fun splitIntoDiffHunks(inputContent: String, outputContent: String): List { + val inputLines = inputContent.linesSplitKeepEnds() + val outputLines = outputContent.linesSplitKeepEnds() + + val patch = DiffUtils.diff(inputLines, outputLines) + if (patch.deltas.isEmpty()) return emptyList() + + return patch.deltas.map { delta -> + // java-diff-utils uses 0-based positions; Python difflib uses 1-based + var inputStart = delta.source.position + 1 + val inputLinesList = delta.source.lines.map { it } + val outputStart = delta.target.position + 1 + val outputLinesList = delta.target.lines.map { it } + + // Match Python's off-by-one edge case when input_lines is empty + if (inputLinesList.isEmpty()) { + inputStart += 1 + } + + HunkParseResult(inputStart, inputLinesList, outputStart, outputLinesList) + } + } + + /** + * Check if the completion has a large diff above the cursor position. + * Large = more than 5 lines added AND more than 5 lines deleted. + */ + fun isLargeDiffAboveCursor( + original: String, + completion: String, + relativeCursorPosition: Int, + ): Boolean { + if (original == completion) return false + + val originalAbove = original.substring(0, min(relativeCursorPosition, original.length)) + val completionAbove = completion.substring(0, min(relativeCursorPosition, completion.length)) + + if (originalAbove == completionAbove) return false + + val patch = DiffUtils.diff( + original.lines(), + completion.lines() + ) + + var additions = 0 + var deletions = 0 + for (delta in patch.deltas) { + additions += delta.target.lines.size + deletions += delta.source.lines.size + } + + return additions > 5 && deletions > 5 + } + + /** Truncate lines exceeding AUTOCOMPLETE_TRUNCATION_LINE_LENGTH. */ + fun truncateLongLines(content: String): String { + return content.lines().joinToString("\n") { line -> + if (line.length > AUTOCOMPLETE_TRUNCATION_LINE_LENGTH) { + line.substring(0, AUTOCOMPLETE_TRUNCATION_LINE_LENGTH) + "..." + } else { + line + } + } + } + + /** Check if any line in the code block exceeds AUTOCOMPLETE_MAXIMUM_LINE_LENGTH. */ + fun shouldDisableForCodeBlock(codeBlock: String): Boolean { + return codeBlock.lines().any { it.length > AUTOCOMPLETE_MAXIMUM_LINE_LENGTH } + } + + /** Normalize consecutive newlines to single newlines. */ + fun normalizeNewlines(text: String): String { + return text.replace(Regex("\\n+"), "\n") + } + + /** Compare two strings ignoring differences in consecutive newlines. */ + fun isEqualIgnoringNewlines(text1: String, text2: String): Boolean { + return normalizeNewlines(text1) == normalizeNewlines(text2) + } + + /** + * Pack items from an iterable into a list, respecting a token limit. + * Ported from Python pack_items_for_prompt(). + */ + fun packItemsForPrompt( + items: List, + stringFunction: (T) -> String, + tokenLimit: Int, + charTokenRatio: Double = 3.5, + truncateFromEnd: Boolean = true, + ): List { + val charLimit = (tokenLimit * charTokenRatio).toInt() + val packed = mutableListOf() + var currentLen = 0 + + val orderedItems = if (truncateFromEnd) items else items.reversed() + + for (item in orderedItems) { + val itemStr = stringFunction(item) + if (currentLen + itemStr.length <= charLimit) { + if (truncateFromEnd) packed.add(item) else packed.add(0, item) + currentLen += itemStr.length + } else { + break + } + } + return packed + } + + /** + * Check if completion length suggests it hit max_tokens. + */ + fun isCompletionMaxTokens(completion: String): Boolean { + val tokenCount = (completion.length / CHARS_PER_TOKEN).toInt() + return tokenCount >= AUTOCOMPLETE_OUTPUT_MAX_TOKENS + } + + /** + * Extract minimal diff between original and new code. + * Returns (minimalNew, startOffset, endOffset). + */ + fun extractMinimalDiff(originalCode: String, newCode: String): Triple { + val originalLines = originalCode.linesSplitKeepEnds() + val newLines = newCode.linesSplitKeepEnds() + + var startDiff = 0 + while (startDiff < min(originalLines.size, newLines.size) && + originalLines[startDiff] == newLines[startDiff] + ) { + startDiff++ + } + + var endDiffOrig = originalLines.size - 1 + var endDiffNew = newLines.size - 1 + while (endDiffOrig >= startDiff && endDiffNew >= startDiff && + endDiffOrig >= 0 && endDiffNew >= 0 && + originalLines[endDiffOrig] == newLines[endDiffNew] + ) { + endDiffOrig-- + endDiffNew-- + } + + val startContext = max(0, startDiff - 1) + val endContextOrig = min(originalLines.size - 1, endDiffOrig + 1) + val endContextNew = min(newLines.size - 1, endDiffNew + 1) + + var startOffset = originalLines.take(startContext).sumOf { it.length } + val endOffset = originalLines.take(endContextOrig + 1).sumOf { it.length } + + var minimalNew = newLines.subList(startContext, endContextNew + 1).joinToString("") + if (minimalNew.startsWith("\n")) { + minimalNew = minimalNew.substring(1) + startOffset += 1 + } + + return Triple(minimalNew, startOffset, endOffset) + } + + // Pretokenizer regex compiled once + private val pretokenizePattern: Pattern by lazy { + Pattern.compile(NesConstants.PRETOKENIZE_REGEX, Pattern.UNICODE_CHARACTER_CLASS) + } + + /** + * Tokenize text using the Qwen2 pretokenizer regex. + * Returns the list of pretokens. + */ + fun pretokenize(text: String): List { + val matcher = pretokenizePattern.matcher(text) + val tokens = mutableListOf() + while (matcher.find()) { + tokens.add(matcher.group()) + } + return tokens + } +} + +/** + * Split a string into lines, keeping line endings (like Python's splitlines(True)). + */ +fun String.linesSplitKeepEnds(): List { + if (this.isEmpty()) return listOf("") + val result = mutableListOf() + var start = 0 + for (i in indices) { + if (this[i] == '\n') { + result.add(this.substring(start, i + 1)) + start = i + 1 + } + } + if (start < length) { + result.add(this.substring(start)) + } + return result +} diff --git a/src/main/kotlin/dev/sweep/assistant/autocomplete/edit/engine/NextEditAutocompleteEngine.kt b/src/main/kotlin/dev/sweep/assistant/autocomplete/edit/engine/NextEditAutocompleteEngine.kt new file mode 100644 index 0000000..9136e05 --- /dev/null +++ b/src/main/kotlin/dev/sweep/assistant/autocomplete/edit/engine/NextEditAutocompleteEngine.kt @@ -0,0 +1,358 @@ +package dev.sweep.assistant.autocomplete.edit.engine + +import com.intellij.openapi.diagnostic.Logger +import dev.sweep.assistant.autocomplete.edit.engine.NesCompletionParser.AutocompleteResult +import dev.sweep.assistant.autocomplete.edit.engine.NesConstants.MAX_RETRIEVAL_CHUNK_SIZE_LINES +import dev.sweep.assistant.autocomplete.edit.engine.NesConstants.NUM_LINES_AFTER +import dev.sweep.assistant.autocomplete.edit.engine.NesConstants.NUM_LINES_BEFORE +import java.util.UUID +import java.util.Collections +import kotlin.math.max +import kotlin.math.min +import kotlin.math.abs + +/** + * Top-level orchestrator for the NES engine. + * Coordinates prompt building, LLM inference, and completion parsing. + * + * This replaces the Python sweep-autocomplete server — all logic runs + * in the JVM, calling llama-server's /v1/completions endpoint for inference. + */ +class NextEditAutocompleteEngine( + private val llamaClient: LlamaServerClient, +) { + private val logger = Logger.getInstance(NextEditAutocompleteEngine::class.java) + + data class NesRequest( + val filePath: String, + val fileContents: String, + val originalFileContents: String?, + val recentChanges: String, + val cursorPosition: Int, + val fileChunks: List = emptyList(), + val retrievalChunks: List = emptyList(), + val recentUserActions: List = emptyList(), + val recentChangesHighRes: String = "", + val changesAboveCursor: Boolean = false, + val editorDiagnostics: List? = null, + ) + + data class UserAction( + val actionType: String, + val lineNumber: Int, + val offset: Int, + val filePath: String, + val timestamp: Long = 0, + ) + + data class NesResponse( + val completions: List, + val elapsedMs: Long, + val autocompleteId: String, + ) + + /** + * Main entry point: generate next-edit suggestions for the given request. + * Ported from Python fetch_next_edits() + _fetch_next_edits_core(). + */ + fun fetchNextEdits(request: NesRequest): NesResponse { + val autocompleteId = UUID.randomUUID().toString().replace("-", "") + val fileContents = request.fileContents + val originalFileContents = request.originalFileContents ?: fileContents + val cursorPosition = request.cursorPosition + + // Check if autocomplete should be disabled for this file + if (shouldDisableAutocomplete(fileContents)) { + return emptyResponse(autocompleteId) + } + + // Extract code block around cursor + val block = NesPromptBuilder.getBlockAtCursor(fileContents, cursorPosition) + + if (NesUtils.shouldDisableForCodeBlock(block.codeBlock)) { + return emptyResponse(autocompleteId) + } + + // Truncate retrieval chunks + val retrievalChunks = request.retrievalChunks.map { chunk -> + chunk.copy( + content = chunk.content.linesSplitKeepEnds() + .take(MAX_RETRIEVAL_CHUNK_SIZE_LINES) + .joinToString("") + ) + } + + // Limit chunks for local model + val fileChunks = request.fileChunks.take(1) + val limitedRetrievalChunks = retrievalChunks.take(1) + + // Determine if ghost text should be forced + val forceGhostText = request.recentUserActions.isEmpty() || + request.recentUserActions.lastOrNull()?.actionType == "INSERT_CHAR" + + // First pass: autocomplete at cursor + val startTime = System.currentTimeMillis() + val firstPassResult = runAutocompletePass( + filePath = request.filePath, + fileContents = fileContents, + originalFileContents = originalFileContents, + recentChanges = request.recentChanges, + cursorPosition = cursorPosition, + codeBlock = block.codeBlock, + prefix = block.prefix, + suffix = block.suffix, + blockStartIndex = block.blockStartIndex, + autocompleteId = autocompleteId, + fileChunks = fileChunks, + retrievalChunks = limitedRetrievalChunks, + recentChangesHighRes = request.recentChangesHighRes, + changesAboveCursor = request.changesAboveCursor, + forceGhostText = forceGhostText, + ) + + if (firstPassResult != null && firstPassResult.isNotEmpty() && + firstPassResult.any { it.completion.trim('\n').isNotEmpty() || it.startIndex != it.endIndex } + ) { + val elapsed = System.currentTimeMillis() - startTime + return NesResponse(firstPassResult, elapsed, autocompleteId) + } + + // Second pass: retrieval-based autocomplete + if (request.recentChanges.isNotEmpty()) { + val retrievalResult = NesRetrieval.findBestMatchingBlock( + fileContents, request.recentChanges, cursorPosition, + blockSize = 6, editorDiagnostics = request.editorDiagnostics, + ) + + if (retrievalResult.codeBlock.isNotEmpty()) { + val prefixLines = fileContents.substring(0, retrievalResult.blockStartOffset) + .linesSplitKeepEnds() + val retrievedPrefix = prefixLines.takeLast(NUM_LINES_BEFORE).joinToString("") + + val numRetrievedLines = retrievalResult.codeBlock.lines().size + val numSuffixLines = max(0, NUM_LINES_AFTER + 1 - numRetrievedLines) + val afterBlock = fileContents.substring( + min(fileContents.length, retrievalResult.blockStartOffset + retrievalResult.codeBlock.length) + ) + val retrievedSuffix = afterBlock.linesSplitKeepEnds().take(numSuffixLines).joinToString("") + + val cursorInBlock = retrievalResult.blockStartOffset + + retrievalResult.codeBlock.linesSplitKeepEnds().firstOrNull()?.length.let { it ?: 0 } + + val fullBlock = retrievedPrefix + NesPromptBuilder.truncateCodeBlockByTokensPublic( + retrievalResult.codeBlock + retrievedSuffix + ) + + if (!NesUtils.shouldDisableForCodeBlock(fullBlock)) { + // Add diagnostic as retrieval chunk if present + val extraChunks = if (retrievalResult.diagnostic != null) { + val diagLine = fileContents.lines().getOrElse(retrievalResult.diagnostic.lineNumber) { "" } + listOf( + NesPromptBuilder.FileChunkData( + "diagnostics", + "${retrievalResult.diagnostic.message} at line ${retrievalResult.diagnostic.lineNumber}:\n$diagLine", + 1, 2, + ) + ) + limitedRetrievalChunks + } else { + limitedRetrievalChunks + } + + val secondPassResult = runAutocompletePass( + filePath = request.filePath, + fileContents = fileContents, + originalFileContents = originalFileContents, + recentChanges = request.recentChanges, + cursorPosition = cursorInBlock, + codeBlock = fullBlock, + prefix = retrievedPrefix, + suffix = retrievedSuffix, + blockStartIndex = retrievalResult.blockStartOffset, + autocompleteId = autocompleteId, + fileChunks = fileChunks, + retrievalChunks = extraChunks, + recentChangesHighRes = request.recentChangesHighRes, + changesAboveCursor = request.changesAboveCursor, + forceGhostText = forceGhostText, + ) + + if (secondPassResult != null && secondPassResult.isNotEmpty() && + request.recentChanges.isNotEmpty() && + secondPassResult.first().completion.trim().isNotEmpty() + ) { + val elapsed = System.currentTimeMillis() - startTime + return NesResponse(secondPassResult, elapsed, autocompleteId) + } + } + } + } + + val elapsed = System.currentTimeMillis() - startTime + return emptyResponse(autocompleteId, elapsed) + } + + private fun runAutocompletePass( + filePath: String, + fileContents: String, + originalFileContents: String, + recentChanges: String, + cursorPosition: Int, + codeBlock: String, + prefix: String, + suffix: String, + blockStartIndex: Int, + autocompleteId: String, + fileChunks: List, + retrievalChunks: List, + recentChangesHighRes: String, + changesAboveCursor: Boolean, + forceGhostText: Boolean, + ): List? { + if (codeBlock.isEmpty()) return null + + val promptResult = NesPromptBuilder.buildPrompt( + filePath = filePath, + fileContents = fileContents, + originalFileContents = originalFileContents, + recentChanges = recentChanges, + cursorPosition = cursorPosition, + codeBlock = codeBlock, + prefix = prefix, + suffix = suffix, + blockStartIndex = blockStartIndex, + fileChunks = fileChunks, + retrievalChunks = retrievalChunks, + recentChangesHighRes = recentChangesHighRes, + changesAboveCursor = changesAboveCursor, + forceGhostText = forceGhostText, + useRemoteEndpoint = false, // local llama-server + ) + + if (promptResult.formattedPrompt.isEmpty()) return null + + // Log prompt details for debugging + logger.info("NES: prompt length=${promptResult.formattedPrompt.length} chars, " + + "codeBlock length=${promptResult.cleanedCodeBlock.length}, " + + "relativeCursorPos=${promptResult.relativeCursorPosition}, " + + "relativeCursorLine=${promptResult.relativeCursorLine}, " + + "blockStartIndex=${promptResult.blockStartIndex}") + // Log the last ~200 chars of the prompt (the part right before model generation) + val promptTail = promptResult.formattedPrompt.takeLast(300) + logger.info("NES: prompt tail: ...${promptTail.replace("\n", "\\n")}") + + // Call llama-server + // Allow output up to 2x the code block size (room for insertions) + 20 lines buffer + val maxOutputChars = (promptResult.cleanedCodeBlock.length * 2) + (20 * 80) + val completionResult = try { + llamaClient.generateCompletion( + prompt = promptResult.formattedPrompt, + maxOutputChars = maxOutputChars, + ) + } catch (e: LlamaServerClient.RequestCancelledException) { + logger.info("NES request cancelled") + return null + } catch (e: Exception) { + logger.warn("NES inference error: ${e.message}") + return null + } + + if (completionResult.text.isEmpty()) { + logger.warn("NES: empty completion text") + return null + } + + // Post-process completion + var completion = promptResult.prefill + completionResult.text + logger.info("NES: raw completion (${completionResult.text.length} chars, finish=${completionResult.finishReason}): ${completionResult.text.take(200)}") + logger.info("NES: prefill='${promptResult.prefill.take(50)}', forcedPrefix='${promptResult.forcedPrefix.take(50)}'") + + if (completion.startsWith("<|") || completion.removePrefix(promptResult.forcedPrefix).startsWith("<|")) { + logger.warn("NES: filtered — completion starts with special token") + return null + } + if (promptResult.forcedPrefix.isNotEmpty() && !completionResult.text.startsWith(promptResult.forcedPrefix)) { + logger.warn("NES: filtered — forced prefix '${promptResult.forcedPrefix.take(30)}' not respected") + return null + } + + // Clean up completion + if (completion.trimEnd('\n').endsWith(" No newline at end of file")) { + completion = completion.substringBefore(" No newline at end of file") + } + completion = NesUtils.stripLeadingEmptyNewlines(completion).removeSuffix("<|file_sep|>") + .ifEmpty { promptResult.cleanedCodeBlock } + if ("<|cursor|>" !in promptResult.cleanedCodeBlock) { + completion = completion.replace("<|cursor|>", "") + } + + // Check max tokens + if (completionResult.finishReason == "length") { + logger.warn("NES: filtered — hit max tokens") + return null + } + + // Check for pure insertion above cursor + if (NesCompletionParser.isPureInsertionAboveCursor( + promptResult.cleanedCodeBlock, completion, promptResult.relativeCursorPosition + ) + ) { + logger.warn("NES: filtered — pure insertion above cursor") + return null + } + + // Check for large diff above cursor + if (NesUtils.isLargeDiffAboveCursor( + promptResult.cleanedCodeBlock, completion, promptResult.relativeCursorPosition + ) + ) { + logger.warn("NES: filtered — large diff above cursor") + return null + } + + // Select best hunks + val completions = NesCompletionParser.selectBestHunkFromCompletion( + completion, promptResult.cleanedCodeBlock, fileContents, cursorPosition, autocompleteId, + ) + + logger.info("NES: selectBestHunk returned ${completions.size} completions") + completions.forEachIndexed { i, c -> + logger.info("NES: [$i] start=${c.startIndex} end=${c.endIndex} text='${c.completion.take(60)}'") + } + + if (completions.isEmpty()) { + logger.warn("NES: filtered — no hunks selected from completion") + logger.warn("NES: cleanedCodeBlock (${promptResult.cleanedCodeBlock.length} chars): '${promptResult.cleanedCodeBlock.take(150).replace("\n", "\\n")}'") + logger.warn("NES: completion after cleanup (${completion.length} chars): '${completion.take(150).replace("\n", "\\n")}'") + logger.warn("NES: completion == cleanedCodeBlock? ${completion == promptResult.cleanedCodeBlock}") + return null + } + + // Check for reverts + val codeBlockWithCompletions = NesCompletionParser.applyCompletionsToCodeBlock( + completions, fileContents, promptResult.cleanedCodeBlock, + ) + for (section in promptResult.prevSections) { + if (NesUtils.isEqualIgnoringNewlines(codeBlockWithCompletions, section)) { + logger.warn("NES: filtered — revert detected") + return null + } + } + + logger.info("NES: returning ${completions.size} completions successfully") + return completions + } + + private fun shouldDisableAutocomplete(fileContents: String): Boolean { + if (fileContents.isEmpty()) return false + if (fileContents.length > 10_000_000) return true + val lines = fileContents.lines() + if (lines.size > 50_000) return true + val avgLineLength = fileContents.length.toDouble() / lines.size + if (avgLineLength > 240) return true + return false + } + + private fun emptyResponse(autocompleteId: String, elapsedMs: Long = 0) = NesResponse( + emptyList(), elapsedMs, autocompleteId, + ) +} diff --git a/src/main/kotlin/dev/sweep/assistant/components/SweepConfig.kt b/src/main/kotlin/dev/sweep/assistant/components/SweepConfig.kt index 29f1d42..0a74f5b 100644 --- a/src/main/kotlin/dev/sweep/assistant/components/SweepConfig.kt +++ b/src/main/kotlin/dev/sweep/assistant/components/SweepConfig.kt @@ -1047,6 +1047,18 @@ class SweepConfig( SweepSettings.getInstance().autocompleteLocalMlx = enabled } + fun isAutocompleteLocalNativeEngine(): Boolean = SweepSettings.getInstance().autocompleteLocalNativeEngine + + fun updateAutocompleteLocalNativeEngine(enabled: Boolean) { + SweepSettings.getInstance().autocompleteLocalNativeEngine = enabled + } + + fun getAutocompleteLocalModel(): String = SweepSettings.getInstance().autocompleteLocalModel + + fun updateAutocompleteLocalModel(modelId: String) { + SweepSettings.getInstance().autocompleteLocalModel = modelId + } + fun getAutocompleteLocalPort(): Int = SweepSettings.getInstance().autocompleteLocalPort fun updateAutocompleteLocalPort(port: Int) { @@ -4864,6 +4876,63 @@ class SweepConfig( }, ) } + add(Box.createRigidArea(Dimension(0, 4.scaled))) + add( + JCheckBox("Use native engine (llama-server direct)").apply { + isSelected = isAutocompleteLocalNativeEngine() + withSweepFont(project) + border = JBUI.Borders.emptyLeft(24) + addActionListener { + updateAutocompleteLocalNativeEngine(isSelected) + } + }, + ) + add(Box.createRigidArea(Dimension(0, 2.scaled))) + add( + JLabel("Runs prompt construction in the plugin and calls llama-server directly. Requires llama-server on PATH.").apply { + withSweepFont(project, scale = 0.85f) + foreground = JBColor.GRAY + font = font.deriveFont(Font.ITALIC) + border = JBUI.Borders.emptyLeft(48) + }, + ) + add(Box.createRigidArea(Dimension(0, 4.scaled))) + add( + JPanel().apply { + layout = BoxLayout(this, BoxLayout.X_AXIS) + border = JBUI.Borders.emptyLeft(24) + add(JLabel("Model").apply { withSweepFont(project) }) + add(Box.createRigidArea(Dimension(8.scaled, 0))) + val models = dev.sweep.assistant.autocomplete.edit.engine.NesModelConfig.MODELS + val currentModelId = getAutocompleteLocalModel() + val comboBox = javax.swing.JComboBox(models.map { it.displayName }.toTypedArray()).apply { + withSweepFont(project) + maximumSize = Dimension(300.scaled, 30.scaled) + selectedIndex = models.indexOfFirst { it.id == currentModelId }.coerceAtLeast(0) + addActionListener { + val idx = (this as javax.swing.JComboBox<*>).selectedIndex + if (idx >= 0 && models[idx].id != getAutocompleteLocalModel()) { + updateAutocompleteLocalModel(models[idx].id) + // Auto-restart if server is running + if (isAutocompleteLocalMode() && LocalAutocompleteServerManager.getInstance().isServerHealthy()) { + LocalAutocompleteServerManager.getInstance().restartServerInTerminal(project) + } + } + } + } + add(comboBox) + add(Box.createHorizontalGlue()) + }, + ) + add(Box.createRigidArea(Dimension(0, 2.scaled))) + add( + JLabel("Select the GGUF model for local autocomplete. Larger models are slower but may produce better suggestions.").apply { + withSweepFont(project, scale = 0.85f) + foreground = JBColor.GRAY + font = font.deriveFont(Font.ITALIC) + border = JBUI.Borders.emptyLeft(48) + }, + ) }, gbc, ) diff --git a/src/main/kotlin/dev/sweep/assistant/services/AutocompleteIpResolverService.kt b/src/main/kotlin/dev/sweep/assistant/services/AutocompleteIpResolverService.kt index f800ab6..e7bd726 100644 --- a/src/main/kotlin/dev/sweep/assistant/services/AutocompleteIpResolverService.kt +++ b/src/main/kotlin/dev/sweep/assistant/services/AutocompleteIpResolverService.kt @@ -87,8 +87,15 @@ class AutocompleteIpResolverService( * This centralizes the entire HTTP request flow in the DNS resolver service. */ @RequiresBackgroundThread - suspend fun fetchNextEditAutocomplete(request: NextEditAutocompleteRequest): NextEditAutocompleteResponse? = - try { + suspend fun fetchNextEditAutocomplete(request: NextEditAutocompleteRequest): NextEditAutocompleteResponse? { + // Native engine path: prompt construction + llama-server direct + if (SweepConfig.getInstance(project).isAutocompleteLocalMode() && + SweepConfig.getInstance(project).isAutocompleteLocalNativeEngine() + ) { + return fetchViaNativeEngine(request) + } + + return try { if (SweepConfig.getInstance(project).isAutocompleteLocalMode()) { LocalAutocompleteServerManager.getInstance().ensureServerRunning() } @@ -201,6 +208,7 @@ class AutocompleteIpResolverService( } throw e } + } init { startPeriodicResolution() @@ -326,6 +334,95 @@ class AutocompleteIpResolverService( } } + // --- Native engine (Kotlin prompt construction + llama-server direct) --- + + @Volatile + private var nativeEngine: dev.sweep.assistant.autocomplete.edit.engine.NextEditAutocompleteEngine? = null + + private fun getOrCreateNativeEngine(): dev.sweep.assistant.autocomplete.edit.engine.NextEditAutocompleteEngine { + nativeEngine?.let { return it } + val port = SweepSettings.getInstance().autocompleteLocalPort + val client = dev.sweep.assistant.autocomplete.edit.engine.LlamaServerClient( + baseUrl = "http://localhost:$port", + ) + val engine = dev.sweep.assistant.autocomplete.edit.engine.NextEditAutocompleteEngine(client) + nativeEngine = engine + return engine + } + + private fun fetchViaNativeEngine(request: NextEditAutocompleteRequest): NextEditAutocompleteResponse? { + val engine = getOrCreateNativeEngine() + + val nesRequest = dev.sweep.assistant.autocomplete.edit.engine.NextEditAutocompleteEngine.NesRequest( + filePath = request.file_path, + fileContents = request.file_contents, + originalFileContents = request.original_file_contents, + recentChanges = request.recent_changes, + cursorPosition = request.cursor_position, + fileChunks = request.file_chunks.map { + dev.sweep.assistant.autocomplete.edit.engine.NesPromptBuilder.FileChunkData( + filePath = it.file_path, + content = it.content, + startLine = it.start_line, + endLine = it.end_line, + ) + }, + retrievalChunks = request.retrieval_chunks.map { + dev.sweep.assistant.autocomplete.edit.engine.NesPromptBuilder.FileChunkData( + filePath = it.file_path, + content = it.content, + startLine = it.start_line, + endLine = it.end_line, + ) + }, + recentUserActions = request.recent_user_actions.map { + dev.sweep.assistant.autocomplete.edit.engine.NextEditAutocompleteEngine.UserAction( + actionType = it.action_type.name, + lineNumber = it.line_number, + offset = it.offset, + filePath = it.file_path, + timestamp = it.timestamp, + ) + }, + recentChangesHighRes = request.recent_changes_high_res, + changesAboveCursor = request.changes_above_cursor, + editorDiagnostics = request.editor_diagnostics.map { + dev.sweep.assistant.autocomplete.edit.engine.NesRetrieval.EditorDiagnosticData( + line = it.line, + lineNumber = it.line - 1, + startOffset = it.start_offset, + endOffset = it.end_offset, + severity = it.severity, + message = it.message, + ) + }, + ) + + val result = engine.fetchNextEdits(nesRequest) + + if (result.completions.isEmpty()) return null + + val first = result.completions.first() + return NextEditAutocompleteResponse( + start_index = first.startIndex, + end_index = first.endIndex, + completion = first.completion, + confidence = first.confidence, + autocomplete_id = result.autocompleteId, + elapsed_time_ms = result.elapsedMs, + completions = result.completions.map { + dev.sweep.assistant.autocomplete.edit.NextEditAutocompletion( + start_index = it.startIndex, + end_index = it.endIndex, + completion = it.completion, + confidence = it.confidence, + autocomplete_id = it.autocompleteId, + ) + }, + nativeIndices = true, // indices are already JVM-native, skip Python→JVM conversion + ) + } + override fun dispose() { resolutionJob?.cancel() healthCheckJob?.cancel() diff --git a/src/main/kotlin/dev/sweep/assistant/services/LocalAutocompleteServerManager.kt b/src/main/kotlin/dev/sweep/assistant/services/LocalAutocompleteServerManager.kt index 2813d11..40d12f5 100644 --- a/src/main/kotlin/dev/sweep/assistant/services/LocalAutocompleteServerManager.kt +++ b/src/main/kotlin/dev/sweep/assistant/services/LocalAutocompleteServerManager.kt @@ -364,11 +364,152 @@ class LocalAutocompleteServerManager : Disposable { startServer() } + // --- llama-server support for native engine --- + + private fun getSelectedModel(): dev.sweep.assistant.autocomplete.edit.engine.NesModel { + val modelId = SweepSettings.getInstance().autocompleteLocalModel + return dev.sweep.assistant.autocomplete.edit.engine.NesModelConfig.getModel(modelId) + } + + /** + * Resolve llama-server binary on PATH. + */ + private fun resolveLlamaServer(): String? { + val envPath = try { + val env = com.intellij.util.EnvironmentUtil.getEnvironmentMap() + if (env.isNotEmpty()) env["PATH"] else System.getenv("PATH") + } catch (_: Throwable) { + System.getenv("PATH") + } + + val exeName = if (isWindows) "llama-server.exe" else "llama-server" + if (!envPath.isNullOrEmpty()) { + for (dir in envPath.split(File.pathSeparatorChar)) { + if (dir.isEmpty()) continue + val cand = File(dir, exeName) + if (cand.isFile && cand.canExecute()) return cand.absolutePath + } + } + + // Check common locations + val commonPaths = listOf( + "/opt/homebrew/bin/llama-server", + "/usr/local/bin/llama-server", + System.getProperty("user.home") + "/.local/bin/llama-server", + ) + for (path in commonPaths) { + val f = File(path) + if (f.isFile && f.canExecute()) return f.absolutePath + } + + return null + } + + /** + * Resolve the GGUF model path. + * Checks: 1) HuggingFace cache, 2) Sweep models cache (~/.cache/sweep/models/). + * Returns the path to the .gguf file, or null if not cached yet. + */ + private fun resolveModelPath(): String? { + val model = getSelectedModel() + + // Check HuggingFace cache + val hfCacheBase = File(System.getProperty("user.home"), ".cache/huggingface/hub") + val modelDir = File(hfCacheBase, "models--${model.repo.replace("/", "--")}") + if (modelDir.isDirectory) { + val snapshotsDir = File(modelDir, "snapshots") + if (snapshotsDir.isDirectory) { + val snapshots = snapshotsDir.listFiles()?.filter { it.isDirectory } ?: emptyList() + for (snapshot in snapshots) { + val gguf = File(snapshot, model.filename) + if (gguf.isFile) return gguf.absolutePath + } + } + } + + // Check Sweep models cache + val sweepCache = File(System.getProperty("user.home"), ".cache/sweep/models/${model.filename}") + if (sweepCache.isFile) return sweepCache.absolutePath + + return null + } + + private val SWEEP_MODELS_DIR = System.getProperty("user.home") + "/.cache/sweep/models" + + /** + * Build a shell command to download the model. + * Tries hf first, falls back to curl. + */ + private fun buildModelDownloadCommand(): String { + val model = getSelectedModel() + val hfCliCmd = "hf download ${model.repo} ${model.filename}" + val url = "https://huggingface.co/${model.repo}/resolve/main/${model.filename}" + val destDir = SWEEP_MODELS_DIR + val destFile = "$destDir/${model.filename}" + val curlCmd = "mkdir -p $destDir && curl -L -o \"$destFile\" \"$url\"" + + // Shell one-liner: try hf, fall back to curl + return "if command -v hf >/dev/null 2>&1; then $hfCliCmd; else echo 'hf not found, downloading with curl...' && $curlCmd; fi" + } + + /** + * Build the llama-server command with speculative decoding flags. + */ + private fun buildLlamaServerCommand(llamaServerPath: String, modelPath: String, port: Int): List { + return listOf( + llamaServerPath, + "-m", modelPath, + "--port", port.toString(), + "-ngl", "-1", + "--flash-attn", "auto", + "--spec-type", "ngram-mod", + "--spec-ngram-size-n", "24", + "--draft-min", "48", + "--draft-max", "64", + ) + } + /** * Builds the full command string for starting the server. - * Returns null if uvx cannot be found (and uv install also fails). + * Uses llama-server when native engine is enabled, otherwise uvx. */ fun getServerCommand(): String? { + val useNativeEngine = SweepSettings.getInstance().autocompleteLocalNativeEngine + val port = getPort() + + if (useNativeEngine) { + val llamaPath = resolveLlamaServer() + if (llamaPath == null) { + showNotification( + "llama-server not found. Install it with: brew install llama.cpp", + NotificationType.ERROR, + ) + return null + } + + val modelPath = resolveModelPath() + if (modelPath != null) { + return buildLlamaServerCommand(llamaPath, modelPath, port).joinToString(" ") { arg -> + if (arg.contains(" ")) "\"$arg\"" else arg + } + } + + // Model not downloaded yet — return a command that downloads first, then starts + val model = getSelectedModel() + val downloadCmd = buildModelDownloadCommand() + val repoDirName = model.repo.replace("/", "--") + val sweepCachePath = "$SWEEP_MODELS_DIR/${model.filename}" + val serverCmd = buildLlamaServerCommand(llamaPath, "\$MODEL_PATH", port) + .joinToString(" ") { if (it.contains(" ")) "\"$it\"" else it } + + // After download, find the model in either HF cache or Sweep cache + val findModel = "MODEL_PATH=\$(find ~/.cache/huggingface/hub/models--$repoDirName -name '${model.filename}' 2>/dev/null | head -1); " + + "[ -z \"\$MODEL_PATH\" ] && MODEL_PATH=\"$sweepCachePath\"" + + return "$downloadCmd && $findModel && $serverCmd" + } + + // Fall back to uvx path var uvxPath = resolveUvx() if (uvxPath == null) { logger.info("uvx not found, attempting to install uv") @@ -382,7 +523,7 @@ class LocalAutocompleteServerManager : Disposable { return null } } - return buildUvxCommand(uvxPath, getPort()).joinToString(" ") { arg -> + return buildUvxCommand(uvxPath, port).joinToString(" ") { arg -> if (arg.contains(" ")) "\"$arg\"" else arg } } @@ -441,6 +582,52 @@ class LocalAutocompleteServerManager : Disposable { } } + /** + * Restarts the server in the terminal by sending Ctrl+C to the existing tab, + * then launching the new command. Used when the model is changed. + */ + fun restartServerInTerminal(project: Project) { + val command = getServerCommand() ?: return + + ApplicationManager.getApplication().invokeLater { + try { + val toolWindow = ToolWindowManager.getInstance(project) + .getToolWindow(TerminalToolWindowFactory.TOOL_WINDOW_ID) ?: return@invokeLater + + val existingContent = toolWindow.contentManager.contents.firstOrNull { + it.displayName == TERMINAL_TAB_NAME + } + val widget = if (existingContent != null) { + TerminalToolWindowManager.findWidgetByContent(existingContent) + } else { + null + } + + if (widget != null) { + // Send Ctrl+C to stop the running server, then start new one + ApplicationManager.getApplication().executeOnPooledThread { + val isPowerShell = TerminalApiWrapper.isPowerShell(project) + ApplicationManager.getApplication().invokeLater { + // Send Ctrl+C + TerminalApiWrapper.sendCommand(widget, "\u0003", project, isPowerShell) + } + // Wait for the server to stop + Thread.sleep(1500) + ApplicationManager.getApplication().invokeLater { + TerminalApiWrapper.sendCommand(widget, command, project, isPowerShell) + } + } + logger.info("Restarting local autocomplete server with: $command") + } else { + // No existing terminal — just start fresh + startServerInTerminal(project) + } + } catch (e: Exception) { + logger.warn("Failed to restart local autocomplete server: ${e.message}") + } + } + } + private fun stopServer() { serverProcess?.let { process -> try { diff --git a/src/main/kotlin/dev/sweep/assistant/settings/SweepSettings.kt b/src/main/kotlin/dev/sweep/assistant/settings/SweepSettings.kt index 034b143..0e2197a 100644 --- a/src/main/kotlin/dev/sweep/assistant/settings/SweepSettings.kt +++ b/src/main/kotlin/dev/sweep/assistant/settings/SweepSettings.kt @@ -176,10 +176,14 @@ class SweepSettings : PersistentStateComponent { // Don't notify settings changed for BYOK to avoid excessive chatter } - var autocompleteLocalMode: Boolean = false + var autocompleteLocalMode: Boolean = true var autocompleteLocalMlx: Boolean = false + var autocompleteLocalNativeEngine: Boolean = true + + var autocompleteLocalModel: String = "sweep-0.5B" + var autocompleteLocalPort: Int = 8081 fun ensureDefaultPromptsInitialized() { diff --git a/src/test/kotlin/dev/sweep/assistant/autocomplete/edit/engine/NesCompletionParserTest.kt b/src/test/kotlin/dev/sweep/assistant/autocomplete/edit/engine/NesCompletionParserTest.kt new file mode 100644 index 0000000..56fac64 --- /dev/null +++ b/src/test/kotlin/dev/sweep/assistant/autocomplete/edit/engine/NesCompletionParserTest.kt @@ -0,0 +1,149 @@ +package dev.sweep.assistant.autocomplete.edit.engine + +import com.google.gson.Gson +import com.google.gson.JsonObject +import org.junit.jupiter.api.Assertions.* +import org.junit.jupiter.api.BeforeAll +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.TestInstance + +/** + * Tests for NesCompletionParser, comparing outputs against Python-generated fixtures. + */ +@TestInstance(TestInstance.Lifecycle.PER_CLASS) +class NesCompletionParserTest { + + private lateinit var fixtures: JsonObject + private val gson = Gson() + + @BeforeAll + fun loadFixtures() { + val json = this::class.java.classLoader + .getResourceAsStream("nes_fixtures/parser_fixtures.json")!! + .bufferedReader().readText() + fixtures = gson.fromJson(json, JsonObject::class.java) + } + + @Test + fun `getGhostTextWithLocation matches Python`() { + val f = fixtures["ghost_text_with_location"].asJsonObject + val input = f["input"].asJsonObject + val expected = f["output"].asString + + val result = NesCompletionParser.getGhostTextWithLocation( + input["completion"].asString, + input["code_block"].asString, + input["cursor_pos"].asInt, + ) + assertEquals(expected, result) + } + + @Test + fun `findGhostTextNonLocal matches Python`() { + val f = fixtures["find_ghost_text_non_local"].asJsonObject + val input = f["input"].asJsonObject + val expectedArray = f["output"].asJsonArray + val expectedText = expectedArray[0].asString + val expectedPos = expectedArray[1].asInt + + val (text, pos) = NesCompletionParser.findGhostTextNonLocal( + input["completion"].asString, + input["code_block"].asString, + input["cursor_pos"].asInt, + ) + assertEquals(expectedText, text) + assertEquals(expectedPos, pos) + } + + @Test + fun `isSingleLineGhostText matches Python`() { + val f = fixtures["single_line_ghost_text"].asJsonObject + val input = f["input"].asJsonObject + val expected = f["output"].asString + + val result = NesCompletionParser.isSingleLineGhostText( + input["completion"].asString, + input["code_block"].asString, + input["cursor_pos"].asInt, + ) + assertEquals(expected, result) + } + + @Test + fun `isSingleLineGhostText empty when no match`() { + val f = fixtures["single_line_ghost_text_empty"].asJsonObject + val input = f["input"].asJsonObject + val expected = f["output"].asString + + val result = NesCompletionParser.isSingleLineGhostText( + input["completion"].asString, + input["code_block"].asString, + input["cursor_pos"].asInt, + ) + assertEquals(expected, result) + } + + @Test + fun `isPureInsertionAboveCursor matches Python`() { + val f = fixtures["is_pure_insertion_above_cursor_true"].asJsonObject + val input = f["input"].asJsonObject + val expected = f["output"].asBoolean + + val result = NesCompletionParser.isPureInsertionAboveCursor( + input["code_block"].asString, + input["completion"].asString, + input["cursor_pos"].asInt, + ) + assertEquals(expected, result) + } + + @Test + fun `applyCompletionsToCodeBlock matches Python`() { + val f = fixtures["apply_completions_to_code_block"].asJsonObject + val input = f["input"].asJsonObject + val expected = f["output"].asString + + val comps = input["completions"].asJsonArray.map { it.asJsonObject }.map { + NesCompletionParser.AutocompleteResult( + it["start_index"].asInt, + it["end_index"].asInt, + it["completion"].asString, + 1.0f, + "test", + ) + } + + val result = NesCompletionParser.applyCompletionsToCodeBlock( + comps, + input["file_contents"].asString, + input["cleaned_code_block"].asString, + ) + assertEquals(expected, result) + } + + @Test + fun `selectBestHunkFromCompletion matches Python`() { + val f = fixtures["select_best_hunk_simple_change"].asJsonObject + val input = f["input"].asJsonObject + val expectedArray = f["output"].asJsonArray + + val results = NesCompletionParser.selectBestHunkFromCompletion( + input["completion"].asString, + input["cleaned_code_block"].asString, + input["file_contents"].asString, + input["cursor_position"].asInt, + input["autocomplete_id"].asString, + ) + + assertEquals(expectedArray.size(), results.size, "Number of results should match") + + for (i in results.indices) { + val expected = expectedArray[i].asJsonObject + val actual = results[i] + assertEquals(expected["start_index"].asInt, actual.startIndex, "Result $i: start_index") + assertEquals(expected["end_index"].asInt, actual.endIndex, "Result $i: end_index") + assertEquals(expected["completion"].asString, actual.completion, "Result $i: completion") + assertEquals(expected["autocomplete_id"].asString, actual.autocompleteId, "Result $i: autocomplete_id") + } + } +} diff --git a/src/test/kotlin/dev/sweep/assistant/autocomplete/edit/engine/NesUtilsTest.kt b/src/test/kotlin/dev/sweep/assistant/autocomplete/edit/engine/NesUtilsTest.kt new file mode 100644 index 0000000..06a5c37 --- /dev/null +++ b/src/test/kotlin/dev/sweep/assistant/autocomplete/edit/engine/NesUtilsTest.kt @@ -0,0 +1,276 @@ +package dev.sweep.assistant.autocomplete.edit.engine + +import com.google.gson.Gson +import com.google.gson.JsonObject +import org.junit.jupiter.api.Assertions.* +import org.junit.jupiter.api.BeforeAll +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.TestInstance + +/** + * Tests for NesUtils, comparing outputs against Python-generated fixtures + * to ensure parity between the Kotlin and Python implementations. + */ +@TestInstance(TestInstance.Lifecycle.PER_CLASS) +class NesUtilsTest { + + private lateinit var fixtures: JsonObject + private val gson = Gson() + + @BeforeAll + fun loadFixtures() { + val json = this::class.java.classLoader + .getResourceAsStream("nes_fixtures/utils_fixtures.json")!! + .bufferedReader().readText() + fixtures = gson.fromJson(json, JsonObject::class.java) + } + + // --- extract_diff_parts --- + + @Test + fun `extractDiffParts no context matches Python`() { + val fixture = fixtures["extract_diff_parts_no_context"].asJsonObject + val hunk = fixture["input"].asJsonObject["hunk"].asString + val expected = fixture["output"].asJsonArray.map { it.asString } + + val (oldCode, newCode) = NesUtils.extractDiffParts(hunk, 0) + + assertEquals(expected[0], oldCode) + assertEquals(expected[1], newCode) + } + + @Test + fun `extractDiffParts one context line matches Python`() { + val fixture = fixtures["extract_diff_parts_one_context"].asJsonObject + val hunk = fixture["input"].asJsonObject["hunk"].asString + val expected = fixture["output"].asJsonArray.map { it.asString } + + val (oldCode, newCode) = NesUtils.extractDiffParts(hunk, 1) + + assertEquals(expected[0], oldCode) + assertEquals(expected[1], newCode) + } + + @Test + fun `extractDiffParts all lines matches Python`() { + val fixture = fixtures["extract_diff_parts_all"].asJsonObject + val hunk = fixture["input"].asJsonObject["hunk"].asString + val expected = fixture["output"].asJsonArray.map { it.asString } + + val (oldCode, newCode) = NesUtils.extractDiffParts(hunk, -1) + + assertEquals(expected[0], oldCode) + assertEquals(expected[1], newCode) + } + + // --- split_into_hunks --- + + @Test + fun `splitIntoHunks matches Python`() { + val fixture = fixtures["split_into_hunks"].asJsonObject + val input = fixture["input"].asString + val expected = fixture["output"].asJsonArray.map { it.asString } + + val result = NesUtils.splitIntoHunks(input) + + assertEquals(expected, result) + } + + // --- get_line_number_from_position --- + + @Test + fun `getLineNumberFromPosition matches Python`() { + val fixture = fixtures["get_line_number_from_position"].asJsonObject + val text = fixture["input"].asJsonObject["text"].asString + val outputs = fixture["outputs"].asJsonObject + + assertEquals(outputs["pos_0"].asInt, NesUtils.getLineNumberFromPosition(text, 0)) + assertEquals(outputs["pos_6"].asInt, NesUtils.getLineNumberFromPosition(text, 6)) + assertEquals(outputs["pos_12"].asInt, NesUtils.getLineNumberFromPosition(text, 12)) + assertEquals(outputs["pos_end"].asInt, NesUtils.getLineNumberFromPosition(text, text.length)) + assertEquals(outputs["pos_negative"].asInt, NesUtils.getLineNumberFromPosition(text, -1)) + } + + // --- strip_leading_empty_newlines --- + + @Test + fun `stripLeadingEmptyNewlines matches Python`() { + val fixture = fixtures["strip_leading_empty_newlines"].asJsonObject + val input = fixture["input"].asString + val expected = fixture["output"].asString + + assertEquals(expected, NesUtils.stripLeadingEmptyNewlines(input)) + } + + // --- filter_whitespace_only_hunks --- + + @Test + fun `filterWhitespaceOnlyHunks matches Python`() { + val fixture = fixtures["filter_whitespace_only_hunks"].asJsonObject + val input = fixture["input"].asJsonArray.map { it.asString } + val expected = fixture["output"].asJsonArray.map { it.asString } + + val result = NesUtils.filterWhitespaceOnlyHunks(input) + + assertEquals(expected, result) + } + + // --- split_into_diff_hunks --- + + @Test + fun `splitIntoDiffHunks matches Python semantically`() { + val fixture = fixtures["split_into_diff_hunks"].asJsonObject + val inputContent = fixture["input"].asJsonObject["input_content"].asString + val outputContent = fixture["input"].asJsonObject["output_content"].asString + val expected = fixture["output"].asJsonArray + + val result = NesUtils.splitIntoDiffHunks(inputContent, outputContent) + + // Compare the number of hunks + assertEquals(expected.size(), result.size, "Number of hunks should match") + + // Compare each hunk semantically (line numbers and content) + for (i in result.indices) { + val expectedHunk = expected[i].asJsonObject + val actualHunk = result[i] + + assertEquals( + expectedHunk["input_start"].asInt, actualHunk.inputStart, + "Hunk $i: input_start mismatch" + ) + assertEquals( + expectedHunk["output_start"].asInt, actualHunk.outputStart, + "Hunk $i: output_start mismatch" + ) + + // Compare line content (stripping trailing newlines for robustness) + val expectedInputLines = expectedHunk["input_lines"].asJsonArray.map { it.asString.trimEnd('\n') } + val actualInputLines = actualHunk.inputLines.map { it.trimEnd('\n') } + assertEquals(expectedInputLines, actualInputLines, "Hunk $i: input_lines mismatch") + + val expectedOutputLines = expectedHunk["output_lines"].asJsonArray.map { it.asString.trimEnd('\n') } + val actualOutputLines = actualHunk.outputLines.map { it.trimEnd('\n') } + assertEquals(expectedOutputLines, actualOutputLines, "Hunk $i: output_lines mismatch") + } + } + + // --- is_large_diff_above_cursor --- + + @Test + fun `isLargeDiffAboveCursor matches Python`() { + val fixture = fixtures["is_large_diff_above_cursor"].asJsonObject + + val orig = "line1\nline2\nline3\nline4\nline5\nline6\nline7\nline8\n" + val compLarge = "a\nb\nc\nd\ne\nf\ng\nh\nline7\nline8\n" + val compSmall = "line1\nchanged\nline3\nline4\nline5\nline6\nline7\nline8\n" + + assertEquals(fixture["large_diff"].asBoolean, NesUtils.isLargeDiffAboveCursor(orig, compLarge, 20)) + assertEquals(fixture["small_diff"].asBoolean, NesUtils.isLargeDiffAboveCursor(orig, compSmall, 20)) + assertEquals(fixture["same"].asBoolean, NesUtils.isLargeDiffAboveCursor(orig, orig, 20)) + } + + // --- truncate_long_lines --- + + @Test + fun `truncateLongLines matches Python`() { + val fixture = fixtures["truncate_long_lines"].asJsonObject + val input = fixture["input"].asString + val expected = fixture["output"].asString + + assertEquals(expected, NesUtils.truncateLongLines(input)) + } + + // --- should_disable_for_code_block --- + + @Test + fun `shouldDisableForCodeBlock matches Python`() { + val fixture = fixtures["should_disable_for_code_block"].asJsonObject + + assertEquals(fixture["short_lines"].asBoolean, NesUtils.shouldDisableForCodeBlock("short\nlines\n")) + assertEquals( + fixture["long_line"].asBoolean, + NesUtils.shouldDisableForCodeBlock("short\n" + "x".repeat(1001) + "\nshort") + ) + } + + // --- is_equal_ignoring_newlines --- + + @Test + fun `isEqualIgnoringNewlines matches Python`() { + val fixture = fixtures["is_equal_ignoring_newlines"].asJsonObject + + assertEquals(fixture["equal"].asBoolean, NesUtils.isEqualIgnoringNewlines("a\n\nb", "a\nb")) + assertEquals(fixture["not_equal"].asBoolean, NesUtils.isEqualIgnoringNewlines("a\nb", "a b")) + } + + // --- pretokenize --- + + @Test + fun `pretokenize matches Python`() { + val fixture = fixtures["pretokenize"].asJsonObject + val input = fixture["input"].asString + val expected = fixture["output"].asJsonArray.map { it.asString } + + val result = NesUtils.pretokenize(input) + + assertEquals(expected, result) + } + + // --- get_lines_around_cursor (small file returns full) --- + + @Test + fun `getLinesAroundCursor small file returns full`() { + val fixture = fixtures["get_lines_around_cursor_small"].asJsonObject + val text = fixture["input"].asJsonObject["text"].asString + val cursorPos = fixture["input"].asJsonObject["cursor_position"].asInt + val expected = fixture["output"].asString + + assertEquals(expected, NesUtils.getLinesAroundCursor(text, cursorPos)) + } + + // --- pack_items_for_prompt --- + + @Test + fun `packItemsForPrompt matches Python`() { + val fixture = fixtures["pack_items_for_prompt"].asJsonObject + val items = fixture["input"].asJsonObject["items"].asJsonArray.map { it.asString } + val tokenLimit = fixture["input"].asJsonObject["token_limit"].asInt + val expected = fixture["output"].asJsonArray.map { it.asString } + + val result = NesUtils.packItemsForPrompt(items, { it }, tokenLimit) + + assertEquals(expected, result) + } + + // --- parse_hunk --- + + @Test + fun `parseHunk matches Python`() { + val fixture = fixtures["parse_hunk"].asJsonObject + val input = fixture["input"].asString + val expected = fixture["output"].asJsonObject + + val result = NesUtils.parseHunk(input) + + assertEquals(expected["input_start"].asInt, result.inputStart) + assertEquals(expected["output_start"].asInt, result.outputStart) + + val expectedInputLines = expected["input_lines"].asJsonArray.map { it.asString.trimEnd('\n') } + val actualInputLines = result.inputLines.map { it.trimEnd('\n') } + assertEquals(expectedInputLines, actualInputLines) + + val expectedOutputLines = expected["output_lines"].asJsonArray.map { it.asString.trimEnd('\n') } + val actualOutputLines = result.outputLines.map { it.trimEnd('\n') } + assertEquals(expectedOutputLines, actualOutputLines) + } + + // --- linesSplitKeepEnds --- + + @Test + fun `linesSplitKeepEnds preserves line endings`() { + assertEquals(listOf("a\n", "b\n", "c"), "a\nb\nc".linesSplitKeepEnds()) + assertEquals(listOf("a\n", "b\n"), "a\nb\n".linesSplitKeepEnds()) + assertEquals(listOf("single"), "single".linesSplitKeepEnds()) + assertEquals(listOf(""), "".linesSplitKeepEnds()) + } +} diff --git a/src/test/resources/nes_fixtures/parser_fixtures.json b/src/test/resources/nes_fixtures/parser_fixtures.json new file mode 100644 index 0000000..db29dba --- /dev/null +++ b/src/test/resources/nes_fixtures/parser_fixtures.json @@ -0,0 +1,77 @@ +{ + "ghost_text_with_location": { + "input": { + "completion": "def hello():\n print('hi')\n print('bye')\n return True\n", + "code_block": "def hello():\n print('hi')\n return True\n", + "cursor_pos": 29 + }, + "output": " print('bye')\n" + }, + "find_ghost_text_non_local": { + "input": { + "completion": "def hello():\n print('hi')\n print('bye')\n return True\n", + "code_block": "def hello():\n print('hi')\n return True\n", + "cursor_pos": 29 + }, + "output": [ + "print('bye')\n ", + 33 + ] + }, + "single_line_ghost_text": { + "input": { + "completion": "x = 1\ny = 2\nz = 3\n", + "code_block": "x = 1\ny = \nz = 3\n", + "cursor_pos": 10 + }, + "output": "2" + }, + "single_line_ghost_text_empty": { + "input": { + "completion": "totally different", + "code_block": "x = 1\ny = \nz = 3\n", + "cursor_pos": 10 + }, + "output": "" + }, + "is_pure_insertion_above_cursor_true": { + "input": { + "code_block": "line1\nline2\nline3\n", + "completion": "line1\nnew_line\nline2\nline3\n", + "cursor_pos": 12 + }, + "output": true + }, + "select_best_hunk_simple_change": { + "input": { + "completion": " x = 10\n y = 2\n z = x + y\n return z\n", + "cleaned_code_block": " x = 1\n y = 2\n z = x + y\n return z\n", + "file_contents": "import os\n\ndef process():\n x = 1\n y = 2\n z = x + y\n return z\n", + "cursor_position": 35, + "autocomplete_id": "test-id" + }, + "output": [ + { + "start_index": 35, + "end_index": 35, + "completion": "0", + "confidence": 1.0, + "autocomplete_id": "test-id-0" + } + ] + }, + "apply_completions_to_code_block": { + "input": { + "file_contents": "import os\n\ndef process():\n x = 1\n y = 2\n z = x + y\n return z\n", + "cleaned_code_block": " x = 1\n y = 2\n z = x + y\n return z\n", + "completions": [ + { + "start_index": 26, + "end_index": 35, + "completion": " x = 10" + } + ] + }, + "output": " x = 10\n y = 2\n z = x + y\n return z\n" + } +} \ No newline at end of file diff --git a/src/test/resources/nes_fixtures/utils_fixtures.json b/src/test/resources/nes_fixtures/utils_fixtures.json new file mode 100644 index 0000000..120666b --- /dev/null +++ b/src/test/resources/nes_fixtures/utils_fixtures.json @@ -0,0 +1,183 @@ +{ + "extract_diff_parts_no_context": { + "input": { + "hunk": "@@ -28,7 +28,7 @@\n transformed = {\n \"id\": item[\"id\"],\n- \"score\": round(item.get(\"score\", 0) * 100, 2),\n+ \"value\": round(item.get(\"score\", 0) * 100, 2),\n }", + "num_context_lines": 0 + }, + "output": [ + " \"score\": round(item.get(\"score\", 0) * 100, 2),\n", + " \"value\": round(item.get(\"score\", 0) * 100, 2),\n" + ] + }, + "extract_diff_parts_one_context": { + "input": { + "hunk": "@@ -28,7 +28,7 @@\n transformed = {\n \"id\": item[\"id\"],\n- \"score\": round(item.get(\"score\", 0) * 100, 2),\n+ \"value\": round(item.get(\"score\", 0) * 100, 2),\n }", + "num_context_lines": 1 + }, + "output": [ + " \"id\": item[\"id\"],\n \"score\": round(item.get(\"score\", 0) * 100, 2),\n}", + " \"id\": item[\"id\"],\n \"value\": round(item.get(\"score\", 0) * 100, 2),\n}" + ] + }, + "extract_diff_parts_all": { + "input": { + "hunk": "@@ -28,7 +28,7 @@\n transformed = {\n \"id\": item[\"id\"],\n- \"score\": round(item.get(\"score\", 0) * 100, 2),\n+ \"value\": round(item.get(\"score\", 0) * 100, 2),\n }", + "num_context_lines": -1 + }, + "output": [ + "transformed = {\n \"id\": item[\"id\"],\n \"score\": round(item.get(\"score\", 0) * 100, 2),\n}", + "transformed = {\n \"id\": item[\"id\"],\n \"value\": round(item.get(\"score\", 0) * 100, 2),\n}" + ] + }, + "split_into_hunks": { + "input": "File: src/main.py\n@@ -1,3 +1,3 @@\n-old\n+new\nFile: src/utils.py\n@@ -5,2 +5,2 @@\n-foo\n+bar", + "output": [ + "File: src/main.py\n@@ -1,3 +1,3 @@\n-old\n+new", + "File: src/utils.py\n@@ -5,2 +5,2 @@\n-foo\n+bar" + ] + }, + "get_line_number_from_position": { + "input": { + "text": "line0\nline1\nline2\nline3" + }, + "outputs": { + "pos_0": 0, + "pos_6": 1, + "pos_12": 2, + "pos_end": 3, + "pos_negative": 0 + } + }, + "strip_leading_empty_newlines": { + "input": "\n\n\nhello\nworld", + "output": "hello\nworld" + }, + "filter_whitespace_only_hunks": { + "input": [ + "File: a.py\n@@ -1,1 +1,1 @@\n- foo\n+ foo", + "File: b.py\n@@ -1,1 +1,1 @@\n-old\n+new" + ], + "output": [ + "File: b.py\n@@ -1,1 +1,1 @@\n-old\n+new" + ] + }, + "split_into_diff_hunks": { + "input": { + "input_content": "line1\nline2\nline3\nline4\n", + "output_content": "line1\nchanged2\nline3\nnew_line4\nextra_line\n" + }, + "output": [ + { + "input_start": 2, + "input_lines": [ + "line2\n" + ], + "output_start": 2, + "output_lines": [ + "changed2\n" + ] + }, + { + "input_start": 4, + "input_lines": [ + "line4\n" + ], + "output_start": 4, + "output_lines": [ + "new_line4\n", + "extra_line\n" + ] + } + ] + }, + "is_large_diff_above_cursor": { + "large_diff": true, + "small_diff": false, + "same": false + }, + "truncate_long_lines": { + "input": "short\nxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx\nshort2", + "output": "short\nxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx...\nshort2" + }, + "should_disable_for_code_block": { + "short_lines": false, + "long_line": true + }, + "is_equal_ignoring_newlines": { + "equal": true, + "not_equal": false + }, + "pretokenize": { + "input": "def hello_world(x: int) -> str:\n return f'Hello {x}'", + "output": [ + "def", + " hello", + "_world", + "(x", + ":", + " int", + ")", + " ->", + " str", + ":\n", + " ", + " return", + " f", + "'Hello", + " {", + "x", + "}'" + ] + }, + "get_lines_around_cursor_small": { + "input": { + "text": "line0\nline1\nline2\nline3\nline4\nline5\nline6\nline7\nline8\nline9\nline10\nline11\nline12\nline13\nline14\nline15\nline16\nline17\nline18\nline19\nline20\nline21\nline22\nline23\nline24\nline25\nline26\nline27\nline28\nline29\nline30\nline31\nline32\nline33\nline34\nline35\nline36\nline37\nline38\nline39\nline40\nline41\nline42\nline43\nline44\nline45\nline46\nline47\nline48\nline49", + "cursor_position": 30 + }, + "output": "line0\nline1\nline2\nline3\nline4\nline5\nline6\nline7\nline8\nline9\nline10\nline11\nline12\nline13\nline14\nline15\nline16\nline17\nline18\nline19\nline20\nline21\nline22\nline23\nline24\nline25\nline26\nline27\nline28\nline29\nline30\nline31\nline32\nline33\nline34\nline35\nline36\nline37\nline38\nline39\nline40\nline41\nline42\nline43\nline44\nline45\nline46\nline47\nline48\nline49" + }, + "get_block_at_cursor": { + "input": { + "file_contents": "import json\n\ndef process(items):\n results = []\n for item in items:\n transformed = {\"id\": item[\"id\"], \"name\": item[\"name\"]}\n results.append(transformed)\n return results\n\ndef main():\n data = json.load(open(\"data.json\"))\n print(process(data))\n", + "cursor_position": 158 + }, + "output": { + "code_block": " for item in items:\n transformed = {\"id\": item[\"id\"], \"name\": item[\"name\"]}\n results.append(transformed)\n return results\n\ndef main():\n data = json.load(open(\"data.json\"))\n print(process(data))\n", + "prefix": "import json\n\ndef process(items):\n results = []", + "suffix": "", + "block_start_index": 50 + } + }, + "pack_items_for_prompt": { + "input": { + "items": [ + "short", + "medium length text", + "a very long string that takes up space" + ], + "token_limit": 10 + }, + "output": [ + "short", + "medium length text" + ] + }, + "parse_hunk": { + "input": "@@ -10,3 +10,4 @@\n \n context line\n-old line\n+new line1\n+new line2\n context after\n", + "output": { + "input_start": 10, + "input_lines": [ + "context line\n", + "old line\n", + "context after\n" + ], + "output_start": 10, + "output_lines": [ + "context line\n", + "new line1\n", + "new line2\n", + "context after\n" + ] + } + } +} \ No newline at end of file From 99b9472a1ddb0082768bfd9a9191d925d95c1386 Mon Sep 17 00:00:00 2001 From: Stefan Bethge Date: Tue, 14 Apr 2026 22:04:39 +0200 Subject: [PATCH 5/7] Remove intellij-community submodule Was only used as developer reference (257K files, 5.9GB git data). The same code is available on GitHub. Removing it reduces .git from 5.9GB to 12MB and git status from 3s to 11ms. Co-Authored-By: Claude Opus 4.6 (1M context) --- .gitmodules | 3 --- vendor/intellij-community | 1 - 2 files changed, 4 deletions(-) delete mode 100644 .gitmodules delete mode 160000 vendor/intellij-community diff --git a/.gitmodules b/.gitmodules deleted file mode 100644 index 5515ebb..0000000 --- a/.gitmodules +++ /dev/null @@ -1,3 +0,0 @@ -[submodule "vendor/intellij-community"] - path = vendor/intellij-community - url = https://github.com/JetBrains/intellij-community diff --git a/vendor/intellij-community b/vendor/intellij-community deleted file mode 160000 index 8a2ca5f..0000000 --- a/vendor/intellij-community +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 8a2ca5f6b4193c00d6ca35c79b4362ae99280628 From f37daa9f56e8e72061572a4eff5917b3f51d67f2 Mon Sep 17 00:00:00 2001 From: Stefan Bethge Date: Wed, 15 Apr 2026 00:28:24 +0200 Subject: [PATCH 6/7] Add OpenAI-compatible chat and agent mode with tool calling Major changes: - OpenAIChatService: Fetches models from /v1/models, streams chat completions from /v1/chat/completions (works with LM Studio, Ollama, etc.) - OpenAIAgentService: Multi-turn agent loop with 7 tools (read_file, search_files, glob, list_files, str_replace, create_file, bash) using OpenAI function calling protocol - Stream.kt: Routes to OpenAI path when URL is not a Sweep backend, builds system prompt with current file context, cursor position, and project info - ModelPickerMenu: Fetches models from /v1/models for non-Sweep backends, clears stale Sweep model cache - WelcomeScreen: New local-first welcome page with setup instructions - MarkdownBlock: Added // filepath: comment pattern detection for code blocks, enabling Apply button on model responses - Post-processing: Extracts blocks, generates CodeReplacement annotations for Apply button - Chat works without authentication when native engine is enabled - Stop button works for both chat and agent streaming - Version bumped to 1.30.0 Also fixes: - WriteIntentReadAction crash on IntelliJ 2025.1+ - HttpURLConnection used instead of HttpClient (fixes localhost timeout) - Renamed "Sweep API URL" to "OpenAI Compatible API URL" Co-Authored-By: Claude Opus 4.6 (1M context) --- build.gradle.kts | 2 +- .../assistant/components/MessagesComponent.kt | 25 +- .../sweep/assistant/components/SweepConfig.kt | 4 +- .../assistant/components/WelcomeScreen.kt | 60 ++- .../dev/sweep/assistant/controllers/Stream.kt | 270 ++++++++++ .../assistant/services/OpenAIAgentService.kt | 490 ++++++++++++++++++ .../assistant/services/OpenAIChatService.kt | 176 +++++++ .../sweep/assistant/settings/SweepSettings.kt | 5 +- .../sweep/assistant/views/MarkdownBlock.kt | 20 +- .../sweep/assistant/views/ModelPickerMenu.kt | 102 +++- 10 files changed, 1089 insertions(+), 65 deletions(-) create mode 100644 src/main/kotlin/dev/sweep/assistant/services/OpenAIAgentService.kt create mode 100644 src/main/kotlin/dev/sweep/assistant/services/OpenAIChatService.kt diff --git a/build.gradle.kts b/build.gradle.kts index 01773f7..4a22914 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -18,7 +18,7 @@ val pluginId = "dev.sweep.assistant" val pluginName = "Sweep Autocomplete OSS" println("Building plugin: $pluginName with ID: $pluginId") group = "dev.sweep" -version = "1.29.3" +version = "1.30.0" repositories { mavenCentral() diff --git a/src/main/kotlin/dev/sweep/assistant/components/MessagesComponent.kt b/src/main/kotlin/dev/sweep/assistant/components/MessagesComponent.kt index 51aa07e..f7c67a0 100644 --- a/src/main/kotlin/dev/sweep/assistant/components/MessagesComponent.kt +++ b/src/main/kotlin/dev/sweep/assistant/components/MessagesComponent.kt @@ -4,7 +4,6 @@ import com.intellij.openapi.Disposable import com.intellij.openapi.actionSystem.AnAction import com.intellij.openapi.actionSystem.AnActionEvent import com.intellij.openapi.application.ApplicationManager -import com.intellij.openapi.application.WriteIntentReadAction import com.intellij.openapi.components.Service import com.intellij.openapi.fileEditor.FileEditorManager import com.intellij.openapi.project.Project @@ -642,20 +641,18 @@ class MessagesComponent( val sessionMessageList = messageListService.getMessageListForConversation(conversationId) val latestMessages = sessionMessageList?.snapshot() ?: messageListService.snapshot() val latestMessage = latestMessages.getOrNull(messageIndex) ?: message - // Some nested UI builders access PSI/documents and may require WriteIntentReadAction - // Use write-intent read action to satisfy both PSI reads and nested editor creation + // Some nested UI builders access PSI/documents — create component directly + // since we're already on the EDT via invokeLater val real: JComponent = - WriteIntentReadAction.compute { - createComponent( - message = latestMessage, - index = messageIndex, - currentMessages = latestMessages, - loadedFromHistory = loadedFromHistory, - currentWidth = width, - updateType = updateType, - conversationId = conversationId, - ) - } + createComponent( + message = latestMessage, + index = messageIndex, + currentMessages = latestMessages, + loadedFromHistory = loadedFromHistory, + currentWidth = width, + updateType = updateType, + conversationId = conversationId, + ) removeAll() add(real, BorderLayout.CENTER) if (real is Disposable) { diff --git a/src/main/kotlin/dev/sweep/assistant/components/SweepConfig.kt b/src/main/kotlin/dev/sweep/assistant/components/SweepConfig.kt index 0a74f5b..de5d0b3 100644 --- a/src/main/kotlin/dev/sweep/assistant/components/SweepConfig.kt +++ b/src/main/kotlin/dev/sweep/assistant/components/SweepConfig.kt @@ -4695,7 +4695,7 @@ class SweepConfig( layout = BoxLayout(this, BoxLayout.Y_AXIS) border = JBUI.Borders.empty(8, 16) add( - JLabel("Sweep API URL").apply { + JLabel("OpenAI Compatible API URL").apply { border = JBUI.Borders.empty(0, 4, 4, 0) }, ) @@ -4705,7 +4705,7 @@ class SweepConfig( }, ) add( - createCommentLabel("The Sweep API endpoint (e.g., http://localhost:8080)").apply { + createCommentLabel("OpenAI-compatible API endpoint (e.g., http://localhost:1234 for LM Studio)").apply { border = JBUI.Borders.empty(4, 4, 0, 0) withSweepFont(project, scale = 0.9f) }, diff --git a/src/main/kotlin/dev/sweep/assistant/components/WelcomeScreen.kt b/src/main/kotlin/dev/sweep/assistant/components/WelcomeScreen.kt index 4dc2fc4..8b5184d 100644 --- a/src/main/kotlin/dev/sweep/assistant/components/WelcomeScreen.kt +++ b/src/main/kotlin/dev/sweep/assistant/components/WelcomeScreen.kt @@ -2,19 +2,16 @@ package dev.sweep.assistant.components import com.intellij.openapi.Disposable import com.intellij.openapi.project.Project -import com.intellij.openapi.wm.ToolWindowManager import com.intellij.ui.components.JBScrollPane import com.intellij.ui.dsl.builder.AlignX import com.intellij.ui.dsl.builder.TopGap import com.intellij.ui.dsl.builder.panel import com.intellij.util.ui.JBUI +import dev.sweep.assistant.services.LocalAutocompleteServerManager import dev.sweep.assistant.services.SweepProjectService -import dev.sweep.assistant.settings.SweepSettings -import dev.sweep.assistant.settings.SweepSettingsParser import dev.sweep.assistant.theme.SweepColors import dev.sweep.assistant.theme.SweepIcons import dev.sweep.assistant.theme.SweepIcons.scale -import dev.sweep.assistant.utils.SweepConstants.TOOLWINDOW_NAME import dev.sweep.assistant.views.RoundedButton import java.awt.GridBagConstraints import java.awt.GridBagLayout @@ -32,31 +29,28 @@ class WelcomeScreen( icon(SweepIcons.BigSweepIcon.scale(80f)).align(AlignX.CENTER) } row { - text("Welcome!") + text("Sweep Autocomplete OSS") .applyToComponent { font = font.deriveFont(java.awt.Font.BOLD, font.size * 1.2f) }.align(AlignX.CENTER) }.topGap(TopGap.SMALL) row { - text("Please sign in below:").applyToComponent { font = font.deriveFont(font.size * 1.1f) }.align(AlignX.CENTER) + text( + """ +
+ Local AI code autocomplete powered by
+ llama.cpp with next-edit suggestion models. +
+ """.trimIndent(), + ).applyToComponent { font = font.deriveFont(font.size * 1.0f) }.align(AlignX.CENTER) } row { cell( RoundedButton( - text = " Sign in", + text = " Start Local Server", parentDisposable = parentDisposable, onClick = { - val isCloudEnvironment = SweepSettingsParser.isCloudEnvironment() - if (isCloudEnvironment) { - // Trigger Sweep OAuth flow - ToolWindowManager - .getInstance(project) - .getToolWindow(TOOLWINDOW_NAME) - ?.hide() - SweepSettings.getInstance().initiateGitHubAuth(project) - } else { - SweepConfig.getInstance(project).showConfigPopup() - } + LocalAutocompleteServerManager.getInstance().startServerInTerminal(project) }, ).apply { icon = SweepIcons.UserIcon.scale(16f) @@ -67,14 +61,36 @@ class WelcomeScreen( }, ).align(AlignX.CENTER) }.topGap(TopGap.SMALL) + row { + cell( + RoundedButton( + text = " Configure Settings", + parentDisposable = parentDisposable, + onClick = { + SweepConfig.getInstance(project).showConfigPopup() + }, + ).apply { + font = font.deriveFont(font.size * 1.0f) + border = JBUI.Borders.empty(4, 16) + }, + ).align(AlignX.CENTER) + }.topGap(TopGap.SMALL) row { + val installHint = if (System.getProperty("os.name").lowercase().contains("mac")) { + "Install llama.cpp: brew install llama.cpp" + } else { + "Install llama.cpp from github.com/ggml-org/llama.cpp" + } text( """ -
- Sweep features:
- Agent
- Autocomplete +
+ Quick start:
+ 1. $installHint
+ 2. Click "Start Local Server" above
+ 3. Start editing code — suggestions appear automatically
+
+ Choose model size and runtime in Settings.
""".trimIndent(), ).align(AlignX.CENTER) diff --git a/src/main/kotlin/dev/sweep/assistant/controllers/Stream.kt b/src/main/kotlin/dev/sweep/assistant/controllers/Stream.kt index ef9a78d..97ca254 100644 --- a/src/main/kotlin/dev/sweep/assistant/controllers/Stream.kt +++ b/src/main/kotlin/dev/sweep/assistant/controllers/Stream.kt @@ -194,6 +194,8 @@ class Stream( private var markdownDisplay: MarkdownDisplay? = null private var connection: HttpURLConnection? = null + @Volatile + private var activeAgentService: dev.sweep.assistant.services.OpenAIAgentService? = null private var shouldHideStopButton: Boolean = true private var streamingJob: Job? by Delegates.observable(null) { _, oldJob, newJob -> // update listeners when stream state has changed @@ -275,6 +277,227 @@ class Stream( SweepMetaData.getInstance().chatWithoutSearch++ var finalMessage: Message? = null + + // For non-Sweep backends (LM Studio, Ollama, etc.), use OpenAI-compatible API + if (!dev.sweep.assistant.services.OpenAIChatService.isSweepBackend()) { + try { + currentMarkdownDisplay.startStreaming() + var selectedModel = SweepComponent.getSelectedModelId(project) + ?: sessionMessageList.selectedModel + // "auto" is a Sweep-specific concept — for OpenAI endpoints, use the first available model + if (selectedModel == null || selectedModel == "auto" || selectedModel.isEmpty()) { + val available = dev.sweep.assistant.services.OpenAIChatService.getInstance().fetchModels() + selectedModel = available.values.firstOrNull() ?: "" + } + + logger.info("[OpenAI] Starting chat with model='$selectedModel', ${sessionMessageList.snapshot().size} messages") + + // Build context from current file, included files, and editor state + val contextParts = mutableListOf() + + // Get cursor position for current file + val cursorInfo = com.intellij.openapi.application.ApplicationManager.getApplication().runReadAction { + try { + val editor = com.intellij.openapi.fileEditor.FileEditorManager.getInstance(project) + .selectedTextEditor + if (editor != null) { + val offset = editor.caretModel.offset + val line = editor.document.getLineNumber(offset) + 1 + val col = offset - editor.document.getLineStartOffset(line - 1) + 1 + "Cursor is at line $line, column $col" + } else null + } catch (_: Exception) { null } + } + + // Add current file context with cursor info + if (effectiveCurrentFilePath != null) { + val currentFileContent = readFile(project, effectiveCurrentFilePath, maxLines = 500, maxChars = 50000) + if (currentFileContent != null) { + val lang = effectiveCurrentFilePath.substringAfterLast('.', "") + val header = "Current file: `$effectiveCurrentFilePath`" + + (cursorInfo?.let { " ($it)" } ?: "") + contextParts.add("$header\n```$lang\n$currentFileContent\n```") + } + } + + // Add included/mentioned files + val addedFiles = mutableSetOf(effectiveCurrentFilePath) + for (value in includedFiles.values) { + val filePath = TutorialPage.normalizeTutorialPath(value) + if (filePath in addedFiles) continue + val fileContent = readFile(project, filePath, maxLines = 500, maxChars = 50000) + if (fileContent != null) { + val lang = filePath.substringAfterLast('.', "") + contextParts.add("File: `$filePath`\n```$lang\n$fileContent\n```") + addedFiles.add(filePath) + } + } + + // Add mentioned files from the latest user message + sessionMessageList.getLastUserMessage()?.mentionedFiles?.forEach { mentionedFile -> + val filePath = TutorialPage.normalizeTutorialPath(mentionedFile.relativePath) + if (filePath !in addedFiles) { + val fileContent = readFile(project, filePath, maxLines = 500, maxChars = 50000) + if (fileContent != null) { + val lang = filePath.substringAfterLast('.', "") + contextParts.add("File: `$filePath`\n```$lang\n$fileContent\n```") + addedFiles.add(filePath) + } + } + } + + // Get working directory and repo info + val workingDir = project.basePath ?: "" + + // Build system message with context + val isAgentMode = currentMode.lowercase() == "agent" + val systemPrompt = buildString { + if (isAgentMode) { + append("You are an expert coding agent integrated into a JetBrains IDE. ") + append("You can read files, search code, edit files, and run shell commands to help the user.\n\n") + append("When you need to make changes, use the available tools. ") + append("Read the relevant files first, then make targeted edits using str_replace.\n\n") + } else { + append("You are an expert coding assistant integrated into a JetBrains IDE. ") + } + append("You help the user understand, modify, and debug code.\n\n") + append("Guidelines:\n") + append("- When suggesting code changes, always include the complete modified code in a fenced code block with the language specified.\n") + append("- Include the file path as a comment at the top of the code block, e.g. `// filepath: src/main.py`\n") + append("- Be concise. Explain what you changed and why.\n") + append("- If the user asks to fix a bug, explain the root cause before the fix.\n") + append("- When showing diffs, use unified diff format.\n") + append("- IMPORTANT: Never use LaTeX math notation like $...$ or \\\\(...\\\\). The output is rendered as plain markdown without LaTeX support. Use Unicode math symbols instead: × ÷ ± ≤ ≥ ≠ ≈ √ ∑ ∏ ∫ π θ α β γ ² ³ ⁴ ₀ ₁ ₂ → ← ∈ ∉ ∅ ∞ ∧ ∨.\n") + if (workingDir.isNotEmpty()) { + append("\nProject directory: $workingDir\n") + } + if (isAgentMode) { + // In agent mode, only mention file names — the model can read_file if needed + val fileNames = mutableListOf() + if (effectiveCurrentFilePath != null) fileNames.add(effectiveCurrentFilePath) + for (value in includedFiles.values) { + val fp = TutorialPage.normalizeTutorialPath(value) + if (fp != effectiveCurrentFilePath) fileNames.add(fp) + } + if (fileNames.isNotEmpty()) { + append("\nThe user currently has these files open:\n") + fileNames.forEach { append("- $it\n") } + append("\nUse read_file to examine their contents when needed.\n") + } + } else if (contextParts.isNotEmpty()) { + // In chat mode, include full file contents for context + append("\nThe user is working with the following files:\n\n") + append(contextParts.joinToString("\n\n")) + } + } + + // Convert messages to OpenAI format, prepending system message + val openAIMessages = mutableListOf( + dev.sweep.assistant.services.OpenAIChatService.ChatMessage("system", systemPrompt), + ) + openAIMessages.addAll(sessionMessageList.snapshot().map { msg -> + dev.sweep.assistant.services.OpenAIChatService.ChatMessage( + role = msg.role.name.lowercase(), + content = msg.content, + ) + }) + + // Create empty assistant message with empty annotations for code replacements + val assistantMessage = Message( + role = MessageRole.ASSISTANT, + content = "", + annotations = dev.sweep.assistant.data.Annotations(), + ) + onMessageUpdated(assistantMessage) + + logger.info("[OpenAI] System prompt length=${systemPrompt.length}, mode=$currentMode, agent=$isAgentMode, sending ${openAIMessages.size} messages to ${dev.sweep.assistant.settings.SweepSettings.getInstance().baseUrl}") + + if (isAgentMode) { + // Agent mode: multi-turn loop with tool calling + // Only send system prompt + the latest user message to avoid + // the model trying to redo earlier chat-mode responses with tools + val agentService = dev.sweep.assistant.services.OpenAIAgentService(project) + activeAgentService = agentService + val latestUserMessage = openAIMessages.lastOrNull { it.role == "user" } + val agentMessages = mutableListOf( + dev.sweep.assistant.services.OpenAIAgentService.AgentMessage("system", systemPrompt), + ) + if (latestUserMessage != null) { + agentMessages.add( + dev.sweep.assistant.services.OpenAIAgentService.AgentMessage("user", latestUserMessage.content) + ) + } + + agentService.runAgentLoop( + messages = agentMessages, + model = selectedModel ?: "", + onTextChunk = { chunk -> + assistantMessage.content += chunk + onMessageUpdated(assistantMessage) + }, + onToolCallStart = { toolName, params -> + assistantMessage.content += "\n\n> **Using tool:** `$toolName`" + if (params.isNotEmpty()) { + val paramsStr = params.entries.joinToString(", ") { "${it.key}=${it.value.take(50)}" } + assistantMessage.content += " ($paramsStr)" + } + assistantMessage.content += "\n" + onMessageUpdated(assistantMessage) + }, + onToolCallResult = { toolName, result, success -> + val statusEmoji = if (success) "+" else "!" + val preview = result.take(200).replace("\n", " ") + assistantMessage.content += "> [$statusEmoji] `$toolName` result: $preview\n" + onMessageUpdated(assistantMessage) + }, + onDone = { + logger.info("[OpenAI Agent] Loop completed, total length=${assistantMessage.content.length}") + }, + onError = { e -> + logger.warn("[OpenAI Agent] Error: ${e.message}", e) + assistantMessage.content += "\n\n**Error:** ${e.message}" + onMessageUpdated(assistantMessage) + }, + isCancelled = { cancelledByUser }, + conversationId = currentConversationId, + ) + } else { + // Chat mode: simple streaming without tool calling + dev.sweep.assistant.services.OpenAIChatService.getInstance().streamChatCompletion( + messages = openAIMessages, + model = selectedModel ?: "", + onChunk = { chunk -> + assistantMessage.content += chunk + onMessageUpdated(assistantMessage) + }, + onDone = { + logger.info("[OpenAI] Stream completed, total length=${assistantMessage.content.length}") + }, + onError = { e -> + logger.warn("[OpenAI] Stream error: ${e.message}", e) + assistantMessage.content += "\n\n**Error:** ${e.message}" + onMessageUpdated(assistantMessage) + }, + isCancelled = { cancelledByUser }, + ) + } + + // Post-process: extract code blocks and generate codeReplacements for Apply button + enrichMessageWithCodeReplacements(assistantMessage) + + // Final update + onMessageUpdated(assistantMessage) + currentMarkdownDisplay.stopStreaming() + // Notify stream state service that streaming is done (hides stop button) + StreamStateService.getInstance(project).notify(false, false, false, currentConversationId) + } catch (e: Exception) { + logger.warn("[OpenAI] Chat failed: ${e.message}", e) + currentMarkdownDisplay.stopStreaming() + StreamStateService.getInstance(project).notify(false, false, false, currentConversationId) + } + return + } + try { // Use longer timeouts for streaming chat requests // - 30s to establish connection (standard) @@ -1115,7 +1338,54 @@ class Stream( return latestMessage } + /** + * Post-process an assistant message: + * 1. Extract ... blocks into reasoning annotations + * 2. Generate CodeReplacement annotations from code blocks for the "Apply" button + */ + private fun enrichMessageWithCodeReplacements(message: Message) { + try { + // Extract blocks (used by Qwen, DeepSeek, etc.) + val thinkPattern = """([\s\S]*?)""".toRegex() + val thinkMatch = thinkPattern.find(message.content) + if (thinkMatch != null) { + val thinkingContent = thinkMatch.groupValues[1].trim() + message.content = message.content.replace(thinkMatch.value, "").trimStart('\n') + // Store thinking in annotations via the mutable codeReplacements list + // (thinking is val so we can't set it directly on existing message) + // The thinking content will be visible as a quote block instead + if (thinkingContent.isNotEmpty()) { + message.content = "
Reasoning\n\n$thinkingContent\n\n
\n\n${message.content}" + } + } + + // Generate code replacement annotations for the Apply button + val ann = message.annotations ?: return + val blocks = dev.sweep.assistant.views.parseMarkdownBlocks(message, project) + val codeBlocks = blocks.filterIsInstance() + if (codeBlocks.isEmpty()) return + + codeBlocks.forEachIndexed { index, block -> + if (block.path.isNotEmpty() && block.code.isNotEmpty()) { + message.annotations?.codeReplacements?.add( + dev.sweep.assistant.data.CodeReplacement( + codeBlockIndex = index, + filePath = block.path, + codeBlockContent = block.code, + ) + ) + } + } + } catch (e: Exception) { + logger.debug("[OpenAI] Failed to enrich message: ${e.message}") + } + } + fun stop(isUserInitiated: Boolean = true) { + // Cancel any active agent HTTP connection + activeAgentService?.cancelActiveRequest() + activeAgentService = null + logger.info( "[Stream.stop] Stopping stream: conversationId=$sessionConversationId, isUserInitiated=$isUserInitiated, hasJob=${streamingJob != null}", ) diff --git a/src/main/kotlin/dev/sweep/assistant/services/OpenAIAgentService.kt b/src/main/kotlin/dev/sweep/assistant/services/OpenAIAgentService.kt new file mode 100644 index 0000000..28f51c7 --- /dev/null +++ b/src/main/kotlin/dev/sweep/assistant/services/OpenAIAgentService.kt @@ -0,0 +1,490 @@ +package dev.sweep.assistant.services + +import com.google.gson.Gson +import com.google.gson.JsonParser +import com.intellij.openapi.diagnostic.Logger +import com.intellij.openapi.project.Project +import dev.sweep.assistant.agent.tools.ToolType +import dev.sweep.assistant.data.CompletedToolCall +import dev.sweep.assistant.data.ToolCall +import dev.sweep.assistant.settings.SweepSettings +import java.io.BufferedReader +import java.io.InputStreamReader +import java.io.OutputStreamWriter +import java.net.HttpURLConnection +import java.net.URL + +/** + * OpenAI-compatible agent service with function calling support. + * Executes a multi-turn agent loop: sends messages → model responds with tool calls + * → executes tools → sends results back → repeats until model gives a text response. + */ +class OpenAIAgentService(private val project: Project) { + private val logger = Logger.getInstance(OpenAIAgentService::class.java) + private val gson = Gson() + + @Volatile + private var activeConnection: HttpURLConnection? = null + + /** Force-close any in-flight HTTP connection (called from stop button thread). */ + fun cancelActiveRequest() { + activeConnection?.let { + try { it.disconnect() } catch (_: Exception) {} + activeConnection = null + } + } + + companion object { + private const val MAX_AGENT_TURNS = 20 + + /** OpenAI function definitions for the core tools. */ + val TOOL_DEFINITIONS = listOf( + mapOf( + "type" to "function", + "function" to mapOf( + "name" to "read_file", + "description" to "Read the contents of a file. For large files, use offset and limit to read specific sections. If the file is very large, the output will be truncated — use offset/limit to read the parts you need.", + "parameters" to mapOf( + "type" to "object", + "properties" to mapOf( + "path" to mapOf("type" to "string", "description" to "File path relative to the project root"), + "offset" to mapOf("type" to "integer", "description" to "Starting line number (1-based). Default: 1"), + "limit" to mapOf("type" to "integer", "description" to "Number of lines to read. Recommended: 200 for large files. Default: entire file"), + ), + "required" to listOf("path"), + ), + ), + ), + mapOf( + "type" to "function", + "function" to mapOf( + "name" to "search_files", + "description" to "Search for a regex pattern across files in the project. Returns matching lines with file paths and line numbers.", + "parameters" to mapOf( + "type" to "object", + "properties" to mapOf( + "regex" to mapOf("type" to "string", "description" to "Regular expression pattern to search for"), + "path" to mapOf("type" to "string", "description" to "Directory to search in (optional)"), + "glob" to mapOf("type" to "string", "description" to "File pattern filter, e.g. '*.py' (optional)"), + ), + "required" to listOf("regex"), + ), + ), + ), + mapOf( + "type" to "function", + "function" to mapOf( + "name" to "glob", + "description" to "Find files matching a glob pattern. Returns a list of file paths.", + "parameters" to mapOf( + "type" to "object", + "properties" to mapOf( + "pattern" to mapOf("type" to "string", "description" to "Glob pattern, e.g. '**/*.kt', 'src/*.py'"), + "path" to mapOf("type" to "string", "description" to "Directory to search in (optional)"), + ), + "required" to listOf("pattern"), + ), + ), + ), + mapOf( + "type" to "function", + "function" to mapOf( + "name" to "list_files", + "description" to "List files and directories in a directory tree.", + "parameters" to mapOf( + "type" to "object", + "properties" to mapOf( + "path" to mapOf("type" to "string", "description" to "Directory path to list"), + "max_depth" to mapOf("type" to "integer", "description" to "Maximum recursion depth (optional)"), + ), + "required" to listOf("path"), + ), + ), + ), + mapOf( + "type" to "function", + "function" to mapOf( + "name" to "str_replace", + "description" to "Replace an exact string in a file with new content. The old_str must match exactly (including whitespace).", + "parameters" to mapOf( + "type" to "object", + "properties" to mapOf( + "path" to mapOf("type" to "string", "description" to "File path"), + "old_str" to mapOf("type" to "string", "description" to "Exact string to find and replace"), + "new_str" to mapOf("type" to "string", "description" to "Replacement string"), + ), + "required" to listOf("path", "old_str", "new_str"), + ), + ), + ), + mapOf( + "type" to "function", + "function" to mapOf( + "name" to "create_file", + "description" to "Create a new file with the given content.", + "parameters" to mapOf( + "type" to "object", + "properties" to mapOf( + "path" to mapOf("type" to "string", "description" to "File path to create"), + "content" to mapOf("type" to "string", "description" to "File content"), + ), + "required" to listOf("path", "content"), + ), + ), + ), + mapOf( + "type" to "function", + "function" to mapOf( + "name" to "bash", + "description" to "Execute a shell command and return the output. Use for running tests, git commands, installations, etc.", + "parameters" to mapOf( + "type" to "object", + "properties" to mapOf( + "command" to mapOf("type" to "string", "description" to "Shell command to execute"), + "timeout" to mapOf("type" to "integer", "description" to "Timeout in seconds (default: 300, max: 1800)"), + ), + "required" to listOf("command"), + ), + ), + ), + ) + } + + data class AgentMessage( + val role: String, + val content: String?, + val toolCallId: String? = null, + val name: String? = null, + val toolCalls: List? = null, // For assistant messages with tool calls + ) + + /** + * Run the agent loop: stream responses, execute tool calls, send results back. + * + * @param messages Initial conversation messages (including system prompt) + * @param model The model ID + * @param onTextChunk Called with each text chunk for live display + * @param onToolCallStart Called when a tool execution starts (for UI feedback) + * @param onToolCallResult Called when a tool finishes (for UI feedback) + * @param onDone Called when the agent loop completes + * @param onError Called on error + * @param isCancelled Check for cancellation + * @param conversationId The conversation ID for tool execution context + */ + fun runAgentLoop( + messages: MutableList, + model: String, + onTextChunk: (String) -> Unit, + onToolCallStart: (String, Map) -> Unit, + onToolCallResult: (String, String, Boolean) -> Unit, + onDone: () -> Unit, + onError: (Exception) -> Unit, + isCancelled: () -> Boolean, + conversationId: String, + ) { + var turns = 0 + var lastToolCallSignature = "" // Track last tool call to detect exact loops + var sameToolNameCount = 0 // Track consecutive calls to same tool + var lastToolName = "" + + while (turns < MAX_AGENT_TURNS && !isCancelled()) { + turns++ + logger.info("[Agent] Turn $turns, ${messages.size} messages") + + // Log the last few messages to debug tool call loop issues + val recentMessages = messages.takeLast(3) + recentMessages.forEach { msg -> + logger.info("[Agent] Message: role=${msg.role}, content=${msg.content?.take(100) ?: "null"}, toolCalls=${msg.toolCalls?.size ?: 0}, toolCallId=${msg.toolCallId}") + } + + val response = try { + streamWithToolCalls(messages, model, onTextChunk, isCancelled) + } catch (e: Exception) { + if (isCancelled()) { onDone(); return } + onError(e) + return + } + + if (isCancelled()) { + onDone() + return + } + + when { + response.toolCalls.isNotEmpty() -> { + // Add assistant message WITH tool_calls to conversation + // The OpenAI API requires the assistant message to echo back the tool calls + messages.add(AgentMessage( + role = "assistant", + content = response.textContent, + toolCalls = response.toolCalls, + )) + + // Check for repeated tool call loops + val currentSignature = response.toolCalls.joinToString("|") { "${it.name}:${it.arguments}" } + if (currentSignature == lastToolCallSignature) { + sameToolNameCount++ // reuse counter for exact match tracking + if (sameToolNameCount >= 3) { + logger.warn("[Agent] Detected exact tool call loop (3x), breaking") + onTextChunk("\n\n[Agent detected repeated tool call, stopping]") + onDone() + return + } + } else { + sameToolNameCount = 0 + } + lastToolCallSignature = currentSignature + + // Also track consecutive same-tool-name calls with different args + val primaryToolName = response.toolCalls.firstOrNull()?.name ?: "" + if (primaryToolName == lastToolName) { + // already tracked via sameToolNameCount above + } else { + lastToolName = primaryToolName + } + + // Execute each tool call + for (tc in response.toolCalls) { + if (isCancelled()) break + + logger.info("[Agent] Executing tool: ${tc.name}(${tc.arguments.keys.joinToString(", ")})") + onToolCallStart(tc.name, tc.arguments) + + val result = executeTool(tc.name, tc.arguments, conversationId) + + logger.info("[Agent] Tool ${tc.name} result: ${result.resultString.take(100)}...") + onToolCallResult(tc.name, result.resultString, result.status) + + // Add tool result to conversation + val maxToolResultSize = 8000 + val toolContent = if (result.resultString.length > maxToolResultSize) { + result.resultString.take(maxToolResultSize) + + "\n\n[Output truncated at $maxToolResultSize chars. " + + "Total: ${result.resultString.length} chars. " + + "Use offset/limit parameters to read specific sections.]" + } else { + result.resultString + } + messages.add(AgentMessage( + role = "tool", + content = toolContent, + toolCallId = tc.id, + )) + } + } + response.textContent?.isNotEmpty() == true -> { + onDone() + return + } + else -> { + onDone() + return + } + } + } + + if (turns >= MAX_AGENT_TURNS) { + onTextChunk("\n\n[Agent stopped after $MAX_AGENT_TURNS turns]*") + } + onDone() + } + + data class ParsedToolCall(val id: String, val name: String, val arguments: Map, val rawArguments: String = "") + data class StreamResponse(val textContent: String?, val toolCalls: List) + + /** + * Stream a single chat completion, returning text content and any tool calls. + */ + private fun streamWithToolCalls( + messages: List, + model: String, + onTextChunk: (String) -> Unit, + isCancelled: () -> Boolean, + ): StreamResponse { + val baseUrl = SweepSettings.getInstance().baseUrl.trimEnd('/') + val apiKey = SweepSettings.getInstance().githubToken + + // Build messages JSON with proper tool_calls format + val messagesJson = messages.map { msg -> + val obj = mutableMapOf("role" to msg.role) + // Always include content (null for assistant messages with tool calls) + obj["content"] = msg.content ?: "" + if (msg.toolCallId != null) obj["tool_call_id"] = msg.toolCallId + // Don't include "name" on tool messages — not part of OpenAI spec + // if (msg.name != null) obj["name"] = msg.name + if (msg.toolCalls != null && msg.toolCalls.isNotEmpty()) { + obj["tool_calls"] = msg.toolCalls.map { tc -> + mapOf( + "id" to tc.id, + "type" to "function", + "function" to mapOf( + "name" to tc.name, + // arguments must be a JSON string, not an object + "arguments" to (tc.rawArguments.ifEmpty { gson.toJson(tc.arguments) }), + ), + ) + } + } + obj + } + + val requestMap = mutableMapOf( + "model" to model, + "messages" to messagesJson, + "stream" to true, + "tools" to TOOL_DEFINITIONS, + ) + + val requestJson = gson.toJson(requestMap) + + // Log the full request JSON for debugging tool call issues + logger.info("[Agent] Request JSON (first 2000 chars): ${requestJson.take(2000)}") + // Log messages portion specifically + val messagesOnlyJson = gson.toJson(messagesJson) + logger.info("[Agent] Messages JSON (last 1000 chars): ...${messagesOnlyJson.takeLast(1000)}") + + val url = URL("$baseUrl/v1/chat/completions") + val conn = (url.openConnection() as HttpURLConnection).apply { + requestMethod = "POST" + doOutput = true + connectTimeout = 30_000 + readTimeout = 300_000 + setRequestProperty("Content-Type", "application/json") + if (apiKey.isNotBlank()) { + setRequestProperty("Authorization", "Bearer $apiKey") + } + } + + activeConnection = conn + OutputStreamWriter(conn.outputStream).use { it.write(requestJson); it.flush() } + + val status = conn.responseCode + if (status != 200) { + val errorBody = (conn.errorStream ?: conn.inputStream)?.bufferedReader()?.readText() ?: "" + conn.disconnect() + activeConnection = null + throw RuntimeException("HTTP $status: $errorBody") + } + + val textParts = StringBuilder() + // Accumulate tool calls from deltas + val toolCallAccumulator = mutableMapOf>() // index -> {id, name, arguments} + + BufferedReader(InputStreamReader(conn.inputStream)).use { reader -> + var line: String? + while (reader.readLine().also { line = it } != null) { + if (isCancelled()) { conn.disconnect(); break } + + val l = line ?: continue + if (!l.startsWith("data: ")) continue + val data = l.removePrefix("data: ").trim() + if (data == "[DONE]") break + + try { + val event = JsonParser.parseString(data).asJsonObject + val choices = event.getAsJsonArray("choices") ?: continue + if (choices.size() == 0) continue + val choice = choices[0].asJsonObject + val delta = choice.getAsJsonObject("delta") ?: continue + + // Text content + val content = delta.get("content")?.asString + if (content != null) { + textParts.append(content) + onTextChunk(content) + } + + // Tool calls + val toolCalls = delta.getAsJsonArray("tool_calls") + if (toolCalls != null) { + for (tc in toolCalls) { + val tcObj = tc.asJsonObject + val index = tcObj.get("index")?.asInt ?: 0 + val acc = toolCallAccumulator.getOrPut(index) { mutableMapOf("id" to "", "name" to "", "arguments" to "") } + + tcObj.get("id")?.asString?.let { acc["id"] = it } + val function = tcObj.getAsJsonObject("function") + if (function != null) { + function.get("name")?.asString?.let { acc["name"] = it } + function.get("arguments")?.asString?.let { acc["arguments"] = acc["arguments"]!! + it } + } + } + } + } catch (_: Exception) {} + } + } + + conn.disconnect() + activeConnection = null + + // Parse accumulated tool calls + val parsedToolCalls = toolCallAccumulator.values.mapNotNull { acc -> + val name = acc["name"] ?: return@mapNotNull null + val id = acc["id"] ?: return@mapNotNull null + if (name.isEmpty()) return@mapNotNull null + + val argsJson = acc["arguments"] ?: "{}" + val args = try { + val parsed = JsonParser.parseString(argsJson).asJsonObject + parsed.entrySet().associate { entry -> + // Handle all JSON value types, not just strings + val value = when { + entry.value.isJsonPrimitive -> { + val prim = entry.value.asJsonPrimitive + when { + prim.isString -> prim.asString + prim.isNumber -> prim.asNumber.toString() + prim.isBoolean -> prim.asBoolean.toString() + else -> prim.toString() + } + } + else -> entry.value.toString() // Arrays, objects → raw JSON string + } + entry.key to value + } + } catch (e: Exception) { + logger.warn("[Agent] Failed to parse tool arguments: $argsJson — ${e.message}") + mapOf() + } + + logger.info("[Agent] Parsed tool call: $name($args) from JSON: ${argsJson.take(200)}") + ParsedToolCall(id, name, args, argsJson) + } + + return StreamResponse(textParts.toString().ifEmpty { null }, parsedToolCalls) + } + + /** + * Execute a tool using the existing plugin tool implementations. + */ + private fun executeTool(name: String, arguments: Map, conversationId: String): CompletedToolCall { + return try { + val tool = ToolType.createToolInstance(name, false) + if (tool == null) { + CompletedToolCall( + toolCallId = "", + toolName = name, + resultString = "Unknown tool: $name", + status = false, + ) + } else { + val toolCall = ToolCall( + toolCallId = java.util.UUID.randomUUID().toString(), + toolName = name, + toolParameters = arguments, + rawText = "", + fullyFormed = true, + ) + tool.execute(toolCall, project, conversationId) + } + } catch (e: Exception) { + logger.warn("[Agent] Tool execution error for $name: ${e.message}", e) + CompletedToolCall( + toolCallId = "", + toolName = name, + resultString = "Error executing $name: ${e.message}", + status = false, + ) + } + } +} diff --git a/src/main/kotlin/dev/sweep/assistant/services/OpenAIChatService.kt b/src/main/kotlin/dev/sweep/assistant/services/OpenAIChatService.kt new file mode 100644 index 0000000..0841605 --- /dev/null +++ b/src/main/kotlin/dev/sweep/assistant/services/OpenAIChatService.kt @@ -0,0 +1,176 @@ +package dev.sweep.assistant.services + +import com.google.gson.JsonParser +import com.intellij.openapi.components.Service +import com.intellij.openapi.diagnostic.Logger +import dev.sweep.assistant.settings.SweepSettings +import java.io.BufferedReader +import java.io.InputStreamReader +import java.io.OutputStreamWriter +import java.net.HttpURLConnection +import java.net.URI +import java.net.URL + +/** + * Service for communicating with OpenAI-compatible API endpoints (LM Studio, Ollama, etc.) + * Handles both model listing and chat completions. + * + * Uses HttpURLConnection instead of java.net.http.HttpClient because the latter + * has known timeout issues with some local HTTP servers on macOS. + */ +@Service(Service.Level.APP) +class OpenAIChatService { + private val logger = Logger.getInstance(OpenAIChatService::class.java) + + companion object { + fun getInstance(): OpenAIChatService = + com.intellij.openapi.application.ApplicationManager.getApplication() + .getService(OpenAIChatService::class.java) + + /** Check if the configured baseUrl looks like a Sweep backend vs a generic OpenAI endpoint. */ + fun isSweepBackend(): Boolean { + val url = SweepSettings.getInstance().baseUrl + return url.contains("sweep.dev") || url.contains("sweep-") + } + } + + /** + * Fetch available models from /v1/models endpoint. + * Returns a map of displayName -> modelId. + */ + fun fetchModels(): Map { + val baseUrl = SweepSettings.getInstance().baseUrl.trimEnd('/') + if (baseUrl.isEmpty()) return emptyMap() + + val apiKey = SweepSettings.getInstance().githubToken + + return try { + val url = URL("$baseUrl/v1/models") + val conn = (url.openConnection() as HttpURLConnection).apply { + requestMethod = "GET" + connectTimeout = 10_000 + readTimeout = 10_000 + setRequestProperty("Content-Type", "application/json") + if (apiKey.isNotBlank()) { + setRequestProperty("Authorization", "Bearer $apiKey") + } + } + + val status = conn.responseCode + if (status != 200) { + logger.warn("Failed to fetch models from $baseUrl/v1/models: HTTP $status") + conn.disconnect() + return emptyMap() + } + + val body = conn.inputStream.bufferedReader().use { it.readText() } + conn.disconnect() + + val json = JsonParser.parseString(body).asJsonObject + val data = json.getAsJsonArray("data") ?: return emptyMap() + + val models = mutableMapOf() + for (element in data) { + val model = element.asJsonObject + val id = model.get("id")?.asString ?: continue + val displayName = id + .removePrefix("models/") + .removeSuffix(".gguf") + models[displayName] = id + } + + logger.info("Fetched ${models.size} models from $baseUrl/v1/models") + models + } catch (e: Exception) { + logger.warn("Error fetching models from $baseUrl/v1/models: ${e.message}") + emptyMap() + } + } + + data class ChatMessage(val role: String, val content: String) + + /** + * Stream a chat completion from an OpenAI-compatible endpoint. + * Calls the callback with each text chunk as it arrives. + */ + fun streamChatCompletion( + messages: List, + model: String, + onChunk: (String) -> Unit, + onDone: () -> Unit, + onError: (Exception) -> Unit, + isCancelled: () -> Boolean = { false }, + ) { + val baseUrl = SweepSettings.getInstance().baseUrl.trimEnd('/') + val apiKey = SweepSettings.getInstance().githubToken + + // Build JSON request manually to avoid Gson dependency issues + val messagesJson = messages.joinToString(",") { msg -> + """{"role":"${msg.role}","content":${com.google.gson.Gson().toJson(msg.content)}}""" + } + val requestJson = """{"model":${com.google.gson.Gson().toJson(model)},"messages":[$messagesJson],"stream":true}""" + + try { + val url = URL("$baseUrl/v1/chat/completions") + val conn = (url.openConnection() as HttpURLConnection).apply { + requestMethod = "POST" + doOutput = true + connectTimeout = 30_000 + readTimeout = 300_000 // 5 minutes for long responses + setRequestProperty("Content-Type", "application/json") + if (apiKey.isNotBlank()) { + setRequestProperty("Authorization", "Bearer $apiKey") + } + } + + OutputStreamWriter(conn.outputStream).use { writer -> + writer.write(requestJson) + writer.flush() + } + + val status = conn.responseCode + if (status != 200) { + val errorBody = (conn.errorStream ?: conn.inputStream)?.bufferedReader()?.readText() ?: "" + conn.disconnect() + onError(RuntimeException("HTTP $status: $errorBody")) + return + } + + BufferedReader(InputStreamReader(conn.inputStream)).use { reader -> + var line: String? + while (reader.readLine().also { line = it } != null) { + // Check cancellation on each line + if (isCancelled()) { + logger.info("Chat stream cancelled by user") + conn.disconnect() + onDone() + return + } + + val l = line ?: continue + if (!l.startsWith("data: ")) continue + val data = l.removePrefix("data: ").trim() + if (data == "[DONE]") break + + try { + val event = JsonParser.parseString(data).asJsonObject + val choices = event.getAsJsonArray("choices") ?: continue + if (choices.size() == 0) continue + val delta = choices[0].asJsonObject.getAsJsonObject("delta") ?: continue + val content = delta.get("content")?.asString + if (content != null) { + onChunk(content) + } + } catch (_: Exception) { + // Skip malformed SSE events + } + } + } + + conn.disconnect() + onDone() + } catch (e: Exception) { + onError(e) + } + } +} diff --git a/src/main/kotlin/dev/sweep/assistant/settings/SweepSettings.kt b/src/main/kotlin/dev/sweep/assistant/settings/SweepSettings.kt index 0e2197a..873dfa5 100644 --- a/src/main/kotlin/dev/sweep/assistant/settings/SweepSettings.kt +++ b/src/main/kotlin/dev/sweep/assistant/settings/SweepSettings.kt @@ -242,7 +242,10 @@ class SweepSettings : PersistentStateComponent { */ val hasBeenSet: Boolean get() = - if (SweepSettingsParser.isCloudEnvironment()) { + // Local autocomplete mode with native engine doesn't need any API tokens + if (autocompleteLocalMode && autocompleteLocalNativeEngine) { + true + } else if (SweepSettingsParser.isCloudEnvironment()) { githubToken != DEFAULT_GITHUB_TOKEN } else { githubToken != DEFAULT_GITHUB_TOKEN && baseUrl != DEFAULT_SWEEP_URL diff --git a/src/main/kotlin/dev/sweep/assistant/views/MarkdownBlock.kt b/src/main/kotlin/dev/sweep/assistant/views/MarkdownBlock.kt index 8d3b389..1b41ba2 100644 --- a/src/main/kotlin/dev/sweep/assistant/views/MarkdownBlock.kt +++ b/src/main/kotlin/dev/sweep/assistant/views/MarkdownBlock.kt @@ -165,13 +165,23 @@ fun parseMarkdownBlocks( if (path.isEmpty() && code.isNotEmpty() && language != "bash") { val lines = code.split("\n") val firstLine = lines.first().trim() - val filePathPattern = - """^(?:[a-zA-Z]:)?(?:[\/\\])?(?:[^\s\/\\:*?"<>|]+[\/\\])*[^\s\/\\:*?"<>|]+\.[a-zA-Z0-9]{1,10}$""" - .toRegex() - if (filePathPattern.matches(firstLine)) { - pathFromCode = firstLine + // Check for "// filepath: path" or "# filepath: path" comment pattern + val filepathCommentPattern = """^(?://|#)\s*filepath:\s*(.+)$""".toRegex(RegexOption.IGNORE_CASE) + val filepathMatch = filepathCommentPattern.matchEntire(firstLine) + + if (filepathMatch != null) { + pathFromCode = filepathMatch.groupValues[1].trim() code = lines.drop(1).joinToString("\n") + } else { + val filePathPattern = + """^(?:[a-zA-Z]:)?(?:[\/\\])?(?:[^\s\/\\:*?"<>|]+[\/\\])*[^\s\/\\:*?"<>|]+\.[a-zA-Z0-9]{1,10}$""" + .toRegex() + + if (filePathPattern.matches(firstLine)) { + pathFromCode = firstLine + code = lines.drop(1).joinToString("\n") + } } } diff --git a/src/main/kotlin/dev/sweep/assistant/views/ModelPickerMenu.kt b/src/main/kotlin/dev/sweep/assistant/views/ModelPickerMenu.kt index cf702bb..28a8b0c 100644 --- a/src/main/kotlin/dev/sweep/assistant/views/ModelPickerMenu.kt +++ b/src/main/kotlin/dev/sweep/assistant/views/ModelPickerMenu.kt @@ -27,20 +27,27 @@ import javax.swing.JPanel class ModelPickerMenu( private val project: Project, parentDisposable: Disposable, - private var models: Map = DEFAULT_MODELS, - private var initialModel: String = DEFAULT_MODEL, + private var models: Map = getInitialModels(), + private var initialModel: String = getInitialDefaultModel(), ) : JPanel(), Disposable { companion object { - // Fallback models in case backend is down - private val DEFAULT_MODELS = + private val SWEEP_DEFAULT_MODELS = mapOf( - // Pass through the special "auto" key when Auto is selected "Auto" to "auto", "Sonnet 4 (thinking)" to "claude-sonnet-4-20250514:thinking", "Sonnet 4" to "claude-sonnet-4-20250514", ) - private const val DEFAULT_MODEL = "Auto" + private val LOCAL_DEFAULT_MODELS = + mapOf("(loading models...)" to "") + + fun getInitialModels(): Map = + if (dev.sweep.assistant.services.OpenAIChatService.isSweepBackend()) SWEEP_DEFAULT_MODELS + else LOCAL_DEFAULT_MODELS + + fun getInitialDefaultModel(): String = + if (dev.sweep.assistant.services.OpenAIChatService.isSweepBackend()) "Auto" + else "(loading models...)" // Cache for model display names to IDs mapping private var modelIdCache: Map = mapOf() @@ -68,8 +75,14 @@ class ModelPickerMenu( comboBox.setItemTooltips(SweepConstants.MODEL_HINTS) comboBox.isTransparent = true - // Try to load cached models first - loadCachedModels() + // Only load cached Sweep models for Sweep backends; clear stale cache otherwise + if (dev.sweep.assistant.services.OpenAIChatService.isSweepBackend()) { + loadCachedModels() + } else { + // Clear any cached Sweep models from a previous config + SweepMetaData.getInstance().cachedModels = null + SweepMetaData.getInstance().cachedDefaultModel = null + } // Check if there's a saved model preference val savedModel = SweepComponent.getSelectedModel(project) @@ -135,25 +148,28 @@ class ModelPickerMenu( } private fun updateComboBoxModel() { - val favorites = SweepMetaData.getInstance().favoriteModels - val validFavorites = favorites.filter { models.keys.contains(it) } - - // Determine which models to show - fall back to all models if favorites are empty - var modelNames = - validFavorites.ifEmpty { - models.keys.toList() - } + val isSweep = dev.sweep.assistant.services.OpenAIChatService.isSweepBackend() + + val modelNames = if (isSweep) { + val favorites = SweepMetaData.getInstance().favoriteModels + val validFavorites = favorites.filter { models.keys.contains(it) } + validFavorites.ifEmpty { models.keys.toList() } + } else { + // For non-Sweep backends, just show all models from /v1/models + models.keys.toList() + } - // Add options at the end - val options = modelNames + configureFavoritesOption + // Only show "+ More models" for Sweep backends + val options = if (isSweep) modelNames + configureFavoritesOption else modelNames comboBox.setOptions(options) - // Ensure we never show "+ More models" as selected - it's only an action option comboBox.selectedItem = if (currentModel == configureFavoritesOption) { modelNames.firstOrNull() ?: currentModel - } else { + } else if (modelNames.contains(currentModel)) { currentModel + } else { + modelNames.firstOrNull() ?: currentModel } } @@ -222,6 +238,12 @@ class ModelPickerMenu( } private fun fetchAllowedModels() { + // For non-Sweep backends (LM Studio, Ollama, etc.), fetch from /v1/models + if (!dev.sweep.assistant.services.OpenAIChatService.isSweepBackend()) { + fetchOpenAIModels() + return + } + coroutineScope.launch { try { var connection: HttpURLConnection? = null @@ -331,6 +353,46 @@ class ModelPickerMenu( } } + /** + * Fetch models from an OpenAI-compatible /v1/models endpoint. + */ + private fun fetchOpenAIModels() { + val baseUrl = SweepSettings.getInstance().baseUrl + logger.info("Fetching models from OpenAI-compatible endpoint: $baseUrl/v1/models") + + coroutineScope.launch { + try { + val fetchedModels = dev.sweep.assistant.services.OpenAIChatService.getInstance().fetchModels() + logger.info("OpenAI model fetch returned ${fetchedModels.size} models: ${fetchedModels.keys.take(5)}") + if (fetchedModels.isNotEmpty()) { + ApplicationManager.getApplication().invokeLater { + modelIdCache = fetchedModels + lastFetchTime = System.currentTimeMillis() + models = fetchedModels + defaultModelFromBackend = fetchedModels.keys.first() + + // Check for saved model preference + val savedModel = SweepComponent.getSelectedModel(project) + if (savedModel.isNotEmpty() && models.keys.contains(savedModel)) { + initialModel = savedModel + currentModel = savedModel + } else { + initialModel = defaultModelFromBackend + currentModel = initialModel + } + + updateComboBoxModel() + logger.info("Updated model picker with ${fetchedModels.size} models from OpenAI endpoint") + } + } else { + logger.warn("No models returned from $baseUrl/v1/models") + } + } catch (e: Exception) { + logger.warn("Error fetching OpenAI models from $baseUrl: ${e.message}", e) + } + } + } + fun reset() { setModel(initialModel) } From b480a66ed746548ee754043fa124626d856c5c01 Mon Sep 17 00:00:00 2001 From: Stefan Bethge Date: Wed, 15 Apr 2026 01:17:11 +0200 Subject: [PATCH 7/7] Route OpenAI agent through Sweep pipeline; rules support; cleanups MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Stream.kt - Agent mode is now a single completion turn that hands tool calls to SweepAgent.ingestToolCalls + awaitToolCalls. The existing CONTINUE_AGENT path drives multi-turn naturally — no parallel loop, no concurrent Stream.start() race, no manual loop detection. - New buildOpenAiAgentMessages() converts session messages to OpenAI's tool-calling shape (assistant.tool_calls + tool messages by id), preserving raw JSON arguments via ToolCall.rawText so numbers/booleans aren't string-coerced when echoed back. - System prompt now appends project rules from SweepConfig (SWEEP.md / AGENTS.md / CLAUDE.md, hierarchical + scoped to context). - stop() sets cancelledByUser before the streamingJob null-check so the OpenAI path (which has no streamingJob) can be cancelled. OpenAIAgentService - Drop dead runAgentLoop / executeTool / DESTRUCTIVE_TOOLS — execution goes through the Sweep tool pipeline now. - Drop streamWithToolCallsPublic wrapper and clearActiveConnection, drop unused project parameter. LocalAutocompleteServerManager - Replace java.net.http health check with HttpURLConnection (HttpClient has localhost timeout issues on macOS). - Add terminalStartInProgress guard to avoid duplicate server starts when multiple project windows open simultaneously. Co-Authored-By: Claude Opus 4.6 (1M context) --- build.gradle.kts | 2 +- .../components/ChatHistoryComponent.kt | 21 +- .../FileModificationToolCallItem.kt | 10 +- .../assistant/components/WelcomeScreen.kt | 2 +- .../dev/sweep/assistant/controllers/Stream.kt | 276 +++++++++++++----- .../services/AutocompleteIpResolverService.kt | 11 + .../LocalAutocompleteServerManager.kt | 42 ++- .../assistant/services/OpenAIAgentService.kt | 187 +----------- .../sweep/assistant/settings/SweepSettings.kt | 4 +- .../assistant/startup/SweepStartupActivity.kt | 10 +- 10 files changed, 292 insertions(+), 273 deletions(-) diff --git a/build.gradle.kts b/build.gradle.kts index 4a22914..a7eb40f 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -15,7 +15,7 @@ plugins { val remoteRobotVersion = "0.11.20" val pluginId = "dev.sweep.assistant" -val pluginName = "Sweep Autocomplete OSS" +val pluginName = "Sweep Autocomplete and Agent" println("Building plugin: $pluginName with ID: $pluginId") group = "dev.sweep" version = "1.30.0" diff --git a/src/main/kotlin/dev/sweep/assistant/components/ChatHistoryComponent.kt b/src/main/kotlin/dev/sweep/assistant/components/ChatHistoryComponent.kt index cad1adf..145bdfb 100644 --- a/src/main/kotlin/dev/sweep/assistant/components/ChatHistoryComponent.kt +++ b/src/main/kotlin/dev/sweep/assistant/components/ChatHistoryComponent.kt @@ -1414,13 +1414,24 @@ class ChatHistoryComponent( Disposer.register(popup, popupDisposable) + // Find a visible anchor component — prefer SweepComponent, fall back to the + // parent container or the project's frame so the popup actually appears. val sweepComponent = SweepComponent.getInstance(project) - if (!sweepComponent.component.isShowing) { - Disposer.dispose(popupDisposable) - return + val anchor = when { + sweepComponent.component.isShowing -> sweepComponent.component + parentContainer.isShowing -> parentContainer + else -> { + // Last resort: show relative to the IDE frame + val frame = com.intellij.openapi.wm.WindowManager.getInstance().getFrame(project) + if (frame != null && frame.isShowing) { + popup.showInCenterOf(frame) + return + } + Disposer.dispose(popupDisposable) + return + } } - - popup.showInCenterOf(sweepComponent.component) + popup.showInCenterOf(anchor) } private fun performImport( diff --git a/src/main/kotlin/dev/sweep/assistant/components/FileModificationToolCallItem.kt b/src/main/kotlin/dev/sweep/assistant/components/FileModificationToolCallItem.kt index 3a3b02d..fe5bdc1 100644 --- a/src/main/kotlin/dev/sweep/assistant/components/FileModificationToolCallItem.kt +++ b/src/main/kotlin/dev/sweep/assistant/components/FileModificationToolCallItem.kt @@ -519,9 +519,13 @@ class FileModificationToolCallItem( // Create a light virtual file with the correct extension for syntax highlighting val virtualFile = LightVirtualFile(fileName, "") - // Create PSI file and document - val psiFile = PsiManager.getInstance(project).findFile(virtualFile) ?: return - diffDocument = PsiDocumentManager.getInstance(project).getDocument(psiFile) ?: return + // Create PSI file and document (requires read action even on EDT) + val psiFile = com.intellij.openapi.application.ReadAction.compute { + PsiManager.getInstance(project).findFile(virtualFile) + } ?: return + diffDocument = com.intellij.openapi.application.ReadAction.compute { + PsiDocumentManager.getInstance(project).getDocument(psiFile) + } ?: return // Create the editor diffEditor = diff --git a/src/main/kotlin/dev/sweep/assistant/components/WelcomeScreen.kt b/src/main/kotlin/dev/sweep/assistant/components/WelcomeScreen.kt index 8b5184d..94dd272 100644 --- a/src/main/kotlin/dev/sweep/assistant/components/WelcomeScreen.kt +++ b/src/main/kotlin/dev/sweep/assistant/components/WelcomeScreen.kt @@ -29,7 +29,7 @@ class WelcomeScreen( icon(SweepIcons.BigSweepIcon.scale(80f)).align(AlignX.CENTER) } row { - text("Sweep Autocomplete OSS") + text("Sweep Autocomplete and Agent") .applyToComponent { font = font.deriveFont(java.awt.Font.BOLD, font.size * 1.2f) }.align(AlignX.CENTER) diff --git a/src/main/kotlin/dev/sweep/assistant/controllers/Stream.kt b/src/main/kotlin/dev/sweep/assistant/controllers/Stream.kt index 97ca254..6c7a759 100644 --- a/src/main/kotlin/dev/sweep/assistant/controllers/Stream.kt +++ b/src/main/kotlin/dev/sweep/assistant/controllers/Stream.kt @@ -280,15 +280,23 @@ class Stream( // For non-Sweep backends (LM Studio, Ollama, etc.), use OpenAI-compatible API if (!dev.sweep.assistant.services.OpenAIChatService.isSweepBackend()) { + currentMarkdownDisplay.startStreaming() + // Mark streaming active so stop button shows + StreamStateService.getInstance(project).notify(true, false, true, currentConversationId) + var selectedModel = SweepComponent.getSelectedModelId(project) + ?: sessionMessageList.selectedModel + // "auto" is a Sweep-specific concept — for OpenAI endpoints, use the first available model + if (selectedModel == null || selectedModel == "auto" || selectedModel.isEmpty()) { + val available = dev.sweep.assistant.services.OpenAIChatService.getInstance().fetchModels() + selectedModel = available.values.firstOrNull() ?: "" + } + + // Launch HTTP work in a background coroutine — just like the Sweep backend path + // uses streamingJob. This prevents blocking the caller (which may be on EDT via + // runBlocking in MessagesComponent.update). + val model = selectedModel + streamingJob = coroutineScope.launch { try { - currentMarkdownDisplay.startStreaming() - var selectedModel = SweepComponent.getSelectedModelId(project) - ?: sessionMessageList.selectedModel - // "auto" is a Sweep-specific concept — for OpenAI endpoints, use the first available model - if (selectedModel == null || selectedModel == "auto" || selectedModel.isEmpty()) { - val available = dev.sweep.assistant.services.OpenAIChatService.getInstance().fetchModels() - selectedModel = available.values.firstOrNull() ?: "" - } logger.info("[OpenAI] Starting chat with model='$selectedModel', ${sessionMessageList.snapshot().size} messages") @@ -356,7 +364,7 @@ class Stream( append("You are an expert coding agent integrated into a JetBrains IDE. ") append("You can read files, search code, edit files, and run shell commands to help the user.\n\n") append("When you need to make changes, use the available tools. ") - append("Read the relevant files first, then make targeted edits using str_replace.\n\n") + append("Read the relevant files first, then make targeted edits.\n\n") } else { append("You are an expert coding assistant integrated into a JetBrains IDE. ") } @@ -389,18 +397,28 @@ class Stream( append("\nThe user is working with the following files:\n\n") append(contextParts.joinToString("\n\n")) } - } - // Convert messages to OpenAI format, prepending system message - val openAIMessages = mutableListOf( - dev.sweep.assistant.services.OpenAIChatService.ChatMessage("system", systemPrompt), - ) - openAIMessages.addAll(sessionMessageList.snapshot().map { msg -> - dev.sweep.assistant.services.OpenAIChatService.ChatMessage( - role = msg.role.name.lowercase(), - content = msg.content, - ) - }) + // Append project-level rules (SWEEP.md / AGENTS.md / CLAUDE.md) plus + // any hierarchical rules scoped to the files currently in context. + val sweepConfig = dev.sweep.assistant.components.SweepConfig.getInstance(project) + val contextForRules = buildList { + effectiveCurrentFilePath?.let { if (it.isNotBlank()) add(it) } + for (value in includedFiles.values) { + val fp = TutorialPage.normalizeTutorialPath(value) + if (fp.isNotBlank() && fp != effectiveCurrentFilePath) add(fp) + } + }.distinct() + val rules = try { + if (contextForRules.isNotEmpty()) + sweepConfig.getDynamicRulesContent(contextForRules) + ?: sweepConfig.getCurrentRulesContent() + else sweepConfig.getCurrentRulesContent() + } catch (_: Exception) { null } + if (!rules.isNullOrBlank()) { + append("\n\n") + append(rules) + } + } // Create empty assistant message with empty annotations for code replacements val assistantMessage = Message( @@ -410,59 +428,78 @@ class Stream( ) onMessageUpdated(assistantMessage) - logger.info("[OpenAI] System prompt length=${systemPrompt.length}, mode=$currentMode, agent=$isAgentMode, sending ${openAIMessages.size} messages to ${dev.sweep.assistant.settings.SweepSettings.getInstance().baseUrl}") + logger.info("[OpenAI] System prompt length=${systemPrompt.length}, mode=$currentMode, agent=$isAgentMode, baseUrl=${dev.sweep.assistant.settings.SweepSettings.getInstance().baseUrl}") if (isAgentMode) { - // Agent mode: multi-turn loop with tool calling - // Only send system prompt + the latest user message to avoid - // the model trying to redo earlier chat-mode responses with tools - val agentService = dev.sweep.assistant.services.OpenAIAgentService(project) + // Agent mode: a single completion turn. Tool calls are handed to the + // SweepAgent pipeline (which renders the UI and triggers CONTINUE_AGENT + // → another Stream.start(isFollowupToToolCall=true) for the next turn). + val agentService = dev.sweep.assistant.services.OpenAIAgentService() activeAgentService = agentService - val latestUserMessage = openAIMessages.lastOrNull { it.role == "user" } - val agentMessages = mutableListOf( - dev.sweep.assistant.services.OpenAIAgentService.AgentMessage("system", systemPrompt), - ) - if (latestUserMessage != null) { - agentMessages.add( - dev.sweep.assistant.services.OpenAIAgentService.AgentMessage("user", latestUserMessage.content) - ) - } - agentService.runAgentLoop( - messages = agentMessages, - model = selectedModel ?: "", - onTextChunk = { chunk -> - assistantMessage.content += chunk - onMessageUpdated(assistantMessage) - }, - onToolCallStart = { toolName, params -> - assistantMessage.content += "\n\n> **Using tool:** `$toolName`" - if (params.isNotEmpty()) { - val paramsStr = params.entries.joinToString(", ") { "${it.key}=${it.value.take(50)}" } - assistantMessage.content += " ($paramsStr)" - } - assistantMessage.content += "\n" - onMessageUpdated(assistantMessage) - }, - onToolCallResult = { toolName, result, success -> - val statusEmoji = if (success) "+" else "!" - val preview = result.take(200).replace("\n", " ") - assistantMessage.content += "> [$statusEmoji] `$toolName` result: $preview\n" - onMessageUpdated(assistantMessage) - }, - onDone = { - logger.info("[OpenAI Agent] Loop completed, total length=${assistantMessage.content.length}") - }, - onError = { e -> - logger.warn("[OpenAI Agent] Error: ${e.message}", e) + // Build OpenAI conversation by walking the session messages and expanding + // assistant messages that have tool_calls/completed results into the + // tool-calling protocol shape. + val agentMessages = buildOpenAiAgentMessages(systemPrompt, sessionMessageList.snapshot()) + + val response = try { + agentService.streamWithToolCalls( + agentMessages, selectedModel ?: "", + onTextChunk = { chunk -> + assistantMessage.content += chunk + onMessageUpdated(assistantMessage) + }, + isCancelled = { cancelledByUser }, + ) + } catch (e: Exception) { + if (!cancelledByUser) { assistantMessage.content += "\n\n**Error:** ${e.message}" onMessageUpdated(assistantMessage) - }, - isCancelled = { cancelledByUser }, - conversationId = currentConversationId, - ) + } + null + } + + if (response != null && !cancelledByUser && response.toolCalls.isNotEmpty()) { + // Convert OpenAI tool calls to Sweep ToolCall objects. + // Always assign a fresh UUID: some models re-emit the SAME tool_call_id + // across turns, which collides with SweepAgentSession.jobsById (keyed by + // toolCallId) — the dedup short-circuits execution, awaitToolCalls returns + // with no completion in the new message, and AgentActionBlockDisplay's + // stream-stop listener then persists a spurious "Rejected" entry. + val sweepToolCalls = response.toolCalls.map { tc -> + dev.sweep.assistant.data.ToolCall( + toolCallId = java.util.UUID.randomUUID().toString(), + toolName = tc.name, + toolParameters = tc.arguments, + rawText = tc.rawArguments, + fullyFormed = true, + ) + } + assistantMessage.annotations?.toolCalls?.addAll(sweepToolCalls) + onMessageUpdated(assistantMessage) + + // Hand off to Sweep: ingest queues execution and renders UI; awaitToolCalls + // blocks until all tools complete (or are cancelled) and then fires + // CONTINUE_AGENT, which re-enters Stream.start with isFollowupToToolCall=true. + // We MUST return after this — falling through to the post-processing + // code would re-emit onMessageUpdated with tool calls in annotations, + // causing MessagesComponent to re-ingest them and spawn a duplicate cycle. + val sweepAgent = dev.sweep.assistant.agent.SweepAgent.getInstance(project) + sweepAgent.ingestToolCalls(sweepToolCalls, currentConversationId) + sweepAgent.awaitToolCalls(currentConversationId, assistantMessage) + return@launch // Sweep pipeline owns the continuation from here + } } else { - // Chat mode: simple streaming without tool calling + // Chat mode: simple streaming without tool calling. + val openAIMessages = mutableListOf( + dev.sweep.assistant.services.OpenAIChatService.ChatMessage("system", systemPrompt), + ) + openAIMessages.addAll(sessionMessageList.snapshot().map { msg -> + dev.sweep.assistant.services.OpenAIChatService.ChatMessage( + role = msg.role.name.lowercase(), + content = msg.content, + ) + }) dev.sweep.assistant.services.OpenAIChatService.getInstance().streamChatCompletion( messages = openAIMessages, model = selectedModel ?: "", @@ -482,8 +519,13 @@ class Stream( ) } - // Post-process: extract code blocks and generate codeReplacements for Apply button - enrichMessageWithCodeReplacements(assistantMessage) + // Post-process: extract code blocks and generate codeReplacements for Apply button. + // Skip in agent mode — tool results (diffs, file contents) are already rendered + // by the SweepAgent pipeline; extracting code blocks from the model's text + // response would create duplicate widgets. + if (!isAgentMode) { + enrichMessageWithCodeReplacements(assistantMessage) + } // Final update onMessageUpdated(assistantMessage) @@ -495,6 +537,7 @@ class Stream( currentMarkdownDisplay.stopStreaming() StreamStateService.getInstance(project).notify(false, false, false, currentConversationId) } + } // end streamingJob launch return } @@ -1338,6 +1381,74 @@ class Stream( return latestMessage } + /** + * Convert Sweep session messages to OpenAI's tool-calling protocol shape. + * Each Sweep ASSISTANT message with tool calls expands into: + * - one assistant message containing the original tool_calls, then + * - one `tool` message per completed call carrying the result string. + * + * Empty trailing assistant placeholders (added by SweepAgent's CONTINUE_AGENT) + * are skipped; they exist only as a UI slot for the upcoming response. + */ + private fun buildOpenAiAgentMessages( + systemPrompt: String, + sessionMessages: List, + ): MutableList { + val out = mutableListOf( + dev.sweep.assistant.services.OpenAIAgentService.AgentMessage("system", systemPrompt), + ) + val gson = com.google.gson.Gson() + for (msg in sessionMessages) { + when (msg.role) { + MessageRole.USER -> { + if (msg.content.isNotEmpty()) { + out.add(dev.sweep.assistant.services.OpenAIAgentService.AgentMessage("user", msg.content)) + } + } + MessageRole.ASSISTANT -> { + val toolCalls = msg.annotations?.toolCalls.orEmpty() + val completed = msg.annotations?.completedToolCalls.orEmpty() + if (toolCalls.isEmpty() && msg.content.isEmpty()) continue // empty placeholder slot + if (toolCalls.isNotEmpty()) { + val parsed = toolCalls.map { tc -> + // Re-emit raw JSON args verbatim when available so we don't + // string-coerce numbers/booleans/arrays back to quoted strings. + val raw = if (tc.rawText.isNotEmpty()) tc.rawText else gson.toJson(tc.toolParameters) + dev.sweep.assistant.services.OpenAIAgentService.ParsedToolCall( + id = tc.toolCallId, + name = tc.toolName, + arguments = tc.toolParameters, + rawArguments = raw, + ) + } + out.add(dev.sweep.assistant.services.OpenAIAgentService.AgentMessage( + role = "assistant", + content = msg.content, + toolCalls = parsed, + )) + for (tc in toolCalls) { + val result = completed.firstOrNull { it.toolCallId == tc.toolCallId } + val body = result?.resultString?.take(8000) ?: "Tool execution pending" + val suffix = if ((result?.resultString?.length ?: 0) > 8000) + "\n\n[Output truncated. Use offset/limit parameters to read specific sections.]" else "" + out.add(dev.sweep.assistant.services.OpenAIAgentService.AgentMessage( + role = "tool", + content = body + suffix, + toolCallId = tc.toolCallId, + )) + } + } else if (msg.content.isNotEmpty()) { + out.add(dev.sweep.assistant.services.OpenAIAgentService.AgentMessage("assistant", msg.content)) + } + } + MessageRole.SYSTEM -> { + // System message is supplied by the caller; ignore any persisted ones. + } + } + } + return out + } + /** * Post-process an assistant message: * 1. Extract ... blocks into reasoning annotations @@ -1382,6 +1493,15 @@ class Stream( } fun stop(isUserInitiated: Boolean = true) { + // Only set the cancellation flag for user-initiated stops. + // Non-user-initiated stops (e.g. Stream.start() cleaning up the previous stream + // before a CONTINUE_AGENT follow-up) must NOT set this flag — doing so would + // cause AgentActionBlockDisplay's stream-stop listener to mark successfully + // completed tool calls as "Rejected". + if (isUserInitiated) { + cancelledByUser = true + } + // Cancel any active agent HTTP connection activeAgentService?.cancelActiveRequest() activeAgentService = null @@ -1389,12 +1509,22 @@ class Stream( logger.info( "[Stream.stop] Stopping stream: conversationId=$sessionConversationId, isUserInitiated=$isUserInitiated, hasJob=${streamingJob != null}", ) - // Only stop tool execution for this specific session, not all sessions - sessionConversationId?.let { convId -> - SweepAgent.getInstance(project).stopToolExecution(convId) + // Only stop tool execution for user-initiated stops + if (isUserInitiated) { + sessionConversationId?.let { convId -> + SweepAgent.getInstance(project).stopToolExecution(convId) + } + } + if (streamingJob == null) { + // OpenAI path has no streamingJob — for user-initiated stops, hide the + // stop button. For non-user-initiated (CONTINUE_AGENT follow-up), don't + // fire notify — the new stream will set its own state. + if (isUserInitiated) { + markdownDisplay?.stopStreaming() + StreamStateService.getInstance(project).notify(false, false, false, sessionConversationId) + } + return } - if (streamingJob == null) return - cancelledByUser = true MessagesComponent.getInstance(project).showScrollbar() markdownDisplay?.stopStreaming() streamingJob?.cancel() diff --git a/src/main/kotlin/dev/sweep/assistant/services/AutocompleteIpResolverService.kt b/src/main/kotlin/dev/sweep/assistant/services/AutocompleteIpResolverService.kt index e7bd726..42e743a 100644 --- a/src/main/kotlin/dev/sweep/assistant/services/AutocompleteIpResolverService.kt +++ b/src/main/kotlin/dev/sweep/assistant/services/AutocompleteIpResolverService.kt @@ -351,6 +351,17 @@ class AutocompleteIpResolverService( } private fun fetchViaNativeEngine(request: NextEditAutocompleteRequest): NextEditAutocompleteResponse? { + // Ensure the local llama-server is running before attempting an autocomplete. + // The startup activity also tries to start it on project open, but this is a + // safety net for cases where startup didn't fire or the server died. + val serverManager = LocalAutocompleteServerManager.getInstance() + if (!serverManager.isServerHealthy()) { + logger.info("Local autocomplete server not healthy on first request — starting in terminal") + serverManager.startServerInTerminal(project) + // Server takes several seconds to load; skip this autocomplete request rather + // than block. Subsequent requests will succeed once the server is up. + return null + } val engine = getOrCreateNativeEngine() val nesRequest = dev.sweep.assistant.autocomplete.edit.engine.NextEditAutocompleteEngine.NesRequest( diff --git a/src/main/kotlin/dev/sweep/assistant/services/LocalAutocompleteServerManager.kt b/src/main/kotlin/dev/sweep/assistant/services/LocalAutocompleteServerManager.kt index 40d12f5..7d9f135 100644 --- a/src/main/kotlin/dev/sweep/assistant/services/LocalAutocompleteServerManager.kt +++ b/src/main/kotlin/dev/sweep/assistant/services/LocalAutocompleteServerManager.kt @@ -81,20 +81,22 @@ class LocalAutocompleteServerManager : Disposable { startServer(onStatus) } - fun isServerHealthy(): Boolean = - try { - val request = - HttpRequest - .newBuilder() - .uri(URI.create(getServerUrl())) - .timeout(Duration.ofMillis(HEALTH_CHECK_TIMEOUT_MS)) - .GET() - .build() - val response = httpClient.send(request, HttpResponse.BodyHandlers.discarding()) - response.statusCode() in 200..499 + fun isServerHealthy(): Boolean { + // Use HttpURLConnection — java.net.http.HttpClient has known issues with localhost on macOS + return try { + val url = java.net.URL(getServerUrl()) + val conn = (url.openConnection() as java.net.HttpURLConnection).apply { + requestMethod = "GET" + connectTimeout = HEALTH_CHECK_TIMEOUT_MS.toInt() + readTimeout = HEALTH_CHECK_TIMEOUT_MS.toInt() + } + val status = conn.responseCode + conn.disconnect() + status in 200..499 } catch (e: Exception) { false } + } @Synchronized private fun startServer(onStatus: ((String) -> Unit)? = null) { @@ -532,13 +534,25 @@ class LocalAutocompleteServerManager : Disposable { * Starts the local autocomplete server in a visible IDE terminal tab. * If the server is already healthy, does nothing. */ + @Volatile + private var terminalStartInProgress = false + fun startServerInTerminal(project: Project) { if (isServerHealthy()) { logger.info("Local autocomplete server is already running") return } + // Prevent multiple concurrent terminal starts (e.g. when multiple projects open at once) + if (terminalStartInProgress) { + logger.info("Local autocomplete server terminal start already in progress, skipping") + return + } + terminalStartInProgress = true - val command = getServerCommand() ?: return + val command = getServerCommand() ?: run { + terminalStartInProgress = false + return + } ApplicationManager.getApplication().invokeLater { try { @@ -570,10 +584,14 @@ class LocalAutocompleteServerManager : Disposable { ApplicationManager.getApplication().invokeLater { TerminalApiWrapper.sendCommand(targetWidget, command, project, isPowerShell) } + // Clear the in-progress flag after the server has had time to start + Thread.sleep(30_000) + terminalStartInProgress = false } logger.info("Started local autocomplete server in terminal: $command") } catch (e: Exception) { logger.warn("Failed to start local autocomplete server in terminal: ${e.message}") + terminalStartInProgress = false showNotification( "Failed to open terminal for local autocomplete server: ${e.message}", NotificationType.ERROR, diff --git a/src/main/kotlin/dev/sweep/assistant/services/OpenAIAgentService.kt b/src/main/kotlin/dev/sweep/assistant/services/OpenAIAgentService.kt index 28f51c7..b3f7694 100644 --- a/src/main/kotlin/dev/sweep/assistant/services/OpenAIAgentService.kt +++ b/src/main/kotlin/dev/sweep/assistant/services/OpenAIAgentService.kt @@ -3,10 +3,6 @@ package dev.sweep.assistant.services import com.google.gson.Gson import com.google.gson.JsonParser import com.intellij.openapi.diagnostic.Logger -import com.intellij.openapi.project.Project -import dev.sweep.assistant.agent.tools.ToolType -import dev.sweep.assistant.data.CompletedToolCall -import dev.sweep.assistant.data.ToolCall import dev.sweep.assistant.settings.SweepSettings import java.io.BufferedReader import java.io.InputStreamReader @@ -15,11 +11,12 @@ import java.net.HttpURLConnection import java.net.URL /** - * OpenAI-compatible agent service with function calling support. - * Executes a multi-turn agent loop: sends messages → model responds with tool calls - * → executes tools → sends results back → repeats until model gives a text response. + * OpenAI-compatible chat client with function-calling support. + * Streams a single completion turn; the multi-turn agent loop and tool + * execution live in [dev.sweep.assistant.controllers.Stream] and dispatch + * through the existing SweepAgent pipeline. */ -class OpenAIAgentService(private val project: Project) { +class OpenAIAgentService { private val logger = Logger.getInstance(OpenAIAgentService::class.java) private val gson = Gson() @@ -158,144 +155,14 @@ class OpenAIAgentService(private val project: Project) { val toolCalls: List? = null, // For assistant messages with tool calls ) - /** - * Run the agent loop: stream responses, execute tool calls, send results back. - * - * @param messages Initial conversation messages (including system prompt) - * @param model The model ID - * @param onTextChunk Called with each text chunk for live display - * @param onToolCallStart Called when a tool execution starts (for UI feedback) - * @param onToolCallResult Called when a tool finishes (for UI feedback) - * @param onDone Called when the agent loop completes - * @param onError Called on error - * @param isCancelled Check for cancellation - * @param conversationId The conversation ID for tool execution context - */ - fun runAgentLoop( - messages: MutableList, - model: String, - onTextChunk: (String) -> Unit, - onToolCallStart: (String, Map) -> Unit, - onToolCallResult: (String, String, Boolean) -> Unit, - onDone: () -> Unit, - onError: (Exception) -> Unit, - isCancelled: () -> Boolean, - conversationId: String, - ) { - var turns = 0 - var lastToolCallSignature = "" // Track last tool call to detect exact loops - var sameToolNameCount = 0 // Track consecutive calls to same tool - var lastToolName = "" - - while (turns < MAX_AGENT_TURNS && !isCancelled()) { - turns++ - logger.info("[Agent] Turn $turns, ${messages.size} messages") - - // Log the last few messages to debug tool call loop issues - val recentMessages = messages.takeLast(3) - recentMessages.forEach { msg -> - logger.info("[Agent] Message: role=${msg.role}, content=${msg.content?.take(100) ?: "null"}, toolCalls=${msg.toolCalls?.size ?: 0}, toolCallId=${msg.toolCallId}") - } - - val response = try { - streamWithToolCalls(messages, model, onTextChunk, isCancelled) - } catch (e: Exception) { - if (isCancelled()) { onDone(); return } - onError(e) - return - } - - if (isCancelled()) { - onDone() - return - } - - when { - response.toolCalls.isNotEmpty() -> { - // Add assistant message WITH tool_calls to conversation - // The OpenAI API requires the assistant message to echo back the tool calls - messages.add(AgentMessage( - role = "assistant", - content = response.textContent, - toolCalls = response.toolCalls, - )) - - // Check for repeated tool call loops - val currentSignature = response.toolCalls.joinToString("|") { "${it.name}:${it.arguments}" } - if (currentSignature == lastToolCallSignature) { - sameToolNameCount++ // reuse counter for exact match tracking - if (sameToolNameCount >= 3) { - logger.warn("[Agent] Detected exact tool call loop (3x), breaking") - onTextChunk("\n\n[Agent detected repeated tool call, stopping]") - onDone() - return - } - } else { - sameToolNameCount = 0 - } - lastToolCallSignature = currentSignature - - // Also track consecutive same-tool-name calls with different args - val primaryToolName = response.toolCalls.firstOrNull()?.name ?: "" - if (primaryToolName == lastToolName) { - // already tracked via sameToolNameCount above - } else { - lastToolName = primaryToolName - } - - // Execute each tool call - for (tc in response.toolCalls) { - if (isCancelled()) break - - logger.info("[Agent] Executing tool: ${tc.name}(${tc.arguments.keys.joinToString(", ")})") - onToolCallStart(tc.name, tc.arguments) - - val result = executeTool(tc.name, tc.arguments, conversationId) - - logger.info("[Agent] Tool ${tc.name} result: ${result.resultString.take(100)}...") - onToolCallResult(tc.name, result.resultString, result.status) - - // Add tool result to conversation - val maxToolResultSize = 8000 - val toolContent = if (result.resultString.length > maxToolResultSize) { - result.resultString.take(maxToolResultSize) + - "\n\n[Output truncated at $maxToolResultSize chars. " + - "Total: ${result.resultString.length} chars. " + - "Use offset/limit parameters to read specific sections.]" - } else { - result.resultString - } - messages.add(AgentMessage( - role = "tool", - content = toolContent, - toolCallId = tc.id, - )) - } - } - response.textContent?.isNotEmpty() == true -> { - onDone() - return - } - else -> { - onDone() - return - } - } - } - - if (turns >= MAX_AGENT_TURNS) { - onTextChunk("\n\n[Agent stopped after $MAX_AGENT_TURNS turns]*") - } - onDone() - } - data class ParsedToolCall(val id: String, val name: String, val arguments: Map, val rawArguments: String = "") data class StreamResponse(val textContent: String?, val toolCalls: List) /** * Stream a single chat completion, returning text content and any tool calls. + * Caller is responsible for the multi-turn loop (see Stream.kt). */ - private fun streamWithToolCalls( + fun streamWithToolCalls( messages: List, model: String, onTextChunk: (String) -> Unit, @@ -356,7 +223,11 @@ class OpenAIAgentService(private val project: Project) { } activeConnection = conn - OutputStreamWriter(conn.outputStream).use { it.write(requestJson); it.flush() } + // Write request body + val writer = OutputStreamWriter(conn.outputStream) + writer.write(requestJson) + writer.flush() + writer.close() val status = conn.responseCode if (status != 200) { @@ -453,38 +324,4 @@ class OpenAIAgentService(private val project: Project) { return StreamResponse(textParts.toString().ifEmpty { null }, parsedToolCalls) } - - /** - * Execute a tool using the existing plugin tool implementations. - */ - private fun executeTool(name: String, arguments: Map, conversationId: String): CompletedToolCall { - return try { - val tool = ToolType.createToolInstance(name, false) - if (tool == null) { - CompletedToolCall( - toolCallId = "", - toolName = name, - resultString = "Unknown tool: $name", - status = false, - ) - } else { - val toolCall = ToolCall( - toolCallId = java.util.UUID.randomUUID().toString(), - toolName = name, - toolParameters = arguments, - rawText = "", - fullyFormed = true, - ) - tool.execute(toolCall, project, conversationId) - } - } catch (e: Exception) { - logger.warn("[Agent] Tool execution error for $name: ${e.message}", e) - CompletedToolCall( - toolCallId = "", - toolName = name, - resultString = "Error executing $name: ${e.message}", - status = false, - ) - } - } } diff --git a/src/main/kotlin/dev/sweep/assistant/settings/SweepSettings.kt b/src/main/kotlin/dev/sweep/assistant/settings/SweepSettings.kt index 873dfa5..129ef27 100644 --- a/src/main/kotlin/dev/sweep/assistant/settings/SweepSettings.kt +++ b/src/main/kotlin/dev/sweep/assistant/settings/SweepSettings.kt @@ -248,7 +248,9 @@ class SweepSettings : PersistentStateComponent { } else if (SweepSettingsParser.isCloudEnvironment()) { githubToken != DEFAULT_GITHUB_TOKEN } else { - githubToken != DEFAULT_GITHUB_TOKEN && baseUrl != DEFAULT_SWEEP_URL + // Non-Sweep backends (LM Studio, Ollama, etc.) only need a base URL — + // an API token is optional for local servers. + baseUrl != DEFAULT_SWEEP_URL || (githubToken != DEFAULT_GITHUB_TOKEN) } fun notifySettingsChanged() { diff --git a/src/main/kotlin/dev/sweep/assistant/startup/SweepStartupActivity.kt b/src/main/kotlin/dev/sweep/assistant/startup/SweepStartupActivity.kt index c729716..20ce56c 100644 --- a/src/main/kotlin/dev/sweep/assistant/startup/SweepStartupActivity.kt +++ b/src/main/kotlin/dev/sweep/assistant/startup/SweepStartupActivity.kt @@ -222,10 +222,16 @@ class SweepStartupActivity : OSNotificationService.getInstance(project) // Auto-start local autocomplete server if enabled and not already running - if (SweepSettings.getInstance().autocompleteLocalMode) { + val startupLogger = com.intellij.openapi.diagnostic.Logger.getInstance(SweepStartupActivity::class.java) + val autocompleteLocal = SweepSettings.getInstance().autocompleteLocalMode + startupLogger.info("[SweepStartup] autocompleteLocalMode=$autocompleteLocal — checking server status") + if (autocompleteLocal) { ApplicationManager.getApplication().executeOnPooledThread { val manager = LocalAutocompleteServerManager.getInstance() - if (!manager.isServerHealthy()) { + val healthy = manager.isServerHealthy() + startupLogger.info("[SweepStartup] Local server healthy=$healthy") + if (!healthy) { + startupLogger.info("[SweepStartup] Starting local autocomplete server in terminal") manager.startServerInTerminal(project) } }