Skip to content

Commit 8581787

Browse files
xiaoboxiaobo
authored andcommitted
Upgrade to Pydantic v2: Update Validators and Packaging
- Refactored code to use `@model_validator` instead of `@root_validator` - Updated `Field` usage to align with Pydantic v2 standards - Modified validation logic to comply with the new API - Updated `poetry.lock` and dependencies in `pyproject.toml`
1 parent 08569c8 commit 8581787

File tree

11 files changed

+40
-38
lines changed

11 files changed

+40
-38
lines changed

nemoguardrails/eval/models.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import os
1717
from typing import Any, Dict, List, Optional, Union
1818

19-
from pydantic import BaseModel, Field, root_validator
19+
from pydantic import BaseModel, Field, model_validator
2020

2121
from nemoguardrails.eval.utils import load_dict_from_path
2222
from nemoguardrails.logging.explain import LLMCallInfo
@@ -107,7 +107,7 @@ class InteractionSet(BaseModel):
107107
description="A list of tags that should be associated with the interactions. Useful for filtering when reporting.",
108108
)
109109

110-
@root_validator(pre=True)
110+
@model_validator(mode="before")
111111
def instantiate_expected_output(cls, values: Any):
112112
"""Creates the right instance of the expected output."""
113113
type_mapping = {
@@ -147,11 +147,11 @@ class EvalConfig(BaseModel):
147147
description="The prompts that should be used for the various LLM tasks.",
148148
)
149149

