Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 33 additions & 129 deletions src/themefinder/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,118 +22,32 @@ class EvidenceRich(str, Enum):
NO = "NO"


class ValidatedModel(BaseModel):
"""Base model with common validation methods"""
class SentimentAnalysisOutput(BaseModel):
"""Model for sentiment analysis output"""

def validate_non_empty_fields(self) -> "ValidatedModel":
"""
Validate that all string fields are non-empty and all list fields are not empty.
"""
for field_name, value in self.__dict__.items():
if isinstance(value, str) and not value.strip():
raise ValueError(f"{field_name} cannot be empty or only whitespace")
if isinstance(value, list) and not value:
raise ValueError(f"{field_name} cannot be an empty list")
if isinstance(value, list):
for i, item in enumerate(value):
if isinstance(item, str) and not item.strip():
raise ValueError(
f"Item {i} in {field_name} cannot be empty or only whitespace"
)
return self
response_id: int = Field(gt=0)
position: Position

def validate_unique_items(
self, field_name: str, transform_func: Optional[callable] = None
) -> "ValidatedModel":
"""
Validate that a field contains unique values.

Args:
field_name: The name of the field to check for uniqueness
transform_func: Optional function to transform items before checking uniqueness
(e.g., lowercasing strings)
"""
if not hasattr(self, field_name):
raise ValueError(f"Field '{field_name}' does not exist")
items = getattr(self, field_name)
if not isinstance(items, list):
raise ValueError(f"Field '{field_name}' is not a list")
if transform_func:
transformed_items = [transform_func(item) for item in items]
else:
transformed_items = items
if len(transformed_items) != len(set(transformed_items)):
raise ValueError(f"'{field_name}' must contain unique values")
return self
class SentimentAnalysisResponses(BaseModel):
"""Container for all sentiment analysis responses"""

def validate_unique_attribute_in_list(
self, list_field: str, attr_name: str
) -> "ValidatedModel":
"""
Validate that an attribute across all objects in a list field is unique.

Args:
list_field: The name of the list field containing objects
attr_name: The attribute within each object to check for uniqueness
"""
if not hasattr(self, list_field):
raise ValueError(f"Field '{list_field}' does not exist")

items = getattr(self, list_field)
if not isinstance(items, list):
raise ValueError(f"Field '{list_field}' is not a list")

attr_values = []
for item in items:
if not hasattr(item, attr_name):
raise ValueError(
f"Item in '{list_field}' does not have attribute '{attr_name}'"
)
attr_values.append(getattr(item, attr_name))
if len(attr_values) != len(set(attr_values)):
raise ValueError(
f"'{attr_name}' must be unique across all items in '{list_field}'"
)
return self

def validate_equal_lengths(self, *field_names) -> "ValidatedModel":
"""
Validate that multiple list fields have the same length.

Args:
*field_names: Variable number of field names to check for equal lengths
"""
if len(field_names) < 2:
return self
lengths = []
for field_name in field_names:
if not hasattr(self, field_name):
raise ValueError(f"Field '{field_name}' does not exist")

items = getattr(self, field_name)
if not isinstance(items, list):
raise ValueError(f"Field '{field_name}' is not a list")

lengths.append(len(items))
if len(set(lengths)) > 1:
raise ValueError(
f"Fields {', '.join(field_names)} must all have the same length"
)
return self
responses: List[SentimentAnalysisOutput] = Field(min_length=1)

@model_validator(mode="after")
def run_validations(self) -> "ValidatedModel":
"""
Run common validations. Override in subclasses to add specific validations.
"""
return self.validate_non_empty_fields()
def run_validations(self) -> "SentimentAnalysisResponses":
"""Validate that response_ids are unique"""
response_ids = [resp.response_id for resp in self.responses]
if len(response_ids) != len(set(response_ids)):
raise ValueError("Response IDs must be unique")
return self


def lower_case_strip_str(value: str) -> str:
return value.lower().strip()


class Theme(ValidatedModel):
class Theme(BaseModel):
"""Model for a single extracted theme"""

