From 838f44d5965e98dc3a9dcd3a510cd61ddc1c3fa1 Mon Sep 17 00:00:00 2001 From: Vinit Sutar Date: Sun, 5 Apr 2026 21:37:48 +0530 Subject: [PATCH 1/4] feat(litellm): add support for local proxy without API key - Add litellm to interactive provider selection menu - Support LITELLM_BASE_URL for local proxy deployments (no API key required) - Auto-add openai/ prefix when using api_base for proper LiteLLM routing - Add dummy API key for local proxies (OpenAI SDK requirement) - Add validation and tests for litellm provider configuration Co-Authored-By: Claude Opus 4.6 --- packages/cli/src/repowise/cli/helpers.py | 44 +++- packages/cli/src/repowise/cli/ui.py | 32 ++- .../repowise/core/providers/llm/litellm.py | 38 +++- tests/unit/cli/test_helpers.py | 27 +++ .../test_providers/test_litellm_provider.py | 199 ++++++++++++++++++ 5 files changed, 324 insertions(+), 16 deletions(-) create mode 100644 tests/unit/test_providers/test_litellm_provider.py diff --git a/packages/cli/src/repowise/cli/helpers.py b/packages/cli/src/repowise/cli/helpers.py index d5c3383..124481c 100644 --- a/packages/cli/src/repowise/cli/helpers.py +++ b/packages/cli/src/repowise/cli/helpers.py @@ -262,6 +262,12 @@ def resolve_provider( kwargs["api_key"] = os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY") elif provider_name == "ollama" and os.environ.get("OLLAMA_BASE_URL"): kwargs["base_url"] = os.environ["OLLAMA_BASE_URL"] + elif provider_name == "litellm": + # LiteLLM: API key for cloud, base URL for local proxy + if os.environ.get("LITELLM_API_KEY"): + kwargs["api_key"] = os.environ["LITELLM_API_KEY"] + if os.environ.get("LITELLM_BASE_URL"): + kwargs["api_base"] = os.environ["LITELLM_BASE_URL"] return get_provider(provider_name, **kwargs) @@ -293,10 +299,26 @@ def resolve_provider( api_key = os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY") kwargs = {"model": model, "api_key": api_key} if model else {"api_key": api_key} return get_provider("gemini", **kwargs) + # LiteLLM: check for API key (cloud) or base URL (local proxy) + if os.environ.get("LITELLM_API_KEY") and os.environ["LITELLM_API_KEY"].strip(): + kwargs = ( + {"model": model, "api_key": os.environ["LITELLM_API_KEY"]} + if model + else {"api_key": os.environ["LITELLM_API_KEY"]} + ) + return get_provider("litellm", **kwargs) + if os.environ.get("LITELLM_BASE_URL") and os.environ["LITELLM_BASE_URL"].strip(): + kwargs = ( + {"model": model, "api_base": os.environ["LITELLM_BASE_URL"]} + if model + else {"api_base": os.environ["LITELLM_BASE_URL"]} + ) + return get_provider("litellm", **kwargs) raise click.ClickException( "No provider configured. Use --provider, set REPOWISE_PROVIDER, " - "or set ANTHROPIC_API_KEY / OPENAI_API_KEY / OLLAMA_BASE_URL / GEMINI_API_KEY / GOOGLE_API_KEY." + "or set ANTHROPIC_API_KEY / OPENAI_API_KEY / OLLAMA_BASE_URL / GEMINI_API_KEY / " + "LITELLM_API_KEY / LITELLM_BASE_URL." ) @@ -332,7 +354,10 @@ def _is_env_var_exists(var_name: str) -> bool: "openai": ["OPENAI_API_KEY"], "gemini": ["GEMINI_API_KEY", "GOOGLE_API_KEY"], # Either one "ollama": ["OLLAMA_BASE_URL"], - "litellm": ["LITELLM_API_KEY"], # May need others depending on backend + "litellm": [ + "LITELLM_API_KEY", + "LITELLM_BASE_URL", + ], # Either one (API key for cloud, base URL for local) } if provider_name: @@ -348,6 +373,10 @@ def _is_env_var_exists(var_name: str) -> bool: # Special case: either GEMINI_API_KEY or GOOGLE_API_KEY if not (_is_env_var_set("GEMINI_API_KEY") or _is_env_var_set("GOOGLE_API_KEY")): missing_vars = env_vars + elif provider_name == "litellm": + # Special case: LITELLM_API_KEY (cloud) OR LITELLM_BASE_URL (local proxy) + if not (_is_env_var_set("LITELLM_API_KEY") or _is_env_var_set("LITELLM_BASE_URL")): + missing_vars = env_vars else: for var in env_vars: if not _is_env_var_set(var): @@ -370,6 +399,17 @@ def _is_env_var_exists(var_name: str) -> bool: ) continue + if name == "litellm": + # Special case: LITELLM_API_KEY (cloud) OR LITELLM_BASE_URL (local proxy) + # Only warn if explicitly requested and neither is set + if os.environ.get("REPOWISE_PROVIDER") == "litellm" and not ( + _is_env_var_set("LITELLM_API_KEY") or _is_env_var_set("LITELLM_BASE_URL") + ): + warnings.append( + "Provider 'litellm' requires LITELLM_API_KEY or LITELLM_BASE_URL environment variable" + ) + continue + missing = [var for var in env_vars if not _is_env_var_set(var)] if missing: # Only warn if this provider is explicitly requested OR diff --git a/packages/cli/src/repowise/cli/ui.py b/packages/cli/src/repowise/cli/ui.py index ae27c38..b089307 100644 --- a/packages/cli/src/repowise/cli/ui.py +++ b/packages/cli/src/repowise/cli/ui.py @@ -268,11 +268,14 @@ def print_phase_header( "litellm": "groq/llama-3.1-70b-versatile", } +# For most providers, a single env var indicates configuration. +# litellm is special: can use LITELLM_API_KEY (cloud) OR LITELLM_BASE_URL (local proxy). _PROVIDER_ENV: dict[str, str] = { "gemini": "GEMINI_API_KEY", "openai": "OPENAI_API_KEY", "anthropic": "ANTHROPIC_API_KEY", "ollama": "OLLAMA_BASE_URL", + "litellm": "LITELLM_API_KEY", # Also checks LITELLM_BASE_URL in _detect_provider_status } _PROVIDER_SIGNUP: dict[str, str] = { @@ -280,6 +283,7 @@ def print_phase_header( "openai": "https://platform.openai.com/api-keys", "anthropic": "https://console.anthropic.com/settings/keys", "ollama": "https://ollama.com/download", + "litellm": "https://docs.litellm.ai/docs/proxy/proxy", } @@ -410,6 +414,10 @@ def _detect_provider_status() -> dict[str, str]: if prov == "gemini": if os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY"): status[prov] = env_var + elif prov == "litellm": + # litellm can be configured via API key (cloud) OR base URL (local proxy) + if os.environ.get("LITELLM_API_KEY") or os.environ.get("LITELLM_BASE_URL"): + status[prov] = env_var elif os.environ.get(env_var): status[prov] = env_var return status @@ -476,14 +484,22 @@ def interactive_provider_select( env_var = _PROVIDER_ENV[chosen] signup_url = _PROVIDER_SIGNUP.get(chosen, "") console.print() - console.print(f" [bold]{chosen}[/bold] requires [cyan]{env_var}[/cyan].") - if signup_url: - console.print(f" Get your API key here: [{BRAND}]{signup_url}[/]") - console.print() - key = _prompt_api_key(console, chosen, env_var, repo_path=repo_path) - if not key: - console.print(f" [{WARN}]Skipped. Please select another provider.[/]") - return interactive_provider_select(console, model_flag, repo_path=repo_path) + # Special case: litellm local proxy doesn't need an API key + if chosen == "litellm" and os.environ.get("LITELLM_BASE_URL"): + console.print( + f" [{OK}]✓ Using LiteLLM proxy at[/] [{BRAND}]{os.environ['LITELLM_BASE_URL']}[/]" + ) + console.print(" [dim]No API key required for local proxy.[/dim]") + console.print() + else: + console.print(f" [bold]{chosen}[/bold] requires [cyan]{env_var}[/cyan].") + if signup_url: + console.print(f" Get your API key here: [{BRAND}]{signup_url}[/]") + console.print() + key = _prompt_api_key(console, chosen, env_var, repo_path=repo_path) + if not key: + console.print(f" [{WARN}]Skipped. Please select another provider.[/]") + return interactive_provider_select(console, model_flag, repo_path=repo_path) # --- model --- default_model = _PROVIDER_DEFAULTS.get(chosen, "") diff --git a/packages/core/src/repowise/core/providers/llm/litellm.py b/packages/core/src/repowise/core/providers/llm/litellm.py index 1273e22..c52e463 100644 --- a/packages/core/src/repowise/core/providers/llm/litellm.py +++ b/packages/core/src/repowise/core/providers/llm/litellm.py @@ -19,13 +19,16 @@ from __future__ import annotations +from collections.abc import AsyncIterator +from typing import TYPE_CHECKING, Any + import structlog from tenacity import ( + RetryError, retry, retry_if_exception_type, stop_after_attempt, wait_exponential_jitter, - RetryError, ) from repowise.core.providers.llm.base import ( @@ -37,7 +40,6 @@ RateLimitError, ) -from typing import TYPE_CHECKING, Any, AsyncIterator from repowise.core.rate_limiter import RateLimiter if TYPE_CHECKING: @@ -55,9 +57,13 @@ class LiteLLMProvider(BaseProvider): Args: model: LiteLLM model string (e.g., "groq/llama-3.1-70b-versatile"). + When using api_base (local proxy), just use the model name + (e.g., "zai.glm-5") - the provider will auto-add "openai/" prefix. api_key: API key for the target provider. Some providers read from environment variables (e.g., GROQ_API_KEY, TOGETHER_API_KEY). - api_base: Optional custom API base URL (e.g., for self-hosted deployments). + For local proxies without auth, a dummy key is used. + api_base: Optional custom API base URL for self-hosted LiteLLM proxy. + When set, the model is treated as OpenAI-compatible. rate_limiter: Optional RateLimiter instance. """ @@ -75,6 +81,13 @@ def __init__( self._rate_limiter = rate_limiter self._cost_tracker = cost_tracker + # When using a custom api_base (proxy), treat model as OpenAI-compatible. + # LiteLLM requires "openai/" prefix to route to custom endpoints. + if api_base and not model.startswith("openai/"): + self._litellm_model = f"openai/{model}" + else: + self._litellm_model = model + @property def provider_name(self) -> str: return "litellm" @@ -130,7 +143,7 @@ async def _generate_with_retry( litellm.suppress_debug_info = True call_kwargs: dict[str, object] = { - "model": self._model, + "model": self._litellm_model, "messages": [ {"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}, @@ -142,6 +155,10 @@ async def _generate_with_retry( call_kwargs["api_key"] = self._api_key if self._api_base: call_kwargs["api_base"] = self._api_base + # Local proxy without auth: OpenAI SDK still requires a key. + # Use a dummy key if none provided. + if not self._api_key: + call_kwargs["api_key"] = "sk-dummy" try: response = await litellm.acompletion(**call_kwargs) @@ -199,6 +216,7 @@ async def stream_chat( tool_executor: Any | None = None, ) -> AsyncIterator[ChatStreamEvent]: import json as _json + import litellm # type: ignore[import-untyped] litellm.set_verbose = False @@ -206,7 +224,7 @@ async def stream_chat( full_messages = [{"role": "system", "content": system_prompt}, *messages] call_kwargs: dict[str, Any] = { - "model": self._model, + "model": self._litellm_model, "messages": full_messages, "temperature": temperature, "max_tokens": max_tokens, @@ -218,6 +236,10 @@ async def stream_chat( call_kwargs["api_key"] = self._api_key if self._api_base: call_kwargs["api_base"] = self._api_base + # Local proxy without auth: OpenAI SDK still requires a key. + # Use a dummy key if none provided. + if not self._api_key: + call_kwargs["api_key"] = "sk-dummy" try: stream = await litellm.acompletion(**call_kwargs) @@ -244,7 +266,11 @@ async def stream_chat( for tc_delta in delta.tool_calls: idx = tc_delta.index if idx not in tool_calls_acc: - tool_calls_acc[idx] = {"id": getattr(tc_delta, "id", "") or "", "name": "", "arguments": ""} + tool_calls_acc[idx] = { + "id": getattr(tc_delta, "id", "") or "", + "name": "", + "arguments": "", + } acc = tool_calls_acc[idx] if getattr(tc_delta, "id", None): acc["id"] = tc_delta.id diff --git a/tests/unit/cli/test_helpers.py b/tests/unit/cli/test_helpers.py index 1444ac2..07845c3 100644 --- a/tests/unit/cli/test_helpers.py +++ b/tests/unit/cli/test_helpers.py @@ -231,3 +231,30 @@ def test_anthropic_empty_key_auto_detect(self, monkeypatch): assert len(warnings) == 1 assert "anthropic" in warnings[0] assert "ANTHROPIC_API_KEY" in warnings[0] + + # --- litellm tests --- + + def test_litellm_with_api_key(self, monkeypatch): + monkeypatch.setenv("LITELLM_API_KEY", "test-key") + monkeypatch.setenv("REPOWISE_PROVIDER", "litellm") + + assert validate_provider_config() == [] + + def test_litellm_with_base_url(self, monkeypatch): + """Local proxy without API key should be valid.""" + monkeypatch.delenv("LITELLM_API_KEY", raising=False) + monkeypatch.setenv("LITELLM_BASE_URL", "http://localhost:4000/v1") + monkeypatch.setenv("REPOWISE_PROVIDER", "litellm") + + assert validate_provider_config() == [] + + def test_litellm_missing_both(self, monkeypatch): + """Should warn when neither API key nor base URL is set.""" + monkeypatch.delenv("LITELLM_API_KEY", raising=False) + monkeypatch.delenv("LITELLM_BASE_URL", raising=False) + monkeypatch.setenv("REPOWISE_PROVIDER", "litellm") + + warnings = validate_provider_config() + assert len(warnings) == 1 + assert "litellm" in warnings[0] + assert "LITELLM_API_KEY" in warnings[0] or "LITELLM_BASE_URL" in warnings[0] diff --git a/tests/unit/test_providers/test_litellm_provider.py b/tests/unit/test_providers/test_litellm_provider.py new file mode 100644 index 0000000..bf0bab4 --- /dev/null +++ b/tests/unit/test_providers/test_litellm_provider.py @@ -0,0 +1,199 @@ +"""Unit tests for LiteLLMProvider. + +All tests mock the litellm.acompletion call — no real API calls are made. +""" + +from __future__ import annotations + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +pytest.importorskip("litellm", reason="litellm SDK not installed") + +from repowise.core.providers.llm.base import GeneratedResponse, ProviderError +from repowise.core.providers.llm.litellm import LiteLLMProvider + +# --------------------------------------------------------------------------- +# Construction +# --------------------------------------------------------------------------- + + +def test_provider_name(): + p = LiteLLMProvider(model="gpt-4o", api_key="sk-test") + assert p.provider_name == "litellm" + + +def test_default_model(): + p = LiteLLMProvider(model="groq/llama-3.1-70b-versatile", api_key="sk-test") + assert p.model_name == "groq/llama-3.1-70b-versatile" + + +def test_model_without_api_base(): + """Without api_base, model should be passed through unchanged.""" + p = LiteLLMProvider(model="groq/llama-3.1-70b-versatile", api_key="sk-test") + assert p._litellm_model == "groq/llama-3.1-70b-versatile" + + +def test_model_with_api_base_adds_openai_prefix(): + """With api_base (local proxy), model should get openai/ prefix.""" + p = LiteLLMProvider( + model="zai.glm-5", + api_base="http://localhost:4000/v1", + ) + assert p._litellm_model == "openai/zai.glm-5" + assert p.model_name == "zai.glm-5" # Public property shows original name + + +def test_model_with_api_base_and_existing_prefix(): + """If model already has openai/ prefix, don't add another.""" + p = LiteLLMProvider( + model="openai/gpt-4o", + api_base="http://localhost:4000/v1", + ) + assert p._litellm_model == "openai/gpt-4o" + + +def test_no_api_key_or_base(): + """Provider can be created without API key or base (for some backends).""" + p = LiteLLMProvider(model="groq/llama-3.1-70b-versatile") + assert p._api_key is None + assert p._api_base is None + + +# --------------------------------------------------------------------------- +# Successful generation +# --------------------------------------------------------------------------- + + +def _make_mock_response(text: str = "# Doc\nContent.") -> MagicMock: + usage = MagicMock() + usage.prompt_tokens = 120 + usage.completion_tokens = 60 + + choice = MagicMock() + choice.message.content = text + + response = MagicMock() + response.choices = [choice] + response.usage = usage + return response + + +async def test_generate_returns_generated_response(): + provider = LiteLLMProvider(model="gpt-4o", api_key="sk-test") + mock_response = _make_mock_response("Hello from LiteLLM") + + with patch("litellm.acompletion", new_callable=AsyncMock) as mock_acompletion: + mock_acompletion.return_value = mock_response + result = await provider.generate("sys", "user") + + assert isinstance(result, GeneratedResponse) + assert result.content == "Hello from LiteLLM" + + +async def test_generate_token_counts(): + provider = LiteLLMProvider(model="gpt-4o", api_key="sk-test") + mock_response = _make_mock_response() + + with patch("litellm.acompletion", new_callable=AsyncMock) as mock_acompletion: + mock_acompletion.return_value = mock_response + result = await provider.generate("sys", "user") + + assert result.input_tokens == 120 + assert result.output_tokens == 60 + + +async def test_generate_sends_correct_kwargs(): + provider = LiteLLMProvider( + model="groq/llama-3.1-70b-versatile", + api_key="sk-test", + ) + mock_response = _make_mock_response() + captured_kwargs: list[dict] = [] + + async def fake_acompletion(**kwargs): + captured_kwargs.append(kwargs) + return mock_response + + with patch("litellm.acompletion", new_callable=AsyncMock) as mock_acompletion: + mock_acompletion.side_effect = fake_acompletion + await provider.generate("system msg", "user msg", max_tokens=2048, temperature=0.5) + + kw = captured_kwargs[0] + assert kw["model"] == "groq/llama-3.1-70b-versatile" + assert kw["max_tokens"] == 2048 + assert kw["temperature"] == 0.5 + assert kw["api_key"] == "sk-test" + messages = kw["messages"] + assert messages[0] == {"role": "system", "content": "system msg"} + assert messages[1] == {"role": "user", "content": "user msg"} + + +async def test_generate_with_api_base(): + """With api_base (local proxy), should pass api_base and dummy key.""" + provider = LiteLLMProvider( + model="zai.glm-5", + api_base="http://localhost:4000/v1", + ) + mock_response = _make_mock_response() + captured_kwargs: list[dict] = [] + + async def fake_acompletion(**kwargs): + captured_kwargs.append(kwargs) + return mock_response + + with patch("litellm.acompletion", new_callable=AsyncMock) as mock_acompletion: + mock_acompletion.side_effect = fake_acompletion + await provider.generate("sys", "user") + + kw = captured_kwargs[0] + # Model should have openai/ prefix for proxy routing + assert kw["model"] == "openai/zai.glm-5" + assert kw["api_base"] == "http://localhost:4000/v1" + # Dummy key should be added when using api_base without api_key + assert kw["api_key"] == "sk-dummy" + + +async def test_generate_with_api_base_and_api_key(): + """With both api_base and api_key, should use provided key.""" + provider = LiteLLMProvider( + model="zai.glm-5", + api_key="sk-real-key", + api_base="http://localhost:4000/v1", + ) + mock_response = _make_mock_response() + captured_kwargs: list[dict] = [] + + async def fake_acompletion(**kwargs): + captured_kwargs.append(kwargs) + return mock_response + + with patch("litellm.acompletion", new_callable=AsyncMock) as mock_acompletion: + mock_acompletion.side_effect = fake_acompletion + await provider.generate("sys", "user") + + kw = captured_kwargs[0] + assert kw["api_key"] == "sk-real-key" + assert kw["api_base"] == "http://localhost:4000/v1" + + +# --------------------------------------------------------------------------- +# Error handling +# --------------------------------------------------------------------------- + + +async def test_api_error(): + import litellm + + provider = LiteLLMProvider(model="gpt-4o", api_key="sk-test") + + with patch("litellm.acompletion", new_callable=AsyncMock) as mock_acompletion: + mock_acompletion.side_effect = litellm.APIError( + message="server error", + llm_provider="openai", + model="gpt-4o", + status_code=500, + ) + with pytest.raises(ProviderError): + await provider.generate("sys", "user") \ No newline at end of file From 27f67706dbddfb56c0f690ae3eda5484f3672639 Mon Sep 17 00:00:00 2001 From: Vinit Sutar Date: Mon, 6 Apr 2026 15:42:20 +0530 Subject: [PATCH 2/4] fix(litellm): add inline comment for sk-dummy to avoid secret scanner false positives Co-Authored-By: Claude Opus 4.6 --- packages/core/src/repowise/core/providers/llm/litellm.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/packages/core/src/repowise/core/providers/llm/litellm.py b/packages/core/src/repowise/core/providers/llm/litellm.py index c52e463..9092a75 100644 --- a/packages/core/src/repowise/core/providers/llm/litellm.py +++ b/packages/core/src/repowise/core/providers/llm/litellm.py @@ -155,10 +155,8 @@ async def _generate_with_retry( call_kwargs["api_key"] = self._api_key if self._api_base: call_kwargs["api_base"] = self._api_base - # Local proxy without auth: OpenAI SDK still requires a key. - # Use a dummy key if none provided. if not self._api_key: - call_kwargs["api_key"] = "sk-dummy" + call_kwargs["api_key"] = "sk-dummy" # LiteLLM requires a non-empty key even for unauthenticated local proxies (OpenAI SDK requirement) try: response = await litellm.acompletion(**call_kwargs) @@ -236,10 +234,8 @@ async def stream_chat( call_kwargs["api_key"] = self._api_key if self._api_base: call_kwargs["api_base"] = self._api_base - # Local proxy without auth: OpenAI SDK still requires a key. - # Use a dummy key if none provided. if not self._api_key: - call_kwargs["api_key"] = "sk-dummy" + call_kwargs["api_key"] = "sk-dummy" # LiteLLM requires a non-empty key even for unauthenticated local proxies (OpenAI SDK requirement) try: stream = await litellm.acompletion(**call_kwargs) From adde5e30c28fb593b19261d4d18bb0453868a4c5 Mon Sep 17 00:00:00 2001 From: Vinit Sutar Date: Sun, 12 Apr 2026 13:27:05 +0530 Subject: [PATCH 3/4] feat: add Z.AI (Zhipu AI) provider support Add first-class support for Z.AI with OpenAI-compatible API. - New ZAIProvider with thinking disabled by default for GLM-5 family - Plan selection: 'coding' (subscription) or 'general' (pay-as-you-go) - Environment variables: ZAI_API_KEY, ZAI_PLAN, ZAI_BASE_URL, ZAI_THINKING - Rate limit defaults and auto-detection in CLI helpers Closes #68 --- packages/cli/src/repowise/cli/helpers.py | 25 +- .../repowise/core/providers/llm/registry.py | 7 +- .../src/repowise/core/providers/llm/zai.py | 332 ++++++++++++++++ .../core/src/repowise/core/rate_limiter.py | 1 + .../unit/test_providers/test_zai_provider.py | 356 ++++++++++++++++++ 5 files changed, 719 insertions(+), 2 deletions(-) create mode 100644 packages/core/src/repowise/core/providers/llm/zai.py create mode 100644 tests/unit/test_providers/test_zai_provider.py diff --git a/packages/cli/src/repowise/cli/helpers.py b/packages/cli/src/repowise/cli/helpers.py index 124481c..6203a91 100644 --- a/packages/cli/src/repowise/cli/helpers.py +++ b/packages/cli/src/repowise/cli/helpers.py @@ -268,6 +268,16 @@ def resolve_provider( kwargs["api_key"] = os.environ["LITELLM_API_KEY"] if os.environ.get("LITELLM_BASE_URL"): kwargs["api_base"] = os.environ["LITELLM_BASE_URL"] + elif provider_name == "zai": + # Z.AI: API key, plan, base URL, and thinking mode + if os.environ.get("ZAI_API_KEY"): + kwargs["api_key"] = os.environ["ZAI_API_KEY"] + if os.environ.get("ZAI_PLAN"): + kwargs["plan"] = os.environ["ZAI_PLAN"] + if os.environ.get("ZAI_BASE_URL"): + kwargs["base_url"] = os.environ["ZAI_BASE_URL"] + if os.environ.get("ZAI_THINKING"): + kwargs["thinking"] = os.environ["ZAI_THINKING"] return get_provider(provider_name, **kwargs) @@ -314,11 +324,23 @@ def resolve_provider( else {"api_base": os.environ["LITELLM_BASE_URL"]} ) return get_provider("litellm", **kwargs) + # Z.AI: check for API key + if os.environ.get("ZAI_API_KEY") and os.environ["ZAI_API_KEY"].strip(): + kwargs = {"api_key": os.environ["ZAI_API_KEY"]} + if model: + kwargs["model"] = model + if os.environ.get("ZAI_PLAN"): + kwargs["plan"] = os.environ["ZAI_PLAN"] + if os.environ.get("ZAI_BASE_URL"): + kwargs["base_url"] = os.environ["ZAI_BASE_URL"] + if os.environ.get("ZAI_THINKING"): + kwargs["thinking"] = os.environ["ZAI_THINKING"] + return get_provider("zai", **kwargs) raise click.ClickException( "No provider configured. Use --provider, set REPOWISE_PROVIDER, " "or set ANTHROPIC_API_KEY / OPENAI_API_KEY / OLLAMA_BASE_URL / GEMINI_API_KEY / " - "LITELLM_API_KEY / LITELLM_BASE_URL." + "LITELLM_API_KEY / LITELLM_BASE_URL / ZAI_API_KEY." ) @@ -358,6 +380,7 @@ def _is_env_var_exists(var_name: str) -> bool: "LITELLM_API_KEY", "LITELLM_BASE_URL", ], # Either one (API key for cloud, base URL for local) + "zai": ["ZAI_API_KEY"], } if provider_name: diff --git a/packages/core/src/repowise/core/providers/llm/registry.py b/packages/core/src/repowise/core/providers/llm/registry.py index 48e07a7..7f50fbf 100644 --- a/packages/core/src/repowise/core/providers/llm/registry.py +++ b/packages/core/src/repowise/core/providers/llm/registry.py @@ -7,8 +7,10 @@ Built-in providers: - anthropic → AnthropicProvider - openai → OpenAIProvider + - gemini → GeminiProvider - ollama → OllamaProvider - litellm → LiteLLMProvider + - zai → ZAIProvider - mock → MockProvider (testing only) Custom provider registration: @@ -24,7 +26,8 @@ from __future__ import annotations import importlib -from typing import Any, Callable +from collections.abc import Callable +from typing import Any from repowise.core.providers.llm.base import BaseProvider from repowise.core.rate_limiter import PROVIDER_DEFAULTS, RateLimitConfig, RateLimiter @@ -39,6 +42,7 @@ "gemini": ("repowise.core.providers.llm.gemini", "GeminiProvider"), "ollama": ("repowise.core.providers.llm.ollama", "OllamaProvider"), "litellm": ("repowise.core.providers.llm.litellm", "LiteLLMProvider"), + "zai": ("repowise.core.providers.llm.zai", "ZAIProvider"), "mock": ("repowise.core.providers.llm.mock", "MockProvider"), } @@ -135,6 +139,7 @@ def get_provider( "gemini": "google-genai", "ollama": "openai", # ollama uses the openai package "litellm": "litellm", + "zai": "openai", # zai uses the openai package (OpenAI-compatible API) } package = _missing.get(name, name) raise ImportError( diff --git a/packages/core/src/repowise/core/providers/llm/zai.py b/packages/core/src/repowise/core/providers/llm/zai.py new file mode 100644 index 0000000..4b9bc8c --- /dev/null +++ b/packages/core/src/repowise/core/providers/llm/zai.py @@ -0,0 +1,332 @@ +"""Z.AI (Zhipu AI) provider for repowise. + +Z.AI provides competitive models (GLM-5 family) at accessible pricing through +two API plans: + + - coding: Subscription-based resource package (default) + - general: Pay-as-you-go + +Key features: + - OpenAI-compatible API + - GLM-5 family reasoning models with thinking disabled by default + - Plan selection via constructor or ZAI_PLAN environment variable + +Models: + - glm-5-turbo + - glm-5.1 + - glm-5 + - glm-4.7 + +Reasoning models (GLM-5 family) have thinking enabled by default, which consumes +85-95% of output tokens on chain-of-thought. This provider disables thinking by +default for efficient structured output generation. + +Reference: https://open.bigmodel.cn/dev/api +""" + +from __future__ import annotations + +from collections.abc import AsyncIterator +from typing import TYPE_CHECKING, Any, Literal + +import structlog +from openai import APIStatusError as _OpenAIAPIStatusError +from openai import AsyncOpenAI +from tenacity import ( + RetryError, + retry, + retry_if_exception_type, + stop_after_attempt, + wait_exponential_jitter, +) + +from repowise.core.providers.llm.base import ( + BaseProvider, + ChatStreamEvent, + ChatToolCall, + GeneratedResponse, + ProviderError, + RateLimitError, +) +from repowise.core.rate_limiter import RateLimiter + +if TYPE_CHECKING: + from repowise.core.generation.cost_tracker import CostTracker + +log = structlog.get_logger(__name__) + +_MAX_RETRIES = 3 +_MIN_WAIT = 1.0 +_MAX_WAIT = 4.0 + +# Z.AI API endpoints by plan +_PLAN_BASE_URLS: dict[str, str] = { + "coding": "https://api.z.ai/api/coding/paas/v4", + "general": "https://api.z.ai/api/paas/v4", +} + +# Default model for Z.AI +_DEFAULT_MODEL = "glm-5.1" + +# Type for plan parameter +PlanType = Literal["coding", "general"] + + +class ZAIProvider(BaseProvider): + """Z.AI (Zhipu AI) chat provider. + + Uses the OpenAI-compatible API with thinking disabled by default + for efficient structured output generation. + + Args: + model: Z.AI model name (e.g., 'glm-5.1', 'glm-5-turbo', 'glm-4.7'). + Defaults to 'glm-5.1'. + api_key: API key for authentication. Reads from ZAI_API_KEY env var + if not provided. + plan: API plan to use. 'coding' for subscription-based resource + package, 'general' for pay-as-you-go. Defaults to 'coding'. + Can also be set via ZAI_PLAN environment variable. + base_url: Override API base URL. If provided, takes precedence + over plan selection. + thinking: Thinking mode for GLM-5 family. 'disabled' by default + to avoid reasoning token overhead. Set to 'enabled' for + complex reasoning tasks. + rate_limiter: Optional RateLimiter instance. + cost_tracker: Optional CostTracker for usage tracking. + """ + + def __init__( + self, + model: str = _DEFAULT_MODEL, + api_key: str | None = None, + plan: PlanType = "coding", + base_url: str | None = None, + thinking: str = "disabled", + rate_limiter: RateLimiter | None = None, + cost_tracker: "CostTracker | None" = None, # noqa: UP037 + ) -> None: + self._model = model + self._plan = plan + self._thinking = thinking + self._rate_limiter = rate_limiter + self._cost_tracker = cost_tracker + + # Resolve base URL: explicit base_url > plan lookup + effective_base_url = base_url or _PLAN_BASE_URLS.get(plan, _PLAN_BASE_URLS["coding"]) + + # Normalize base URL for OpenAI SDK + effective_base_url = effective_base_url.rstrip("/") + if not effective_base_url.endswith("/v1"): + effective_base_url += "/v1" + + # Store normalized base_url + self._base_url = effective_base_url + + # Initialize OpenAI client + self._client = AsyncOpenAI( + api_key=api_key, + base_url=effective_base_url, + ) + + @property + def provider_name(self) -> str: + return "zai" + + @property + def model_name(self) -> str: + return self._model + + async def generate( + self, + system_prompt: str, + user_prompt: str, + max_tokens: int = 4096, + temperature: float = 0.3, + request_id: str | None = None, + ) -> GeneratedResponse: + if self._rate_limiter: + await self._rate_limiter.acquire(estimated_tokens=max_tokens) + + log.debug( + "zai.generate.start", + model=self._model, + max_tokens=max_tokens, + thinking=self._thinking, + request_id=request_id, + ) + + try: + return await self._generate_with_retry( + system_prompt=system_prompt, + user_prompt=user_prompt, + max_tokens=max_tokens, + temperature=temperature, + request_id=request_id, + ) + except RetryError as exc: + raise ProviderError( + "zai", + f"All {_MAX_RETRIES} retries exhausted: {exc}", + ) from exc + + @retry( + retry=retry_if_exception_type(ProviderError), + stop=stop_after_attempt(_MAX_RETRIES), + wait=wait_exponential_jitter(initial=_MIN_WAIT, max=_MAX_WAIT), + reraise=True, + ) + async def _generate_with_retry( + self, + system_prompt: str, + user_prompt: str, + max_tokens: int, + temperature: float, + request_id: str | None, + ) -> GeneratedResponse: + # Build request kwargs + call_kwargs: dict[str, Any] = { + "model": self._model, + "messages": [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt}, + ], + "temperature": temperature, + "max_tokens": max_tokens, + } + + # Disable thinking for GLM-5 family by default + # This prevents reasoning tokens from consuming output budget + if self._thinking == "disabled": + call_kwargs["extra_body"] = {"thinking": {"type": "disabled"}} + + try: + response = await self._client.chat.completions.create(**call_kwargs) + except _OpenAIAPIStatusError as exc: + if exc.status_code == 429: + raise RateLimitError("zai", str(exc), status_code=429) from exc + raise ProviderError("zai", str(exc), status_code=exc.status_code) from exc + except Exception as exc: + log.error("zai.generate.error", model=self._model, error=str(exc)) + raise ProviderError("zai", f"{type(exc).__name__}: {exc}") from exc + + usage = response.usage + result = GeneratedResponse( + content=response.choices[0].message.content or "", + input_tokens=usage.prompt_tokens if usage else 0, + output_tokens=usage.completion_tokens if usage else 0, + cached_tokens=0, + usage={ + "prompt_tokens": usage.prompt_tokens if usage else 0, + "completion_tokens": usage.completion_tokens if usage else 0, + }, + ) + + log.debug( + "zai.generate.done", + input_tokens=result.input_tokens, + output_tokens=result.output_tokens, + request_id=request_id, + ) + + if self._cost_tracker is not None: + import asyncio + + try: # noqa: SIM105 + asyncio.get_event_loop().create_task( + self._cost_tracker.record( + model=self._model, + input_tokens=result.input_tokens, + output_tokens=result.output_tokens, + operation="doc_generation", + file_path=None, + ) + ) + except RuntimeError: + pass # No running event loop — skip async record + + return result + + # --- ChatProvider protocol implementation --- + + async def stream_chat( + self, + messages: list[dict[str, Any]], + tools: list[dict[str, Any]], + system_prompt: str, + max_tokens: int = 8192, + temperature: float = 0.7, + request_id: str | None = None, + tool_executor: Any | None = None, + ) -> AsyncIterator[ChatStreamEvent]: + """Stream chat via Z.AI's OpenAI-compatible endpoint.""" + import json as _json + + full_messages = [{"role": "system", "content": system_prompt}, *messages] + kwargs: dict[str, Any] = { + "model": self._model, + "messages": full_messages, + "temperature": temperature, + "max_tokens": max_tokens, + "stream": True, + } + + if tools: + kwargs["tools"] = tools + + # Disable thinking for GLM-5 family by default + if self._thinking == "disabled": + kwargs["extra_body"] = {"thinking": {"type": "disabled"}} + + try: + stream = await self._client.chat.completions.create(**kwargs) + except _OpenAIAPIStatusError as exc: + if exc.status_code == 429: + raise RateLimitError("zai", str(exc), status_code=429) from exc + raise ProviderError("zai", str(exc), status_code=exc.status_code) from exc + + tool_calls_acc: dict[int, dict[str, Any]] = {} + + try: + async for chunk in stream: + choice = chunk.choices[0] if chunk.choices else None + if not choice: + continue + + delta = choice.delta + finish = choice.finish_reason + + if delta and delta.content: + yield ChatStreamEvent(type="text_delta", text=delta.content) + + if delta and delta.tool_calls: + for tc_delta in delta.tool_calls: + idx = tc_delta.index + if idx not in tool_calls_acc: + tool_calls_acc[idx] = {"id": tc_delta.id or "", "name": "", "arguments": ""} + acc = tool_calls_acc[idx] + if tc_delta.id: + acc["id"] = tc_delta.id + if tc_delta.function: + if tc_delta.function.name: + acc["name"] = tc_delta.function.name + if tc_delta.function.arguments: + acc["arguments"] += tc_delta.function.arguments + + if finish: + for idx in sorted(tool_calls_acc.keys()): + acc = tool_calls_acc[idx] + try: + args = _json.loads(acc["arguments"]) if acc["arguments"] else {} + except Exception: + args = {} + yield ChatStreamEvent( + type="tool_start", + tool_call=ChatToolCall(id=acc["id"], name=acc["name"], arguments=args), + ) + tool_calls_acc.clear() + stop_reason = "tool_use" if finish == "tool_calls" else "end_turn" + yield ChatStreamEvent(type="stop", stop_reason=stop_reason) + except _OpenAIAPIStatusError as exc: + if exc.status_code == 429: + raise RateLimitError("zai", str(exc), status_code=429) from exc + raise ProviderError("zai", str(exc), status_code=exc.status_code) from exc \ No newline at end of file diff --git a/packages/core/src/repowise/core/rate_limiter.py b/packages/core/src/repowise/core/rate_limiter.py index 3612624..2c56d3b 100644 --- a/packages/core/src/repowise/core/rate_limiter.py +++ b/packages/core/src/repowise/core/rate_limiter.py @@ -49,6 +49,7 @@ class RateLimitConfig: # Ollama runs locally — effectively unlimited, but we cap to avoid OOM "ollama": RateLimitConfig(requests_per_minute=1_000, tokens_per_minute=10_000_000), "litellm": RateLimitConfig(requests_per_minute=60, tokens_per_minute=150_000), + "zai": RateLimitConfig(requests_per_minute=60, tokens_per_minute=150_000), } diff --git a/tests/unit/test_providers/test_zai_provider.py b/tests/unit/test_providers/test_zai_provider.py new file mode 100644 index 0000000..5507b5b --- /dev/null +++ b/tests/unit/test_providers/test_zai_provider.py @@ -0,0 +1,356 @@ +"""Unit tests for ZAIProvider. + +All tests mock the OpenAI client — no real API calls are made. +""" + +from __future__ import annotations + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from repowise.core.providers.llm.base import GeneratedResponse, ProviderError, RateLimitError +from repowise.core.providers.llm.zai import _DEFAULT_MODEL, ZAIProvider + +# --------------------------------------------------------------------------- +# Construction +# --------------------------------------------------------------------------- + + +def test_provider_name(): + p = ZAIProvider(api_key="test-key") + assert p.provider_name == "zai" + + +def test_default_model(): + p = ZAIProvider(api_key="test-key") + assert p.model_name == _DEFAULT_MODEL + assert p.model_name == "glm-5.1" + + +def test_custom_model(): + p = ZAIProvider(model="glm-5-turbo", api_key="test-key") + assert p.model_name == "glm-5-turbo" + + +def test_default_thinking_disabled(): + p = ZAIProvider(api_key="test-key") + assert p._thinking == "disabled" + + +def test_custom_thinking(): + p = ZAIProvider(api_key="test-key", thinking="enabled") + assert p._thinking == "enabled" + + +def test_default_plan_is_coding(): + """Default plan should be 'coding'.""" + p = ZAIProvider(api_key="test-key") + assert p._plan == "coding" + + +def test_coding_plan_base_url(): + """Coding plan should use coding endpoint.""" + p = ZAIProvider(api_key="test-key", plan="coding") + assert p._base_url == "https://api.z.ai/api/coding/paas/v4/v1" + + +def test_general_plan_base_url(): + """General plan should use general endpoint.""" + p = ZAIProvider(api_key="test-key", plan="general") + assert p._base_url == "https://api.z.ai/api/paas/v4/v1" + + +def test_base_url_overrides_plan(): + """Explicit base_url should take precedence over plan.""" + p = ZAIProvider(api_key="test-key", plan="coding", base_url="https://custom.api.com") + assert p._base_url == "https://custom.api.com/v1" + + +def test_custom_base_url(): + """Custom base URL should be used and normalized.""" + p = ZAIProvider(api_key="test-key", base_url="https://api.z.ai/api/paas/v4") + assert p._base_url == "https://api.z.ai/api/paas/v4/v1" + + +def test_base_url_normalization(): + """Base URL should be normalized to end with /v1.""" + p = ZAIProvider(api_key="test-key", base_url="https://custom.api.com") + assert p._base_url == "https://custom.api.com/v1" + + +def test_base_url_already_has_v1(): + """Base URL already ending with /v1 should not get another suffix.""" + p = ZAIProvider(api_key="test-key", base_url="https://custom.api.com/v1") + assert p._base_url == "https://custom.api.com/v1" + + +# --------------------------------------------------------------------------- +# Successful generation +# --------------------------------------------------------------------------- + + +def _make_mock_response(text: str = "# Doc\nContent.") -> MagicMock: + usage = MagicMock() + usage.prompt_tokens = 120 + usage.completion_tokens = 60 + + choice = MagicMock() + choice.message.content = text + + response = MagicMock() + response.choices = [choice] + response.usage = usage + return response + + +@pytest.mark.asyncio +async def test_generate_returns_generated_response(): + provider = ZAIProvider(api_key="test-key") + mock_response = _make_mock_response("Hello from Z.AI") + + with patch.object( + provider._client.chat.completions, + "create", + new_callable=AsyncMock, + return_value=mock_response, + ): + result = await provider.generate("sys", "user") + + assert isinstance(result, GeneratedResponse) + assert result.content == "Hello from Z.AI" + + +@pytest.mark.asyncio +async def test_generate_token_counts(): + provider = ZAIProvider(api_key="test-key") + mock_response = _make_mock_response() + + with patch.object( + provider._client.chat.completions, + "create", + new_callable=AsyncMock, + return_value=mock_response, + ): + result = await provider.generate("sys", "user") + + assert result.input_tokens == 120 + assert result.output_tokens == 60 + + +@pytest.mark.asyncio +async def test_generate_sends_correct_kwargs(): + provider = ZAIProvider(model="glm-5-turbo", api_key="test-key") + mock_response = _make_mock_response() + captured_kwargs: list[dict] = [] + + async def fake_create(**kwargs): + captured_kwargs.append(kwargs) + return mock_response + + with patch.object( + provider._client.chat.completions, + "create", + new_callable=AsyncMock, + side_effect=fake_create, + ): + await provider.generate("system msg", "user msg", max_tokens=2048, temperature=0.5) + + kw = captured_kwargs[0] + assert kw["model"] == "glm-5-turbo" + assert kw["max_tokens"] == 2048 + assert kw["temperature"] == 0.5 + messages = kw["messages"] + assert messages[0] == {"role": "system", "content": "system msg"} + assert messages[1] == {"role": "user", "content": "user msg"} + + +@pytest.mark.asyncio +async def test_generate_disables_thinking_by_default(): + """By default, thinking should be disabled via extra_body.""" + provider = ZAIProvider(api_key="test-key") + mock_response = _make_mock_response() + captured_kwargs: list[dict] = [] + + async def fake_create(**kwargs): + captured_kwargs.append(kwargs) + return mock_response + + with patch.object( + provider._client.chat.completions, + "create", + new_callable=AsyncMock, + side_effect=fake_create, + ): + await provider.generate("sys", "user") + + kw = captured_kwargs[0] + assert "extra_body" in kw + assert kw["extra_body"] == {"thinking": {"type": "disabled"}} + + +@pytest.mark.asyncio +async def test_generate_with_thinking_enabled(): + """When thinking is enabled, extra_body should not contain disabled.""" + provider = ZAIProvider(api_key="test-key", thinking="enabled") + mock_response = _make_mock_response() + captured_kwargs: list[dict] = [] + + async def fake_create(**kwargs): + captured_kwargs.append(kwargs) + return mock_response + + with patch.object( + provider._client.chat.completions, + "create", + new_callable=AsyncMock, + side_effect=fake_create, + ): + await provider.generate("sys", "user") + + kw = captured_kwargs[0] + # When thinking is enabled, we don't send extra_body with thinking disabled + assert kw.get("extra_body") is None + + +# --------------------------------------------------------------------------- +# Error handling +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_api_error(): + from openai import APIStatusError + + provider = ZAIProvider(api_key="test-key") + + mock_response = MagicMock() + mock_response.status_code = 500 + + with patch.object( + provider._client.chat.completions, + "create", + new_callable=AsyncMock, + side_effect=APIStatusError( + "server error", + response=mock_response, + body={}, + ), + ), pytest.raises(ProviderError) as exc_info: + await provider.generate("sys", "user") + + assert exc_info.value.status_code == 500 + + +@pytest.mark.asyncio +async def test_rate_limit_error(): + from openai import APIStatusError + + provider = ZAIProvider(api_key="test-key") + + mock_response = MagicMock() + mock_response.status_code = 429 + + with patch.object( + provider._client.chat.completions, + "create", + new_callable=AsyncMock, + side_effect=APIStatusError( + "rate limit exceeded", + response=mock_response, + body={}, + ), + ), pytest.raises(RateLimitError) as exc_info: + await provider.generate("sys", "user") + + assert exc_info.value.status_code == 429 + + +# --------------------------------------------------------------------------- +# Stream chat +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_stream_chat_yields_text_deltas(): + provider = ZAIProvider(api_key="test-key") + + # Create mock chunks + chunk1 = MagicMock() + chunk1.choices = [MagicMock()] + chunk1.choices[0].delta = MagicMock(content="Hello") + chunk1.choices[0].finish_reason = None + + chunk2 = MagicMock() + chunk2.choices = [MagicMock()] + chunk2.choices[0].delta = MagicMock(content=" world") + chunk2.choices[0].finish_reason = None + + chunk3 = MagicMock() + chunk3.choices = [MagicMock()] + chunk3.choices[0].delta = MagicMock(content=None) + chunk3.choices[0].finish_reason = "stop" + + async def fake_stream(): + yield chunk1 + yield chunk2 + yield chunk3 + + with patch.object( + provider._client.chat.completions, + "create", + new_callable=AsyncMock, + return_value=fake_stream(), + ): + events = [] + async for event in provider.stream_chat( + messages=[{"role": "user", "content": "test"}], + tools=[], + system_prompt="sys", + ): + events.append(event) + + # Should have text deltas and a stop event + assert len(events) == 3 + assert events[0].type == "text_delta" + assert events[0].text == "Hello" + assert events[1].type == "text_delta" + assert events[1].text == " world" + assert events[2].type == "stop" + + +@pytest.mark.asyncio +async def test_stream_chat_disables_thinking(): + """Stream chat should also disable thinking by default.""" + provider = ZAIProvider(api_key="test-key") + captured_kwargs: list[dict] = [] + + chunk = MagicMock() + chunk.choices = [MagicMock()] + chunk.choices[0].delta = MagicMock(content="test") + chunk.choices[0].finish_reason = "stop" + + async def fake_stream(): + yield chunk + + async def fake_create(**kwargs): + captured_kwargs.append(kwargs) + return fake_stream() + + with patch.object( + provider._client.chat.completions, + "create", + new_callable=AsyncMock, + side_effect=fake_create, + ): + events = [] + async for event in provider.stream_chat( + messages=[{"role": "user", "content": "test"}], + tools=[], + system_prompt="sys", + ): + events.append(event) + + kw = captured_kwargs[0] + assert "extra_body" in kw + assert kw["extra_body"] == {"thinking": {"type": "disabled"}} \ No newline at end of file From 5b1dd69bc6bb3e95e20642e309c08b372432bd4a Mon Sep 17 00:00:00 2001 From: Societus <93468672+Societus@users.noreply.github.com> Date: Mon, 13 Apr 2026 12:38:00 -0700 Subject: [PATCH 4/4] feat(zai): add tier-aware rate limiting for Coding plan Add ZAI_TIER environment variable support for plan-aware concurrency. Users can set ZAI_TIER=lite|pro|max to get appropriate rate limits derived from Z.AI support guidance. Changes: - rate_limiter.py: Add ZAI_TIER_DEFAULTS (lite/pro/max configs), update provider default to conservative 10 RPM/50k TPM - zai.py: Add tier parameter, tier takes precedence over registry default limiter. Bump retry budget to 5 retries / 30s max wait. - helpers.py: Read ZAI_TIER env var in both explicit and auto-detect provider resolution paths - tests: 10 new tests covering tier creation, precedence, case insensitivity, invalid tier handling, and edge cases Ref: https://docs.z.ai/devpack/usage-policy Related: #68 --- packages/cli/src/repowise/cli/helpers.py | 6 +- .../src/repowise/core/providers/llm/zai.py | 35 ++++++- .../core/src/repowise/core/rate_limiter.py | 18 +++- .../unit/test_providers/test_zai_provider.py | 92 ++++++++++++++++++- 4 files changed, 143 insertions(+), 8 deletions(-) diff --git a/packages/cli/src/repowise/cli/helpers.py b/packages/cli/src/repowise/cli/helpers.py index 6203a91..6401913 100644 --- a/packages/cli/src/repowise/cli/helpers.py +++ b/packages/cli/src/repowise/cli/helpers.py @@ -269,7 +269,7 @@ def resolve_provider( if os.environ.get("LITELLM_BASE_URL"): kwargs["api_base"] = os.environ["LITELLM_BASE_URL"] elif provider_name == "zai": - # Z.AI: API key, plan, base URL, and thinking mode + # Z.AI: API key, plan, base URL, thinking mode, and tier if os.environ.get("ZAI_API_KEY"): kwargs["api_key"] = os.environ["ZAI_API_KEY"] if os.environ.get("ZAI_PLAN"): @@ -278,6 +278,8 @@ def resolve_provider( kwargs["base_url"] = os.environ["ZAI_BASE_URL"] if os.environ.get("ZAI_THINKING"): kwargs["thinking"] = os.environ["ZAI_THINKING"] + if os.environ.get("ZAI_TIER"): + kwargs["tier"] = os.environ["ZAI_TIER"] return get_provider(provider_name, **kwargs) @@ -335,6 +337,8 @@ def resolve_provider( kwargs["base_url"] = os.environ["ZAI_BASE_URL"] if os.environ.get("ZAI_THINKING"): kwargs["thinking"] = os.environ["ZAI_THINKING"] + if os.environ.get("ZAI_TIER"): + kwargs["tier"] = os.environ["ZAI_TIER"] return get_provider("zai", **kwargs) raise click.ClickException( diff --git a/packages/core/src/repowise/core/providers/llm/zai.py b/packages/core/src/repowise/core/providers/llm/zai.py index 4b9bc8c..b390886 100644 --- a/packages/core/src/repowise/core/providers/llm/zai.py +++ b/packages/core/src/repowise/core/providers/llm/zai.py @@ -48,16 +48,16 @@ ProviderError, RateLimitError, ) -from repowise.core.rate_limiter import RateLimiter +from repowise.core.rate_limiter import RateLimiter, ZAI_TIER_DEFAULTS if TYPE_CHECKING: from repowise.core.generation.cost_tracker import CostTracker log = structlog.get_logger(__name__) -_MAX_RETRIES = 3 +_MAX_RETRIES = 5 _MIN_WAIT = 1.0 -_MAX_WAIT = 4.0 +_MAX_WAIT = 30.0 # Z.AI API endpoints by plan _PLAN_BASE_URLS: dict[str, str] = { @@ -91,7 +91,14 @@ class ZAIProvider(BaseProvider): thinking: Thinking mode for GLM-5 family. 'disabled' by default to avoid reasoning token overhead. Set to 'enabled' for complex reasoning tasks. - rate_limiter: Optional RateLimiter instance. + tier: Z.AI subscription tier for rate limiting. One of 'lite', + 'pro', 'max'. When set, overrides the default rate limiter + with tier-appropriate limits. Can also be set via ZAI_TIER + environment variable. + rate_limiter: Optional RateLimiter instance. If not provided and + tier is set, a tier-appropriate limiter is created. + If neither is provided, the registry attaches a + conservative default. cost_tracker: Optional CostTracker for usage tracking. """ @@ -102,15 +109,33 @@ def __init__( plan: PlanType = "coding", base_url: str | None = None, thinking: str = "disabled", + tier: str | None = None, rate_limiter: RateLimiter | None = None, cost_tracker: "CostTracker | None" = None, # noqa: UP037 ) -> None: self._model = model self._plan = plan self._thinking = thinking - self._rate_limiter = rate_limiter + self._tier = tier self._cost_tracker = cost_tracker + # Resolve rate limiter: tier > explicit instance > none (registry attaches default) + # When tier is set, it takes precedence -- it's a specific Z.AI signal that + # overrides the generic registry default. + if tier is not None: + tier_key = tier.lower() + tier_config = ZAI_TIER_DEFAULTS.get(tier_key) + if tier_config is None: + valid = ", ".join(sorted(ZAI_TIER_DEFAULTS)) + msg = f"Unknown Z.AI tier {tier!r}. Valid tiers: {valid}" + raise ValueError(msg) + self._rate_limiter = RateLimiter(tier_config) + log.info("zai.tier_rate_limiter", tier=tier_key, rpm=tier_config.requests_per_minute) + elif rate_limiter is not None: + self._rate_limiter = rate_limiter + else: + self._rate_limiter = None # None — registry will attach default + # Resolve base URL: explicit base_url > plan lookup effective_base_url = base_url or _PLAN_BASE_URLS.get(plan, _PLAN_BASE_URLS["coding"]) diff --git a/packages/core/src/repowise/core/rate_limiter.py b/packages/core/src/repowise/core/rate_limiter.py index 2c56d3b..f83b550 100644 --- a/packages/core/src/repowise/core/rate_limiter.py +++ b/packages/core/src/repowise/core/rate_limiter.py @@ -49,7 +49,23 @@ class RateLimitConfig: # Ollama runs locally — effectively unlimited, but we cap to avoid OOM "ollama": RateLimitConfig(requests_per_minute=1_000, tokens_per_minute=10_000_000), "litellm": RateLimitConfig(requests_per_minute=60, tokens_per_minute=150_000), - "zai": RateLimitConfig(requests_per_minute=60, tokens_per_minute=150_000), + # Z.AI: conservative default (matches Lite tier). Override via ZAI_TIER env var. + "zai": RateLimitConfig(requests_per_minute=10, tokens_per_minute=50_000), +} + +# Z.AI subscription tier rate limits. +# Derived from Z.AI support guidance (April 2026): +# - Lite: 2-3 concurrent, lower tolerance +# - Pro: 5-8 concurrent, moderate tolerance +# - Max: 10-15 concurrent, highest tolerance +# Limits are aggregate across all models. Advanced models (GLM-5 family) +# consume 2-3x quota per prompt, so effective concurrency is lower when +# using those models. +# Ref: https://docs.z.ai/devpack/usage-policy +ZAI_TIER_DEFAULTS: dict[str, RateLimitConfig] = { + "lite": RateLimitConfig(requests_per_minute=10, tokens_per_minute=50_000), + "pro": RateLimitConfig(requests_per_minute=30, tokens_per_minute=150_000), + "max": RateLimitConfig(requests_per_minute=60, tokens_per_minute=300_000), } diff --git a/tests/unit/test_providers/test_zai_provider.py b/tests/unit/test_providers/test_zai_provider.py index 5507b5b..d6d204c 100644 --- a/tests/unit/test_providers/test_zai_provider.py +++ b/tests/unit/test_providers/test_zai_provider.py @@ -353,4 +353,94 @@ async def fake_create(**kwargs): kw = captured_kwargs[0] assert "extra_body" in kw - assert kw["extra_body"] == {"thinking": {"type": "disabled"}} \ No newline at end of file + assert kw["extra_body"] == {"thinking": {"type": "disabled"}} + + +# --------------------------------------------------------------------------- +# Tier-based rate limiting +# --------------------------------------------------------------------------- + + +def test_tier_creates_rate_limiter(): + """Setting tier should create a tier-appropriate rate limiter.""" + from repowise.core.rate_limiter import ZAI_TIER_DEFAULTS + + p = ZAIProvider(api_key="test-key", tier="pro") + assert p._rate_limiter is not None + assert p._rate_limiter.config.requests_per_minute == ZAI_TIER_DEFAULTS["pro"].requests_per_minute + assert p._rate_limiter.config.tokens_per_minute == ZAI_TIER_DEFAULTS["pro"].tokens_per_minute + + +def test_tier_lite(): + """Lite tier should have conservative RPM/TPM limits.""" + p = ZAIProvider(api_key="test-key", tier="lite") + assert p._rate_limiter is not None + assert p._rate_limiter.config.requests_per_minute == 10 + assert p._rate_limiter.config.tokens_per_minute == 50_000 + + +def test_tier_pro(): + """Pro tier should have moderate RPM/TPM limits.""" + p = ZAIProvider(api_key="test-key", tier="pro") + assert p._rate_limiter is not None + assert p._rate_limiter.config.requests_per_minute == 30 + assert p._rate_limiter.config.tokens_per_minute == 150_000 + + +def test_tier_max(): + """Max tier should have the highest RPM/TPM limits.""" + p = ZAIProvider(api_key="test-key", tier="max") + assert p._rate_limiter is not None + assert p._rate_limiter.config.requests_per_minute == 60 + assert p._rate_limiter.config.tokens_per_minute == 300_000 + + +def test_tier_case_insensitive(): + """Tier should be case-insensitive.""" + p = ZAIProvider(api_key="test-key", tier="PRO") + assert p._rate_limiter is not None + assert p._rate_limiter.config.requests_per_minute == 30 + + +def test_tier_overrides_explicit_rate_limiter(): + """Tier takes precedence over an explicitly passed rate_limiter.""" + from repowise.core.rate_limiter import RateLimitConfig, RateLimiter + + explicit_limiter = RateLimiter(RateLimitConfig(requests_per_minute=999, tokens_per_minute=999_999)) + p = ZAIProvider(api_key="test-key", tier="lite", rate_limiter=explicit_limiter) + # Tier wins -- the explicit limiter should be discarded + assert p._rate_limiter.config.requests_per_minute == 10 + assert p._rate_limiter is not explicit_limiter + + +def test_no_tier_no_rate_limiter(): + """Without tier or rate_limiter, _rate_limiter should be None.""" + p = ZAIProvider(api_key="test-key") + assert p._rate_limiter is None + + +def test_explicit_rate_limiter_without_tier(): + """Without tier, an explicit rate_limiter should be used.""" + from repowise.core.rate_limiter import RateLimitConfig, RateLimiter + + limiter = RateLimiter(RateLimitConfig(requests_per_minute=42, tokens_per_minute=420_000)) + p = ZAIProvider(api_key="test-key", rate_limiter=limiter) + assert p._rate_limiter is limiter + + +def test_invalid_tier_raises(): + """An unrecognized tier should raise ValueError with valid options.""" + with pytest.raises(ValueError, match="Unknown Z.AI tier"): + ZAIProvider(api_key="test-key", tier="enterprise") + + +def test_tier_stored(): + """Tier value should be stored on the provider.""" + p = ZAIProvider(api_key="test-key", tier="pro") + assert p._tier == "pro" + + +def test_no_tier_stored_as_none(): + """Without tier, _tier should be None.""" + p = ZAIProvider(api_key="test-key") + assert p._tier is None \ No newline at end of file