Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
273 changes: 171 additions & 102 deletions README.md

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ license = { text = "MIT" }
readme = "README.md"
requires-python = ">=3.11"
dependencies = [
"textual>=8.0.0",
"textual>=8.0.1",
"ollama>=0.6.1",
"pydantic>=2.12.5",
"rich>=14.3.2",
Expand Down
154 changes: 144 additions & 10 deletions src/ollama_chat/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@
ConversationPickerScreen,
ImageAttachScreen,
InfoScreen,
QuestionScreen,
SimplePickerScreen,
TextPromptScreen,
ThemePickerScreen,
Expand All @@ -70,6 +69,7 @@
)
from .tools.base import ToolContext
from .widgets.activity_bar import ActivityBar
from .widgets.ask_question_widget import AskQuestionWidget
from .widgets.conversation import ConversationView
from .widgets.input_box import InputBox
from .widgets.message import MessageBubble
Expand Down Expand Up @@ -473,6 +473,10 @@ def __init__(self) -> None:
self._w_activity: ActivityBar | None = None
self._w_status: StatusBar | None = None
self._w_conversation: ConversationView | None = None
self._w_input_box: InputBox | None = None
self._w_ask_question: AskQuestionWidget | None = None

self._question_answer_queue: asyncio.Queue[str | None] = asyncio.Queue()

