Skip to content

Commit 3f7b4c3

Browse files
authored
add tests regarding hashes (#100)
- add tests regarding hashes
- update
- fix stable json test
1 parent a473f0c commit 3f7b4c3

File tree

2 files changed

+331
-13
lines changed

2 files changed

+331
-13
lines changed

eval_protocol/models.py

Lines changed: 50 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,14 @@ class InputMetadata(BaseModel):
211211

212212
model_config = ConfigDict(extra="allow")
213213

214-
row_id: Optional[str] = Field(None, description="Unique string to ID the row")
214+
row_id: Optional[str] = Field(
215+
default=None,
216+
description=(
217+
"Unique string to ID the row. If not provided, a stable hash will be generated "
218+
"based on the row's content. The hash removes fields that are not typically stable "
219+
"across processes such as created_at, execution_metadata, and pid."
220+
),
221+
)
215222
completion_params: CompletionParams = Field(
216223
default_factory=dict, description="Completion endpoint parameters used"
217224
)
@@ -430,20 +437,53 @@ def get_termination_reason(self) -> str:
430437
return "unknown"
431438

432439
def __hash__(self) -> int:
433-
# Use a stable hash by sorting keys and ensuring compact output
434-
json_str = self.stable_json(self)
435-
return hash(json_str)
440+
# Use a stable hash that works across Python processes
441+
return self._stable_hash()
442+
443+
def _stable_hash(self) -> int:
444+
"""Generate a stable hash that works across Python processes."""
445+
import hashlib
446+
447+
# Get the stable JSON representation
448+
json_str = self._stable_json()
449+
450+
# Use SHA-256 for deterministic hashing across processes
451+
hash_obj = hashlib.sha256(json_str.encode("utf-8"))
452+
453+
# Convert to a positive integer (first 8 bytes)
454+
hash_bytes = hash_obj.digest()[:8]
455+
return int.from_bytes(hash_bytes, byteorder="big")
456+
457+
def _stable_json(self) -> str:
458+
"""Generate a stable JSON string representation for hashing."""
459+
# Produce a canonical, key-sorted JSON across nested structures and
460+
# exclude volatile fields that can differ across processes
461+
import json
462+
from enum import Enum
463+
464+
def canonicalize(value):
465+
# Recursively convert to a structure with deterministic key ordering
466+
if isinstance(value, dict):
467+
return {k: canonicalize(value[k]) for k in sorted(value.keys())}
468+
if isinstance(value, list):
469+
return [canonicalize(v) for v in value]
470+
if isinstance(value, Enum):
471+
return value.value
472+
return value
436473

437-
@classmethod
438-
def stable_json(cls, row: "EvaluationRow") -> int:
439-
json_str = row.model_dump_json(
474+
# Dump to a plain Python structure first
475+
data = self.model_dump(
440476
exclude_none=True,
441477
exclude_defaults=True,
442478
by_alias=True,
443-
indent=None,
444-
exclude=["created_at", "execution_metadata"],
479+
exclude={"created_at", "execution_metadata", "pid"},
445480
)
446-
return json_str
481+
482+
# Ensure deterministic ordering for all nested dicts
483+
canonical_data = canonicalize(data)
484+
485+
# Compact, sorted JSON string
486+
return json.dumps(canonical_data, separators=(",", ":"), sort_keys=True, ensure_ascii=False)
447487

448488

449489
# Original dataclass-based models for backwards compatibility

tests/test_models.py

Lines changed: 281 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,46 @@
1313
)
1414

1515

16+
def dummy_row() -> EvaluationRow:
17+
from eval_protocol.models import (
18+
EvaluateResult as _EvaluateResult,
19+
EvaluationRow as _EvaluationRow,
20+
InputMetadata as _InputMetadata,
21+
Message as _Message,
22+
MetricResult as _MetricResult,
23+
)
24+
25+
msgs = [
26+
_Message(role="system", content="You are a helpful assistant"),
27+
_Message(role="user", content="Compute 2+2"),
28+
_Message(role="assistant", content="4"),
29+
]
30+
eval_res = _EvaluateResult(
31+
score=1.0,
32+
reason="Correct",
33+
metrics={
34+
"accuracy": _MetricResult(score=1.0, reason="matches ground truth"),
35+
},
36+
)
37+
child_row = _EvaluationRow(
38+
messages=msgs,
39+
ground_truth="4",
40+
evaluation_result=eval_res,
41+
input_metadata=_InputMetadata(
42+
row_id="arith_0001",
43+
completion_params={"model": "dummy/local-model", "temperature": 0.0},
44+
dataset_info={"source": "unit_test", "variant": "subprocess"},
45+
session_data={"attempt": 1},
46+
),
47+
)
48+
return child_row
49+
50+
51+
def _child_compute_hash_value(_unused=None) -> int:
52+
row = dummy_row()
53+
return hash(row)
54+
55+
1656
def test_metric_result_creation():
1757
"""Test creating a MetricResult."""
1858
metric = MetricResult(score=0.5, reason="Test reason", is_score_valid=True)
@@ -289,7 +329,7 @@ def test_evaluation_row_creation():
289329
assert not row.is_trajectory_evaluation()
290330

291331

292-
def test_stable_hash():
332+
def test_stable_json():
293333
"""Test the stable hash method."""
294334
row = EvaluationRow(
295335
messages=[Message(role="user", content="What is 2+2?"), Message(role="assistant", content="2+2 equals 4.")],
@@ -299,8 +339,8 @@ def test_stable_hash():
299339
messages=[Message(role="user", content="What is 2+2?"), Message(role="assistant", content="2+2 equals 4.")],
300340
ground_truth="4",
301341
)
302-
stable_json = EvaluationRow.stable_json(row)
303-
stable_json2 = EvaluationRow.stable_json(row2)
342+
stable_json = row._stable_json()
343+
stable_json2 = row2._stable_json()
304344
assert stable_json == stable_json2
305345
assert "created_at" not in stable_json
306346
assert "execution_metadata" not in stable_json
@@ -382,3 +422,241 @@ def test_message_creation_requires_role():
382422
msg_none_content = Message(role="user") # content defaults to ""
383423
assert msg_none_content.role == "user"
384424
assert msg_none_content.content == ""
425+
426+
427+
def test_stable_hash_consistency():
428+
"""Test that the same EvaluationRow produces the same hash value consistently."""
429+
row1 = EvaluationRow(
430+
messages=[Message(role="user", content="What is 2+2?"), Message(role="assistant", content="2+2 equals 4.")],
431+
ground_truth="4",
432+
)
433+
row2 = EvaluationRow(
434+
messages=[Message(role="user", content="What is 2+2?"), Message(role="assistant", content="2+2 equals 4.")],
435+
ground_truth="4",
436+
)
437+
438+
# Same content should produce same hash
439+
assert hash(row1) == hash(row2)
440+
441+
# Hash should be consistent across multiple calls
442+
hash1_first = hash(row1)
443+
hash1_second = hash(row1)
444+
hash1_third = hash(row1)
445+
446+
assert hash1_first == hash1_second == hash1_third
447+
448+
# Hash should be a positive integer
449+
assert isinstance(hash1_first, int)
450+
assert hash1_first > 0
451+
452+
453+
def test_stable_hash_different_content():
454+
"""Test that different content produces different hash values."""
455+
row1 = EvaluationRow(
456+
messages=[Message(role="user", content="What is 2+2?"), Message(role="assistant", content="2+2 equals 4.")],
457+
ground_truth="4",
458+
)
459+
row2 = EvaluationRow(
460+
messages=[Message(role="user", content="What is 3+3?"), Message(role="assistant", content="3+3 equals 6.")],
461+
ground_truth="6",
462+
)
463+
464+
# Different content should produce different hashes
465+
assert hash(row1) != hash(row2)
466+
467+
468+
def test_stable_hash_ignores_volatile_fields():
469+
"""Test that volatile fields like timestamps don't affect the hash."""
470+
messages = [Message(role="user", content="Test"), Message(role="assistant", content="Response")]
471+
472+
# Create rows with different timestamps
473+
row1 = EvaluationRow(messages=messages, ground_truth="test")
474+
row2 = EvaluationRow(messages=messages, ground_truth="test")
475+
476+
# Wait a moment to ensure different timestamps
477+
import time
478+
479+
time.sleep(0.001)
480+
481+
# Create another row
482+
row3 = EvaluationRow(messages=messages, ground_truth="test")
483+
484+
# All should have the same hash despite different timestamps
485+
assert hash(row1) == hash(row2) == hash(row3)
486+
487+
488+
def test_stable_hash_with_complex_data():
489+
"""Test stable hashing with complex nested data structures."""
490+
complex_messages = [
491+
Message(role="system", content="You are a helpful assistant"),
492+
Message(role="user", content="Solve this math problem: 15 * 23"),
493+
Message(
494+
role="assistant",
495+
content="Let me solve this step by step:\n1. 15 * 20 = 300\n2. 15 * 3 = 45\n3. 300 + 45 = 345",
496+
),
497+
Message(role="user", content="Thank you!"),
498+
Message(role="assistant", content="You're welcome! Let me know if you need help with anything else."),
499+
]
500+
501+
complex_evaluation = EvaluateResult(
502+
score=0.95,
503+
reason="Excellent step-by-step solution with clear explanation",
504+
metrics={
505+
"accuracy": MetricResult(score=1.0, reason="Correct mathematical calculation"),
506+
"explanation_quality": MetricResult(score=0.9, reason="Clear step-by-step breakdown"),
507+
"completeness": MetricResult(score=0.95, reason="Covers all aspects of the problem"),
508+
},
509+
)
510+
511+
row1 = EvaluationRow(
512+
messages=complex_messages,
513+
ground_truth="345",
514+
evaluation_result=complex_evaluation,
515+
input_metadata=InputMetadata(
516+
row_id="complex_math_001",
517+
completion_params={"model": "gpt-4", "temperature": 0.1},
518+
dataset_info={"source": "math_eval", "difficulty": "medium"},
519+
session_data={"user_id": "test_user", "session_id": "session_123"},
520+
),
521+
)
522+
523+
row2 = EvaluationRow(
524+
messages=complex_messages,
525+
ground_truth="345",
526+
evaluation_result=complex_evaluation,
527+
input_metadata=InputMetadata(
528+
row_id="complex_math_001",
529+
completion_params={"model": "gpt-4", "temperature": 0.1},
530+
dataset_info={"source": "math_eval", "difficulty": "medium"},
531+
session_data={"user_id": "test_user", "session_id": "session_123"},
532+
),
533+
)
534+
535+
# Complex data should still produce consistent hashes
536+
assert hash(row1) == hash(row2)
537+
538+
# Hash should be different from simple rows
539+
simple_row = EvaluationRow(
540+
messages=[Message(role="user", content="Simple"), Message(role="assistant", content="Response")],
541+
ground_truth="test",
542+
)
543+
assert hash(row1) != hash(simple_row)
544+
545+
546+
def test_stable_hash_json_representation():
547+
"""Test that the stable JSON representation is consistent and excludes volatile fields."""
548+
row = EvaluationRow(
549+
messages=[Message(role="user", content="Test"), Message(role="assistant", content="Response")],
550+
ground_truth="test",
551+
)
552+
553+
# Get the stable JSON representation
554+
stable_json = row._stable_json()
555+
556+
# Should be a valid JSON string
557+
parsed = json.loads(stable_json)
558+
559+
# Should contain the core data
560+
assert "messages" in parsed
561+
assert "ground_truth" in parsed
562+
assert parsed["ground_truth"] == "test"
563+
564+
# Should NOT contain volatile fields
565+
assert "created_at" not in parsed
566+
assert "execution_metadata" not in parsed
567+
568+
# Should be deterministic (same content produces same JSON)
569+
stable_json2 = row._stable_json()
570+
assert stable_json == stable_json2
571+
572+
573+
def test_stable_hash_consistency_for_identical_rows():
574+
"""Test that identical EvaluationRow objects produce the same stable hash.
575+
576+
This simulates the behavior expected across Python process restarts by
577+
creating multiple identical objects and ensuring their hashes match.
578+
"""
579+
# Create a complex evaluation row
580+
messages = [
581+
Message(role="user", content="What is the capital of France?"),
582+
Message(role="assistant", content="The capital of France is Paris."),
583+
Message(role="user", content="What about Germany?"),
584+
Message(role="assistant", content="The capital of Germany is Berlin."),
585+
]
586+
587+
evaluation_result = EvaluateResult(
588+
score=0.9,
589+
reason="Correct answers for both questions",
590+
metrics={
591+
"geography_knowledge": MetricResult(score=1.0, reason="Both capitals correctly identified"),
592+
"response_quality": MetricResult(score=0.8, reason="Clear and concise responses"),
593+
},
594+
)
595+
596+
# Create multiple identical rows
597+
rows = []
598+
for i in range(5):
599+
row = EvaluationRow(
600+
messages=messages,
601+
ground_truth="Paris, Berlin",
602+
evaluation_result=evaluation_result,
603+
input_metadata=InputMetadata(
604+
completion_params={"model": "gpt-4"},
605+
dataset_info={"source": "geography_eval"},
606+
),
607+
)
608+
rows.append(row)
609+
610+
# All rows should have identical hashes
611+
first_hash = hash(rows[0])
612+
for row in rows[1:]:
613+
assert hash(row) == first_hash
614+
615+
# The hash should be a large positive integer (SHA-256 first 8 bytes)
616+
assert first_hash > 0
617+
assert first_hash < 2**64 # 8 bytes = 64 bits
618+
619+
620+
def test_stable_hash_edge_cases():
621+
"""Test stable hashing with edge cases like empty data and None values."""
622+
# Empty messages
623+
empty_row = EvaluationRow(messages=[], ground_truth="")
624+
empty_hash = hash(empty_row)
625+
assert isinstance(empty_hash, int)
626+
assert empty_hash > 0
627+
628+
# None values in optional fields
629+
none_row = EvaluationRow(
630+
messages=[Message(role="user", content="Test")], ground_truth=None, evaluation_result=None
631+
)
632+
none_hash = hash(none_row)
633+
assert isinstance(none_hash, int)
634+
assert none_hash > 0
635+
636+
# Different from empty row
637+
assert empty_hash != none_hash
638+
639+
# Row with only required fields
640+
minimal_row = EvaluationRow(messages=[Message(role="user", content="Minimal")])
641+
minimal_hash = hash(minimal_row)
642+
assert isinstance(minimal_hash, int)
643+
assert minimal_hash > 0
644+
645+
# Should be different from other edge cases
646+
assert minimal_hash != empty_hash
647+
assert minimal_hash != none_hash
648+
649+
650+
def test_stable_hash_across_subprocess():
651+
"""Verify the same EvaluationRow produces the same hash in a separate Python process."""
652+
import multiprocessing as mp
653+
654+
row = dummy_row()
655+
parent_hash = hash(row)
656+
# Compute the same hash in a fresh interpreter via Pool.map (spawned process)
657+
ctx = mp.get_context("spawn")
658+
with ctx.Pool(processes=1) as pool:
659+
[child_hash] = pool.map(_child_compute_hash_value, [None])
660+
661+
assert isinstance(child_hash, int)
662+
assert parent_hash == child_hash

0 commit comments

Comments (0)