From 838f44d5965e98dc3a9dcd3a510cd61ddc1c3fa1 Mon Sep 17 00:00:00 2001 From: Vinit Sutar Date: Sun, 5 Apr 2026 21:37:48 +0530 Subject: [PATCH 1/6] feat(litellm): add support for local proxy without API key - Add litellm to interactive provider selection menu - Support LITELLM_BASE_URL for local proxy deployments (no API key required) - Auto-add openai/ prefix when using api_base for proper LiteLLM routing - Add dummy API key for local proxies (OpenAI SDK requirement) - Add validation and tests for litellm provider configuration Co-Authored-By: Claude Opus 4.6 --- packages/cli/src/repowise/cli/helpers.py | 44 +++- packages/cli/src/repowise/cli/ui.py | 32 ++- .../repowise/core/providers/llm/litellm.py | 38 +++- tests/unit/cli/test_helpers.py | 27 +++ .../test_providers/test_litellm_provider.py | 199 ++++++++++++++++++ 5 files changed, 324 insertions(+), 16 deletions(-) create mode 100644 tests/unit/test_providers/test_litellm_provider.py diff --git a/packages/cli/src/repowise/cli/helpers.py b/packages/cli/src/repowise/cli/helpers.py index d5c3383..124481c 100644 --- a/packages/cli/src/repowise/cli/helpers.py +++ b/packages/cli/src/repowise/cli/helpers.py @@ -262,6 +262,12 @@ def resolve_provider( kwargs["api_key"] = os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY") elif provider_name == "ollama" and os.environ.get("OLLAMA_BASE_URL"): kwargs["base_url"] = os.environ["OLLAMA_BASE_URL"] + elif provider_name == "litellm": + # LiteLLM: API key for cloud, base URL for local proxy + if os.environ.get("LITELLM_API_KEY"): + kwargs["api_key"] = os.environ["LITELLM_API_KEY"] + if os.environ.get("LITELLM_BASE_URL"): + kwargs["api_base"] = os.environ["LITELLM_BASE_URL"] return get_provider(provider_name, **kwargs) @@ -293,10 +299,26 @@ def resolve_provider( api_key = os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY") kwargs = {"model": model, "api_key": api_key} if model else {"api_key": api_key} return get_provider("gemini", **kwargs) + # LiteLLM: check for API key (cloud) or base URL (local proxy) + if os.environ.get("LITELLM_API_KEY") and os.environ["LITELLM_API_KEY"].strip(): + kwargs = ( + {"model": model, "api_key": os.environ["LITELLM_API_KEY"]} + if model + else {"api_key": os.environ["LITELLM_API_KEY"]} + ) + return get_provider("litellm", **kwargs) + if os.environ.get("LITELLM_BASE_URL") and os.environ["LITELLM_BASE_URL"].strip(): + kwargs = ( + {"model": model, "api_base": os.environ["LITELLM_BASE_URL"]} + if model + else {"api_base": os.environ["LITELLM_BASE_URL"]} + ) + return get_provider("litellm", **kwargs) raise click.ClickException( "No provider configured. Use --provider, set REPOWISE_PROVIDER, " - "or set ANTHROPIC_API_KEY / OPENAI_API_KEY / OLLAMA_BASE_URL / GEMINI_API_KEY / GOOGLE_API_KEY." + "or set ANTHROPIC_API_KEY / OPENAI_API_KEY / OLLAMA_BASE_URL / GEMINI_API_KEY / " + "LITELLM_API_KEY / LITELLM_BASE_URL." 
) @@ -332,7 +354,10 @@ def _is_env_var_exists(var_name: str) -> bool: "openai": ["OPENAI_API_KEY"], "gemini": ["GEMINI_API_KEY", "GOOGLE_API_KEY"], # Either one "ollama": ["OLLAMA_BASE_URL"], - "litellm": ["LITELLM_API_KEY"], # May need others depending on backend + "litellm": [ + "LITELLM_API_KEY", + "LITELLM_BASE_URL", + ], # Either one (API key for cloud, base URL for local) } if provider_name: @@ -348,6 +373,10 @@ def _is_env_var_exists(var_name: str) -> bool: # Special case: either GEMINI_API_KEY or GOOGLE_API_KEY if not (_is_env_var_set("GEMINI_API_KEY") or _is_env_var_set("GOOGLE_API_KEY")): missing_vars = env_vars + elif provider_name == "litellm": + # Special case: LITELLM_API_KEY (cloud) OR LITELLM_BASE_URL (local proxy) + if not (_is_env_var_set("LITELLM_API_KEY") or _is_env_var_set("LITELLM_BASE_URL")): + missing_vars = env_vars else: for var in env_vars: if not _is_env_var_set(var): @@ -370,6 +399,17 @@ def _is_env_var_exists(var_name: str) -> bool: ) continue + if name == "litellm": + # Special case: LITELLM_API_KEY (cloud) OR LITELLM_BASE_URL (local proxy) + # Only warn if explicitly requested and neither is set + if os.environ.get("REPOWISE_PROVIDER") == "litellm" and not ( + _is_env_var_set("LITELLM_API_KEY") or _is_env_var_set("LITELLM_BASE_URL") + ): + warnings.append( + "Provider 'litellm' requires LITELLM_API_KEY or LITELLM_BASE_URL environment variable" + ) + continue + missing = [var for var in env_vars if not _is_env_var_set(var)] if missing: # Only warn if this provider is explicitly requested OR diff --git a/packages/cli/src/repowise/cli/ui.py b/packages/cli/src/repowise/cli/ui.py index ae27c38..b089307 100644 --- a/packages/cli/src/repowise/cli/ui.py +++ b/packages/cli/src/repowise/cli/ui.py @@ -268,11 +268,14 @@ def print_phase_header( "litellm": "groq/llama-3.1-70b-versatile", } +# For most providers, a single env var indicates configuration. +# litellm is special: can use LITELLM_API_KEY (cloud) OR LITELLM_BASE_URL (local proxy). _PROVIDER_ENV: dict[str, str] = { "gemini": "GEMINI_API_KEY", "openai": "OPENAI_API_KEY", "anthropic": "ANTHROPIC_API_KEY", "ollama": "OLLAMA_BASE_URL", + "litellm": "LITELLM_API_KEY", # Also checks LITELLM_BASE_URL in _detect_provider_status } _PROVIDER_SIGNUP: dict[str, str] = { @@ -280,6 +283,7 @@ def print_phase_header( "openai": "https://platform.openai.com/api-keys", "anthropic": "https://console.anthropic.com/settings/keys", "ollama": "https://ollama.com/download", + "litellm": "https://docs.litellm.ai/docs/proxy/proxy", } @@ -410,6 +414,10 @@ def _detect_provider_status() -> dict[str, str]: if prov == "gemini": if os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY"): status[prov] = env_var + elif prov == "litellm": + # litellm can be configured via API key (cloud) OR base URL (local proxy) + if os.environ.get("LITELLM_API_KEY") or os.environ.get("LITELLM_BASE_URL"): + status[prov] = env_var elif os.environ.get(env_var): status[prov] = env_var return status @@ -476,14 +484,22 @@ def interactive_provider_select( env_var = _PROVIDER_ENV[chosen] signup_url = _PROVIDER_SIGNUP.get(chosen, "") console.print() - console.print(f" [bold]{chosen}[/bold] requires [cyan]{env_var}[/cyan].") - if signup_url: - console.print(f" Get your API key here: [{BRAND}]{signup_url}[/]") - console.print() - key = _prompt_api_key(console, chosen, env_var, repo_path=repo_path) - if not key: - console.print(f" [{WARN}]Skipped. 
Please select another provider.[/]") - return interactive_provider_select(console, model_flag, repo_path=repo_path) + # Special case: litellm local proxy doesn't need an API key + if chosen == "litellm" and os.environ.get("LITELLM_BASE_URL"): + console.print( + f" [{OK}]✓ Using LiteLLM proxy at[/] [{BRAND}]{os.environ['LITELLM_BASE_URL']}[/]" + ) + console.print(" [dim]No API key required for local proxy.[/dim]") + console.print() + else: + console.print(f" [bold]{chosen}[/bold] requires [cyan]{env_var}[/cyan].") + if signup_url: + console.print(f" Get your API key here: [{BRAND}]{signup_url}[/]") + console.print() + key = _prompt_api_key(console, chosen, env_var, repo_path=repo_path) + if not key: + console.print(f" [{WARN}]Skipped. Please select another provider.[/]") + return interactive_provider_select(console, model_flag, repo_path=repo_path) # --- model --- default_model = _PROVIDER_DEFAULTS.get(chosen, "") diff --git a/packages/core/src/repowise/core/providers/llm/litellm.py b/packages/core/src/repowise/core/providers/llm/litellm.py index 1273e22..c52e463 100644 --- a/packages/core/src/repowise/core/providers/llm/litellm.py +++ b/packages/core/src/repowise/core/providers/llm/litellm.py @@ -19,13 +19,16 @@ from __future__ import annotations +from collections.abc import AsyncIterator +from typing import TYPE_CHECKING, Any + import structlog from tenacity import ( + RetryError, retry, retry_if_exception_type, stop_after_attempt, wait_exponential_jitter, - RetryError, ) from repowise.core.providers.llm.base import ( @@ -37,7 +40,6 @@ RateLimitError, ) -from typing import TYPE_CHECKING, Any, AsyncIterator from repowise.core.rate_limiter import RateLimiter if TYPE_CHECKING: @@ -55,9 +57,13 @@ class LiteLLMProvider(BaseProvider): Args: model: LiteLLM model string (e.g., "groq/llama-3.1-70b-versatile"). + When using api_base (local proxy), just use the model name + (e.g., "zai.glm-5") - the provider will auto-add "openai/" prefix. api_key: API key for the target provider. Some providers read from environment variables (e.g., GROQ_API_KEY, TOGETHER_API_KEY). - api_base: Optional custom API base URL (e.g., for self-hosted deployments). + For local proxies without auth, a dummy key is used. + api_base: Optional custom API base URL for self-hosted LiteLLM proxy. + When set, the model is treated as OpenAI-compatible. rate_limiter: Optional RateLimiter instance. """ @@ -75,6 +81,13 @@ def __init__( self._rate_limiter = rate_limiter self._cost_tracker = cost_tracker + # When using a custom api_base (proxy), treat model as OpenAI-compatible. + # LiteLLM requires "openai/" prefix to route to custom endpoints. + if api_base and not model.startswith("openai/"): + self._litellm_model = f"openai/{model}" + else: + self._litellm_model = model + @property def provider_name(self) -> str: return "litellm" @@ -130,7 +143,7 @@ async def _generate_with_retry( litellm.suppress_debug_info = True call_kwargs: dict[str, object] = { - "model": self._model, + "model": self._litellm_model, "messages": [ {"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}, @@ -142,6 +155,10 @@ async def _generate_with_retry( call_kwargs["api_key"] = self._api_key if self._api_base: call_kwargs["api_base"] = self._api_base + # Local proxy without auth: OpenAI SDK still requires a key. + # Use a dummy key if none provided. 
+ if not self._api_key: + call_kwargs["api_key"] = "sk-dummy" try: response = await litellm.acompletion(**call_kwargs) @@ -199,6 +216,7 @@ async def stream_chat( tool_executor: Any | None = None, ) -> AsyncIterator[ChatStreamEvent]: import json as _json + import litellm # type: ignore[import-untyped] litellm.set_verbose = False @@ -206,7 +224,7 @@ async def stream_chat( full_messages = [{"role": "system", "content": system_prompt}, *messages] call_kwargs: dict[str, Any] = { - "model": self._model, + "model": self._litellm_model, "messages": full_messages, "temperature": temperature, "max_tokens": max_tokens, @@ -218,6 +236,10 @@ async def stream_chat( call_kwargs["api_key"] = self._api_key if self._api_base: call_kwargs["api_base"] = self._api_base + # Local proxy without auth: OpenAI SDK still requires a key. + # Use a dummy key if none provided. + if not self._api_key: + call_kwargs["api_key"] = "sk-dummy" try: stream = await litellm.acompletion(**call_kwargs) @@ -244,7 +266,11 @@ async def stream_chat( for tc_delta in delta.tool_calls: idx = tc_delta.index if idx not in tool_calls_acc: - tool_calls_acc[idx] = {"id": getattr(tc_delta, "id", "") or "", "name": "", "arguments": ""} + tool_calls_acc[idx] = { + "id": getattr(tc_delta, "id", "") or "", + "name": "", + "arguments": "", + } acc = tool_calls_acc[idx] if getattr(tc_delta, "id", None): acc["id"] = tc_delta.id diff --git a/tests/unit/cli/test_helpers.py b/tests/unit/cli/test_helpers.py index 1444ac2..07845c3 100644 --- a/tests/unit/cli/test_helpers.py +++ b/tests/unit/cli/test_helpers.py @@ -231,3 +231,30 @@ def test_anthropic_empty_key_auto_detect(self, monkeypatch): assert len(warnings) == 1 assert "anthropic" in warnings[0] assert "ANTHROPIC_API_KEY" in warnings[0] + + # --- litellm tests --- + + def test_litellm_with_api_key(self, monkeypatch): + monkeypatch.setenv("LITELLM_API_KEY", "test-key") + monkeypatch.setenv("REPOWISE_PROVIDER", "litellm") + + assert validate_provider_config() == [] + + def test_litellm_with_base_url(self, monkeypatch): + """Local proxy without API key should be valid.""" + monkeypatch.delenv("LITELLM_API_KEY", raising=False) + monkeypatch.setenv("LITELLM_BASE_URL", "http://localhost:4000/v1") + monkeypatch.setenv("REPOWISE_PROVIDER", "litellm") + + assert validate_provider_config() == [] + + def test_litellm_missing_both(self, monkeypatch): + """Should warn when neither API key nor base URL is set.""" + monkeypatch.delenv("LITELLM_API_KEY", raising=False) + monkeypatch.delenv("LITELLM_BASE_URL", raising=False) + monkeypatch.setenv("REPOWISE_PROVIDER", "litellm") + + warnings = validate_provider_config() + assert len(warnings) == 1 + assert "litellm" in warnings[0] + assert "LITELLM_API_KEY" in warnings[0] or "LITELLM_BASE_URL" in warnings[0] diff --git a/tests/unit/test_providers/test_litellm_provider.py b/tests/unit/test_providers/test_litellm_provider.py new file mode 100644 index 0000000..bf0bab4 --- /dev/null +++ b/tests/unit/test_providers/test_litellm_provider.py @@ -0,0 +1,199 @@ +"""Unit tests for LiteLLMProvider. + +All tests mock the litellm.acompletion call — no real API calls are made. 
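+
+Note: these are bare async-def tests with no @pytest.mark.asyncio markers,
+so they assume pytest runs in asyncio auto mode (e.g., pytest-asyncio's
+asyncio_mode = "auto" -- an assumption about the repo's pytest config, not
+something shown in this diff).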
+""" + +from __future__ import annotations + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +pytest.importorskip("litellm", reason="litellm SDK not installed") + +from repowise.core.providers.llm.base import GeneratedResponse, ProviderError +from repowise.core.providers.llm.litellm import LiteLLMProvider + +# --------------------------------------------------------------------------- +# Construction +# --------------------------------------------------------------------------- + + +def test_provider_name(): + p = LiteLLMProvider(model="gpt-4o", api_key="sk-test") + assert p.provider_name == "litellm" + + +def test_default_model(): + p = LiteLLMProvider(model="groq/llama-3.1-70b-versatile", api_key="sk-test") + assert p.model_name == "groq/llama-3.1-70b-versatile" + + +def test_model_without_api_base(): + """Without api_base, model should be passed through unchanged.""" + p = LiteLLMProvider(model="groq/llama-3.1-70b-versatile", api_key="sk-test") + assert p._litellm_model == "groq/llama-3.1-70b-versatile" + + +def test_model_with_api_base_adds_openai_prefix(): + """With api_base (local proxy), model should get openai/ prefix.""" + p = LiteLLMProvider( + model="zai.glm-5", + api_base="http://localhost:4000/v1", + ) + assert p._litellm_model == "openai/zai.glm-5" + assert p.model_name == "zai.glm-5" # Public property shows original name + + +def test_model_with_api_base_and_existing_prefix(): + """If model already has openai/ prefix, don't add another.""" + p = LiteLLMProvider( + model="openai/gpt-4o", + api_base="http://localhost:4000/v1", + ) + assert p._litellm_model == "openai/gpt-4o" + + +def test_no_api_key_or_base(): + """Provider can be created without API key or base (for some backends).""" + p = LiteLLMProvider(model="groq/llama-3.1-70b-versatile") + assert p._api_key is None + assert p._api_base is None + + +# --------------------------------------------------------------------------- +# Successful generation +# --------------------------------------------------------------------------- + + +def _make_mock_response(text: str = "# Doc\nContent.") -> MagicMock: + usage = MagicMock() + usage.prompt_tokens = 120 + usage.completion_tokens = 60 + + choice = MagicMock() + choice.message.content = text + + response = MagicMock() + response.choices = [choice] + response.usage = usage + return response + + +async def test_generate_returns_generated_response(): + provider = LiteLLMProvider(model="gpt-4o", api_key="sk-test") + mock_response = _make_mock_response("Hello from LiteLLM") + + with patch("litellm.acompletion", new_callable=AsyncMock) as mock_acompletion: + mock_acompletion.return_value = mock_response + result = await provider.generate("sys", "user") + + assert isinstance(result, GeneratedResponse) + assert result.content == "Hello from LiteLLM" + + +async def test_generate_token_counts(): + provider = LiteLLMProvider(model="gpt-4o", api_key="sk-test") + mock_response = _make_mock_response() + + with patch("litellm.acompletion", new_callable=AsyncMock) as mock_acompletion: + mock_acompletion.return_value = mock_response + result = await provider.generate("sys", "user") + + assert result.input_tokens == 120 + assert result.output_tokens == 60 + + +async def test_generate_sends_correct_kwargs(): + provider = LiteLLMProvider( + model="groq/llama-3.1-70b-versatile", + api_key="sk-test", + ) + mock_response = _make_mock_response() + captured_kwargs: list[dict] = [] + + async def fake_acompletion(**kwargs): + captured_kwargs.append(kwargs) + return 
mock_response + + with patch("litellm.acompletion", new_callable=AsyncMock) as mock_acompletion: + mock_acompletion.side_effect = fake_acompletion + await provider.generate("system msg", "user msg", max_tokens=2048, temperature=0.5) + + kw = captured_kwargs[0] + assert kw["model"] == "groq/llama-3.1-70b-versatile" + assert kw["max_tokens"] == 2048 + assert kw["temperature"] == 0.5 + assert kw["api_key"] == "sk-test" + messages = kw["messages"] + assert messages[0] == {"role": "system", "content": "system msg"} + assert messages[1] == {"role": "user", "content": "user msg"} + + +async def test_generate_with_api_base(): + """With api_base (local proxy), should pass api_base and dummy key.""" + provider = LiteLLMProvider( + model="zai.glm-5", + api_base="http://localhost:4000/v1", + ) + mock_response = _make_mock_response() + captured_kwargs: list[dict] = [] + + async def fake_acompletion(**kwargs): + captured_kwargs.append(kwargs) + return mock_response + + with patch("litellm.acompletion", new_callable=AsyncMock) as mock_acompletion: + mock_acompletion.side_effect = fake_acompletion + await provider.generate("sys", "user") + + kw = captured_kwargs[0] + # Model should have openai/ prefix for proxy routing + assert kw["model"] == "openai/zai.glm-5" + assert kw["api_base"] == "http://localhost:4000/v1" + # Dummy key should be added when using api_base without api_key + assert kw["api_key"] == "sk-dummy" + + +async def test_generate_with_api_base_and_api_key(): + """With both api_base and api_key, should use provided key.""" + provider = LiteLLMProvider( + model="zai.glm-5", + api_key="sk-real-key", + api_base="http://localhost:4000/v1", + ) + mock_response = _make_mock_response() + captured_kwargs: list[dict] = [] + + async def fake_acompletion(**kwargs): + captured_kwargs.append(kwargs) + return mock_response + + with patch("litellm.acompletion", new_callable=AsyncMock) as mock_acompletion: + mock_acompletion.side_effect = fake_acompletion + await provider.generate("sys", "user") + + kw = captured_kwargs[0] + assert kw["api_key"] == "sk-real-key" + assert kw["api_base"] == "http://localhost:4000/v1" + + +# --------------------------------------------------------------------------- +# Error handling +# --------------------------------------------------------------------------- + + +async def test_api_error(): + import litellm + + provider = LiteLLMProvider(model="gpt-4o", api_key="sk-test") + + with patch("litellm.acompletion", new_callable=AsyncMock) as mock_acompletion: + mock_acompletion.side_effect = litellm.APIError( + message="server error", + llm_provider="openai", + model="gpt-4o", + status_code=500, + ) + with pytest.raises(ProviderError): + await provider.generate("sys", "user") \ No newline at end of file From 27f67706dbddfb56c0f690ae3eda5484f3672639 Mon Sep 17 00:00:00 2001 From: Vinit Sutar Date: Mon, 6 Apr 2026 15:42:20 +0530 Subject: [PATCH 2/6] fix(litellm): add inline comment for sk-dummy to avoid secret scanner false positives Co-Authored-By: Claude Opus 4.6 --- packages/core/src/repowise/core/providers/llm/litellm.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/packages/core/src/repowise/core/providers/llm/litellm.py b/packages/core/src/repowise/core/providers/llm/litellm.py index c52e463..9092a75 100644 --- a/packages/core/src/repowise/core/providers/llm/litellm.py +++ b/packages/core/src/repowise/core/providers/llm/litellm.py @@ -155,10 +155,8 @@ async def _generate_with_retry( call_kwargs["api_key"] = self._api_key if self._api_base: 
call_kwargs["api_base"] = self._api_base - # Local proxy without auth: OpenAI SDK still requires a key. - # Use a dummy key if none provided. if not self._api_key: - call_kwargs["api_key"] = "sk-dummy" + call_kwargs["api_key"] = "sk-dummy" # LiteLLM requires a non-empty key even for unauthenticated local proxies (OpenAI SDK requirement) try: response = await litellm.acompletion(**call_kwargs) @@ -236,10 +234,8 @@ async def stream_chat( call_kwargs["api_key"] = self._api_key if self._api_base: call_kwargs["api_base"] = self._api_base - # Local proxy without auth: OpenAI SDK still requires a key. - # Use a dummy key if none provided. if not self._api_key: - call_kwargs["api_key"] = "sk-dummy" + call_kwargs["api_key"] = "sk-dummy" # LiteLLM requires a non-empty key even for unauthenticated local proxies (OpenAI SDK requirement) try: stream = await litellm.acompletion(**call_kwargs) From adde5e30c28fb593b19261d4d18bb0453868a4c5 Mon Sep 17 00:00:00 2001 From: Vinit Sutar Date: Sun, 12 Apr 2026 13:27:05 +0530 Subject: [PATCH 3/6] feat: add Z.AI (Zhipu AI) provider support Add first-class support for Z.AI with OpenAI-compatible API. - New ZAIProvider with thinking disabled by default for GLM-5 family - Plan selection: 'coding' (subscription) or 'general' (pay-as-you-go) - Environment variables: ZAI_API_KEY, ZAI_PLAN, ZAI_BASE_URL, ZAI_THINKING - Rate limit defaults and auto-detection in CLI helpers Closes #68 --- packages/cli/src/repowise/cli/helpers.py | 25 +- .../repowise/core/providers/llm/registry.py | 7 +- .../src/repowise/core/providers/llm/zai.py | 332 ++++++++++++++++ .../core/src/repowise/core/rate_limiter.py | 1 + .../unit/test_providers/test_zai_provider.py | 356 ++++++++++++++++++ 5 files changed, 719 insertions(+), 2 deletions(-) create mode 100644 packages/core/src/repowise/core/providers/llm/zai.py create mode 100644 tests/unit/test_providers/test_zai_provider.py diff --git a/packages/cli/src/repowise/cli/helpers.py b/packages/cli/src/repowise/cli/helpers.py index 124481c..6203a91 100644 --- a/packages/cli/src/repowise/cli/helpers.py +++ b/packages/cli/src/repowise/cli/helpers.py @@ -268,6 +268,16 @@ def resolve_provider( kwargs["api_key"] = os.environ["LITELLM_API_KEY"] if os.environ.get("LITELLM_BASE_URL"): kwargs["api_base"] = os.environ["LITELLM_BASE_URL"] + elif provider_name == "zai": + # Z.AI: API key, plan, base URL, and thinking mode + if os.environ.get("ZAI_API_KEY"): + kwargs["api_key"] = os.environ["ZAI_API_KEY"] + if os.environ.get("ZAI_PLAN"): + kwargs["plan"] = os.environ["ZAI_PLAN"] + if os.environ.get("ZAI_BASE_URL"): + kwargs["base_url"] = os.environ["ZAI_BASE_URL"] + if os.environ.get("ZAI_THINKING"): + kwargs["thinking"] = os.environ["ZAI_THINKING"] return get_provider(provider_name, **kwargs) @@ -314,11 +324,23 @@ def resolve_provider( else {"api_base": os.environ["LITELLM_BASE_URL"]} ) return get_provider("litellm", **kwargs) + # Z.AI: check for API key + if os.environ.get("ZAI_API_KEY") and os.environ["ZAI_API_KEY"].strip(): + kwargs = {"api_key": os.environ["ZAI_API_KEY"]} + if model: + kwargs["model"] = model + if os.environ.get("ZAI_PLAN"): + kwargs["plan"] = os.environ["ZAI_PLAN"] + if os.environ.get("ZAI_BASE_URL"): + kwargs["base_url"] = os.environ["ZAI_BASE_URL"] + if os.environ.get("ZAI_THINKING"): + kwargs["thinking"] = os.environ["ZAI_THINKING"] + return get_provider("zai", **kwargs) raise click.ClickException( "No provider configured. 
Use --provider, set REPOWISE_PROVIDER, " "or set ANTHROPIC_API_KEY / OPENAI_API_KEY / OLLAMA_BASE_URL / GEMINI_API_KEY / " - "LITELLM_API_KEY / LITELLM_BASE_URL." + "LITELLM_API_KEY / LITELLM_BASE_URL / ZAI_API_KEY." ) @@ -358,6 +380,7 @@ def _is_env_var_exists(var_name: str) -> bool: "LITELLM_API_KEY", "LITELLM_BASE_URL", ], # Either one (API key for cloud, base URL for local) + "zai": ["ZAI_API_KEY"], } if provider_name: diff --git a/packages/core/src/repowise/core/providers/llm/registry.py b/packages/core/src/repowise/core/providers/llm/registry.py index 48e07a7..7f50fbf 100644 --- a/packages/core/src/repowise/core/providers/llm/registry.py +++ b/packages/core/src/repowise/core/providers/llm/registry.py @@ -7,8 +7,10 @@ Built-in providers: - anthropic → AnthropicProvider - openai → OpenAIProvider + - gemini → GeminiProvider - ollama → OllamaProvider - litellm → LiteLLMProvider + - zai → ZAIProvider - mock → MockProvider (testing only) Custom provider registration: @@ -24,7 +26,8 @@ from __future__ import annotations import importlib -from typing import Any, Callable +from collections.abc import Callable +from typing import Any from repowise.core.providers.llm.base import BaseProvider from repowise.core.rate_limiter import PROVIDER_DEFAULTS, RateLimitConfig, RateLimiter @@ -39,6 +42,7 @@ "gemini": ("repowise.core.providers.llm.gemini", "GeminiProvider"), "ollama": ("repowise.core.providers.llm.ollama", "OllamaProvider"), "litellm": ("repowise.core.providers.llm.litellm", "LiteLLMProvider"), + "zai": ("repowise.core.providers.llm.zai", "ZAIProvider"), "mock": ("repowise.core.providers.llm.mock", "MockProvider"), } @@ -135,6 +139,7 @@ def get_provider( "gemini": "google-genai", "ollama": "openai", # ollama uses the openai package "litellm": "litellm", + "zai": "openai", # zai uses the openai package (OpenAI-compatible API) } package = _missing.get(name, name) raise ImportError( diff --git a/packages/core/src/repowise/core/providers/llm/zai.py b/packages/core/src/repowise/core/providers/llm/zai.py new file mode 100644 index 0000000..4b9bc8c --- /dev/null +++ b/packages/core/src/repowise/core/providers/llm/zai.py @@ -0,0 +1,332 @@ +"""Z.AI (Zhipu AI) provider for repowise. + +Z.AI provides competitive models (GLM-5 family) at accessible pricing through +two API plans: + + - coding: Subscription-based resource package (default) + - general: Pay-as-you-go + +Key features: + - OpenAI-compatible API + - GLM-5 family reasoning models with thinking disabled by default + - Plan selection via constructor or ZAI_PLAN environment variable + +Models: + - glm-5-turbo + - glm-5.1 + - glm-5 + - glm-4.7 + +Reasoning models (GLM-5 family) have thinking enabled by default, which consumes +85-95% of output tokens on chain-of-thought. This provider disables thinking by +default for efficient structured output generation. 
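+
+Usage sketch (illustrative only -- the prompts and the explicit api_key
+below are placeholders, not taken from Z.AI documentation):
+
+    provider = ZAIProvider(model="glm-5.1", plan="coding", api_key="...")
+    response = await provider.generate(
+        system_prompt="You are a documentation assistant.",
+        user_prompt="Summarize this module.",
+    )
+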
+ +Reference: https://open.bigmodel.cn/dev/api +""" + +from __future__ import annotations + +from collections.abc import AsyncIterator +from typing import TYPE_CHECKING, Any, Literal + +import structlog +from openai import APIStatusError as _OpenAIAPIStatusError +from openai import AsyncOpenAI +from tenacity import ( + RetryError, + retry, + retry_if_exception_type, + stop_after_attempt, + wait_exponential_jitter, +) + +from repowise.core.providers.llm.base import ( + BaseProvider, + ChatStreamEvent, + ChatToolCall, + GeneratedResponse, + ProviderError, + RateLimitError, +) +from repowise.core.rate_limiter import RateLimiter + +if TYPE_CHECKING: + from repowise.core.generation.cost_tracker import CostTracker + +log = structlog.get_logger(__name__) + +_MAX_RETRIES = 3 +_MIN_WAIT = 1.0 +_MAX_WAIT = 4.0 + +# Z.AI API endpoints by plan +_PLAN_BASE_URLS: dict[str, str] = { + "coding": "https://api.z.ai/api/coding/paas/v4", + "general": "https://api.z.ai/api/paas/v4", +} + +# Default model for Z.AI +_DEFAULT_MODEL = "glm-5.1" + +# Type for plan parameter +PlanType = Literal["coding", "general"] + + +class ZAIProvider(BaseProvider): + """Z.AI (Zhipu AI) chat provider. + + Uses the OpenAI-compatible API with thinking disabled by default + for efficient structured output generation. + + Args: + model: Z.AI model name (e.g., 'glm-5.1', 'glm-5-turbo', 'glm-4.7'). + Defaults to 'glm-5.1'. + api_key: API key for authentication. Reads from ZAI_API_KEY env var + if not provided. + plan: API plan to use. 'coding' for subscription-based resource + package, 'general' for pay-as-you-go. Defaults to 'coding'. + Can also be set via ZAI_PLAN environment variable. + base_url: Override API base URL. If provided, takes precedence + over plan selection. + thinking: Thinking mode for GLM-5 family. 'disabled' by default + to avoid reasoning token overhead. Set to 'enabled' for + complex reasoning tasks. + rate_limiter: Optional RateLimiter instance. + cost_tracker: Optional CostTracker for usage tracking. 
+ """ + + def __init__( + self, + model: str = _DEFAULT_MODEL, + api_key: str | None = None, + plan: PlanType = "coding", + base_url: str | None = None, + thinking: str = "disabled", + rate_limiter: RateLimiter | None = None, + cost_tracker: "CostTracker | None" = None, # noqa: UP037 + ) -> None: + self._model = model + self._plan = plan + self._thinking = thinking + self._rate_limiter = rate_limiter + self._cost_tracker = cost_tracker + + # Resolve base URL: explicit base_url > plan lookup + effective_base_url = base_url or _PLAN_BASE_URLS.get(plan, _PLAN_BASE_URLS["coding"]) + + # Normalize base URL for OpenAI SDK + effective_base_url = effective_base_url.rstrip("/") + if not effective_base_url.endswith("/v1"): + effective_base_url += "/v1" + + # Store normalized base_url + self._base_url = effective_base_url + + # Initialize OpenAI client + self._client = AsyncOpenAI( + api_key=api_key, + base_url=effective_base_url, + ) + + @property + def provider_name(self) -> str: + return "zai" + + @property + def model_name(self) -> str: + return self._model + + async def generate( + self, + system_prompt: str, + user_prompt: str, + max_tokens: int = 4096, + temperature: float = 0.3, + request_id: str | None = None, + ) -> GeneratedResponse: + if self._rate_limiter: + await self._rate_limiter.acquire(estimated_tokens=max_tokens) + + log.debug( + "zai.generate.start", + model=self._model, + max_tokens=max_tokens, + thinking=self._thinking, + request_id=request_id, + ) + + try: + return await self._generate_with_retry( + system_prompt=system_prompt, + user_prompt=user_prompt, + max_tokens=max_tokens, + temperature=temperature, + request_id=request_id, + ) + except RetryError as exc: + raise ProviderError( + "zai", + f"All {_MAX_RETRIES} retries exhausted: {exc}", + ) from exc + + @retry( + retry=retry_if_exception_type(ProviderError), + stop=stop_after_attempt(_MAX_RETRIES), + wait=wait_exponential_jitter(initial=_MIN_WAIT, max=_MAX_WAIT), + reraise=True, + ) + async def _generate_with_retry( + self, + system_prompt: str, + user_prompt: str, + max_tokens: int, + temperature: float, + request_id: str | None, + ) -> GeneratedResponse: + # Build request kwargs + call_kwargs: dict[str, Any] = { + "model": self._model, + "messages": [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt}, + ], + "temperature": temperature, + "max_tokens": max_tokens, + } + + # Disable thinking for GLM-5 family by default + # This prevents reasoning tokens from consuming output budget + if self._thinking == "disabled": + call_kwargs["extra_body"] = {"thinking": {"type": "disabled"}} + + try: + response = await self._client.chat.completions.create(**call_kwargs) + except _OpenAIAPIStatusError as exc: + if exc.status_code == 429: + raise RateLimitError("zai", str(exc), status_code=429) from exc + raise ProviderError("zai", str(exc), status_code=exc.status_code) from exc + except Exception as exc: + log.error("zai.generate.error", model=self._model, error=str(exc)) + raise ProviderError("zai", f"{type(exc).__name__}: {exc}") from exc + + usage = response.usage + result = GeneratedResponse( + content=response.choices[0].message.content or "", + input_tokens=usage.prompt_tokens if usage else 0, + output_tokens=usage.completion_tokens if usage else 0, + cached_tokens=0, + usage={ + "prompt_tokens": usage.prompt_tokens if usage else 0, + "completion_tokens": usage.completion_tokens if usage else 0, + }, + ) + + log.debug( + "zai.generate.done", + input_tokens=result.input_tokens, + 
output_tokens=result.output_tokens, + request_id=request_id, + ) + + if self._cost_tracker is not None: + import asyncio + + try: # noqa: SIM105 + asyncio.get_event_loop().create_task( + self._cost_tracker.record( + model=self._model, + input_tokens=result.input_tokens, + output_tokens=result.output_tokens, + operation="doc_generation", + file_path=None, + ) + ) + except RuntimeError: + pass # No running event loop — skip async record + + return result + + # --- ChatProvider protocol implementation --- + + async def stream_chat( + self, + messages: list[dict[str, Any]], + tools: list[dict[str, Any]], + system_prompt: str, + max_tokens: int = 8192, + temperature: float = 0.7, + request_id: str | None = None, + tool_executor: Any | None = None, + ) -> AsyncIterator[ChatStreamEvent]: + """Stream chat via Z.AI's OpenAI-compatible endpoint.""" + import json as _json + + full_messages = [{"role": "system", "content": system_prompt}, *messages] + kwargs: dict[str, Any] = { + "model": self._model, + "messages": full_messages, + "temperature": temperature, + "max_tokens": max_tokens, + "stream": True, + } + + if tools: + kwargs["tools"] = tools + + # Disable thinking for GLM-5 family by default + if self._thinking == "disabled": + kwargs["extra_body"] = {"thinking": {"type": "disabled"}} + + try: + stream = await self._client.chat.completions.create(**kwargs) + except _OpenAIAPIStatusError as exc: + if exc.status_code == 429: + raise RateLimitError("zai", str(exc), status_code=429) from exc + raise ProviderError("zai", str(exc), status_code=exc.status_code) from exc + + tool_calls_acc: dict[int, dict[str, Any]] = {} + + try: + async for chunk in stream: + choice = chunk.choices[0] if chunk.choices else None + if not choice: + continue + + delta = choice.delta + finish = choice.finish_reason + + if delta and delta.content: + yield ChatStreamEvent(type="text_delta", text=delta.content) + + if delta and delta.tool_calls: + for tc_delta in delta.tool_calls: + idx = tc_delta.index + if idx not in tool_calls_acc: + tool_calls_acc[idx] = {"id": tc_delta.id or "", "name": "", "arguments": ""} + acc = tool_calls_acc[idx] + if tc_delta.id: + acc["id"] = tc_delta.id + if tc_delta.function: + if tc_delta.function.name: + acc["name"] = tc_delta.function.name + if tc_delta.function.arguments: + acc["arguments"] += tc_delta.function.arguments + + if finish: + for idx in sorted(tool_calls_acc.keys()): + acc = tool_calls_acc[idx] + try: + args = _json.loads(acc["arguments"]) if acc["arguments"] else {} + except Exception: + args = {} + yield ChatStreamEvent( + type="tool_start", + tool_call=ChatToolCall(id=acc["id"], name=acc["name"], arguments=args), + ) + tool_calls_acc.clear() + stop_reason = "tool_use" if finish == "tool_calls" else "end_turn" + yield ChatStreamEvent(type="stop", stop_reason=stop_reason) + except _OpenAIAPIStatusError as exc: + if exc.status_code == 429: + raise RateLimitError("zai", str(exc), status_code=429) from exc + raise ProviderError("zai", str(exc), status_code=exc.status_code) from exc \ No newline at end of file diff --git a/packages/core/src/repowise/core/rate_limiter.py b/packages/core/src/repowise/core/rate_limiter.py index 3612624..2c56d3b 100644 --- a/packages/core/src/repowise/core/rate_limiter.py +++ b/packages/core/src/repowise/core/rate_limiter.py @@ -49,6 +49,7 @@ class RateLimitConfig: # Ollama runs locally — effectively unlimited, but we cap to avoid OOM "ollama": RateLimitConfig(requests_per_minute=1_000, tokens_per_minute=10_000_000), "litellm": 
RateLimitConfig(requests_per_minute=60, tokens_per_minute=150_000), + "zai": RateLimitConfig(requests_per_minute=60, tokens_per_minute=150_000), } diff --git a/tests/unit/test_providers/test_zai_provider.py b/tests/unit/test_providers/test_zai_provider.py new file mode 100644 index 0000000..5507b5b --- /dev/null +++ b/tests/unit/test_providers/test_zai_provider.py @@ -0,0 +1,356 @@ +"""Unit tests for ZAIProvider. + +All tests mock the OpenAI client — no real API calls are made. +""" + +from __future__ import annotations + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from repowise.core.providers.llm.base import GeneratedResponse, ProviderError, RateLimitError +from repowise.core.providers.llm.zai import _DEFAULT_MODEL, ZAIProvider + +# --------------------------------------------------------------------------- +# Construction +# --------------------------------------------------------------------------- + + +def test_provider_name(): + p = ZAIProvider(api_key="test-key") + assert p.provider_name == "zai" + + +def test_default_model(): + p = ZAIProvider(api_key="test-key") + assert p.model_name == _DEFAULT_MODEL + assert p.model_name == "glm-5.1" + + +def test_custom_model(): + p = ZAIProvider(model="glm-5-turbo", api_key="test-key") + assert p.model_name == "glm-5-turbo" + + +def test_default_thinking_disabled(): + p = ZAIProvider(api_key="test-key") + assert p._thinking == "disabled" + + +def test_custom_thinking(): + p = ZAIProvider(api_key="test-key", thinking="enabled") + assert p._thinking == "enabled" + + +def test_default_plan_is_coding(): + """Default plan should be 'coding'.""" + p = ZAIProvider(api_key="test-key") + assert p._plan == "coding" + + +def test_coding_plan_base_url(): + """Coding plan should use coding endpoint.""" + p = ZAIProvider(api_key="test-key", plan="coding") + assert p._base_url == "https://api.z.ai/api/coding/paas/v4/v1" + + +def test_general_plan_base_url(): + """General plan should use general endpoint.""" + p = ZAIProvider(api_key="test-key", plan="general") + assert p._base_url == "https://api.z.ai/api/paas/v4/v1" + + +def test_base_url_overrides_plan(): + """Explicit base_url should take precedence over plan.""" + p = ZAIProvider(api_key="test-key", plan="coding", base_url="https://custom.api.com") + assert p._base_url == "https://custom.api.com/v1" + + +def test_custom_base_url(): + """Custom base URL should be used and normalized.""" + p = ZAIProvider(api_key="test-key", base_url="https://api.z.ai/api/paas/v4") + assert p._base_url == "https://api.z.ai/api/paas/v4/v1" + + +def test_base_url_normalization(): + """Base URL should be normalized to end with /v1.""" + p = ZAIProvider(api_key="test-key", base_url="https://custom.api.com") + assert p._base_url == "https://custom.api.com/v1" + + +def test_base_url_already_has_v1(): + """Base URL already ending with /v1 should not get another suffix.""" + p = ZAIProvider(api_key="test-key", base_url="https://custom.api.com/v1") + assert p._base_url == "https://custom.api.com/v1" + + +# --------------------------------------------------------------------------- +# Successful generation +# --------------------------------------------------------------------------- + + +def _make_mock_response(text: str = "# Doc\nContent.") -> MagicMock: + usage = MagicMock() + usage.prompt_tokens = 120 + usage.completion_tokens = 60 + + choice = MagicMock() + choice.message.content = text + + response = MagicMock() + response.choices = [choice] + response.usage = usage + return response + + 
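+# The helper above mirrors the attributes ZAIProvider.generate() reads from an
+# OpenAI SDK response: choices[0].message.content plus usage.prompt_tokens and
+# usage.completion_tokens. Nothing else on the response object is inspected.
+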
+@pytest.mark.asyncio +async def test_generate_returns_generated_response(): + provider = ZAIProvider(api_key="test-key") + mock_response = _make_mock_response("Hello from Z.AI") + + with patch.object( + provider._client.chat.completions, + "create", + new_callable=AsyncMock, + return_value=mock_response, + ): + result = await provider.generate("sys", "user") + + assert isinstance(result, GeneratedResponse) + assert result.content == "Hello from Z.AI" + + +@pytest.mark.asyncio +async def test_generate_token_counts(): + provider = ZAIProvider(api_key="test-key") + mock_response = _make_mock_response() + + with patch.object( + provider._client.chat.completions, + "create", + new_callable=AsyncMock, + return_value=mock_response, + ): + result = await provider.generate("sys", "user") + + assert result.input_tokens == 120 + assert result.output_tokens == 60 + + +@pytest.mark.asyncio +async def test_generate_sends_correct_kwargs(): + provider = ZAIProvider(model="glm-5-turbo", api_key="test-key") + mock_response = _make_mock_response() + captured_kwargs: list[dict] = [] + + async def fake_create(**kwargs): + captured_kwargs.append(kwargs) + return mock_response + + with patch.object( + provider._client.chat.completions, + "create", + new_callable=AsyncMock, + side_effect=fake_create, + ): + await provider.generate("system msg", "user msg", max_tokens=2048, temperature=0.5) + + kw = captured_kwargs[0] + assert kw["model"] == "glm-5-turbo" + assert kw["max_tokens"] == 2048 + assert kw["temperature"] == 0.5 + messages = kw["messages"] + assert messages[0] == {"role": "system", "content": "system msg"} + assert messages[1] == {"role": "user", "content": "user msg"} + + +@pytest.mark.asyncio +async def test_generate_disables_thinking_by_default(): + """By default, thinking should be disabled via extra_body.""" + provider = ZAIProvider(api_key="test-key") + mock_response = _make_mock_response() + captured_kwargs: list[dict] = [] + + async def fake_create(**kwargs): + captured_kwargs.append(kwargs) + return mock_response + + with patch.object( + provider._client.chat.completions, + "create", + new_callable=AsyncMock, + side_effect=fake_create, + ): + await provider.generate("sys", "user") + + kw = captured_kwargs[0] + assert "extra_body" in kw + assert kw["extra_body"] == {"thinking": {"type": "disabled"}} + + +@pytest.mark.asyncio +async def test_generate_with_thinking_enabled(): + """When thinking is enabled, extra_body should not contain disabled.""" + provider = ZAIProvider(api_key="test-key", thinking="enabled") + mock_response = _make_mock_response() + captured_kwargs: list[dict] = [] + + async def fake_create(**kwargs): + captured_kwargs.append(kwargs) + return mock_response + + with patch.object( + provider._client.chat.completions, + "create", + new_callable=AsyncMock, + side_effect=fake_create, + ): + await provider.generate("sys", "user") + + kw = captured_kwargs[0] + # When thinking is enabled, we don't send extra_body with thinking disabled + assert kw.get("extra_body") is None + + +# --------------------------------------------------------------------------- +# Error handling +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_api_error(): + from openai import APIStatusError + + provider = ZAIProvider(api_key="test-key") + + mock_response = MagicMock() + mock_response.status_code = 500 + + with patch.object( + provider._client.chat.completions, + "create", + new_callable=AsyncMock, + side_effect=APIStatusError( + "server 
error", + response=mock_response, + body={}, + ), + ), pytest.raises(ProviderError) as exc_info: + await provider.generate("sys", "user") + + assert exc_info.value.status_code == 500 + + +@pytest.mark.asyncio +async def test_rate_limit_error(): + from openai import APIStatusError + + provider = ZAIProvider(api_key="test-key") + + mock_response = MagicMock() + mock_response.status_code = 429 + + with patch.object( + provider._client.chat.completions, + "create", + new_callable=AsyncMock, + side_effect=APIStatusError( + "rate limit exceeded", + response=mock_response, + body={}, + ), + ), pytest.raises(RateLimitError) as exc_info: + await provider.generate("sys", "user") + + assert exc_info.value.status_code == 429 + + +# --------------------------------------------------------------------------- +# Stream chat +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_stream_chat_yields_text_deltas(): + provider = ZAIProvider(api_key="test-key") + + # Create mock chunks + chunk1 = MagicMock() + chunk1.choices = [MagicMock()] + chunk1.choices[0].delta = MagicMock(content="Hello") + chunk1.choices[0].finish_reason = None + + chunk2 = MagicMock() + chunk2.choices = [MagicMock()] + chunk2.choices[0].delta = MagicMock(content=" world") + chunk2.choices[0].finish_reason = None + + chunk3 = MagicMock() + chunk3.choices = [MagicMock()] + chunk3.choices[0].delta = MagicMock(content=None) + chunk3.choices[0].finish_reason = "stop" + + async def fake_stream(): + yield chunk1 + yield chunk2 + yield chunk3 + + with patch.object( + provider._client.chat.completions, + "create", + new_callable=AsyncMock, + return_value=fake_stream(), + ): + events = [] + async for event in provider.stream_chat( + messages=[{"role": "user", "content": "test"}], + tools=[], + system_prompt="sys", + ): + events.append(event) + + # Should have text deltas and a stop event + assert len(events) == 3 + assert events[0].type == "text_delta" + assert events[0].text == "Hello" + assert events[1].type == "text_delta" + assert events[1].text == " world" + assert events[2].type == "stop" + + +@pytest.mark.asyncio +async def test_stream_chat_disables_thinking(): + """Stream chat should also disable thinking by default.""" + provider = ZAIProvider(api_key="test-key") + captured_kwargs: list[dict] = [] + + chunk = MagicMock() + chunk.choices = [MagicMock()] + chunk.choices[0].delta = MagicMock(content="test") + chunk.choices[0].finish_reason = "stop" + + async def fake_stream(): + yield chunk + + async def fake_create(**kwargs): + captured_kwargs.append(kwargs) + return fake_stream() + + with patch.object( + provider._client.chat.completions, + "create", + new_callable=AsyncMock, + side_effect=fake_create, + ): + events = [] + async for event in provider.stream_chat( + messages=[{"role": "user", "content": "test"}], + tools=[], + system_prompt="sys", + ): + events.append(event) + + kw = captured_kwargs[0] + assert "extra_body" in kw + assert kw["extra_body"] == {"thinking": {"type": "disabled"}} \ No newline at end of file From ed682fbad44cd2276f558204a0536b428d05dfab Mon Sep 17 00:00:00 2001 From: Societus <93468672+Societus@users.noreply.github.com> Date: Mon, 13 Apr 2026 18:58:31 -0700 Subject: [PATCH 4/6] feat: add generic tier-aware rate limiting framework Add RATE_LIMIT_TIERS class attribute and resolve_rate_limiter() static method to BaseProvider. 
Any provider with subscription tiers can define RATE_LIMIT_TIERS and pass tier + tiers to resolve_rate_limiter() to get automatic tier-aware rate limiter creation. Precedence: tier > explicit rate_limiter > None. Tier matching is case-insensitive. Invalid tiers raise ValueError. This is a provider-agnostic foundation -- no provider-specific code. Providers adopt it by defining RATE_LIMIT_TIERS and calling resolve_rate_limiter() in their constructor. Ref: #68 --- .../src/repowise/core/providers/llm/base.py | 57 ++++++++++++++- .../test_generic_tier_framework.py | 71 +++++++++++++++++++ 2 files changed, 127 insertions(+), 1 deletion(-) create mode 100644 tests/unit/test_providers/test_generic_tier_framework.py diff --git a/packages/core/src/repowise/core/providers/llm/base.py b/packages/core/src/repowise/core/providers/llm/base.py index 9e649ad..e8d7397 100644 --- a/packages/core/src/repowise/core/providers/llm/base.py +++ b/packages/core/src/repowise/core/providers/llm/base.py @@ -15,7 +15,10 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, field -from typing import Any, AsyncIterator, Protocol, runtime_checkable +from typing import Any, AsyncIterator, Protocol, TYPE_CHECKING, runtime_checkable + +if TYPE_CHECKING: + from repowise.core.rate_limiter import RateLimitConfig, RateLimiter @dataclass @@ -59,8 +62,60 @@ class BaseProvider(ABC): - Return GeneratedResponse with correct token counts - Raise ProviderError on non-recoverable API errors - Raise RateLimitError on 429 responses after retries are exhausted + + Class Attributes: + RATE_LIMIT_TIERS: Optional mapping of tier name to RateLimitConfig. + Providers with subscription tiers (e.g., Z.AI's lite/pro/max, + MiniMax's starter/plus/max) define this to support tier-aware + rate limiting. When set, users can pass ``tier="pro"`` to the + constructor and the appropriate rate limiter is created automatically. """ + RATE_LIMIT_TIERS: dict[str, Any] = {} # Override in subclasses + + @staticmethod + def resolve_rate_limiter( + tier: str | None = None, + tiers: dict[str, Any] | None = None, + rate_limiter: Any | None = None, + ) -> Any | None: + """Resolve rate limiter using tier precedence. + + Precedence: tier > explicit rate_limiter > None. + + When tier is set, it takes precedence -- it represents a specific + provider signal that overrides the generic registry default. + + Args: + tier: Tier name (e.g., 'lite', 'pro', 'max'). Case-insensitive. + tiers: Mapping of tier name to RateLimitConfig. + rate_limiter: Explicitly provided RateLimiter instance. + + Returns: + A RateLimiter instance, or None if neither tier nor + rate_limiter is provided. + + Raises: + ValueError: If tier is not found in the tiers mapping. + """ + # Late import to avoid circular dependency at module level + from repowise.core.rate_limiter import RateLimiter + + if tier is not None: + if not tiers: + msg = f"Tier {tier!r} specified but provider defines no tiers" + raise ValueError(msg) + tier_key = tier.lower() + tier_config = tiers.get(tier_key) + if tier_config is None: + valid = ", ".join(sorted(tiers)) + msg = f"Unknown tier {tier!r}. 
Valid tiers: {valid}" + raise ValueError(msg) + return RateLimiter(tier_config) + if rate_limiter is not None: + return rate_limiter + return None + @abstractmethod async def generate( self, diff --git a/tests/unit/test_providers/test_generic_tier_framework.py b/tests/unit/test_providers/test_generic_tier_framework.py new file mode 100644 index 0000000..1f617f5 --- /dev/null +++ b/tests/unit/test_providers/test_generic_tier_framework.py @@ -0,0 +1,71 @@ +"""Unit tests for BaseProvider.resolve_rate_limiter (generic tier framework). + +These tests verify the tier resolution logic independent of any specific provider. +Any provider that defines RATE_LIMIT_TIERS gets this behavior for free via +BaseProvider.resolve_rate_limiter(). +""" + +from __future__ import annotations + +import pytest + +from repowise.core.providers.llm.base import BaseProvider +from repowise.core.rate_limiter import RateLimitConfig, RateLimiter + + +def test_resolve_rate_limiter_with_tier(): + """resolve_rate_limiter should create a limiter from tier config.""" + tiers = { + "basic": RateLimitConfig(requests_per_minute=5, tokens_per_minute=10_000), + "premium": RateLimitConfig(requests_per_minute=50, tokens_per_minute=100_000), + } + limiter = BaseProvider.resolve_rate_limiter(tier="premium", tiers=tiers) + assert limiter is not None + assert limiter.config.requests_per_minute == 50 + + +def test_resolve_rate_limiter_tier_overrides_explicit(): + """Tier should take precedence over explicit rate_limiter.""" + tiers = {"pro": RateLimitConfig(requests_per_minute=30, tokens_per_minute=100_000)} + explicit = RateLimiter(RateLimitConfig(requests_per_minute=999, tokens_per_minute=999_999)) + limiter = BaseProvider.resolve_rate_limiter(tier="pro", tiers=tiers, rate_limiter=explicit) + assert limiter is not explicit + assert limiter.config.requests_per_minute == 30 + + +def test_resolve_rate_limiter_explicit_without_tier(): + """Without tier, explicit rate_limiter should be returned.""" + explicit = RateLimiter(RateLimitConfig(requests_per_minute=42, tokens_per_minute=420_000)) + limiter = BaseProvider.resolve_rate_limiter(rate_limiter=explicit) + assert limiter is explicit + + +def test_resolve_rate_limiter_none_when_nothing_provided(): + """Should return None when neither tier nor rate_limiter is provided.""" + limiter = BaseProvider.resolve_rate_limiter() + assert limiter is None + + +def test_resolve_rate_limiter_invalid_tier(): + """Invalid tier should raise ValueError.""" + tiers = {"basic": RateLimitConfig(requests_per_minute=5, tokens_per_minute=10_000)} + with pytest.raises(ValueError, match="Unknown tier"): + BaseProvider.resolve_rate_limiter(tier="enterprise", tiers=tiers) + + +def test_resolve_rate_limiter_tier_but_no_tiers_defined(): + """Tier with empty tiers dict should raise ValueError.""" + with pytest.raises(ValueError, match="defines no tiers"): + BaseProvider.resolve_rate_limiter(tier="pro", tiers={}) + + +def test_resolve_rate_limiter_case_insensitive(): + """Tier matching should be case-insensitive.""" + tiers = {"pro": RateLimitConfig(requests_per_minute=30, tokens_per_minute=100_000)} + limiter = BaseProvider.resolve_rate_limiter(tier="PRO", tiers=tiers) + assert limiter.config.requests_per_minute == 30 + + +def test_base_provider_default_empty_tiers(): + """BaseProvider should have empty RATE_LIMIT_TIERS by default.""" + assert BaseProvider.RATE_LIMIT_TIERS == {} From b4bdfd67e5e2b6da3762cd29390178e953fee018 Mon Sep 17 00:00:00 2001 From: Societus <93468672+Societus@users.noreply.github.com> Date: Mon, 13 
Apr 2026 19:08:33 -0700 Subject: [PATCH 5/6] feat(zai): adopt generic tier framework for plan-aware rate limiting Wire Z.AI provider into the BaseProvider tier framework (from PR #NN). Changes: - Define RATE_LIMIT_TIERS on ZAIProvider with Lite/Pro/Max configs derived from Z.AI support guidance (April 2026) - Use resolve_rate_limiter() in constructor (tier > explicit > none) - Add ZAI_TIER env var support in CLI helpers - Add ZAI_TIER_DEFAULTS to rate_limiter.py for reference - Update PROVIDER_DEFAULTS['zai'] to conservative Lite-tier default - Bump retry budget: 5 retries / 30s max wait (from 3/4s) for Z.AI load-shedding tolerance - Add tier parameter to constructor and docstring Rate limit context: - Z.AI concurrency limits are aggregate, dynamic, and load-dependent - Advanced models (GLM-5 family) consume 2-3x quota per prompt - Conservative defaults: Lite 10 RPM, Pro 30 RPM, Max 60 RPM - Ref: https://docs.z.ai/devpack/usage-policy Depends on: feat/generic-tier-framework Supersedes: #80 (deprecates monolithic PR in favor of layered approach) Ref: #68 --- packages/cli/src/repowise/cli/helpers.py | 4 + .../src/repowise/core/providers/llm/zai.py | 44 ++++++-- .../core/src/repowise/core/rate_limiter.py | 11 +- .../unit/test_providers/test_zai_provider.py | 102 ++++++++++++++++++ 4 files changed, 154 insertions(+), 7 deletions(-) diff --git a/packages/cli/src/repowise/cli/helpers.py b/packages/cli/src/repowise/cli/helpers.py index 6203a91..8ac377b 100644 --- a/packages/cli/src/repowise/cli/helpers.py +++ b/packages/cli/src/repowise/cli/helpers.py @@ -278,6 +278,8 @@ def resolve_provider( kwargs["base_url"] = os.environ["ZAI_BASE_URL"] if os.environ.get("ZAI_THINKING"): kwargs["thinking"] = os.environ["ZAI_THINKING"] + if os.environ.get("ZAI_TIER"): + kwargs["tier"] = os.environ["ZAI_TIER"] return get_provider(provider_name, **kwargs) @@ -335,6 +337,8 @@ def resolve_provider( kwargs["base_url"] = os.environ["ZAI_BASE_URL"] if os.environ.get("ZAI_THINKING"): kwargs["thinking"] = os.environ["ZAI_THINKING"] + if os.environ.get("ZAI_TIER"): + kwargs["tier"] = os.environ["ZAI_TIER"] return get_provider("zai", **kwargs) raise click.ClickException( diff --git a/packages/core/src/repowise/core/providers/llm/zai.py b/packages/core/src/repowise/core/providers/llm/zai.py index 4b9bc8c..9a492b6 100644 --- a/packages/core/src/repowise/core/providers/llm/zai.py +++ b/packages/core/src/repowise/core/providers/llm/zai.py @@ -48,16 +48,16 @@ ProviderError, RateLimitError, ) -from repowise.core.rate_limiter import RateLimiter +from repowise.core.rate_limiter import RateLimitConfig, RateLimiter if TYPE_CHECKING: from repowise.core.generation.cost_tracker import CostTracker log = structlog.get_logger(__name__) -_MAX_RETRIES = 3 -_MIN_WAIT = 1.0 -_MAX_WAIT = 4.0 +_MAX_RETRIES = 5 +_MIN_WAIT = 2.0 +_MAX_WAIT = 30.0 # Z.AI API endpoints by plan _PLAN_BASE_URLS: dict[str, str] = { @@ -91,10 +91,32 @@ class ZAIProvider(BaseProvider): thinking: Thinking mode for GLM-5 family. 'disabled' by default to avoid reasoning token overhead. Set to 'enabled' for complex reasoning tasks. - rate_limiter: Optional RateLimiter instance. + tier: Z.AI subscription tier for rate limiting. One of 'lite', + 'pro', 'max'. When set, overrides the default rate limiter + with tier-appropriate limits. Can also be set via ZAI_TIER + environment variable. + rate_limiter: Optional RateLimiter instance. If not provided and + tier is set, a tier-appropriate limiter is created. 
+ If neither is provided, the registry attaches a + conservative default. cost_tracker: Optional CostTracker for usage tracking. """ + # Z.AI subscription tier rate limits. + # Derived from Z.AI support guidance (April 2026): + # - Lite: 2-3 concurrent, lower tolerance + # - Pro: 5-8 concurrent, moderate tolerance + # - Max: 10-15 concurrent, highest tolerance + # Limits are aggregate across all models. Advanced models (GLM-5 family) + # consume 2-3x quota per prompt, so effective concurrency is lower when + # using those models. + # Ref: https://docs.z.ai/devpack/usage-policy + RATE_LIMIT_TIERS: dict[str, RateLimitConfig] = { + "lite": RateLimitConfig(requests_per_minute=10, tokens_per_minute=50_000), + "pro": RateLimitConfig(requests_per_minute=30, tokens_per_minute=150_000), + "max": RateLimitConfig(requests_per_minute=60, tokens_per_minute=300_000), + } + def __init__( self, model: str = _DEFAULT_MODEL, @@ -102,15 +124,25 @@ def __init__( plan: PlanType = "coding", base_url: str | None = None, thinking: str = "disabled", + tier: str | None = None, rate_limiter: RateLimiter | None = None, cost_tracker: "CostTracker | None" = None, # noqa: UP037 ) -> None: self._model = model self._plan = plan self._thinking = thinking - self._rate_limiter = rate_limiter + self._tier = tier self._cost_tracker = cost_tracker + # Resolve rate limiter: tier > explicit instance > none (registry attaches default) + self._rate_limiter = self.resolve_rate_limiter( + tier=tier, + tiers=self.RATE_LIMIT_TIERS, + rate_limiter=rate_limiter, + ) + if tier is not None and self._rate_limiter is not None: + log.info("zai.tier_rate_limiter", tier=tier.lower(), rpm=self._rate_limiter.config.requests_per_minute) + # Resolve base URL: explicit base_url > plan lookup effective_base_url = base_url or _PLAN_BASE_URLS.get(plan, _PLAN_BASE_URLS["coding"]) diff --git a/packages/core/src/repowise/core/rate_limiter.py b/packages/core/src/repowise/core/rate_limiter.py index 2c56d3b..cadd722 100644 --- a/packages/core/src/repowise/core/rate_limiter.py +++ b/packages/core/src/repowise/core/rate_limiter.py @@ -49,7 +49,16 @@ class RateLimitConfig: # Ollama runs locally — effectively unlimited, but we cap to avoid OOM "ollama": RateLimitConfig(requests_per_minute=1_000, tokens_per_minute=10_000_000), "litellm": RateLimitConfig(requests_per_minute=60, tokens_per_minute=150_000), - "zai": RateLimitConfig(requests_per_minute=60, tokens_per_minute=150_000), + # Z.AI: conservative default (Lite tier). Set ZAI_TIER for plan-specific limits. + "zai": RateLimitConfig(requests_per_minute=10, tokens_per_minute=50_000), +} + +# Z.AI per-tier rate limits derived from support correspondence (April 2026). +# Limits are dynamic and load-dependent; these are conservative estimates. 
+ZAI_TIER_DEFAULTS: dict[str, RateLimitConfig] = { + "lite": RateLimitConfig(requests_per_minute=10, tokens_per_minute=50_000), + "pro": RateLimitConfig(requests_per_minute=30, tokens_per_minute=150_000), + "max": RateLimitConfig(requests_per_minute=60, tokens_per_minute=300_000), } diff --git a/tests/unit/test_providers/test_zai_provider.py b/tests/unit/test_providers/test_zai_provider.py index 5507b5b..62a11d0 100644 --- a/tests/unit/test_providers/test_zai_provider.py +++ b/tests/unit/test_providers/test_zai_provider.py @@ -189,6 +189,108 @@ async def fake_create(**kwargs): assert kw["extra_body"] == {"thinking": {"type": "disabled"}} +# --------------------------------------------------------------------------- +# Tier-based rate limiting (via generic framework) +# --------------------------------------------------------------------------- + + +def test_tier_creates_rate_limiter(): + """Setting tier should create a tier-appropriate rate limiter.""" + p = ZAIProvider(api_key="test-key", tier="pro") + assert p._rate_limiter is not None + assert p._rate_limiter.config.requests_per_minute == ZAIProvider.RATE_LIMIT_TIERS["pro"].requests_per_minute + assert p._rate_limiter.config.tokens_per_minute == ZAIProvider.RATE_LIMIT_TIERS["pro"].tokens_per_minute + + +def test_tier_lite(): + """Lite tier should have conservative RPM/TPM limits.""" + p = ZAIProvider(api_key="test-key", tier="lite") + assert p._rate_limiter is not None + assert p._rate_limiter.config.requests_per_minute == 10 + assert p._rate_limiter.config.tokens_per_minute == 50_000 + + +def test_tier_pro(): + """Pro tier should have moderate RPM/TPM limits.""" + p = ZAIProvider(api_key="test-key", tier="pro") + assert p._rate_limiter is not None + assert p._rate_limiter.config.requests_per_minute == 30 + assert p._rate_limiter.config.tokens_per_minute == 150_000 + + +def test_tier_max(): + """Max tier should have the highest RPM/TPM limits.""" + p = ZAIProvider(api_key="test-key", tier="max") + assert p._rate_limiter is not None + assert p._rate_limiter.config.requests_per_minute == 60 + assert p._rate_limiter.config.tokens_per_minute == 300_000 + + +def test_tier_case_insensitive(): + """Tier should be case-insensitive.""" + p = ZAIProvider(api_key="test-key", tier="PRO") + assert p._rate_limiter is not None + assert p._rate_limiter.config.requests_per_minute == 30 + + +def test_tier_overrides_explicit_rate_limiter(): + """Tier takes precedence over an explicitly passed rate_limiter.""" + from repowise.core.rate_limiter import RateLimitConfig, RateLimiter + + explicit_limiter = RateLimiter(RateLimitConfig(requests_per_minute=999, tokens_per_minute=999_999)) + p = ZAIProvider(api_key="test-key", tier="lite", rate_limiter=explicit_limiter) + # Tier wins -- the explicit limiter should be discarded + assert p._rate_limiter.config.requests_per_minute == 10 + assert p._rate_limiter is not explicit_limiter + + +def test_no_tier_no_rate_limiter(): + """Without tier or rate_limiter, _rate_limiter should be None.""" + p = ZAIProvider(api_key="test-key") + assert p._rate_limiter is None + + +def test_explicit_rate_limiter_without_tier(): + """Without tier, an explicit rate_limiter should be used.""" + from repowise.core.rate_limiter import RateLimitConfig, RateLimiter + + limiter = RateLimiter(RateLimitConfig(requests_per_minute=42, tokens_per_minute=420_000)) + p = ZAIProvider(api_key="test-key", rate_limiter=limiter) + assert p._rate_limiter is limiter + + +def test_invalid_tier_raises(): + """An unrecognized tier should raise ValueError 
with valid options.""" + with pytest.raises(ValueError, match="Unknown tier"): + ZAIProvider(api_key="test-key", tier="enterprise") + + +def test_tier_stored(): + """Tier value should be stored on the provider.""" + p = ZAIProvider(api_key="test-key", tier="pro") + assert p._tier == "pro" + + +def test_no_tier_stored_as_none(): + """Without tier, _tier should be None.""" + p = ZAIProvider(api_key="test-key") + assert p._tier is None + + +def test_provider_has_rate_limit_tiers_attribute(): + """ZAIProvider should have RATE_LIMIT_TIERS class attribute.""" + assert hasattr(ZAIProvider, "RATE_LIMIT_TIERS") + assert "lite" in ZAIProvider.RATE_LIMIT_TIERS + assert "pro" in ZAIProvider.RATE_LIMIT_TIERS + assert "max" in ZAIProvider.RATE_LIMIT_TIERS + + +def test_anthropic_no_tiers(): + """Providers without tier support should have empty RATE_LIMIT_TIERS.""" + from repowise.core.providers.llm.anthropic import AnthropicProvider + + assert AnthropicProvider.RATE_LIMIT_TIERS == {} + @pytest.mark.asyncio async def test_generate_with_thinking_enabled(): """When thinking is enabled, extra_body should not contain disabled.""" From 110a8ffa0db012301c7b082d40e79cc80c3790ba Mon Sep 17 00:00:00 2001 From: Societus <93468672+Societus@users.noreply.github.com> Date: Sat, 18 Apr 2026 10:28:44 -0700 Subject: [PATCH 6/6] fix(zai): remove dead ZAI_TIER_DEFAULTS and /v1 base URL append MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ZAI_TIER_DEFAULTS in rate_limiter.py duplicated the same values as ZAIProvider.RATE_LIMIT_TIERS and nothing imported it. Single source of truth lives on the provider class. The /v1 suffix normalization produced /v4/v1/chat/completions which 404s against Z.AI's live API. Their endpoint is /paas/v4 as-is; the OpenAI SDK appends /chat/completions itself. Tested against live Z.AI API: /v4/chat/completions → 200 /v4/v1/chat/completions → 404 Addresses review feedback from @swati510 on #83. --- packages/core/src/repowise/core/providers/llm/zai.py | 2 -- packages/core/src/repowise/core/rate_limiter.py | 7 ------- 2 files changed, 9 deletions(-) diff --git a/packages/core/src/repowise/core/providers/llm/zai.py b/packages/core/src/repowise/core/providers/llm/zai.py index 9a492b6..061aa18 100644 --- a/packages/core/src/repowise/core/providers/llm/zai.py +++ b/packages/core/src/repowise/core/providers/llm/zai.py @@ -148,8 +148,6 @@ def __init__( # Normalize base URL for OpenAI SDK effective_base_url = effective_base_url.rstrip("/") - if not effective_base_url.endswith("/v1"): - effective_base_url += "/v1" # Store normalized base_url self._base_url = effective_base_url diff --git a/packages/core/src/repowise/core/rate_limiter.py b/packages/core/src/repowise/core/rate_limiter.py index cadd722..f150b03 100644 --- a/packages/core/src/repowise/core/rate_limiter.py +++ b/packages/core/src/repowise/core/rate_limiter.py @@ -53,13 +53,6 @@ class RateLimitConfig: "zai": RateLimitConfig(requests_per_minute=10, tokens_per_minute=50_000), } -# Z.AI per-tier rate limits derived from support correspondence (April 2026). -# Limits are dynamic and load-dependent; these are conservative estimates. -ZAI_TIER_DEFAULTS: dict[str, RateLimitConfig] = { - "lite": RateLimitConfig(requests_per_minute=10, tokens_per_minute=50_000), - "pro": RateLimitConfig(requests_per_minute=30, tokens_per_minute=150_000), - "max": RateLimitConfig(requests_per_minute=60, tokens_per_minute=300_000), -} class RateLimiter:
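
A note for reviewers reading this series standalone: resolve_rate_limiter() is defined on BaseProvider in the tier-framework PR that PATCH 5/6 depends on, so its body never appears in the hunks above. The sketch below reconstructs the precedence the call site and tests imply (tier > explicit instance > none); it is an assumption drawn from those tests, not the shipped code.

    from repowise.core.rate_limiter import RateLimitConfig, RateLimiter

    class BaseProvider:
        # Providers without tier support keep this empty (see test_anthropic_no_tiers).
        RATE_LIMIT_TIERS: dict[str, RateLimitConfig] = {}

        def resolve_rate_limiter(
            self,
            tier: str | None,
            tiers: dict[str, RateLimitConfig],
            rate_limiter: RateLimiter | None,
        ) -> RateLimiter | None:
            """Precedence: tier > explicit instance > None (registry adds a default)."""
            if tier is not None:
                key = tier.lower()  # tier names are case-insensitive per the tests
                if key not in tiers:
                    raise ValueError(
                        f"Unknown tier {tier!r}; expected one of {sorted(tiers)}"
                    )
                # Tier wins even over an explicitly passed limiter
                # (see test_tier_overrides_explicit_rate_limiter).
                return RateLimiter(tiers[key])
            return rate_limiter  # may be None; the registry attaches a conservative default

With the CLI wiring from PATCH 5/6, setting e.g. ZAI_TIER=pro in the environment flows through resolve_provider() into this path and yields the 30 RPM / 150k TPM limiter.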
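
A back-of-envelope illustration of the 2-3x quota note in PATCH 5/6: the tier RPMs below are the patch's client-side defaults, and the divide-through is our own simplification of how the aggregate allowance stretches across fewer GLM-5 prompts; Z.AI's real limits are dynamic and load-dependent.

    # Rough effective prompt budget when every request hits a GLM-5 family model.
    for tier, rpm in {"lite": 10, "pro": 30, "max": 60}.items():
        low, high = rpm // 3, rpm // 2  # 3x and 2x quota consumption respectively
        print(f"{tier}: {rpm} RPM nominal -> roughly {low}-{high} effective")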
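
And a minimal repro of the 404 that PATCH 6/6 fixes. The host below is a placeholder -- only the /paas/v4 path shape and the fact that the OpenAI SDK appends /chat/completions itself come from the commit message.

    base = "https://example-zai-host/api/paas/v4"  # hypothetical host, real path shape

    old = base.rstrip("/") + "/v1"     # the normalization removed by this patch
    print(old + "/chat/completions")   # .../paas/v4/v1/chat/completions -> 404
    print(base + "/chat/completions")  # .../paas/v4/chat/completions    -> 200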