From 3bd617ad5717f76bc2cd2ce2f1aa87705569ac46 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=A5=BF=E7=B1=B3?= Date: Mon, 9 Mar 2026 19:37:18 +0800 Subject: [PATCH] feat: add MiniMax provider support Add MiniMax as a new LLM provider with OpenAI-compatible API. Supported models: - MiniMax-M2.5 (default) - Peak Performance, Ultimate Value - MiniMax-M2.5-highspeed - Same performance, faster and more agile Changes: - Add MiniMaxLLMBackend (inherits OpenAILLMBackend) - Register MiniMax in LLM and embedding backend registries - Add MINIMAX_API_KEY environment variable support with provider defaults - Add unit tests and integration test for MiniMax provider API Documentation: - https://platform.minimax.io/docs/api-reference/text-openai-api - https://platform.minimax.io/docs/api-reference/text-anthropic-api --- src/memu/app/settings.py | 8 ++ src/memu/llm/backends/__init__.py | 10 +- src/memu/llm/backends/minimax.py | 15 +++ src/memu/llm/http_client.py | 3 + tests/llm/test_minimax_provider.py | 111 ++++++++++++++++++++ tests/test_minimax.py | 161 +++++++++++++++++++++++++++++ 6 files changed, 307 insertions(+), 1 deletion(-) create mode 100644 src/memu/llm/backends/minimax.py create mode 100644 tests/llm/test_minimax_provider.py create mode 100644 tests/test_minimax.py diff --git a/src/memu/app/settings.py b/src/memu/app/settings.py index adcb4f16..b721ba1d 100644 --- a/src/memu/app/settings.py +++ b/src/memu/app/settings.py @@ -135,6 +135,14 @@ def set_provider_defaults(self) -> "LLMConfig": self.api_key = "XAI_API_KEY" if self.chat_model == "gpt-4o-mini": self.chat_model = "grok-2-latest" + elif self.provider == "minimax": + # If values match the OpenAI defaults, switch them to MiniMax defaults + if self.base_url == "https://api.openai.com/v1": + self.base_url = "https://api.minimax.io/v1" + if self.api_key == "OPENAI_API_KEY": + self.api_key = "MINIMAX_API_KEY" + if self.chat_model == "gpt-4o-mini": + self.chat_model = "MiniMax-M2.5" return self diff --git a/src/memu/llm/backends/__init__.py b/src/memu/llm/backends/__init__.py index 5350e7b2..c17078d2 100644 --- a/src/memu/llm/backends/__init__.py +++ b/src/memu/llm/backends/__init__.py @@ -1,7 +1,15 @@ from memu.llm.backends.base import LLMBackend from memu.llm.backends.doubao import DoubaoLLMBackend from memu.llm.backends.grok import GrokBackend +from memu.llm.backends.minimax import MiniMaxLLMBackend from memu.llm.backends.openai import OpenAILLMBackend from memu.llm.backends.openrouter import OpenRouterLLMBackend -__all__ = ["DoubaoLLMBackend", "GrokBackend", "LLMBackend", "OpenAILLMBackend", "OpenRouterLLMBackend"] +__all__ = [ + "DoubaoLLMBackend", + "GrokBackend", + "LLMBackend", + "MiniMaxLLMBackend", + "OpenAILLMBackend", + "OpenRouterLLMBackend", +] diff --git a/src/memu/llm/backends/minimax.py b/src/memu/llm/backends/minimax.py new file mode 100644 index 00000000..76a05d9f --- /dev/null +++ b/src/memu/llm/backends/minimax.py @@ -0,0 +1,15 @@ +from __future__ import annotations + +from memu.llm.backends.openai import OpenAILLMBackend + + +class MiniMaxLLMBackend(OpenAILLMBackend): + """Backend for MiniMax LLM API (OpenAI-compatible). + + MiniMax provides OpenAI-compatible API endpoints. + Supported models: MiniMax-M2.5, MiniMax-M2.5-highspeed. + """ + + name = "minimax" + # MiniMax uses the same /chat/completions endpoint and payload structure as OpenAI. + # We inherit build_summary_payload, parse_summary_response, build_vision_payload, etc. diff --git a/src/memu/llm/http_client.py b/src/memu/llm/http_client.py index ba84b05b..2612d087 100644 --- a/src/memu/llm/http_client.py +++ b/src/memu/llm/http_client.py @@ -12,6 +12,7 @@ from memu.llm.backends.base import LLMBackend from memu.llm.backends.doubao import DoubaoLLMBackend from memu.llm.backends.grok import GrokBackend +from memu.llm.backends.minimax import MiniMaxLLMBackend from memu.llm.backends.openai import OpenAILLMBackend from memu.llm.backends.openrouter import OpenRouterLLMBackend @@ -73,6 +74,7 @@ def parse_embedding_response(self, data: dict[str, Any]) -> list[list[float]]: OpenAILLMBackend.name: OpenAILLMBackend, DoubaoLLMBackend.name: DoubaoLLMBackend, GrokBackend.name: GrokBackend, + MiniMaxLLMBackend.name: MiniMaxLLMBackend, OpenRouterLLMBackend.name: OpenRouterLLMBackend, } @@ -291,6 +293,7 @@ def _load_embedding_backend(self, provider: str) -> _EmbeddingBackend: _OpenAIEmbeddingBackend.name: _OpenAIEmbeddingBackend, _DoubaoEmbeddingBackend.name: _DoubaoEmbeddingBackend, "grok": _OpenAIEmbeddingBackend, + "minimax": _OpenAIEmbeddingBackend, _OpenRouterEmbeddingBackend.name: _OpenRouterEmbeddingBackend, } factory = backends.get(provider) diff --git a/tests/llm/test_minimax_provider.py b/tests/llm/test_minimax_provider.py new file mode 100644 index 00000000..712a8619 --- /dev/null +++ b/tests/llm/test_minimax_provider.py @@ -0,0 +1,111 @@ +import unittest +from unittest.mock import patch + +from memu.app.settings import LLMConfig +from memu.llm.backends.minimax import MiniMaxLLMBackend +from memu.llm.openai_sdk import OpenAISDKClient + + +class TestMiniMaxProvider(unittest.IsolatedAsyncioTestCase): + def test_settings_defaults(self): + """Test that setting provider='minimax' sets the correct defaults.""" + config = LLMConfig(provider="minimax") + self.assertEqual(config.base_url, "https://api.minimax.io/v1") + self.assertEqual(config.api_key, "MINIMAX_API_KEY") + self.assertEqual(config.chat_model, "MiniMax-M2.5") + + def test_settings_custom_model(self): + """Test that custom model can be set for MiniMax provider.""" + config = LLMConfig(provider="minimax", chat_model="MiniMax-M2.5-highspeed") + self.assertEqual(config.base_url, "https://api.minimax.io/v1") + self.assertEqual(config.chat_model, "MiniMax-M2.5-highspeed") + + def test_settings_custom_base_url(self): + """Test that custom base_url is preserved for MiniMax provider.""" + config = LLMConfig(provider="minimax", base_url="https://api.minimaxi.com/v1") + self.assertEqual(config.base_url, "https://api.minimaxi.com/v1") + + @patch("memu.llm.openai_sdk.AsyncOpenAI") + async def test_client_initialization_with_minimax_config(self, mock_async_openai): + """Test that OpenAISDKClient initializes with MiniMax base URL when configured.""" + config = LLMConfig(provider="minimax") + + client = OpenAISDKClient( + base_url=config.base_url, + api_key="fake-key", + chat_model=config.chat_model, + embed_model=config.embed_model, + ) + + mock_async_openai.assert_called_with(api_key="fake-key", base_url="https://api.minimax.io/v1") + self.assertEqual(client.chat_model, "MiniMax-M2.5") + + def test_minimax_backend_payload_parsing(self): + """Test that MiniMaxLLMBackend parses responses correctly (inherited from OpenAI).""" + backend = MiniMaxLLMBackend() + + dummy_response = {"choices": [{"message": {"content": "MiniMax response content", "role": "assistant"}}]} + + result = backend.parse_summary_response(dummy_response) + self.assertEqual(result, "MiniMax response content") + + def test_minimax_backend_name(self): + """Test that MiniMaxLLMBackend has the correct name.""" + backend = MiniMaxLLMBackend() + self.assertEqual(backend.name, "minimax") + + def test_minimax_backend_summary_payload(self): + """Test that MiniMaxLLMBackend builds the correct summary payload.""" + backend = MiniMaxLLMBackend() + payload = backend.build_summary_payload( + text="Hello world", + system_prompt="Summarize this.", + chat_model="MiniMax-M2.5", + max_tokens=100, + ) + self.assertEqual(payload["model"], "MiniMax-M2.5") + self.assertEqual(len(payload["messages"]), 2) + self.assertEqual(payload["messages"][0]["role"], "system") + self.assertEqual(payload["messages"][1]["role"], "user") + self.assertEqual(payload["messages"][1]["content"], "Hello world") + self.assertEqual(payload["max_tokens"], 100) + + def test_minimax_backend_vision_payload(self): + """Test that MiniMaxLLMBackend builds the correct vision payload.""" + backend = MiniMaxLLMBackend() + payload = backend.build_vision_payload( + prompt="Describe this image", + base64_image="base64data", + mime_type="image/png", + system_prompt=None, + chat_model="MiniMax-M2.5", + max_tokens=200, + ) + self.assertEqual(payload["model"], "MiniMax-M2.5") + self.assertIsInstance(payload["messages"], list) + # Should have user message with text and image content + user_msg = payload["messages"][0] + self.assertEqual(user_msg["role"], "user") + self.assertIsInstance(user_msg["content"], list) + self.assertEqual(len(user_msg["content"]), 2) + + def test_minimax_http_backend_registration(self): + """Test that MiniMax is registered in the HTTP LLM backends.""" + from memu.llm.http_client import LLM_BACKENDS + + self.assertIn("minimax", LLM_BACKENDS) + backend = LLM_BACKENDS["minimax"]() + self.assertEqual(backend.name, "minimax") + + def test_minimax_http_embedding_backend_registration(self): + """Test that MiniMax is registered in the HTTP embedding backends.""" + from memu.llm.http_client import HTTPLLMClient + + client = HTTPLLMClient( + base_url="https://api.minimax.io/v1", + api_key="fake-key", + chat_model="MiniMax-M2.5", + provider="minimax", + ) + self.assertEqual(client.provider, "minimax") + self.assertEqual(client.backend.name, "minimax") diff --git a/tests/test_minimax.py b/tests/test_minimax.py new file mode 100644 index 00000000..90b3e541 --- /dev/null +++ b/tests/test_minimax.py @@ -0,0 +1,161 @@ +""" +Test MiniMax integration with MemU's full workflow. + +Tests: +1. Conversation memorization using MiniMax +2. RAG-based retrieval using MiniMax embeddings +3. LLM-based retrieval using MiniMax + +Usage: + export MINIMAX_API_KEY=your_api_key + python tests/test_minimax.py +""" + +import asyncio +import json +import os +import sys +from typing import Any + +import pytest + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "src"))) + +from memu.app import MemoryService + + +def _print_categories(categories, max_items=3): + """Print category summaries.""" + if categories: + print(" Categories:") + for cat in categories[:max_items]: + summary = cat.get("summary") or cat.get("description", "") + print(f" - {cat.get('name')}: {summary[:60]}...") + + +def _print_items(items, max_items=3): + """Print memory item summaries.""" + if items: + print(" Items:") + for item in items[:max_items]: + memory_type = item.get("memory_type", "unknown") + summary = item.get("summary", "")[:80] + print(f" - [{memory_type}] {summary}...") + + +async def _test_memorize(service, file_path, output_data): + """Test conversation memorization.""" + print("\n[MINIMAX] Test 1: Memorizing conversation...") + memory = await service.memorize( + resource_url=file_path, modality="conversation", user={"user_id": "minimax_test_user"} + ) + items_count = len(memory.get("items", [])) + categories_count = len(memory.get("categories", [])) + + print(f" Memorized {items_count} items") + print(f" Created {categories_count} categories") + + output_data["memorize"] = memory + + assert items_count > 0, "Expected at least 1 memory item" + assert categories_count > 0, "Expected at least 1 category" + + _print_categories(memory.get("categories", [])) + return memory + + +async def _test_retrieve(service, queries, method, test_num, output_data): + """Test retrieval with specified method.""" + print(f"\n[MINIMAX] Test {test_num}: {method.upper()}-based retrieval...") + service.retrieve_config.method = method + result = await service.retrieve(queries=queries, where={"user_id": "minimax_test_user"}) + + categories_retrieved = len(result.get("categories", [])) + items_retrieved = len(result.get("items", [])) + + print(f" Retrieved {categories_retrieved} categories") + print(f" Retrieved {items_retrieved} items") + + output_data[f"retrieve_{method}"] = result + + _print_categories(result.get("categories", [])) + _print_items(result.get("items", [])) + return result + + +async def test_minimax_full_workflow(): + """Test MiniMax integration with full MemU workflow.""" + api_key = os.environ.get("MINIMAX_API_KEY") + if not api_key: + pytest.skip("MINIMAX_API_KEY environment variable not set") + + file_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "example", "example_conversation.json")) + if not os.path.exists(file_path): + pytest.skip(f"Test file not found: {file_path}") + + output_data: dict[str, Any] = {} + + print("\n" + "=" * 60) + print("[MINIMAX] Starting full workflow test...") + print("=" * 60) + + service = MemoryService( + llm_profiles={ + "default": { + "provider": "minimax", + "client_backend": "httpx", + "base_url": "https://api.minimax.io/v1", + "api_key": api_key, + "chat_model": "MiniMax-M2.5", + "embed_model": "text-embedding-3-small", + }, + }, + database_config={ + "metadata_store": {"provider": "inmemory"}, + }, + retrieve_config={ + "method": "rag", + "route_intention": False, + }, + ) + + queries = [ + {"role": "user", "content": {"text": "What foods does the user like to eat?"}}, + ] + + await _test_memorize(service, file_path, output_data) + await _test_retrieve(service, queries, "rag", 2, output_data) + await _test_retrieve(service, queries, "llm", 3, output_data) + + # Test 4: List memory items + print("\n[MINIMAX] Test 4: List memory items...") + items_result = await service.list_memory_items(where={"user_id": "minimax_test_user"}) + items_list = items_result.get("items", []) + print(f" Listed {len(items_list)} memory items") + output_data["list_items"] = items_result + assert len(items_list) > 0, "Expected at least 1 item in list" + + # Test 5: List memory categories + print("\n[MINIMAX] Test 5: List memory categories...") + cats_result = await service.list_memory_categories(where={"user_id": "minimax_test_user"}) + cats_list = cats_result.get("categories", []) + print(f" Listed {len(cats_list)} categories") + output_data["list_categories"] = cats_result + assert len(cats_list) > 0, "Expected at least 1 category in list" + + # Save output to file + output_file = os.path.abspath( + os.path.join(os.path.dirname(__file__), "..", "examples", "output", "minimax_test_output.json") + ) + os.makedirs(os.path.dirname(output_file), exist_ok=True) + with open(output_file, "w", encoding="utf-8") as f: + json.dump(output_data, f, indent=2, default=str) + print(f"\n[MINIMAX] Output saved to: {output_file}") + + print("\n" + "=" * 60) + print("[MINIMAX] All tests completed!") + print("=" * 60) + + +if __name__ == "__main__": + asyncio.run(test_minimax_full_workflow())