Skip to content

Commit 7b252d3

Browse files
authored
Switching Single Turn to LiteLLM (#31)
* Switch to LiteLLM * Update tests * Update single-turn processor
1 parent fffd75c commit 7b252d3

11 files changed

+40
-24
lines changed

eval_protocol/pytest/default_single_turn_rollout_process.py

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,17 @@
11
import asyncio
22
from typing import List
33

4-
from openai import AsyncOpenAI
4+
from litellm import acompletion
5+
from openai.types.chat.chat_completion_message import ChatCompletionMessageToolCall
56

6-
from eval_protocol.auth import get_fireworks_api_base, get_fireworks_api_key
77
from eval_protocol.models import EvaluationRow, Message
88
from eval_protocol.pytest.types import RolloutProcessorConfig
99

1010

1111
async def default_single_turn_rollout_processor(
1212
rows: List[EvaluationRow], config: RolloutProcessorConfig
1313
) -> List[EvaluationRow]:
14-
"""Generate a single response from a Fireworks model concurrently."""
15-
16-
api_key = get_fireworks_api_key()
17-
api_base = get_fireworks_api_base()
18-
client = AsyncOpenAI(api_key=api_key, base_url=f"{api_base}/inference/v1")
14+
"""Generate a single response from any supported model provider using LiteLLM."""
1915

2016
async def process_row(row: EvaluationRow) -> EvaluationRow:
2117
"""Process a single row asynchronously."""
@@ -24,17 +20,35 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
2420

2521
messages_payload = [{"role": m.role, "content": m.content} for m in row.messages]
2622

27-
create_kwargs = dict(model=config.model, messages=messages_payload, **config.input_params)
23+
request_params = {"model": config.model, "messages": messages_payload, **config.input_params}
24+
2825
if row.tools is not None:
29-
create_kwargs["tools"] = row.tools
30-
response = await client.chat.completions.create(**create_kwargs)
26+
request_params["tools"] = row.tools
27+
28+
response = await acompletion(**request_params)
29+
3130
assistant_content = response.choices[0].message.content or ""
3231
tool_calls = response.choices[0].message.tool_calls if response.choices[0].message.tool_calls else None
32+
33+
converted_tool_calls = None
34+
if tool_calls:
35+
converted_tool_calls = [
36+
ChatCompletionMessageToolCall(
37+
id=tool_call.id,
38+
type=tool_call.type,
39+
function={
40+
"name": tool_call.function.name,
41+
"arguments": tool_call.function.arguments,
42+
},
43+
)
44+
for tool_call in tool_calls
45+
]
46+
3347
messages = list(row.messages) + [
3448
Message(
3549
role="assistant",
3650
content=assistant_content,
37-
tool_calls=tool_calls,
51+
tool_calls=converted_tool_calls,
3852
)
3953
]
4054

tests/pytest/test_apps_coding.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def apps_dataset_to_evaluation_row(data: List[Dict[str, Any]]) -> List[Evaluatio
2929
@evaluation_test(
3030
input_dataset=["tests/pytest/data/apps_sample_dataset.jsonl"],
3131
dataset_adapter=apps_dataset_to_evaluation_row,
32-
model=["accounts/fireworks/models/kimi-k2-instruct"],
32+
model=["fireworks_ai/accounts/fireworks/models/kimi-k2-instruct"],
3333
rollout_input_params=[{"temperature": 0.0, "max_tokens": 4096}],
3434
threshold_of_success=0.33,
3535
rollout_processor=default_single_turn_rollout_processor,

tests/pytest/test_basic_coding.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def coding_dataset_to_evaluation_row(data: List[Dict[str, Any]]) -> List[Evaluat
2828
@evaluation_test(
2929
input_dataset=["tests/pytest/data/basic_coding_dataset.jsonl"],
3030
dataset_adapter=coding_dataset_to_evaluation_row,
31-
model=["accounts/fireworks/models/kimi-k2-instruct"],
31+
model=["fireworks_ai/accounts/fireworks/models/kimi-k2-instruct"],
3232
rollout_input_params=[{"temperature": 0.0, "max_tokens": 4096}],
3333
threshold_of_success=0.8,
3434
rollout_processor=default_single_turn_rollout_processor,

tests/pytest/test_hallucination.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,13 @@
99
import json
1010
from typing import Any, Dict, List
1111

12-
from fireworks import LLM
12+
import litellm
1313

1414
from eval_protocol.models import EvaluateResult, EvaluationRow, Message, MetricResult
1515
from eval_protocol.pytest import default_single_turn_rollout_processor, evaluation_test
1616

17-
judge_llm = LLM(model="accounts/fireworks/models/kimi-k2-instruct", deployment_type="serverless")
17+
# Configure the judge model for LiteLLM
18+
JUDGE_MODEL = "fireworks_ai/accounts/fireworks/models/kimi-k2-instruct"
1819

1920

2021
def hallucination_dataset_adapter(data: List[Dict[str, Any]]) -> List[EvaluationRow]:
@@ -31,7 +32,7 @@ def hallucination_dataset_adapter(data: List[Dict[str, Any]]) -> List[Evaluation
3132
@evaluation_test(
3233
input_dataset=["tests/pytest/data/halueval_sample_dataset.jsonl"],
3334
dataset_adapter=hallucination_dataset_adapter,
34-
model=["accounts/fireworks/models/kimi-k2-instruct"],
35+
model=["fireworks_ai/accounts/fireworks/models/kimi-k2-instruct"],
3536
rollout_input_params=[{"temperature": 0.0, "max_tokens": 512}],
3637
rollout_processor=default_single_turn_rollout_processor,
3738
threshold_of_success=0.33,
@@ -77,7 +78,8 @@ def test_hallucination_detection(row: EvaluationRow) -> EvaluationRow:
7778
"""
7879

7980
try:
80-
response = judge_llm.chat.completions.create(
81+
response = litellm.completion(
82+
model=JUDGE_MODEL,
8183
messages=[{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}],
8284
temperature=0.1,
8385
max_tokens=500,

tests/pytest/test_markdown_highlighting.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def markdown_dataset_to_evaluation_row(data: List[Dict[str, Any]]) -> List[Evalu
2424
@evaluation_test(
2525
input_dataset=["tests/pytest/data/markdown_dataset.jsonl"],
2626
dataset_adapter=markdown_dataset_to_evaluation_row,
27-
model=["accounts/fireworks/models/kimi-k2-instruct"],
27+
model=["fireworks_ai/accounts/fireworks/models/kimi-k2-instruct"],
2828
rollout_input_params=[{"temperature": 0.0, "max_tokens": 4096}],
2929
threshold_of_success=0.5,
3030
rollout_processor=default_single_turn_rollout_processor,

tests/pytest/test_pytest_function_calling.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def function_calling_to_evaluation_row(rows: List[Dict[str, Any]]) -> List[Evalu
1919

2020
@evaluation_test(
2121
input_dataset=["tests/pytest/data/function_calling.jsonl"],
22-
model=["accounts/fireworks/models/kimi-k2-instruct"],
22+
model=["fireworks_ai/accounts/fireworks/models/kimi-k2-instruct"],
2323
mode="pointwise",
2424
dataset_adapter=function_calling_to_evaluation_row,
2525
rollout_processor=default_single_turn_rollout_processor,

tests/pytest/test_pytest_input_messages.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
Message(role="user", content="What is the capital of France?"),
1111
]
1212
],
13-
model=["accounts/fireworks/models/kimi-k2-instruct"],
13+
model=["fireworks_ai/accounts/fireworks/models/kimi-k2-instruct"],
1414
rollout_processor=default_single_turn_rollout_processor,
1515
)
1616
def test_input_messages_in_decorator(rows: List[EvaluationRow]) -> List[EvaluationRow]:

tests/pytest/test_pytest_json_schema.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def json_schema_to_evaluation_row(rows: List[Dict[str, Any]]) -> List[Evaluation
2323

2424
@evaluation_test(
2525
input_dataset=["tests/pytest/data/json_schema.jsonl"],
26-
model=["accounts/fireworks/models/kimi-k2-instruct"],
26+
model=["fireworks_ai/accounts/fireworks/models/kimi-k2-instruct"],
2727
mode="pointwise",
2828
rollout_processor=default_single_turn_rollout_processor,
2929
dataset_adapter=json_schema_to_evaluation_row,

tests/pytest/test_pytest_math_example.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
@evaluation_test(
99
input_dataset=["development/gsm8k_sample.jsonl"],
1010
dataset_adapter=gsm8k_to_evaluation_row,
11-
model=["accounts/fireworks/models/kimi-k2-instruct"],
11+
model=["fireworks_ai/accounts/fireworks/models/kimi-k2-instruct"],
1212
rollout_input_params=[{"temperature": 0.0}],
1313
max_dataset_rows=5,
1414
threshold_of_success=0.0,

tests/pytest/test_pytest_math_format_length.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
@evaluation_test(
1212
input_dataset=["development/gsm8k_sample.jsonl"],
1313
dataset_adapter=gsm8k_to_evaluation_row,
14-
model=["accounts/fireworks/models/kimi-k2-instruct"],
14+
model=["fireworks_ai/accounts/fireworks/models/kimi-k2-instruct"],
1515
rollout_input_params=[{"temperature": 0.0}],
1616
max_dataset_rows=5,
1717
threshold_of_success=0.0,

0 commit comments

Comments
 (0)