150-
@root_validator(pre=False, skip_on_failure=True)
151-
def validate_policy_ids(cls, values: Any):
150+
@model_validator(mode="after")
151+
def validate_policy_ids(cls, values: "EvalConfig") -> "EvalConfig":
152152
"""Validates the policy ids used in the interactions."""
153-
policy_ids = {policy.id for policy in values.get("policies")}
154-
for interaction_set in values.get("interactions"):
153+
policy_ids = {policy.id for policy in values.policies}
154+
for interaction_set in values.interactions:
155155
for expected_output in interaction_set.expected_output:
156156
if expected_output.policy not in policy_ids:
157157
raise ValueError(
@@ -180,7 +180,7 @@ def from_path(
180180
else:
181181
raise ValueError(f"Invalid config path {config_path}.")
182182

183-
return cls.parse_obj(config_obj)
183+
return cls.model_validate(config_obj)
184184

185185

186186
class ComplianceCheckLog(BaseModel):
@@ -361,4 +361,4 @@ def from_path(
361361
else:
362362
raise ValueError(f"Invalid config path {output_path}.")
363363

364-
return cls.parse_obj(output_obj)
364+
return cls.model_validate(output_obj)

nemoguardrails/library/factchecking/align_score/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# The minimal set of requirements for the AlignScore server to run.
2-
pydantic>=1.10
2+
pydantic>=2.10.6
33
fastapi>=0.109.1
44
starlette>=0.36.2
55
typer>=0.7.0

nemoguardrails/library/jailbreak_detection/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# The minimal set of requirements for the jailbreak detection server to run.
2-
pydantic>=1.10.9
2+
pydantic>=2.10.6
33
fastapi>=0.103.1
44
starlette>=0.27.0
55
typer>=0.7.0

nemoguardrails/llm/providers/_langchain_nvidia_ai_endpoints_patch.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from langchain_core.messages import BaseMessage
2323
from langchain_core.outputs import ChatResult
2424
from langchain_nvidia_ai_endpoints import ChatNVIDIA as ChatNVIDIAOriginal
25-
from pydantic.v1 import Field
25+
from pydantic import Field
2626

2727
log = logging.getLogger(__name__)
2828

nemoguardrails/llm/providers/nemollm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from langchain.schema import Generation
2828
from langchain.schema.output import GenerationChunk, LLMResult
2929
from langchain_core.language_models.llms import BaseLLM
30-
from pydantic.v1 import root_validator
30+
from pydantic import model_validator
3131

3232
log = logging.getLogger(__name__)
3333

@@ -52,7 +52,7 @@ class NeMoLLM(BaseLLM):
5252
streaming: bool = False
5353
check_api_host_version: bool = True
5454

55-
@root_validator(pre=True, allow_reuse=True)
55+
@model_validator(mode="before")
5656
def check_env_variables(cls, values):
5757
for field in ["api_host", "api_key", "organization_id"]:
5858
# If it's an explicit environment variable, we use that

nemoguardrails/llm/providers/trtllm/llm.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
from langchain.callbacks.manager import CallbackManagerForLLMRun
2424
from langchain_core.language_models.llms import BaseLLM
25-
from pydantic.v1 import Field, root_validator
25+
from pydantic import Field, model_validator
2626

2727
from nemoguardrails.llm.providers.trtllm.client import TritonClient
2828

@@ -61,12 +61,12 @@ class TRTLLM(BaseLLM):
6161
client: Any
6262
streaming: Optional[bool] = True
6363

64-
@root_validator(allow_reuse=True)
64+
@model_validator(mode="after")
6565
@classmethod
66-
def validate_environment(cls, values: Dict[str, Any]) -> Dict[str, Any]:
66+
def validate_environment(cls, values: "TRTLLM") -> "TRTLLM":
6767
"""Validate that python package exists in environment."""
6868
try:
69-
values["client"] = TritonClient(values["server_url"])
69+
values.client = TritonClient(values.server_url)
7070

7171
except ImportError as err:
7272
raise ImportError(

nemoguardrails/rails/llm/config.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,7 @@
2222
from typing import Any, Dict, List, Optional, Set, Tuple, Union
2323

2424
import yaml
25-
from pydantic import BaseModel, ConfigDict, ValidationError, root_validator
26-
from pydantic.fields import Field
25+
from pydantic import BaseModel, ConfigDict, Field, ValidationError, model_validator
2726

2827
from nemoguardrails import utils
2928
from nemoguardrails.colang import parse_colang_file, parse_flow_elements
@@ -213,7 +212,7 @@ class TaskPrompt(BaseModel):
213212
description="The maximum number of tokens that can be generated in the chat completion.",
214213
)
215214

216-
@root_validator(pre=True, allow_reuse=True)
215+
@model_validator(mode="before")
217216
def check_fields(cls, values):
218217
if not values.get("content") and not values.get("messages"):
219218
raise ValidationError("One of `content` or `messages` must be provided.")
@@ -867,16 +866,15 @@ class RailsConfig(BaseModel):
867866
description="The list of bot messages that should be used for the rails.",
868867
)
869868

870-
# NOTE: the Any below is used to get rid of a warning with pydantic 1.10.x;
871-
# The correct typing should be List[Dict, Flow]. To be updated when
872-
# support for pydantic 1.10.x is dropped.
873-
flows: List[Union[Dict, Any]] = Field(
869+
flows: List[Union[Dict, Flow]] = Field(
874870
default_factory=list,
875871
description="The list of flows that should be used for the rails.",
876872
)
877873

878874
instructions: Optional[List[Instruction]] = Field(
879-
default=[Instruction.parse_obj(obj) for obj in _default_config["instructions"]],
875+
default=[
876+
Instruction.model_validate(obj) for obj in _default_config["instructions"]
877+
],
880878
description="List of instructions in natural language that the LLM should use.",
881879
)
882880

@@ -981,7 +979,7 @@ class RailsConfig(BaseModel):
981979
description="Configuration for tracing.",
982980
)
983981

984-
@root_validator(pre=True, allow_reuse=True)
982+
@model_validator(mode="before")
985983
def check_prompt_exist_for_self_check_rails(cls, values):
986984
rails = values.get("rails", {})
987985

@@ -1035,7 +1033,7 @@ def check_prompt_exist_for_self_check_rails(cls, values):
10351033

10361034
return values
10371035

1038-
@root_validator(pre=True, allow_reuse=True)
1036+
@model_validator(mode="before")
10391037
def check_output_parser_exists(cls, values):
10401038
tasks_requiring_output_parser = [
10411039
"self_check_input",
@@ -1068,7 +1066,7 @@ def check_output_parser_exists(cls, values):
10681066
)
10691067
return values
10701068

1071-
@root_validator(pre=True, allow_reuse=True)
1069+
@model_validator(mode="before")
10721070
def fill_in_default_values_for_v2_x(cls, values):
10731071
instructions = values.get("instructions", {})
10741072
sample_conversation = values.get("sample_conversation")
@@ -1197,7 +1195,7 @@ def parse_object(cls, obj):
11971195
):
11981196
flow_data["elements"] = parse_flow_elements(flow_data["elements"])
11991197

1200-
return cls.parse_obj(obj)
1198+
return cls.model_validate(obj)
12011199

12021200
@property
12031201
def streaming_supported(self):

nemoguardrails/rails/llm/options.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@
7878
"""
7979
from typing import Any, Dict, List, Optional, Union
8080

81-
from pydantic import BaseModel, Field, root_validator
81+
from pydantic import BaseModel, Field, model_validator
8282

8383
from nemoguardrails.logging.explain import LLMCallInfo, LLMCallSummary
8484

@@ -168,7 +168,7 @@ class GenerationOptions(BaseModel):
168168
description="Options about what to include in the log. By default, nothing is included. ",
169169
)
170170

171-
@root_validator(pre=True, allow_reuse=True)
171+
@model_validator(mode="before")
172172
def check_fields(cls, values):
173173
# Translate the `rails` generation option from List[str] to dict.
174174
if "rails" in values and isinstance(values["rails"], list):

nemoguardrails/server/api.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525

2626
from fastapi import FastAPI, Request
2727
from fastapi.middleware.cors import CORSMiddleware
28-
from pydantic import BaseModel, Field, root_validator, validator
28+
from pydantic import BaseModel, Field, field_validator, model_validator
2929
from starlette.responses import StreamingResponse
3030
from starlette.staticfiles import StaticFiles
3131

@@ -140,7 +140,7 @@ class RequestBody(BaseModel):
140140
description="A state object that should be used to continue the interaction.",
141141
)
142142

143-
@root_validator(pre=True)
143+
@model_validator(mode="before")
144144
def ensure_config_id(cls, data: Any) -> Any:
145145
if isinstance(data, dict):
146146
if data.get("config_id") is not None and data.get("config_ids") is not None:
@@ -155,11 +155,15 @@ def ensure_config_id(cls, data: Any) -> Any:
155155
)
156156
return data
157157

158-
@validator("config_ids", pre=True, always=True)
158+
@field_validator("config_ids", mode="before")
159159
def ensure_config_ids(cls, v, values):
160-
if v is None and values.get("config_id") and values.get("config_ids") is None:
160+
if (
161+
v is None
162+
and values.data.get("config_id")
163+
and values.data.get("config_ids") is None
164+
):
161165
# populate config_ids with config_id if only config_id is provided
162-
return [values["config_id"]]
166+
return [values.data["config_id"]]
163167
return v
164168

165169

poetry.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ langchain-community = ">=0.0.16,<0.4.0"
6565
lark = ">=1.1.7"
6666
nest-asyncio = ">=1.5.6,"
6767
prompt-toolkit = ">=3.0"
68-
pydantic = ">=1.10"
68+
pydantic = ">=2.0"
6969
pyyaml = ">=6.0"
7070
rich = ">=13.5.2"
7171
simpleeval = ">=0.9.13,"

0 commit comments

Comments
 (0)