diff --git a/README.MD b/README.MD
index 8418f76..8946c13 100644
--- a/README.MD
+++ b/README.MD
@@ -67,6 +67,11 @@ bash scripts/evaluate_claude.sh
 bash scripts/evaluate_gemini.sh
 ```
 
+* Evaluate MiniMax models (MiniMax-M2.5, MiniMax-M2.5-highspeed with 204K context)
+```
+bash scripts/evaluate_minimax.sh
+```
+
 * Evaluate models available on Huggingface
 ```
 bash scripts/evaluate_hf_llm.sh
diff --git a/global_methods.py b/global_methods.py
index 7ada78d..a06ebcf 100644
--- a/global_methods.py
+++ b/global_methods.py
@@ -4,6 +4,8 @@
 import time
 import sys
 import os
+import re
+import httpx
 
 import google.generativeai as genai
 from anthropic import Anthropic
@@ -13,6 +15,9 @@ def get_openai_embedding(texts, model="text-embedding-ada-002"):
     texts = [text.replace("\n", " ") for text in texts]
     return np.array([openai.Embedding.create(input = texts, model=model)['data'][i]['embedding'] for i in range(len(texts))])
 
+def set_minimax_key():
+    pass
+
 def set_anthropic_key():
     pass
 
@@ -79,6 +84,40 @@ def run_claude(query, max_new_tokens, model_name):
     return message.content[0].text
 
 
+def run_minimax(query, max_new_tokens, model_name, temperature=0):
+    """Run MiniMax model via OpenAI-compatible API."""
+
+    if model_name == 'minimax-m2.5':
+        api_model_name = "MiniMax-M2.5"
+    elif model_name == 'minimax-m2.5-highspeed':
+        api_model_name = "MiniMax-M2.5-highspeed"
+    elif model_name == 'minimax-m2.7':
+        api_model_name = "MiniMax-M2.7"
+    else:
+        api_model_name = model_name
+
+    url = "https://api.minimax.io/v1/chat/completions"
+    headers = {
+        "Authorization": f"Bearer {os.environ.get('MINIMAX_API_KEY', '')}",
+        "Content-Type": "application/json",
+    }
+    # MiniMax temperature must be in (0.0, 1.0]
+    clamped_temp = max(0.01, min(temperature, 1.0)) if temperature > 0 else 0.01
+    payload = {
+        "model": api_model_name,
+        "messages": [{"role": "user", "content": query}],
+        "max_tokens": max_new_tokens,
+        "temperature": clamped_temp,
+    }
+    response = httpx.post(url, headers=headers, json=payload, timeout=120)
+    response.raise_for_status()
+    data = response.json()
+    text = data["choices"][0]["message"]["content"]
+    # Strip thinking tags if present (M2.5 models may include them)
+    text = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL).strip()
+    return text
+
+
 def run_gemini(model, content: str, max_tokens: int = 0):
 
     try:
diff --git a/scripts/env.sh b/scripts/env.sh
index cdb5df3..7f3fcf0 100644
--- a/scripts/env.sh
+++ b/scripts/env.sh
@@ -24,5 +24,8 @@ export GOOGLE_API_KEY=
 # Anthropic API Key
 export ANTHROPIC_API_KEY=
 
+# MiniMax API Key
+export MINIMAX_API_KEY=
+
 # HuggingFace Token
 export HF_TOKEN=
diff --git a/scripts/evaluate_minimax.sh b/scripts/evaluate_minimax.sh
new file mode 100644
index 0000000..f58e556
--- /dev/null
+++ b/scripts/evaluate_minimax.sh
@@ -0,0 +1,12 @@
+# sets necessary environment variables
+source scripts/env.sh
+
+# Evaluate MiniMax-M2.5
+python3 task_eval/evaluate_qa.py \
+    --data-file $DATA_FILE_PATH --out-file $OUT_DIR/$QA_OUTPUT_FILE \
+    --model minimax-m2.5 --batch-size 10
+
+# Evaluate MiniMax-M2.5-highspeed (204K context, faster inference)
+python3 task_eval/evaluate_qa.py \
+    --data-file $DATA_FILE_PATH --out-file $OUT_DIR/$QA_OUTPUT_FILE \
+    --model minimax-m2.5-highspeed --batch-size 10
diff --git a/task_eval/evaluate_qa.py b/task_eval/evaluate_qa.py
index c3e888c..1043fb2 100644
--- a/task_eval/evaluate_qa.py
+++ b/task_eval/evaluate_qa.py
@@ -5,12 +5,13 @@
 import os, json
 from tqdm import tqdm
 import argparse
-from global_methods import set_openai_key, set_anthropic_key, set_gemini_key
+from global_methods import set_openai_key, set_anthropic_key, set_gemini_key, set_minimax_key
 from task_eval.evaluation import eval_question_answering
 from task_eval.evaluation_stats import analyze_aggr_acc
 from task_eval.gpt_utils import get_gpt_answers
 from task_eval.claude_utils import get_claude_answers
 from task_eval.gemini_utils import get_gemini_answers
+from task_eval.minimax_utils import get_minimax_answers
 from task_eval.hf_llm_utils import init_hf_model, get_hf_answers
 import numpy as np
 
@@ -56,7 +57,10 @@
         model_name = "models/gemini-1.0-pro-latest"
         gemini_model = genai.GenerativeModel(model_name)
-    
+
+    elif 'minimax' in args.model:
+        set_minimax_key()
+
     elif any([model_name in args.model for model_name in ['gemma', 'llama', 'mistral']]):
         hf_pipeline, hf_model_name = init_hf_model(args)
 
@@ -90,6 +94,8 @@
             answers = get_claude_answers(data, out_data, prediction_key, args)
         elif 'gemini' in args.model:
             answers = get_gemini_answers(gemini_model, data, out_data, prediction_key, args)
+        elif 'minimax' in args.model:
+            answers = get_minimax_answers(data, out_data, prediction_key, args)
         elif any([model_name in args.model for model_name in ['gemma', 'llama', 'mistral']]):
             answers = get_hf_answers(data, out_data, args, hf_pipeline, hf_model_name)
         else:
diff --git a/task_eval/minimax_utils.py b/task_eval/minimax_utils.py
new file mode 100644
index 0000000..60bc8ce
--- /dev/null
+++ b/task_eval/minimax_utils.py
@@ -0,0 +1,203 @@
+import sys
+from pathlib import Path
+sys.path.insert(0, str(Path(__file__).parent.parent))
+
+import random
+import os, json
+from tqdm import tqdm
+import time
+from global_methods import run_minimax
+
+
+MAX_LENGTH = {
+    'minimax-m2.5': 204000,
+    'minimax-m2.5-highspeed': 204000,
+    'minimax-m2.7': 1000000,
+}
+PER_QA_TOKEN_BUDGET = 50
+
+QA_PROMPT = """
+Based on the above context, write an answer in the form of a short phrase for the following question. Answer with exact words from the context whenever possible.
+
+Question: {} Short answer:
+"""
+
+QA_PROMPT_CAT_5 = """
+Based on the above context, answer the following question.
+
+Question: {} Short answer:
+"""
+
+QA_PROMPT_BATCH = """
+Based on the above conversations, write short answers for each of the following questions in a few words. Write the answers in the form of a json dictionary where each entry contains the string format of question number as 'key' and the short answer as value. Use single-quote characters for named entities. Answer with exact words from the conversations whenever possible.
+
+"""
+
+CONV_START_PROMPT = "Below is a conversation between two people: {} and {}. The conversation takes place over multiple days and the date of each conversation is written at the beginning of the conversation.\n\n"
+
+
+def process_ouput(text):
+
+    text = text.strip()
+    if text[0] != '{':
+        start = text.index('{')
+        text = text[start:].strip()
+
+    return json.loads(text)
+
+
+def get_cat_5_answer(model_prediction, answer_key):
+
+    model_prediction = model_prediction.strip().lower()
+    if len(model_prediction) == 1:
+        if 'a' in model_prediction:
+            return answer_key['a']
+        else:
+            return answer_key['b']
+    elif len(model_prediction) == 3:
+        if '(a)' in model_prediction:
+            return answer_key['a']
+        else:
+            return answer_key['b']
+    else:
+        return model_prediction
+
+
+def get_input_context(data, num_question_tokens, model, args):
+
+    query_conv = ''
+    stop = False
+    session_nums = [int(k.split('_')[-1]) for k in data.keys() if 'session' in k and 'date_time' not in k]
+    for i in range(min(session_nums), max(session_nums) + 1):
+        if 'session_%s' % i in data:
+            query_conv += "\n\n"
+            for dialog in data['session_%s' % i][::-1]:
+                turn = ''
+                turn = dialog['speaker'] + ' said, \"' + dialog['text'] + '\"' + '\n'
+                if "blip_caption" in dialog:
+                    turn += ' and shared %s.' % dialog["blip_caption"]
+                turn += '\n'
+
+                query_conv = turn + query_conv
+
+            query_conv = '\nDATE: ' + data['session_%s_date_time' % i] + '\n' + 'CONVERSATION:\n' + query_conv
+        if stop:
+            break
+
+    return query_conv
+
+
+def get_minimax_answers(in_data, out_data, prediction_key, args):
+
+    assert len(in_data['qa']) == len(out_data['qa']), (len(in_data['qa']), len(out_data['qa']))
+
+    # start instruction prompt
+    speakers_names = list(set([d['speaker'] for d in in_data['conversation']['session_1']]))
+    start_prompt = CONV_START_PROMPT.format(speakers_names[0], speakers_names[1])
+    start_tokens = 100
+
+    if args.rag_mode:
+        raise NotImplementedError
+    else:
+        context_database, query_vectors = None, None
+
+    for batch_start_idx in tqdm(range(0, len(in_data['qa']), args.batch_size), desc='Generating answers'):
+
+        questions = []
+        include_idxs = []
+        cat_5_idxs = []
+        cat_5_answers = []
+        for i in range(batch_start_idx, batch_start_idx + args.batch_size):
+
+            if i >= len(in_data['qa']):
+                break
+
+            qa = in_data['qa'][i]
+
+            if prediction_key not in out_data['qa'][i] or args.overwrite:
+                include_idxs.append(i)
+            else:
+                continue
+
+            if qa['category'] == 2:
+                questions.append(qa['question'] + ' Use DATE of CONVERSATION to answer with an approximate date.')
+            elif qa['category'] == 5:
+                question = qa['question'] + " Select the correct answer: (a) {} (b) {}. "
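+                # randomly assign the adversarial 'Not mentioned' option to choice (a) or (b)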
+                if random.random() < 0.5:
+                    question = question.format('Not mentioned in the conversation', qa['answer'])
+                    answer = {'a': 'Not mentioned in the conversation', 'b': qa['answer']}
+                else:
+                    question = question.format(qa['answer'], 'Not mentioned in the conversation')
+                    answer = {'b': 'Not mentioned in the conversation', 'a': qa['answer']}
+
+                cat_5_idxs.append(len(questions))
+                questions.append(question)
+                cat_5_answers.append(answer)
+            else:
+                questions.append(qa['question'])
+
+        if questions == []:
+            continue
+
+        context_ids = None
+        if args.use_rag:
+            raise NotImplementedError
+        else:
+            question_prompt = QA_PROMPT_BATCH + "\n".join(["%s: %s" % (k, q) for k, q in enumerate(questions)])
+            num_question_tokens = 100
+            query_conv = get_input_context(in_data['conversation'], num_question_tokens + start_tokens, None, args)
+            query_conv = start_prompt + query_conv
+
+        if args.batch_size == 1:
+
+            query = query_conv + '\n\n' + QA_PROMPT.format(questions[0]) if len(cat_5_idxs) == 0 else query_conv + '\n\n' + QA_PROMPT_CAT_5.format(questions[0])
+            answer = run_minimax(query, PER_QA_TOKEN_BUDGET, args.model)
+
+            if len(cat_5_idxs) > 0:
+                answer = get_cat_5_answer(answer, cat_5_answers[0])
+
+            out_data['qa'][include_idxs[0]][prediction_key] = answer.strip()
+            if args.use_rag:
+                out_data['qa'][include_idxs[0]][prediction_key + '_context'] = context_ids
+
+        else:
+            query = query_conv + '\n' + question_prompt
+
+            trials = 0
+            while trials < 5:
+                try:
+                    trials += 1
+                    answer = run_minimax(query, PER_QA_TOKEN_BUDGET * args.batch_size, args.model)
+                    answer = answer.replace('\\"', "'").replace('json', '').replace('`', '').strip()
+                    answers = process_ouput(answer.strip())
+                    break
+                except json.decoder.JSONDecodeError:
+                    pass
+
+            for k, idx in enumerate(include_idxs):
+                try:
+                    answers = process_ouput(answer.strip())
+                    if k in cat_5_idxs:
+                        predicted_answer = get_cat_5_answer(answers[str(k)], cat_5_answers[cat_5_idxs.index(k)])
+                        out_data['qa'][idx][prediction_key] = predicted_answer
+                    else:
+                        try:
+                            out_data['qa'][idx][prediction_key] = str(answers[str(k)]).replace('(a)', '').replace('(b)', '').strip()
+                        except:
+                            out_data['qa'][idx][prediction_key] = ', '.join([str(n) for n in list(answers[str(k)].values())])
+                except:
+                    try:
+                        answers = json.loads(answer.strip())
+                        if k in cat_5_idxs:
+                            predicted_answer = get_cat_5_answer(answers[k], cat_5_answers[cat_5_idxs.index(k)])
+                            out_data['qa'][idx][prediction_key] = predicted_answer
+                        else:
+                            out_data['qa'][idx][prediction_key] = answers[k].replace('(a)', '').replace('(b)', '').strip()
+                    except:
+                        if k in cat_5_idxs:
+                            predicted_answer = get_cat_5_answer(answer.strip(), cat_5_answers[cat_5_idxs.index(k)])
+                            out_data['qa'][idx][prediction_key] = predicted_answer
+                        else:
+                            out_data['qa'][idx][prediction_key] = json.loads(answer.strip().replace('(a)', '').replace('(b)', '').split('\n')[k])[0]
+
+    return out_data
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/test_minimax_integration.py b/tests/test_minimax_integration.py
new file mode 100644
index 0000000..2482d59
--- /dev/null
+++ b/tests/test_minimax_integration.py
@@ -0,0 +1,90 @@
+"""Integration tests for MiniMax provider in LoCoMo.
+
+These tests verify the full pipeline works with the MiniMax API.
+They require MINIMAX_API_KEY to be set in the environment.
+"""
+
+import sys
+import os
+import json
+import unittest
+from pathlib import Path
+
+# Add project root to path
+sys.path.insert(0, str(Path(__file__).parent.parent))
+
+MINIMAX_API_KEY = os.environ.get('MINIMAX_API_KEY', '')
+SKIP_REASON = "MINIMAX_API_KEY not set; skipping integration tests"
+
+
+@unittest.skipUnless(MINIMAX_API_KEY, SKIP_REASON)
+class TestMinimaxAPIIntegration(unittest.TestCase):
+    """Integration tests that call the real MiniMax API."""
+
+    def test_run_minimax_m2_5(self):
+        """Test a real API call to MiniMax-M2.5."""
+        from global_methods import run_minimax
+
+        # M2.5 uses thinking tokens that count against max_tokens,
+        # so we need a larger budget
+        result = run_minimax(
+            "What is 2 + 2? Answer with just the number.",
+            1024,
+            "minimax-m2.5"
+        )
+        self.assertIsInstance(result, str)
+        self.assertTrue(len(result) > 0)
+        self.assertIn('4', result)
+
+    def test_run_minimax_m2_5_highspeed(self):
+        """Test a real API call to MiniMax-M2.5-highspeed."""
+        from global_methods import run_minimax
+
+        result = run_minimax(
+            "Name a color of the rainbow. Answer with just one word.",
+            1024,
+            "minimax-m2.5-highspeed"
+        )
+        self.assertIsInstance(result, str)
+        self.assertTrue(len(result) > 0)
+
+    def test_minimax_qa_pipeline(self):
+        """Test the full MiniMax QA pipeline with a minimal conversation."""
+        from task_eval.minimax_utils import get_minimax_answers
+        from unittest.mock import MagicMock
+
+        in_data = {
+            'conversation': {
+                'session_1': [
+                    {'speaker': 'Alice', 'text': 'I just adopted a golden retriever named Max!', 'dia_id': 'd1'},
+                    {'speaker': 'Bob', 'text': 'That is wonderful! What breed is he?', 'dia_id': 'd2'},
+                    {'speaker': 'Alice', 'text': 'He is a golden retriever, about 2 years old.', 'dia_id': 'd3'},
+                ],
+                'session_1_date_time': '2024-03-15',
+            },
+            'qa': [
+                {'question': "What is Alice's dog's name?", 'answer': 'Max', 'category': 1, 'evidence': ['d1']}
+            ]
+        }
+        out_data = {
+            'qa': [
+                {'question': "What is Alice's dog's name?", 'answer': 'Max', 'category': 1, 'evidence': ['d1']}
+            ]
+        }
+
+        args = MagicMock()
+        args.model = 'minimax-m2.5'
+        args.batch_size = 1
+        args.use_rag = False
+        args.rag_mode = ''
+        args.overwrite = True
+
+        result = get_minimax_answers(in_data, out_data, 'minimax-m2.5_prediction', args)
+
+        self.assertIn('minimax-m2.5_prediction', result['qa'][0])
+        prediction = result['qa'][0]['minimax-m2.5_prediction'].lower()
+        self.assertIn('max', prediction)
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/tests/test_minimax_unit.py b/tests/test_minimax_unit.py
new file mode 100644
index 0000000..12265a9
--- /dev/null
+++ b/tests/test_minimax_unit.py
@@ -0,0 +1,426 @@
+"""Unit tests for MiniMax provider integration in LoCoMo."""
+
+import sys
+import os
+import json
+import unittest
+from unittest.mock import patch, MagicMock
+from pathlib import Path
+
+# Add project root to path
+sys.path.insert(0, str(Path(__file__).parent.parent))
+
+
+class TestRunMinimax(unittest.TestCase):
+    """Test the run_minimax function in global_methods."""
+
+    @patch('global_methods.httpx')
+    def test_run_minimax_basic(self, mock_httpx):
+        """Test basic MiniMax API call."""
+        mock_response = MagicMock()
+        mock_response.json.return_value = {
+            "choices": [{"message": {"content": "Paris"}}]
+        }
+        mock_response.raise_for_status = MagicMock()
+        mock_httpx.post.return_value = mock_response
+
+        from global_methods import run_minimax
+        os.environ['MINIMAX_API_KEY'] = 'test-key'
+        result = run_minimax("What is the capital of France?", 32, "minimax-m2.5")
+
+        self.assertEqual(result, "Paris")
+        mock_httpx.post.assert_called_once()
+        call_args = mock_httpx.post.call_args
+        payload = call_args[1]['json']
+        self.assertEqual(payload['model'], 'MiniMax-M2.5')
+        self.assertEqual(payload['max_tokens'], 32)
+
+    @patch('global_methods.httpx')
+    def test_run_minimax_model_mapping(self, mock_httpx):
+        """Test model name mapping for different MiniMax models."""
+        mock_response = MagicMock()
+        mock_response.json.return_value = {
+            "choices": [{"message": {"content": "ok"}}]
+        }
+        mock_response.raise_for_status = MagicMock()
+        mock_httpx.post.return_value = mock_response
+
+        from global_methods import run_minimax
+        os.environ['MINIMAX_API_KEY'] = 'test-key'
+
+        test_cases = [
+            ('minimax-m2.5', 'MiniMax-M2.5'),
+            ('minimax-m2.5-highspeed', 'MiniMax-M2.5-highspeed'),
+            ('minimax-m2.7', 'MiniMax-M2.7'),
+        ]
+
+        for input_name, expected_api_name in test_cases:
+            mock_httpx.post.reset_mock()
+            run_minimax("test", 32, input_name)
+            payload = mock_httpx.post.call_args[1]['json']
+            self.assertEqual(payload['model'], expected_api_name,
+                             f"Model {input_name} should map to {expected_api_name}")
+
+    @patch('global_methods.httpx')
+    def test_run_minimax_temperature_clamping(self, mock_httpx):
+        """Test that temperature is clamped to valid range."""
+        mock_response = MagicMock()
+        mock_response.json.return_value = {
+            "choices": [{"message": {"content": "ok"}}]
+        }
+        mock_response.raise_for_status = MagicMock()
+        mock_httpx.post.return_value = mock_response
+
+        from global_methods import run_minimax
+        os.environ['MINIMAX_API_KEY'] = 'test-key'
+
+        # temperature=0 should be clamped to 0.01
+        run_minimax("test", 32, "minimax-m2.5", temperature=0)
+        payload = mock_httpx.post.call_args[1]['json']
+        self.assertEqual(payload['temperature'], 0.01)
+
+        # temperature=0.5 should pass through
+        run_minimax("test", 32, "minimax-m2.5", temperature=0.5)
+        payload = mock_httpx.post.call_args[1]['json']
+        self.assertEqual(payload['temperature'], 0.5)
+
+        # temperature > 1.0 should be clamped to 1.0
+        run_minimax("test", 32, "minimax-m2.5", temperature=1.5)
+        payload = mock_httpx.post.call_args[1]['json']
+        self.assertEqual(payload['temperature'], 1.0)
+
+    @patch('global_methods.httpx')
+    def test_run_minimax_strips_think_tags(self, mock_httpx):
+        """Test that thinking tags from M2.5 models are stripped."""
+        mock_response = MagicMock()
+        mock_response.json.return_value = {
+            "choices": [{"message": {"content": "<think>Let me think about this...</think>Paris"}}]
+        }
+        mock_response.raise_for_status = MagicMock()
+        mock_httpx.post.return_value = mock_response
+
+        from global_methods import run_minimax
+        os.environ['MINIMAX_API_KEY'] = 'test-key'
+        result = run_minimax("test", 32, "minimax-m2.5")
+        self.assertEqual(result, "Paris")
+
+    @patch('global_methods.httpx')
+    def test_run_minimax_api_url(self, mock_httpx):
+        """Test that the correct API URL is used."""
+        mock_response = MagicMock()
+        mock_response.json.return_value = {
+            "choices": [{"message": {"content": "ok"}}]
+        }
+        mock_response.raise_for_status = MagicMock()
+        mock_httpx.post.return_value = mock_response
+
+        from global_methods import run_minimax
+        os.environ['MINIMAX_API_KEY'] = 'test-key'
+        run_minimax("test", 32, "minimax-m2.5")
+
+        call_args = mock_httpx.post.call_args
+        self.assertEqual(call_args[0][0], "https://api.minimax.io/v1/chat/completions")
+
+    @patch('global_methods.httpx')
+    def test_run_minimax_auth_header(self, mock_httpx):
+        """Test that the authorization header includes the API key."""
+        mock_response = MagicMock()
+        mock_response.json.return_value = {
+            "choices": [{"message": {"content": "ok"}}]
+        }
+        mock_response.raise_for_status = MagicMock()
+        mock_httpx.post.return_value = mock_response
+
+        from global_methods import run_minimax
+        os.environ['MINIMAX_API_KEY'] = 'my-secret-key'
+        run_minimax("test", 32, "minimax-m2.5")
+
+        call_args = mock_httpx.post.call_args
+        headers = call_args[1]['headers']
+        self.assertEqual(headers['Authorization'], 'Bearer my-secret-key')
+
+
+class TestSetMinimaxKey(unittest.TestCase):
+    """Test the set_minimax_key function."""
+
+    def test_set_minimax_key_no_error(self):
+        """Test that set_minimax_key runs without error."""
+        from global_methods import set_minimax_key
+        set_minimax_key()  # Should not raise
+
+
+class TestMinimaxUtils(unittest.TestCase):
+    """Test the minimax_utils module."""
+
+    def test_process_output_valid_json(self):
+        """Test process_ouput with valid JSON."""
+        from task_eval.minimax_utils import process_ouput
+
+        result = process_ouput('{"0": "Paris", "1": "London"}')
+        self.assertEqual(result, {"0": "Paris", "1": "London"})
+
+    def test_process_output_with_prefix(self):
+        """Test process_ouput strips text before JSON."""
+        from task_eval.minimax_utils import process_ouput
+
+        result = process_ouput('Here is the answer: {"0": "Paris"}')
+        self.assertEqual(result, {"0": "Paris"})
+
+    def test_get_cat_5_answer_single_char_a(self):
+        """Test adversarial answer parsing: single character 'a'."""
+        from task_eval.minimax_utils import get_cat_5_answer
+
+        answer_key = {'a': 'Not mentioned', 'b': 'Paris'}
+        result = get_cat_5_answer('a', answer_key)
+        self.assertEqual(result, 'Not mentioned')
+
+    def test_get_cat_5_answer_single_char_b(self):
+        """Test adversarial answer parsing: single character 'b'."""
+        from task_eval.minimax_utils import get_cat_5_answer
+
+        answer_key = {'a': 'Not mentioned', 'b': 'Paris'}
+        result = get_cat_5_answer('b', answer_key)
+        self.assertEqual(result, 'Paris')
+
+    def test_get_cat_5_answer_parenthesized(self):
+        """Test adversarial answer parsing: parenthesized."""
+        from task_eval.minimax_utils import get_cat_5_answer
+
+        answer_key = {'a': 'Not mentioned', 'b': 'Paris'}
+        result = get_cat_5_answer('(a)', answer_key)
+        self.assertEqual(result, 'Not mentioned')
+
+    def test_get_cat_5_answer_freeform(self):
+        """Test adversarial answer parsing: free-form text returned as-is (lowercased)."""
+        from task_eval.minimax_utils import get_cat_5_answer
+
+        answer_key = {'a': 'Not mentioned', 'b': 'Paris'}
+        # The function lowercases the input, then returns it when length > 3
+        result = get_cat_5_answer('The answer is Paris', answer_key)
+        self.assertEqual(result, 'the answer is paris')
+
+    def test_get_input_context(self):
+        """Test conversation context extraction."""
+        from task_eval.minimax_utils import get_input_context
+
+        data = {
+            'session_1': [
+                {'speaker': 'Alice', 'text': 'Hello Bob!', 'dia_id': 'd1'},
+                {'speaker': 'Bob', 'text': 'Hi Alice!', 'dia_id': 'd2'},
+            ],
+            'session_1_date_time': '2024-01-01',
+        }
+        args = MagicMock()
+        result = get_input_context(data, 100, None, args)
+        self.assertIn('Alice', result)
+        self.assertIn('Bob', result)
+        self.assertIn('2024-01-01', result)
+
+    def test_get_input_context_multimodal(self):
+        """Test that blip captions are included in context."""
+        from task_eval.minimax_utils import get_input_context
+
+        data = {
+            'session_1': [
+                {'speaker': 'Alice', 'text': 'Look at this!', 'dia_id': 'd1',
+                 'blip_caption': 'a photo of a sunset'},
+            ],
+            'session_1_date_time': '2024-01-01',
+        }
+        args = MagicMock()
+        result = get_input_context(data, 100, None, args)
+        self.assertIn('a photo of a sunset', result)
+
+    def test_max_length_definitions(self):
+        """Test that MAX_LENGTH contains correct MiniMax model entries."""
+        from task_eval.minimax_utils import MAX_LENGTH
+
+        self.assertIn('minimax-m2.5', MAX_LENGTH)
+        self.assertIn('minimax-m2.5-highspeed', MAX_LENGTH)
+        self.assertIn('minimax-m2.7', MAX_LENGTH)
+        self.assertEqual(MAX_LENGTH['minimax-m2.5'], 204000)
+        self.assertEqual(MAX_LENGTH['minimax-m2.7'], 1000000)
+
+    def test_prompts_defined(self):
+        """Test that required prompts are defined in minimax_utils."""
+        from task_eval import minimax_utils
+
+        self.assertTrue(hasattr(minimax_utils, 'QA_PROMPT'))
+        self.assertTrue(hasattr(minimax_utils, 'QA_PROMPT_CAT_5'))
+        self.assertTrue(hasattr(minimax_utils, 'QA_PROMPT_BATCH'))
+        self.assertTrue(hasattr(minimax_utils, 'CONV_START_PROMPT'))
+
+
+class TestEvaluateQADispatch(unittest.TestCase):
+    """Test that evaluate_qa.py correctly dispatches to MiniMax."""
+
+    def test_minimax_import(self):
+        """Test that minimax_utils can be imported from evaluate_qa context."""
+        from task_eval.minimax_utils import get_minimax_answers
+        self.assertTrue(callable(get_minimax_answers))
+
+    def test_set_minimax_key_import(self):
+        """Test that set_minimax_key can be imported from global_methods."""
+        from global_methods import set_minimax_key
+        self.assertTrue(callable(set_minimax_key))
+
+    @patch('global_methods.httpx')
+    def test_get_minimax_answers_single_batch(self, mock_httpx):
+        """Test get_minimax_answers with batch_size=1."""
+        mock_response = MagicMock()
+        mock_response.json.return_value = {
+            "choices": [{"message": {"content": "Paris"}}]
+        }
+        mock_response.raise_for_status = MagicMock()
+        mock_httpx.post.return_value = mock_response
+
+        from task_eval.minimax_utils import get_minimax_answers
+
+        os.environ['MINIMAX_API_KEY'] = 'test-key'
+        in_data = {
+            'conversation': {
+                'session_1': [
+                    {'speaker': 'Alice', 'text': 'I live in Paris.', 'dia_id': 'd1'},
+                    {'speaker': 'Bob', 'text': 'Nice!', 'dia_id': 'd2'},
+                ],
+                'session_1_date_time': '2024-01-01',
+            },
+            'qa': [
+                {'question': 'Where does Alice live?', 'answer': 'Paris', 'category': 1, 'evidence': []}
+            ]
+        }
+        out_data = {
+            'qa': [
+                {'question': 'Where does Alice live?', 'answer': 'Paris', 'category': 1, 'evidence': []}
+            ]
+        }
+
+        args = MagicMock()
+        args.model = 'minimax-m2.5'
+        args.batch_size = 1
+        args.use_rag = False
+        args.rag_mode = ''
+        args.overwrite = True
+
+        result = get_minimax_answers(in_data, out_data, 'minimax-m2.5_prediction', args)
+
+        self.assertIn('minimax-m2.5_prediction', result['qa'][0])
+        self.assertEqual(result['qa'][0]['minimax-m2.5_prediction'], 'Paris')
+
+    @patch('global_methods.httpx')
+    def test_get_minimax_answers_cat_5(self, mock_httpx):
+        """Test get_minimax_answers with adversarial category 5 question."""
+        mock_response = MagicMock()
+        mock_response.json.return_value = {
+            "choices": [{"message": {"content": "(a)"}}]
+        }
+        mock_response.raise_for_status = MagicMock()
+        mock_httpx.post.return_value = mock_response
+
+        from task_eval.minimax_utils import get_minimax_answers
+
+        os.environ['MINIMAX_API_KEY'] = 'test-key'
+        in_data = {
+            'conversation': {
+                'session_1': [
+                    {'speaker': 'Alice', 'text': 'Hello Bob', 'dia_id': 'd1'},
+                    {'speaker': 'Bob', 'text': 'Hi Alice', 'dia_id': 'd2'},
+                ],
+                'session_1_date_time': '2024-01-01',
+            },
+            'qa': [
+                {'question': 'Does Alice own a car?', 'answer': 'Not mentioned in the conversation', 'category': 5, 'evidence': []}
+            ]
+        }
+        out_data = {
+            'qa': [
+                {'question': 'Does Alice own a car?', 'answer': 'Not mentioned in the conversation', 'category': 5, 'evidence': []}
+            ]
+        }
+
+        args = MagicMock()
+        args.model = 'minimax-m2.5'
+        args.batch_size = 1
+        args.use_rag = False
+        args.rag_mode = ''
+        args.overwrite = True
+
+        result = get_minimax_answers(in_data, out_data, 'minimax-m2.5_prediction', args)
+        self.assertIn('minimax-m2.5_prediction', result['qa'][0])
+
+    @patch('global_methods.httpx')
+    def test_get_minimax_answers_skip_existing(self, mock_httpx):
+        """Test that existing predictions are skipped without overwrite."""
+        from task_eval.minimax_utils import get_minimax_answers
+
+        in_data = {
+            'conversation': {
+                'session_1': [
+                    {'speaker': 'Alice', 'text': 'Hi', 'dia_id': 'd1'},
+                    {'speaker': 'Bob', 'text': 'Hello', 'dia_id': 'd2'},
+                ],
+                'session_1_date_time': '2024-01-01',
+            },
+            'qa': [
+                {'question': 'Q?', 'answer': 'A', 'category': 1, 'evidence': []}
+            ]
+        }
+        out_data = {
+            'qa': [
+                {'question': 'Q?', 'answer': 'A', 'category': 1, 'evidence': [],
+                 'minimax-m2.5_prediction': 'existing_answer'}
+            ]
+        }
+
+        args = MagicMock()
+        args.model = 'minimax-m2.5'
+        args.batch_size = 1
+        args.use_rag = False
+        args.rag_mode = ''
+        args.overwrite = False
+
+        result = get_minimax_answers(in_data, out_data, 'minimax-m2.5_prediction', args)
+        # Should not have called the API
+        mock_httpx.post.assert_not_called()
+        self.assertEqual(result['qa'][0]['minimax-m2.5_prediction'], 'existing_answer')
+
+
+class TestEvaluateMinimaxShellScript(unittest.TestCase):
+    """Test that the evaluation shell script exists and has correct content."""
+
+    def test_script_exists(self):
+        """Test that evaluate_minimax.sh exists."""
+        script_path = Path(__file__).parent.parent / 'scripts' / 'evaluate_minimax.sh'
+        self.assertTrue(script_path.exists())
+
+    def test_script_sources_env(self):
+        """Test that the script sources env.sh."""
+        script_path = Path(__file__).parent.parent / 'scripts' / 'evaluate_minimax.sh'
+        content = script_path.read_text()
+        self.assertIn('source scripts/env.sh', content)
+
+    def test_script_runs_m2_5(self):
+        """Test that the script evaluates minimax-m2.5."""
+        script_path = Path(__file__).parent.parent / 'scripts' / 'evaluate_minimax.sh'
+        content = script_path.read_text()
+        self.assertIn('minimax-m2.5', content)
+
+    def test_script_runs_highspeed(self):
+        """Test that the script evaluates minimax-m2.5-highspeed."""
+        script_path = Path(__file__).parent.parent / 'scripts' / 'evaluate_minimax.sh'
+        content = script_path.read_text()
+        self.assertIn('minimax-m2.5-highspeed', content)
+
+
+class TestEnvScript(unittest.TestCase):
+    """Test that env.sh includes MINIMAX_API_KEY."""
+
+    def test_env_has_minimax_key(self):
+        """Test that env.sh contains MINIMAX_API_KEY export."""
+        env_path = Path(__file__).parent.parent / 'scripts' / 'env.sh'
+        content = env_path.read_text()
+        self.assertIn('MINIMAX_API_KEY', content)
+
+
+if __name__ == '__main__':
+    unittest.main()