diff --git a/README.MD b/README.MD
index 8418f76..8946c13 100644
--- a/README.MD
+++ b/README.MD
@@ -67,6 +67,11 @@ bash scripts/evaluate_claude.sh
bash scripts/evaluate_gemini.sh
```
+* Evaluate MiniMax models (MiniMax-M2.5 and MiniMax-M2.5-highspeed, both with 204K context)
+```
+bash scripts/evaluate_minimax.sh
+```
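+Requires `MINIMAX_API_KEY` to be set in `scripts/env.sh`.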
+
* Evaluate models available on Huggingface
```
bash scripts/evaluate_hf_llm.sh
diff --git a/global_methods.py b/global_methods.py
index 7ada78d..a06ebcf 100644
--- a/global_methods.py
+++ b/global_methods.py
@@ -4,6 +4,8 @@
import time
import sys
import os
+import re
+import httpx
import google.generativeai as genai
from anthropic import Anthropic
@@ -13,6 +15,9 @@ def get_openai_embedding(texts, model="text-embedding-ada-002"):
texts = [text.replace("\n", " ") for text in texts]
return np.array([openai.Embedding.create(input = texts, model=model)['data'][i]['embedding'] for i in range(len(texts))])
+def set_minimax_key():
+ pass
+
def set_anthropic_key():
pass
@@ -79,6 +84,40 @@ def run_claude(query, max_new_tokens, model_name):
return message.content[0].text
+def run_minimax(query, max_new_tokens, model_name, temperature=0):
+    """Run a MiniMax model via the OpenAI-compatible chat completions API."""
+
+    # Map lowercase CLI model names to the canonical API model names;
+    # unknown names are passed through unchanged.
+    model_name_map = {
+        'minimax-m2.5': 'MiniMax-M2.5',
+        'minimax-m2.5-highspeed': 'MiniMax-M2.5-highspeed',
+        'minimax-m2.7': 'MiniMax-M2.7',
+    }
+    api_model_name = model_name_map.get(model_name, model_name)
+
+ url = "https://api.minimax.io/v1/chat/completions"
+ headers = {
+ "Authorization": f"Bearer {os.environ.get('MINIMAX_API_KEY', '')}",
+ "Content-Type": "application/json",
+ }
+    # MiniMax requires temperature in (0.0, 1.0]; clamp out-of-range values
+    clamped_temp = max(0.01, min(temperature, 1.0))
+ payload = {
+ "model": api_model_name,
+ "messages": [{"role": "user", "content": query}],
+ "max_tokens": max_new_tokens,
+ "temperature": clamped_temp,
+ }
+ response = httpx.post(url, headers=headers, json=payload, timeout=120)
+ response.raise_for_status()
+ data = response.json()
+ text = data["choices"][0]["message"]["content"]
+    # Strip <think>...</think> reasoning blocks if present (M2.5 models may emit them)
+    text = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL).strip()
+ return text
+
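+# Usage sketch (assumes MINIMAX_API_KEY is exported, e.g. via scripts/env.sh):
+#   answer = run_minimax("What is 2 + 2? Answer with just the number.", 1024, "minimax-m2.5")
+# The model's thinking tokens count against max_tokens, so keep the budget generous.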
+
def run_gemini(model, content: str, max_tokens: int = 0):
try:
diff --git a/scripts/env.sh b/scripts/env.sh
index cdb5df3..7f3fcf0 100644
--- a/scripts/env.sh
+++ b/scripts/env.sh
@@ -24,5 +24,8 @@ export GOOGLE_API_KEY=
# Anthropic API Key
export ANTHROPIC_API_KEY=
+# MiniMax API Key
+export MINIMAX_API_KEY=
+
# HuggingFace Token
export HF_TOKEN=
diff --git a/scripts/evaluate_minimax.sh b/scripts/evaluate_minimax.sh
new file mode 100644
index 0000000..f58e556
--- /dev/null
+++ b/scripts/evaluate_minimax.sh
@@ -0,0 +1,12 @@
+# sets necessary environment variables
+source scripts/env.sh
+
+# Evaluate MiniMax-M2.5
+python3 task_eval/evaluate_qa.py \
+ --data-file $DATA_FILE_PATH --out-file $OUT_DIR/$QA_OUTPUT_FILE \
+ --model minimax-m2.5 --batch-size 10
+
+# Evaluate MiniMax-M2.5-highspeed (204K context, faster inference)
+python3 task_eval/evaluate_qa.py \
+ --data-file $DATA_FILE_PATH --out-file $OUT_DIR/$QA_OUTPUT_FILE \
+ --model minimax-m2.5-highspeed --batch-size 10
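+
+# For a quick sanity check on a single model, you can call evaluate_qa.py
+# directly with a smaller batch size, e.g.:
+#   python3 task_eval/evaluate_qa.py --data-file $DATA_FILE_PATH \
+#     --out-file $OUT_DIR/$QA_OUTPUT_FILE --model minimax-m2.5 --batch-size 1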
diff --git a/task_eval/evaluate_qa.py b/task_eval/evaluate_qa.py
index c3e888c..1043fb2 100644
--- a/task_eval/evaluate_qa.py
+++ b/task_eval/evaluate_qa.py
@@ -5,12 +5,13 @@
import os, json
from tqdm import tqdm
import argparse
-from global_methods import set_openai_key, set_anthropic_key, set_gemini_key
+from global_methods import set_openai_key, set_anthropic_key, set_gemini_key, set_minimax_key
from task_eval.evaluation import eval_question_answering
from task_eval.evaluation_stats import analyze_aggr_acc
from task_eval.gpt_utils import get_gpt_answers
from task_eval.claude_utils import get_claude_answers
from task_eval.gemini_utils import get_gemini_answers
+from task_eval.minimax_utils import get_minimax_answers
from task_eval.hf_llm_utils import init_hf_model, get_hf_answers
import numpy as np
@@ -56,7 +57,10 @@ def main():
model_name = "models/gemini-1.0-pro-latest"
gemini_model = genai.GenerativeModel(model_name)
-
+
+ elif 'minimax' in args.model:
+ set_minimax_key()
+
elif any([model_name in args.model for model_name in ['gemma', 'llama', 'mistral']]):
hf_pipeline, hf_model_name = init_hf_model(args)
@@ -90,6 +94,8 @@ def main():
answers = get_claude_answers(data, out_data, prediction_key, args)
elif 'gemini' in args.model:
answers = get_gemini_answers(gemini_model, data, out_data, prediction_key, args)
+ elif 'minimax' in args.model:
+ answers = get_minimax_answers(data, out_data, prediction_key, args)
elif any([model_name in args.model for model_name in ['gemma', 'llama', 'mistral']]):
answers = get_hf_answers(data, out_data, args, hf_pipeline, hf_model_name)
else:
diff --git a/task_eval/minimax_utils.py b/task_eval/minimax_utils.py
new file mode 100644
index 0000000..60bc8ce
--- /dev/null
+++ b/task_eval/minimax_utils.py
@@ -0,0 +1,203 @@
+import sys
+from pathlib import Path
+sys.path.insert(0, str(Path(__file__).parent.parent))
+
+import random
+import os, json
+from tqdm import tqdm
+import time
+from global_methods import run_minimax
+
+
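+# Context-window sizes (in tokens) for the supported MiniMax models. These are
+# informational for now: get_input_context below includes the full conversation
+# without truncation, since LoCoMo conversations fit comfortably.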
+MAX_LENGTH = {
+ 'minimax-m2.5': 204000,
+ 'minimax-m2.5-highspeed': 204000,
+ 'minimax-m2.7': 1000000,
+}
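+# Per-question answer budget. MiniMax thinking tokens also count against
+# max_tokens, so this may need raising for reasoning-heavy responses.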
+PER_QA_TOKEN_BUDGET = 50
+
+QA_PROMPT = """
+Based on the above context, write an answer in the form of a short phrase for the following question. Answer with exact words from the context whenever possible.
+
+Question: {} Short answer:
+"""
+
+QA_PROMPT_CAT_5 = """
+Based on the above context, answer the following question.
+
+Question: {} Short answer:
+"""
+
+QA_PROMPT_BATCH = """
+Based on the above conversations, write short answers for each of the following questions in a few words. Write the answers in the form of a json dictionary where each entry contains the string format of question number as 'key' and the short answer as value. Use single-quote characters for named entities. Answer with exact words from the conversations whenever possible.
+
+"""
+
+CONV_START_PROMPT = "Below is a conversation between two people: {} and {}. The conversation takes place over multiple days and the date of each conversation is written at the beginning of the conversation.\n\n"
+
+
+def process_output(text):
+    """Extract and parse the JSON dictionary from a model response."""
+
+    text = text.strip()
+    if not text.startswith('{'):
+        start = text.index('{')  # raises ValueError if no JSON object is present
+        text = text[start:].strip()
+
+    return json.loads(text)
+
+
+def get_cat_5_answer(model_prediction, answer_key):
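+    """Map a multiple-choice reply ('a', 'b', '(a)', '(b)') to the answer key;
+    free-form replies are returned as-is (lowercased)."""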
+
+ model_prediction = model_prediction.strip().lower()
+ if len(model_prediction) == 1:
+ if 'a' in model_prediction:
+ return answer_key['a']
+ else:
+ return answer_key['b']
+ elif len(model_prediction) == 3:
+ if '(a)' in model_prediction:
+ return answer_key['a']
+ else:
+ return answer_key['b']
+ else:
+ return model_prediction
+
+
+def get_input_context(data, num_question_tokens, model, args):
+    """Assemble the multi-session conversation into a single prompt string.
+
+    num_question_tokens and model are accepted for signature parity with the
+    other *_utils modules; no truncation is applied, on the assumption that the
+    MiniMax context windows are large enough to hold full LoCoMo conversations.
+    """
+
+    query_conv = ''
+    session_nums = [int(k.split('_')[-1]) for k in data.keys() if 'session' in k and 'date_time' not in k]
+    # Walk sessions newest-to-oldest, prepending each block so the final
+    # string reads in chronological order.
+    for i in range(max(session_nums), min(session_nums) - 1, -1):
+        if 'session_%s' % i in data:
+            query_conv = "\n\n" + query_conv
+            for dialog in data['session_%s' % i][::-1]:
+                turn = dialog['speaker'] + ' said, \"' + dialog['text'] + '\"'
+                if "blip_caption" in dialog:
+                    turn += ' and shared %s.' % dialog["blip_caption"]
+                turn += '\n'
+
+                query_conv = turn + query_conv
+
+            query_conv = '\nDATE: ' + data['session_%s_date_time' % i] + '\n' + 'CONVERSATION:\n' + query_conv
+
+    return query_conv
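+# Sketch of the assembled context for a one-session conversation:
+#
+#   DATE: 2024-01-01
+#   CONVERSATION:
+#   Alice said, "Hello Bob!"
+#   Bob said, "Hi Alice!"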
+
+
+def get_minimax_answers(in_data, out_data, prediction_key, args):
+
+ assert len(in_data['qa']) == len(out_data['qa']), (len(in_data['qa']), len(out_data['qa']))
+
+ # start instruction prompt
+ speakers_names = list(set([d['speaker'] for d in in_data['conversation']['session_1']]))
+ start_prompt = CONV_START_PROMPT.format(speakers_names[0], speakers_names[1])
+ start_tokens = 100
+
+ if args.rag_mode:
+ raise NotImplementedError
+ else:
+ context_database, query_vectors = None, None
+
+ for batch_start_idx in tqdm(range(0, len(in_data['qa']), args.batch_size), desc='Generating answers'):
+
+ questions = []
+ include_idxs = []
+ cat_5_idxs = []
+ cat_5_answers = []
+ for i in range(batch_start_idx, batch_start_idx + args.batch_size):
+
+ if i >= len(in_data['qa']):
+ break
+
+ qa = in_data['qa'][i]
+
+ if prediction_key not in out_data['qa'][i] or args.overwrite:
+ include_idxs.append(i)
+ else:
+ continue
+
+ if qa['category'] == 2:
+ questions.append(qa['question'] + ' Use DATE of CONVERSATION to answer with an approximate date.')
+ elif qa['category'] == 5:
+ question = qa['question'] + " Select the correct answer: (a) {} (b) {}. "
+ if random.random() < 0.5:
+ question = question.format('Not mentioned in the conversation', qa['answer'])
+ answer = {'a': 'Not mentioned in the conversation', 'b': qa['answer']}
+ else:
+ question = question.format(qa['answer'], 'Not mentioned in the conversation')
+ answer = {'b': 'Not mentioned in the conversation', 'a': qa['answer']}
+
+ cat_5_idxs.append(len(questions))
+ questions.append(question)
+ cat_5_answers.append(answer)
+ else:
+ questions.append(qa['question'])
+
+        if not questions:
+            continue
+
+ context_ids = None
+ if args.use_rag:
+ raise NotImplementedError
+ else:
+ question_prompt = QA_PROMPT_BATCH + "\n".join(["%s: %s" % (k, q) for k, q in enumerate(questions)])
+ num_question_tokens = 100
+ query_conv = get_input_context(in_data['conversation'], num_question_tokens + start_tokens, None, args)
+ query_conv = start_prompt + query_conv
+
+ if args.batch_size == 1:
+
+            if len(cat_5_idxs) == 0:
+                query = query_conv + '\n\n' + QA_PROMPT.format(questions[0])
+            else:
+                query = query_conv + '\n\n' + QA_PROMPT_CAT_5.format(questions[0])
+ answer = run_minimax(query, PER_QA_TOKEN_BUDGET, args.model)
+
+ if len(cat_5_idxs) > 0:
+ answer = get_cat_5_answer(answer, cat_5_answers[0])
+
+ out_data['qa'][include_idxs[0]][prediction_key] = answer.strip()
+ if args.use_rag:
+ out_data['qa'][include_idxs[0]][prediction_key + '_context'] = context_ids
+
+ else:
+ query = query_conv + '\n' + question_prompt
+
+            trials = 0
+            answer = ''
+            while trials < 5:
+                try:
+                    trials += 1
+                    answer = run_minimax(query, PER_QA_TOKEN_BUDGET * args.batch_size, args.model)
+                    answer = answer.replace('\\"', "'").replace('json', '').replace('`', '').strip()
+                    answers = process_output(answer.strip())
+                    break
+                except ValueError:  # json.JSONDecodeError subclasses ValueError
+                    time.sleep(1)  # brief pause before retrying a malformed response
+
+            for k, idx in enumerate(include_idxs):
+                try:
+                    answers = process_output(answer.strip())
+                    if k in cat_5_idxs:
+                        predicted_answer = get_cat_5_answer(answers[str(k)], cat_5_answers[cat_5_idxs.index(k)])
+                        out_data['qa'][idx][prediction_key] = predicted_answer
+                    else:
+                        try:
+                            out_data['qa'][idx][prediction_key] = str(answers[str(k)]).replace('(a)', '').replace('(b)', '').strip()
+                        except Exception:
+                            out_data['qa'][idx][prediction_key] = ', '.join([str(n) for n in list(answers[str(k)].values())])
+                except Exception:
+                    try:
+                        answers = json.loads(answer.strip())
+                        if k in cat_5_idxs:
+                            predicted_answer = get_cat_5_answer(answers[k], cat_5_answers[cat_5_idxs.index(k)])
+                            out_data['qa'][idx][prediction_key] = predicted_answer
+                        else:
+                            out_data['qa'][idx][prediction_key] = answers[k].replace('(a)', '').replace('(b)', '').strip()
+                    except Exception:
+                        if k in cat_5_idxs:
+                            predicted_answer = get_cat_5_answer(answer.strip(), cat_5_answers[cat_5_idxs.index(k)])
+                            out_data['qa'][idx][prediction_key] = predicted_answer
+                        else:
+                            out_data['qa'][idx][prediction_key] = json.loads(answer.strip().replace('(a)', '').replace('(b)', '').split('\n')[k])[0]
+
+ return out_data
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/test_minimax_integration.py b/tests/test_minimax_integration.py
new file mode 100644
index 0000000..2482d59
--- /dev/null
+++ b/tests/test_minimax_integration.py
@@ -0,0 +1,90 @@
+"""Integration tests for MiniMax provider in LoCoMo.
+
+These tests verify the full pipeline works with the MiniMax API.
+They require MINIMAX_API_KEY to be set in the environment.
+"""
+
+import sys
+import os
+import json
+import unittest
+from pathlib import Path
+
+# Add project root to path
+sys.path.insert(0, str(Path(__file__).parent.parent))
+
+MINIMAX_API_KEY = os.environ.get('MINIMAX_API_KEY', '')
+SKIP_REASON = "MINIMAX_API_KEY not set; skipping integration tests"
+
+
+@unittest.skipUnless(MINIMAX_API_KEY, SKIP_REASON)
+class TestMinimaxAPIIntegration(unittest.TestCase):
+ """Integration tests that call the real MiniMax API."""
+
+ def test_run_minimax_m2_5(self):
+ """Test a real API call to MiniMax-M2.5."""
+ from global_methods import run_minimax
+
+ # M2.5 uses thinking tokens that count against max_tokens,
+ # so we need a larger budget
+ result = run_minimax(
+ "What is 2 + 2? Answer with just the number.",
+ 1024,
+ "minimax-m2.5"
+ )
+ self.assertIsInstance(result, str)
+ self.assertTrue(len(result) > 0)
+ self.assertIn('4', result)
+
+ def test_run_minimax_m2_5_highspeed(self):
+ """Test a real API call to MiniMax-M2.5-highspeed."""
+ from global_methods import run_minimax
+
+ result = run_minimax(
+ "Name a color of the rainbow. Answer with just one word.",
+ 1024,
+ "minimax-m2.5-highspeed"
+ )
+ self.assertIsInstance(result, str)
+ self.assertTrue(len(result) > 0)
+
+ def test_minimax_qa_pipeline(self):
+ """Test the full MiniMax QA pipeline with a minimal conversation."""
+ from task_eval.minimax_utils import get_minimax_answers
+ from unittest.mock import MagicMock
+
+ in_data = {
+ 'conversation': {
+ 'session_1': [
+ {'speaker': 'Alice', 'text': 'I just adopted a golden retriever named Max!', 'dia_id': 'd1'},
+ {'speaker': 'Bob', 'text': 'That is wonderful! What breed is he?', 'dia_id': 'd2'},
+ {'speaker': 'Alice', 'text': 'He is a golden retriever, about 2 years old.', 'dia_id': 'd3'},
+ ],
+ 'session_1_date_time': '2024-03-15',
+ },
+ 'qa': [
+ {'question': "What is Alice's dog's name?", 'answer': 'Max', 'category': 1, 'evidence': ['d1']}
+ ]
+ }
+ out_data = {
+ 'qa': [
+ {'question': "What is Alice's dog's name?", 'answer': 'Max', 'category': 1, 'evidence': ['d1']}
+ ]
+ }
+
+ args = MagicMock()
+ args.model = 'minimax-m2.5'
+ args.batch_size = 1
+ args.use_rag = False
+ args.rag_mode = ''
+ args.overwrite = True
+
+ result = get_minimax_answers(in_data, out_data, 'minimax-m2.5_prediction', args)
+
+ self.assertIn('minimax-m2.5_prediction', result['qa'][0])
+ prediction = result['qa'][0]['minimax-m2.5_prediction'].lower()
+ self.assertIn('max', prediction)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/tests/test_minimax_unit.py b/tests/test_minimax_unit.py
new file mode 100644
index 0000000..12265a9
--- /dev/null
+++ b/tests/test_minimax_unit.py
@@ -0,0 +1,426 @@
+"""Unit tests for MiniMax provider integration in LoCoMo."""
+
+import sys
+import os
+import json
+import unittest
+from unittest.mock import patch, MagicMock
+from pathlib import Path
+
+# Add project root to path
+sys.path.insert(0, str(Path(__file__).parent.parent))
+
+
+class TestRunMinimax(unittest.TestCase):
+ """Test the run_minimax function in global_methods."""
+
+ @patch('global_methods.httpx')
+ def test_run_minimax_basic(self, mock_httpx):
+ """Test basic MiniMax API call."""
+ mock_response = MagicMock()
+ mock_response.json.return_value = {
+ "choices": [{"message": {"content": "Paris"}}]
+ }
+ mock_response.raise_for_status = MagicMock()
+ mock_httpx.post.return_value = mock_response
+
+ from global_methods import run_minimax
+ os.environ['MINIMAX_API_KEY'] = 'test-key'
+ result = run_minimax("What is the capital of France?", 32, "minimax-m2.5")
+
+ self.assertEqual(result, "Paris")
+ mock_httpx.post.assert_called_once()
+ call_args = mock_httpx.post.call_args
+ payload = call_args[1]['json']
+ self.assertEqual(payload['model'], 'MiniMax-M2.5')
+ self.assertEqual(payload['max_tokens'], 32)
+
+ @patch('global_methods.httpx')
+ def test_run_minimax_model_mapping(self, mock_httpx):
+ """Test model name mapping for different MiniMax models."""
+ mock_response = MagicMock()
+ mock_response.json.return_value = {
+ "choices": [{"message": {"content": "ok"}}]
+ }
+ mock_response.raise_for_status = MagicMock()
+ mock_httpx.post.return_value = mock_response
+
+ from global_methods import run_minimax
+ os.environ['MINIMAX_API_KEY'] = 'test-key'
+
+ test_cases = [
+ ('minimax-m2.5', 'MiniMax-M2.5'),
+ ('minimax-m2.5-highspeed', 'MiniMax-M2.5-highspeed'),
+ ('minimax-m2.7', 'MiniMax-M2.7'),
+ ]
+
+ for input_name, expected_api_name in test_cases:
+ mock_httpx.post.reset_mock()
+ run_minimax("test", 32, input_name)
+ payload = mock_httpx.post.call_args[1]['json']
+ self.assertEqual(payload['model'], expected_api_name,
+ f"Model {input_name} should map to {expected_api_name}")
+
+ @patch('global_methods.httpx')
+ def test_run_minimax_temperature_clamping(self, mock_httpx):
+ """Test that temperature is clamped to valid range."""
+ mock_response = MagicMock()
+ mock_response.json.return_value = {
+ "choices": [{"message": {"content": "ok"}}]
+ }
+ mock_response.raise_for_status = MagicMock()
+ mock_httpx.post.return_value = mock_response
+
+ from global_methods import run_minimax
+ os.environ['MINIMAX_API_KEY'] = 'test-key'
+
+ # temperature=0 should be clamped to 0.01
+ run_minimax("test", 32, "minimax-m2.5", temperature=0)
+ payload = mock_httpx.post.call_args[1]['json']
+ self.assertEqual(payload['temperature'], 0.01)
+
+ # temperature=0.5 should pass through
+ run_minimax("test", 32, "minimax-m2.5", temperature=0.5)
+ payload = mock_httpx.post.call_args[1]['json']
+ self.assertEqual(payload['temperature'], 0.5)
+
+ # temperature > 1.0 should be clamped to 1.0
+ run_minimax("test", 32, "minimax-m2.5", temperature=1.5)
+ payload = mock_httpx.post.call_args[1]['json']
+ self.assertEqual(payload['temperature'], 1.0)
+
+ @patch('global_methods.httpx')
+ def test_run_minimax_strips_think_tags(self, mock_httpx):
+ """Test that thinking tags from M2.5 models are stripped."""
+ mock_response = MagicMock()
+ mock_response.json.return_value = {
+            "choices": [{"message": {"content": "<think>Let me think about this...</think>Paris"}}]
+ }
+ mock_response.raise_for_status = MagicMock()
+ mock_httpx.post.return_value = mock_response
+
+ from global_methods import run_minimax
+ os.environ['MINIMAX_API_KEY'] = 'test-key'
+ result = run_minimax("test", 32, "minimax-m2.5")
+ self.assertEqual(result, "Paris")
+
+ @patch('global_methods.httpx')
+ def test_run_minimax_api_url(self, mock_httpx):
+ """Test that the correct API URL is used."""
+ mock_response = MagicMock()
+ mock_response.json.return_value = {
+ "choices": [{"message": {"content": "ok"}}]
+ }
+ mock_response.raise_for_status = MagicMock()
+ mock_httpx.post.return_value = mock_response
+
+ from global_methods import run_minimax
+ os.environ['MINIMAX_API_KEY'] = 'test-key'
+ run_minimax("test", 32, "minimax-m2.5")
+
+ call_args = mock_httpx.post.call_args
+ self.assertEqual(call_args[0][0], "https://api.minimax.io/v1/chat/completions")
+
+ @patch('global_methods.httpx')
+ def test_run_minimax_auth_header(self, mock_httpx):
+ """Test that the authorization header includes the API key."""
+ mock_response = MagicMock()
+ mock_response.json.return_value = {
+ "choices": [{"message": {"content": "ok"}}]
+ }
+ mock_response.raise_for_status = MagicMock()
+ mock_httpx.post.return_value = mock_response
+
+ from global_methods import run_minimax
+ os.environ['MINIMAX_API_KEY'] = 'my-secret-key'
+ run_minimax("test", 32, "minimax-m2.5")
+
+ call_args = mock_httpx.post.call_args
+ headers = call_args[1]['headers']
+ self.assertEqual(headers['Authorization'], 'Bearer my-secret-key')
+
+
+class TestSetMinimaxKey(unittest.TestCase):
+ """Test the set_minimax_key function."""
+
+ def test_set_minimax_key_no_error(self):
+ """Test that set_minimax_key runs without error."""
+ from global_methods import set_minimax_key
+ set_minimax_key() # Should not raise
+
+
+class TestMinimaxUtils(unittest.TestCase):
+ """Test the minimax_utils module."""
+
+    def test_process_output_valid_json(self):
+        """Test process_output with valid JSON."""
+        from task_eval.minimax_utils import process_output
+
+        result = process_output('{"0": "Paris", "1": "London"}')
+        self.assertEqual(result, {"0": "Paris", "1": "London"})
+
+    def test_process_output_with_prefix(self):
+        """Test process_output strips text before the JSON object."""
+        from task_eval.minimax_utils import process_output
+
+        result = process_output('Here is the answer: {"0": "Paris"}')
+        self.assertEqual(result, {"0": "Paris"})
+
+ def test_get_cat_5_answer_single_char_a(self):
+ """Test adversarial answer parsing: single character 'a'."""
+ from task_eval.minimax_utils import get_cat_5_answer
+
+ answer_key = {'a': 'Not mentioned', 'b': 'Paris'}
+ result = get_cat_5_answer('a', answer_key)
+ self.assertEqual(result, 'Not mentioned')
+
+ def test_get_cat_5_answer_single_char_b(self):
+ """Test adversarial answer parsing: single character 'b'."""
+ from task_eval.minimax_utils import get_cat_5_answer
+
+ answer_key = {'a': 'Not mentioned', 'b': 'Paris'}
+ result = get_cat_5_answer('b', answer_key)
+ self.assertEqual(result, 'Paris')
+
+ def test_get_cat_5_answer_parenthesized(self):
+ """Test adversarial answer parsing: parenthesized."""
+ from task_eval.minimax_utils import get_cat_5_answer
+
+ answer_key = {'a': 'Not mentioned', 'b': 'Paris'}
+ result = get_cat_5_answer('(a)', answer_key)
+ self.assertEqual(result, 'Not mentioned')
+
+ def test_get_cat_5_answer_freeform(self):
+ """Test adversarial answer parsing: free-form text returned as-is (lowercased)."""
+ from task_eval.minimax_utils import get_cat_5_answer
+
+ answer_key = {'a': 'Not mentioned', 'b': 'Paris'}
+ # The function lowercases the input, then returns it when length > 3
+ result = get_cat_5_answer('The answer is Paris', answer_key)
+ self.assertEqual(result, 'the answer is paris')
+
+ def test_get_input_context(self):
+ """Test conversation context extraction."""
+ from task_eval.minimax_utils import get_input_context
+
+ data = {
+ 'session_1': [
+ {'speaker': 'Alice', 'text': 'Hello Bob!', 'dia_id': 'd1'},
+ {'speaker': 'Bob', 'text': 'Hi Alice!', 'dia_id': 'd2'},
+ ],
+ 'session_1_date_time': '2024-01-01',
+ }
+ args = MagicMock()
+ result = get_input_context(data, 100, None, args)
+ self.assertIn('Alice', result)
+ self.assertIn('Bob', result)
+ self.assertIn('2024-01-01', result)
+
+ def test_get_input_context_multimodal(self):
+ """Test that blip captions are included in context."""
+ from task_eval.minimax_utils import get_input_context
+
+ data = {
+ 'session_1': [
+ {'speaker': 'Alice', 'text': 'Look at this!', 'dia_id': 'd1',
+ 'blip_caption': 'a photo of a sunset'},
+ ],
+ 'session_1_date_time': '2024-01-01',
+ }
+ args = MagicMock()
+ result = get_input_context(data, 100, None, args)
+ self.assertIn('a photo of a sunset', result)
+
+ def test_max_length_definitions(self):
+ """Test that MAX_LENGTH contains correct MiniMax model entries."""
+ from task_eval.minimax_utils import MAX_LENGTH
+
+ self.assertIn('minimax-m2.5', MAX_LENGTH)
+ self.assertIn('minimax-m2.5-highspeed', MAX_LENGTH)
+ self.assertIn('minimax-m2.7', MAX_LENGTH)
+ self.assertEqual(MAX_LENGTH['minimax-m2.5'], 204000)
+ self.assertEqual(MAX_LENGTH['minimax-m2.7'], 1000000)
+
+ def test_prompts_defined(self):
+ """Test that required prompts are defined in minimax_utils."""
+ from task_eval import minimax_utils
+
+ self.assertTrue(hasattr(minimax_utils, 'QA_PROMPT'))
+ self.assertTrue(hasattr(minimax_utils, 'QA_PROMPT_CAT_5'))
+ self.assertTrue(hasattr(minimax_utils, 'QA_PROMPT_BATCH'))
+ self.assertTrue(hasattr(minimax_utils, 'CONV_START_PROMPT'))
+
+
+class TestEvaluateQADispatch(unittest.TestCase):
+ """Test that evaluate_qa.py correctly dispatches to MiniMax."""
+
+ def test_minimax_import(self):
+ """Test that minimax_utils can be imported from evaluate_qa context."""
+ from task_eval.minimax_utils import get_minimax_answers
+ self.assertTrue(callable(get_minimax_answers))
+
+ def test_set_minimax_key_import(self):
+ """Test that set_minimax_key can be imported from global_methods."""
+ from global_methods import set_minimax_key
+ self.assertTrue(callable(set_minimax_key))
+
+ @patch('global_methods.httpx')
+ def test_get_minimax_answers_single_batch(self, mock_httpx):
+ """Test get_minimax_answers with batch_size=1."""
+ mock_response = MagicMock()
+ mock_response.json.return_value = {
+ "choices": [{"message": {"content": "Paris"}}]
+ }
+ mock_response.raise_for_status = MagicMock()
+ mock_httpx.post.return_value = mock_response
+
+ from task_eval.minimax_utils import get_minimax_answers
+
+ os.environ['MINIMAX_API_KEY'] = 'test-key'
+ in_data = {
+ 'conversation': {
+ 'session_1': [
+ {'speaker': 'Alice', 'text': 'I live in Paris.', 'dia_id': 'd1'},
+ {'speaker': 'Bob', 'text': 'Nice!', 'dia_id': 'd2'},
+ ],
+ 'session_1_date_time': '2024-01-01',
+ },
+ 'qa': [
+ {'question': 'Where does Alice live?', 'answer': 'Paris', 'category': 1, 'evidence': []}
+ ]
+ }
+ out_data = {
+ 'qa': [
+ {'question': 'Where does Alice live?', 'answer': 'Paris', 'category': 1, 'evidence': []}
+ ]
+ }
+
+ args = MagicMock()
+ args.model = 'minimax-m2.5'
+ args.batch_size = 1
+ args.use_rag = False
+ args.rag_mode = ''
+ args.overwrite = True
+
+ result = get_minimax_answers(in_data, out_data, 'minimax-m2.5_prediction', args)
+
+ self.assertIn('minimax-m2.5_prediction', result['qa'][0])
+ self.assertEqual(result['qa'][0]['minimax-m2.5_prediction'], 'Paris')
+
+ @patch('global_methods.httpx')
+ def test_get_minimax_answers_cat_5(self, mock_httpx):
+ """Test get_minimax_answers with adversarial category 5 question."""
+ mock_response = MagicMock()
+ mock_response.json.return_value = {
+ "choices": [{"message": {"content": "(a)"}}]
+ }
+ mock_response.raise_for_status = MagicMock()
+ mock_httpx.post.return_value = mock_response
+
+ from task_eval.minimax_utils import get_minimax_answers
+
+ os.environ['MINIMAX_API_KEY'] = 'test-key'
+ in_data = {
+ 'conversation': {
+ 'session_1': [
+ {'speaker': 'Alice', 'text': 'Hello Bob', 'dia_id': 'd1'},
+ {'speaker': 'Bob', 'text': 'Hi Alice', 'dia_id': 'd2'},
+ ],
+ 'session_1_date_time': '2024-01-01',
+ },
+ 'qa': [
+ {'question': 'Does Alice own a car?', 'answer': 'Not mentioned in the conversation', 'category': 5, 'evidence': []}
+ ]
+ }
+ out_data = {
+ 'qa': [
+ {'question': 'Does Alice own a car?', 'answer': 'Not mentioned in the conversation', 'category': 5, 'evidence': []}
+ ]
+ }
+
+ args = MagicMock()
+ args.model = 'minimax-m2.5'
+ args.batch_size = 1
+ args.use_rag = False
+ args.rag_mode = ''
+ args.overwrite = True
+
+ result = get_minimax_answers(in_data, out_data, 'minimax-m2.5_prediction', args)
+ self.assertIn('minimax-m2.5_prediction', result['qa'][0])
+
+ @patch('global_methods.httpx')
+ def test_get_minimax_answers_skip_existing(self, mock_httpx):
+ """Test that existing predictions are skipped without overwrite."""
+ from task_eval.minimax_utils import get_minimax_answers
+
+ in_data = {
+ 'conversation': {
+ 'session_1': [
+ {'speaker': 'Alice', 'text': 'Hi', 'dia_id': 'd1'},
+ {'speaker': 'Bob', 'text': 'Hello', 'dia_id': 'd2'},
+ ],
+ 'session_1_date_time': '2024-01-01',
+ },
+ 'qa': [
+ {'question': 'Q?', 'answer': 'A', 'category': 1, 'evidence': []}
+ ]
+ }
+ out_data = {
+ 'qa': [
+ {'question': 'Q?', 'answer': 'A', 'category': 1, 'evidence': [],
+ 'minimax-m2.5_prediction': 'existing_answer'}
+ ]
+ }
+
+ args = MagicMock()
+ args.model = 'minimax-m2.5'
+ args.batch_size = 1
+ args.use_rag = False
+ args.rag_mode = ''
+ args.overwrite = False
+
+ result = get_minimax_answers(in_data, out_data, 'minimax-m2.5_prediction', args)
+ # Should not have called the API
+ mock_httpx.post.assert_not_called()
+ self.assertEqual(result['qa'][0]['minimax-m2.5_prediction'], 'existing_answer')
+
+
+class TestEvaluateMinimaxShellScript(unittest.TestCase):
+ """Test that the evaluation shell script exists and has correct content."""
+
+ def test_script_exists(self):
+ """Test that evaluate_minimax.sh exists."""
+ script_path = Path(__file__).parent.parent / 'scripts' / 'evaluate_minimax.sh'
+ self.assertTrue(script_path.exists())
+
+ def test_script_sources_env(self):
+ """Test that the script sources env.sh."""
+ script_path = Path(__file__).parent.parent / 'scripts' / 'evaluate_minimax.sh'
+ content = script_path.read_text()
+ self.assertIn('source scripts/env.sh', content)
+
+ def test_script_runs_m2_5(self):
+ """Test that the script evaluates minimax-m2.5."""
+ script_path = Path(__file__).parent.parent / 'scripts' / 'evaluate_minimax.sh'
+ content = script_path.read_text()
+ self.assertIn('minimax-m2.5', content)
+
+ def test_script_runs_highspeed(self):
+ """Test that the script evaluates minimax-m2.5-highspeed."""
+ script_path = Path(__file__).parent.parent / 'scripts' / 'evaluate_minimax.sh'
+ content = script_path.read_text()
+ self.assertIn('minimax-m2.5-highspeed', content)
+
+
+class TestEnvScript(unittest.TestCase):
+ """Test that env.sh includes MINIMAX_API_KEY."""
+
+ def test_env_has_minimax_key(self):
+ """Test that env.sh contains MINIMAX_API_KEY export."""
+ env_path = Path(__file__).parent.parent / 'scripts' / 'env.sh'
+ content = env_path.read_text()
+ self.assertIn('MINIMAX_API_KEY', content)
+
+
+if __name__ == '__main__':
+ unittest.main()