topic_label: Annotated[str, AfterValidator(lower_case_strip_str)] = Field(
Expand Down Expand Up @@ -184,7 +98,7 @@ def _reduce(themes: list[Theme]) -> Theme:
return self


class CondensedTheme(ValidatedModel):
class CondensedTheme(BaseModel):
"""Model for a single condensed theme"""

topic_label: Annotated[str, AfterValidator(lower_case_strip_str)] = Field(
Expand Down Expand Up @@ -239,13 +153,15 @@ def _reduce(topic_label: str) -> CondensedTheme:
return self


class RefinedTheme(ValidatedModel):
class RefinedTheme(BaseModel):
"""Model for a single refined theme"""

# TODO: Split into separate topic_label + topic_description fields to match
# Theme/CondensedTheme models. Currently evals must parse the combined string.
topic: str = Field(
..., description="Topic label and description combined with a colon separator"
...,
description="Topic label and description combined with a colon separator",
min_length=1,
)
source_topic_count: int = Field(
..., gt=0, description="Count of source topics combined"
Expand All @@ -254,7 +170,6 @@ class RefinedTheme(ValidatedModel):
@model_validator(mode="after")
def run_validations(self) -> "RefinedTheme":
"""Run all validations for RefinedTheme"""
self.validate_non_empty_fields()
self.validate_topic_format()
return self

Expand All @@ -278,58 +193,49 @@ def validate_topic_format(self) -> "RefinedTheme":
return self


class ThemeRefinementResponses(ValidatedModel):
class ThemeRefinementResponses(BaseModel):
"""Container for all refined themes"""

responses: List[RefinedTheme] = Field(..., description="List of refined themes")
responses: List[RefinedTheme] = Field(
..., description="List of refined themes", min_length=1
)

@model_validator(mode="after")
def run_validations(self) -> "ThemeRefinementResponses":
"""Ensure there are no duplicate themes"""
self.validate_non_empty_fields()
topics = [theme.topic.lower().strip() for theme in self.responses]
if len(topics) != len(set(topics)):
raise ValueError("Duplicate topics detected")

return self


class ThemeMappingOutput(ValidatedModel):
class ThemeMappingOutput(BaseModel):
"""Model for theme mapping output"""

response_id: int = Field(gt=0, description="Response ID, must be greater than 0")
labels: List[str] = Field(..., description="List of theme labels")

@model_validator(mode="after")
def run_validations(self) -> "ThemeMappingOutput":
"""
Run all validations for ThemeMappingOutput.
"""
self.validate_non_empty_fields()
self.validate_unique_items("labels")
return self
labels: set[str] = Field(..., description="List of theme labels", min_length=1)


class ThemeMappingResponses(ValidatedModel):
class ThemeMappingResponses(BaseModel):
"""Container for all theme mapping responses"""

responses: List[ThemeMappingOutput] = Field(
..., description="List of theme mapping outputs"
..., description="List of theme mapping outputs", min_length=1
)

@model_validator(mode="after")
def run_validations(self) -> "ThemeMappingResponses":
"""
Validate that response_ids are unique.
"""
self.validate_non_empty_fields()
response_ids = [resp.response_id for resp in self.responses]
if len(response_ids) != len(set(response_ids)):
raise ValueError("Response IDs must be unique")
return self


class DetailDetectionOutput(ValidatedModel):
class DetailDetectionOutput(BaseModel):
"""Model for detail detection output"""

response_id: int = Field(gt=0, description="Response ID, must be greater than 0")
Expand All @@ -338,26 +244,25 @@ class DetailDetectionOutput(ValidatedModel):
)


class DetailDetectionResponses(ValidatedModel):
class DetailDetectionResponses(BaseModel):
"""Container for all detail detection responses"""

responses: List[DetailDetectionOutput] = Field(
..., description="List of detail detection outputs"
..., description="List of detail detection outputs", min_length=1
)

@model_validator(mode="after")
def run_validations(self) -> "DetailDetectionResponses":
"""
Validate that response_ids are unique.
"""
self.validate_non_empty_fields()
response_ids = [resp.response_id for resp in self.responses]
if len(response_ids) != len(set(response_ids)):
raise ValueError("Response IDs must be unique")
return self


class ThemeNode(ValidatedModel):
class ThemeNode(BaseModel):
"""Model for topic nodes created during hierarchical clustering"""

topic_id: str = Field(
Expand Down Expand Up @@ -393,12 +298,12 @@ def run_validations(self) -> "ThemeNode":
return self


class HierarchicalClusteringResponse(ValidatedModel):
class HierarchicalClusteringResponse(BaseModel):
"""Model for hierarchical clustering agent response"""

parent_themes: List[ThemeNode] = Field(
default=[],
description="List of parent themes created by merging similar themes",
min_length=1,
)
should_terminate: bool = Field(
...,
Expand All @@ -408,7 +313,6 @@ class HierarchicalClusteringResponse(ValidatedModel):
@model_validator(mode="after")
def run_validations(self) -> "HierarchicalClusteringResponse":
"""Validate clustering response constraints"""
self.validate_non_empty_fields()

# Validate that no child appears in multiple parents
all_children = []
Expand Down
43 changes: 0 additions & 43 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from themefinder.models import (
Position,
EvidenceRich,
ValidatedModel,
Theme,
ThemeGenerationResponses,
CondensedTheme,
Expand All @@ -16,48 +15,6 @@
)


class TestValidatedModelAdditional:
class NestedModel(ValidatedModel):
attr: str

class ContainerModel(ValidatedModel):
items: list["TestValidatedModelAdditional.NestedModel"]
other_list: list[str]

class MockModel(ValidatedModel):
field1: str
field2: list[str] | None = None
field3: list[str] | None = None

def test_validate_unique_attribute_in_list(self):
model = self.ContainerModel(
items=[
self.NestedModel(attr="value1"),
self.NestedModel(attr="value2"),
],
other_list=["a", "b"],
)
model.validate_unique_attribute_in_list("items", "attr")

model = self.ContainerModel(
items=[
self.NestedModel(attr="same"),
self.NestedModel(attr="same"),
],
other_list=["a", "b"],
)
with pytest.raises(ValueError, match="must be unique across all items"):
model.validate_unique_attribute_in_list("items", "attr")

def test_validate_equal_lengths(self):
model = self.MockModel(field1="test", field2=["a", "b"], field3=["x", "y"])
model.validate_equal_lengths("field2", "field3")

model = self.MockModel(field1="test", field2=["a", "b"], field3=["x"])
with pytest.raises(ValueError, match="must all have the same length"):
model.validate_equal_lengths("field2", "field3")


class TestTheme:
def test_valid_theme(self):
theme = Theme(
Expand Down