diff --git a/libs/core/langchain_core/utils/tool_choice.py b/libs/core/langchain_core/utils/tool_choice.py
new file mode 100644
index 0000000000000..491d324ea7b17
--- /dev/null
+++ b/libs/core/langchain_core/utils/tool_choice.py
@@ -0,0 +1,43 @@
+"""Utilities for normalizing `tool_choice` across providers.
+
+This provides a single place to normalize user-facing `tool_choice` values
+into provider-specific representations (or canonical forms) so adapters can
+consume them consistently.
+"""
+from __future__ import annotations
+
+from typing import Any
+
+
+def normalize_tool_choice(tool_choice: Any) -> Any:
+    """Normalize common `tool_choice` inputs to canonical values.
+
+    Normalizations applied:
+    - `"any"` -> `"required"` (many providers use `"required"`)
+    - `True` -> `"required"`, `False` -> `None`
+    - other strings (e.g. `"auto"`, `"none"`, tool names) are left as-is
+    - `None` is left as-is
+
+    This function intentionally performs only conservative mappings; provider
+    adapters may further map the result to provider-specific payloads.
+
+    Args:
+        tool_choice: Arbitrary user-supplied tool_choice value.
+
+    Returns:
+        The normalized value.
+    """
+    if tool_choice is None:
+        return None
+    # Booleans: True means require a tool, False means no-op
+    if isinstance(tool_choice, bool):
+        return "required" if tool_choice else None
+    # Strings
+    if isinstance(tool_choice, str):
+        lc = tool_choice.lower()
+        if lc == "any":
+            return "required"
+        # preserve 'auto', 'none', 'required', and specific names
+        return tool_choice
+    # dicts and other types: return as-is (adapters can validate)
+    return tool_choice
diff --git a/libs/core/tests/unit_tests/test_tool_choice.py b/libs/core/tests/unit_tests/test_tool_choice.py
new file mode 100644
index 0000000000000..5afefe5f3de76
--- /dev/null
+++ b/libs/core/tests/unit_tests/test_tool_choice.py
@@ -0,0 +1,22 @@
+from langchain_core.utils.tool_choice import normalize_tool_choice
+
+
+def test_normalize_any_string():
+    assert normalize_tool_choice("any") == "required"
+
+
+def test_normalize_true_boolean():
+    assert normalize_tool_choice(True) == "required"
+
+
+def test_normalize_false_boolean():
+    assert normalize_tool_choice(False) is None
+
+
+def test_normalize_none():
+    assert normalize_tool_choice(None) is None
+
+
+def test_preserve_auto_and_required():
+    assert normalize_tool_choice("auto") == "auto"
+    assert normalize_tool_choice("required") == "required"
diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py
index 790867cae5328..df78e5cc1532f 100644
--- a/libs/partners/openai/langchain_openai/chat_models/base.py
+++ b/libs/partners/openai/langchain_openai/chat_models/base.py
@@ -40,10 +40,7 @@
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
-from langchain_core.language_models import (
-    LanguageModelInput,
-    ModelProfileRegistry,
-)
+from langchain_core.language_models import LanguageModelInput
 from langchain_core.language_models.chat_models import (
     BaseChatModel,
     LangSmithParams,
@@ -106,6 +103,7 @@
     is_basemodel_subclass,
 )
+from langchain_core.utils.tool_choice import normalize_tool_choice
 from langchain_core.utils.utils import _build_model_kwargs, from_env, secret_from_env
 from pydantic import (
     BaseModel,
     ConfigDict,
@@ -126,10 +124,8 @@
     _convert_from_v1_to_responses,
     _convert_to_v03_ai_message,
 )
-from langchain_openai.data._profiles import _PROFILES
 
 if TYPE_CHECKING:
-    from langchain_core.language_models import ModelProfile
     from openai.types.responses import Response
 
 logger = logging.getLogger(__name__)
@@ -138,14 +134,6 @@
 # https://www.python-httpx.org/advanced/ssl/#configuring-client-instances
 global_ssl_context = ssl.create_default_context(cafile=certifi.where())
 
-_MODEL_PROFILES = cast(ModelProfileRegistry, _PROFILES)
-
-
-def _get_default_model_profile(model_name: str) -> ModelProfile:
-    default = _MODEL_PROFILES.get(model_name) or {}
-    return default.copy()
-
-
 WellKnownTools = (
     "file_search",
     "web_search_preview",
@@ -573,7 +561,6 @@ async def get_api_key() -> str:
 
     !!! version-added "Added in `langchain-openai` 0.3.9"
 
     !!! warning "Behavior changed in `langchain-openai` 0.3.35"
-
         Enabled for default base URL and client.
     """
@@ -804,7 +791,6 @@ async def get_api_key() -> str:
     - `'v1'`: v1 of LangChain cross-provider standard.
 
     !!! warning "Behavior changed in `langchain-openai` 1.0.0"
-
         Default updated to `"responses/v1"`.
     """
 
@@ -826,15 +812,14 @@ def validate_temperature(cls, values: dict[str, Any]) -> Any:
         (Defaults to 1)
         """
         model = values.get("model_name") or values.get("model") or ""
-        model_lower = model.lower()
 
         # For o1 models, set temperature=1 if not provided
-        if model_lower.startswith("o1") and "temperature" not in values:
+        if model.startswith("o1") and "temperature" not in values:
             values["temperature"] = 1
 
         # For gpt-5 models, handle temperature restrictions
         # Note that gpt-5-chat models do support temperature
-        if model_lower.startswith("gpt-5") and "chat" not in model_lower:
+        if model.startswith("gpt-5") and "chat" not in model:
             temperature = values.get("temperature")
             if temperature is not None and temperature != 1:
                 # For gpt-5 (non-chat), only temperature=1 is supported
@@ -968,13 +953,6 @@ def validate_environment(self) -> Self:
             self.async_client = self.root_async_client.chat.completions
         return self
 
-    @model_validator(mode="after")
-    def _set_model_profile(self) -> Self:
-        """Set model profile if not overridden."""
-        if self.profile is None:
-            self.profile = _get_default_model_profile(self.model_name)
-        return self
-
     @property
     def _default_params(self) -> dict[str, Any]:
         """Get the default parameters for calling OpenAI API."""
@@ -1669,13 +1647,15 @@ def _get_encoding_model(self) -> tuple[str, tiktoken.Encoding]:
             model = self.tiktoken_model_name
         else:
             model = self.model_name
-
         try:
             encoding = tiktoken.encoding_for_model(model)
         except KeyError:
-            model_lower = model.lower()
             encoder = "cl100k_base"
-            if model_lower.startswith(("gpt-4o", "gpt-4.1", "gpt-5")):
+            if (
+                self.model_name.startswith("gpt-4o")
+                or self.model_name.startswith("gpt-4.1")
+                or self.model_name.startswith("gpt-5")
+            ):
                 encoder = "o200k_base"
             encoding = tiktoken.get_encoding(encoder)
         return model, encoding
@@ -1837,6 +1817,8 @@ def bind_tools(
             else:
                 pass
         if tool_choice:
+            # Normalize common inputs first (e.g. 'any'/True -> 'required')
+            tool_choice = normalize_tool_choice(tool_choice)
             if isinstance(tool_choice, str):
                 # tool_choice is a tool/function name
                 if tool_choice in tool_names:
@@ -1846,22 +1828,20 @@ def bind_tools(
                     tool_choice = {
                         "type": "function",
                         "function": {"name": tool_choice},
                     }
                 elif tool_choice in WellKnownTools:
                     tool_choice = {"type": tool_choice}
-                # 'any' is not natively supported by OpenAI API.
-                # We support 'any' since other models use this instead of 'required'.
-                elif tool_choice == "any":
-                    tool_choice = "required"
                 else:
+                    # leave other strings (e.g. 'auto', 'required') as-is
                     pass
-            elif isinstance(tool_choice, bool):
-                tool_choice = "required"
             elif isinstance(tool_choice, dict):
                 pass
+            elif tool_choice is None:
+                # explicit None -> no-op
+                pass
             else:
                 msg = (
-                    f"Unrecognized tool_choice type. Expected str, bool or dict. "
+                    f"Unrecognized tool_choice type. Expected str, bool, dict or None. "
                     f"Received: {tool_choice}"
                 )
                 raise ValueError(msg)
@@ -2008,11 +1988,9 @@ def get_weather(location: str) -> str:
             - `'parsing_error'`: `BaseException | None`
 
         !!! warning "Behavior changed in `langchain-openai` 0.3.12"
-
             Support for `tools` added.
 
         !!! warning "Behavior changed in `langchain-openai` 0.3.21"
-
             Pass `kwargs` through to the model.
         """
         if strict is not None and method == "json_mode":
@@ -3132,15 +3110,12 @@ def get_weather(location: str) -> str:
             - `'parsing_error'`: `BaseException | None`
 
         !!! warning "Behavior changed in `langchain-openai` 0.3.0"
-
             `method` default changed from `"function_calling"` to `"json_schema"`.
 
         !!! warning "Behavior changed in `langchain-openai` 0.3.12"
-
             Support for `tools` added.
 
         !!! warning "Behavior changed in `langchain-openai` 0.3.21"
-
             Pass `kwargs` through to the model.
 
         ??? note "Example: `schema=Pydantic` class, `method='json_schema'`, `include_raw=False`, `strict=True`"
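
For reviewers, a quick usage sketch of the new helper. This is illustrative only and mirrors the unit tests above; it is not part of the change itself:

    from langchain_core.utils.tool_choice import normalize_tool_choice

    # Cross-provider aliases collapse to the canonical "required"
    assert normalize_tool_choice("any") == "required"
    assert normalize_tool_choice(True) == "required"

    # Falsy inputs come back as None (a no-op for adapters)
    assert normalize_tool_choice(False) is None
    assert normalize_tool_choice(None) is None

    # Everything else passes through for the adapter to interpret
    assert normalize_tool_choice("auto") == "auto"
    assert normalize_tool_choice({"type": "function"}) == {"type": "function"}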
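
And a sketch of how the normalization surfaces in `ChatOpenAI.bind_tools` after this change. The tool and model name below are placeholders chosen for illustration, not anything introduced by this diff:

    from langchain_core.tools import tool
    from langchain_openai import ChatOpenAI

    @tool
    def get_weather(location: str) -> str:
        """Get the weather for a location."""  # placeholder tool
        return f"It is sunny in {location}."

    llm = ChatOpenAI(model="gpt-4o-mini")  # placeholder model name

    # 'any' and True are normalized to 'required' before the OpenAI payload
    # is built, so all three bindings force a tool call:
    forced_a = llm.bind_tools([get_weather], tool_choice="any")
    forced_b = llm.bind_tools([get_weather], tool_choice=True)
    forced_c = llm.bind_tools([get_weather], tool_choice="required")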