self._slash_commands: list[tuple[str, str]] = [
("/image <path>", "Attach image from filesystem"),
Expand Down Expand Up @@ -802,6 +806,7 @@ def compose(self) -> ComposeResult:
with Container(id="app-root"):
yield ConversationView(id="conversation")
yield InputBox()
yield AskQuestionWidget()
yield StatusBar(id="status_bar")
yield ActivityBar(
shortcut_hints="ctrl+p commands",
Expand Down Expand Up @@ -931,6 +936,8 @@ async def on_mount(self) -> None:
self._w_activity = self.query_one("#activity_bar", ActivityBar)
self._w_status = self.query_one("#status_bar", StatusBar)
self._w_conversation = self.query_one(ConversationView)
self._w_input_box = self.query_one(InputBox)
self._w_ask_question = self.query_one(AskQuestionWidget)

attach_button = self.query_one("#attach_button", Button)
self._w_input.disabled = True
Expand Down Expand Up @@ -1173,32 +1180,53 @@ def _on_support_file_event(self, event: str, payload: dict[str, Any]) -> None:
def _on_question_asked(self, event_name: str, payload: dict[str, Any]) -> None:
"""Handle question.asked event from question_service.

Uses run_worker so the coroutine runs on Textual's main event loop
with a proper worker context — required by push_screen_wait, and safe
to call from any thread (including the tool-executor thread that fires
this callback via _run_async_from_sync).
Schedules _run_question_sequence on Textual's main event loop via
call_from_thread + run_worker. The question_service.ask() call may run
in a background thread (tools are executed via asyncio.to_thread), so
this handler must not call run_worker directly from that thread.
"""
self.run_worker(self._run_question_sequence(payload))
def _start_question_worker() -> None:
self.run_worker(self._run_question_sequence(payload))

# call_from_thread is thread-safe and will invoke the closure on the
# main Textual thread, regardless of which thread published the event.
try:
self.call_from_thread(_start_question_worker)
except RuntimeError:
# Fallback for call sites already running on the app loop.
self.run_worker(self._run_question_sequence(payload))
except Exception as exc: # pragma: no cover - defensive logging
LOGGER.debug(
"app.question_worker.start_failed",
extra={"event": event_name, "error": str(exc)},
)

async def _run_question_sequence(self, payload: dict[str, Any]) -> None:
"""Show QuestionScreen modals sequentially and reply to question_service."""
"""Show inline AskQuestionWidget sequentially and reply to question_service."""
qid: str = payload.get("id", "")
questions: list[dict[str, Any]] = payload.get("questions", [])
all_answers: list[list[str]] = []

for q in questions:
try:
result: list[str] | None = await self.push_screen_wait(
QuestionScreen(q)
result = await self._ask_inline_question(q)
except asyncio.CancelledError:
LOGGER.debug(
"app.question_sequence.cancelled",
extra={"qid": qid},
)
break
except Exception as exc:
LOGGER.debug(
"app.question_screen.failed",
"app.question_inline.failed",
extra={"qid": qid, "error": str(exc)},
)
result = None
all_answers.append(result if result is not None else [])

if len(all_answers) < len(questions):
all_answers.extend([[] for _ in range(len(questions) - len(all_answers))])

try:
question_service.reply(qid, all_answers)
except Exception as exc:
Expand All @@ -1207,6 +1235,104 @@ async def _run_question_sequence(self, payload: dict[str, Any]) -> None:
extra={"qid": qid, "error": str(exc)},
)

def _show_question_widget(self) -> None:
input_box = self._w_input_box or self.query_one(InputBox)
ask_widget = self._w_ask_question or self.query_one(AskQuestionWidget)
input_box.display = False
ask_widget.display = True

def _focus_question_controls() -> None:
try:
options = ask_widget.query_one("#aq-options", OptionList)
if options.option_count:
options.focus()
return
except Exception:
pass

try:
custom_input = ask_widget.query_one("#aq-custom-input", Input)
if custom_input.display:
custom_input.focus()
return
except Exception:
pass

try:
ask_widget.focus()
except Exception:
return

self.call_after_refresh(_focus_question_controls)

def _restore_input(self) -> None:
ask_widget = self._w_ask_question or self.query_one(AskQuestionWidget)
input_box = self._w_input_box or self.query_one(InputBox)
ask_widget.display = False
input_box.display = True
input_widget = self._w_input or self.query_one("#message_input", Input)
input_widget.focus()

def _drain_answer_queue(self) -> None:
while True:
try:
self._question_answer_queue.get_nowait()
except asyncio.QueueEmpty:
return

async def _ask_inline_question(self, question_info: dict[str, Any]) -> list[str] | None:
question = str(question_info.get("question", "")).strip()
header = str(question_info.get("header", "Assistant Question")).strip()
options_raw = question_info.get("options")
custom = bool(question_info.get("custom", True))

options: list[str] = []
if isinstance(options_raw, list):
for item in options_raw:
if isinstance(item, dict):
label = str(item.get("label", "")).strip()
if label:
options.append(label)
elif isinstance(item, str):
label = item.strip()
if label:
options.append(label)

if not question:
return []

ask_widget = self._w_ask_question or self.query_one(AskQuestionWidget)
ask_widget.border_title = header or "Assistant Question"

self._drain_answer_queue()
ask_widget.load_question(question, options, custom=custom)
self._show_question_widget()

try:
answer = await self._question_answer_queue.get()
finally:
try:
self._restore_input()
except Exception as exc:
LOGGER.debug(
"app.question.restore_input_failed",
extra={"error": str(exc)},
)

if answer is None:
return []

value = str(answer).strip()
if not value:
return []
return [value]

async def on_ask_question_widget_answered(
    self, message: AskQuestionWidget.Answered
) -> None:
    """Stop the ``Answered`` message and enqueue its value as the pending answer."""
    message.stop()
    answer = message.value
    # put_nowait: the queue is drained by _ask_inline_question, which awaits it.
    self._question_answer_queue.put_nowait(answer)

@property
def show_timestamps(self) -> bool:
return bool(self.config["ui"]["show_timestamps"])
Expand Down Expand Up @@ -1556,6 +1682,14 @@ def _scroll() -> None:

async def action_interrupt_stream(self) -> None:
"""Cancel an in-flight assistant response (delegates to StreamManager)."""
try:
ask_widget = self._w_ask_question or self.query_one(AskQuestionWidget)
if ask_widget.display:
self._question_answer_queue.put_nowait(None)
return
except Exception:
pass

interrupted = await self.stream_manager.interrupt_stream(self.chat.model)
if interrupted:
self._update_status_bar()
Expand Down
2 changes: 1 addition & 1 deletion src/ollama_chat/capabilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def from_config(cls, config: dict[str, Any]) -> CapabilityContext:
show_thinking=bool(cap_cfg.get("show_thinking", True)),
web_search_enabled=bool(cap_cfg.get("web_search_enabled", False)),
web_search_api_key=str(cap_cfg.get("web_search_api_key", "")),
max_tool_iterations=int(cap_cfg.get("max_tool_iterations", 10)),
max_tool_iterations=int(cap_cfg.get("max_tool_iterations", 20)),
)


Expand Down
28 changes: 14 additions & 14 deletions src/ollama_chat/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
_QUESTION_USE_POLICY = """

CLARIFYING QUESTIONS POLICY:
You have access to a 'question' tool. Use it proactively when:
You have access to an 'ask_user_question' tool. Use it proactively when:

ALWAYS ASK WHEN:
• User request is ambiguous ("fix the bug", "add auth", "optimize code")
Expand All @@ -77,13 +77,10 @@
• Questions should be answerable in <10 seconds

FORMAT REQUIREMENTS (CRITICAL):
• Each question MUST have a 'header' field (short label, max 30 chars)
• Each question MUST have a 'question' field (full question text)
• 'options' MUST be a list of objects, NOT strings
• Each option MUST have both 'label' and 'description' fields
• Example: {"label": "Redis", "description": "Fast, external, scalable"}
• WRONG: ["Redis", "Memcached"] ❌
• CORRECT: [{"label": "Redis", "description": "Fast, external"}, ...] ✓
• The tool call MUST include a 'question' string
• The tool call MUST include an 'options' list of 2-5 strings
• WRONG: {"options": [{"label": "Redis"}]} ❌
• CORRECT: {"options": ["Redis", "Memcached"]} ✓

WHEN NOT TO ASK (rare):
• Task is completely unambiguous ("write factorial function")
Expand All @@ -102,7 +99,7 @@
Example 1 - Ambiguous Code Target:
User: "Refactor the database connection code"
You (thinking): "Multiple files handle DB connections. Need to ask which."
You (action): Call question tool with:
You (action): Call ask_user_question tool with:
question: "Which database connection code should I refactor?"
options: [
"Main connection pool (db/pool.py)",
Expand All @@ -116,7 +113,7 @@
Example 2 - Technology Choice:
User: "Add caching to the API endpoints"
You (thinking): "Many caching strategies exist. Should ask."
You (action): Call question tool with:
You (action): Call ask_user_question tool with:
question: "Which caching backend should I use?"
options: [
"Redis (fast, external, scalable)",
Expand All @@ -143,8 +140,10 @@ def factorial(n: int) -> int:

# Tools that are I/O-bound and fast - don't need thread pool overhead
# These tools complete quickly (<10ms) and don't block the event loop.
# NOTE: Tools that may wait on user interaction (e.g. "question") MUST NOT be
# listed here, otherwise they will block the Textual event loop and freeze the UI.
# NOTE: Tools that may wait on user interaction (e.g. ask_user_question) MUST NOT be
# listed here. The tool execution path is synchronous at this layer; running an
# interactive tool on the main event loop would block Textual from rendering the
# question UI and handling input.
FAST_SYNC_TOOLS = {
"read",
"write",
Expand Down Expand Up @@ -763,9 +762,10 @@ async def _stream_once_with_capabilities(
# Build policy text - start with base tool use policy
policy_text = _TOOL_USE_POLICY

# Check if question tool is available in this request
# Check if ask_user_question tool is available in this request
has_question_tool = any(
t.get("function", {}).get("name") == "question" for t in formatted_tools
t.get("function", {}).get("name") == "ask_user_question"
for t in formatted_tools
)

# Add question-specific guidance if tool is present
Expand Down
2 changes: 1 addition & 1 deletion src/ollama_chat/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
"batch_tool",
"lsp_tool",
"plan_tool",
"question_tool",
"ask_user_question_tool",
"todo_tool",
"skill_tool",
"apply_patch_tool",
Expand Down
73 changes: 73 additions & 0 deletions src/ollama_chat/tools/ask_user_question_tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from __future__ import annotations

from pydantic import Field

from ..support import question_service
from .base import ParamsSchema, Tool, ToolContext, ToolResult


class AskUserQuestionParams(ParamsSchema):
    # Arguments accepted by the ask_user_question tool.
    # NOTE(review): documented with comments rather than a class docstring —
    # if ParamsSchema is a pydantic model, a docstring could surface in the
    # generated schema; confirm before converting.

    # Text of the question shown to the user.
    question: str = Field(description="The question to present to the user")

    # Answer choices; the Field constraints bound the list to 2..5 entries.
    options: list[str] = Field(
        min_length=2,
        max_length=5,
        description="List of 2–5 answer options to display",
    )


class AskUserQuestionTool(Tool):
    # Tool that puts one multiple-choice question to the user through
    # question_service and returns the first answer as plain text.
    id = "ask_user_question"
    description = (
        "Ask the user a multiple-choice question when you need clarification before proceeding. "
        "Use this instead of guessing. Provide 2 to 5 clear, distinct options. "
        "The user may also type a custom answer."
    )
    params_schema = AskUserQuestionParams

    async def execute(
        self, params: AskUserQuestionParams, ctx: ToolContext
    ) -> ToolResult:
        """Ask the question and report the user's choice.

        Returns a result whose output is the stripped first answer. When the
        question text is blank, fewer than two non-blank options remain, or
        no usable answer comes back, the output is "User did not answer."
        and ``metadata["answer"]`` is ``None``.
        """

        def unanswered() -> ToolResult:
            # Single place that builds the "no answer" result.
            return ToolResult(
                title="Question",
                output="User did not answer.",
                metadata={"answer": None},
            )

        prompt = (params.question or "").strip()
        stripped = (str(opt).strip() for opt in (params.options or []))
        choices = [text for text in stripped if text]

        # Nothing is asked without a question and at least two real options.
        if not prompt or len(choices) < 2:
            return unanswered()

        payload = [
            {
                "header": "Assistant Question",
                "question": prompt,
                "options": [{"label": text, "description": ""} for text in choices],
                "multiple": False,
                "custom": True,
            }
        ]
        answers = await question_service.ask(
            session_id=ctx.session_id, questions=payload
        )

        # Only the first answer of the first (and only) question is used.
        picked = ""
        if answers and answers[0]:
            picked = str(answers[0][0]).strip()
        if not picked:
            return unanswered()

        return ToolResult(
            title="Question",
            output=picked,
            metadata={"answer": picked},
        )
Loading
Loading