Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions libs/core/langchain_core/utils/tool_choice.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
"""Utilities for normalizing `tool_choice` across providers.

This provides a single place to normalize user-facing `tool_choice` values
into provider-specific representations (or canonical forms) so adapters can
consume them consistently.
"""
from __future__ import annotations

from typing import Any


def normalize_tool_choice(tool_choice: Any) -> Any:
    """Normalize common `tool_choice` inputs to canonical values.

    Normalizations applied:
    - `"any"` (case-insensitive) -> `"required"` (many providers use `required`)
    - `True` -> `"required"`
    - `False` -> `None` (treated as "no preference", same as `None`)
    - `None` -> `None`
    - other strings (e.g. `"auto"`, `"required"`, specific tool names) left as-is
    - dicts and any other types returned unchanged (adapters can validate)

    This function intentionally performs only conservative mappings; provider
    adapters may further map the result to provider-specific payloads.

    Args:
        tool_choice: Arbitrary user-supplied tool_choice value.

    Returns:
        The normalized value.
    """
    if tool_choice is None:
        return None
    # Check bool before other types (bool is a subclass of int): True means
    # "require some tool call"; False means "no preference" (same as None).
    if isinstance(tool_choice, bool):
        return "required" if tool_choice else None
    if isinstance(tool_choice, str):
        # Map the provider-agnostic alias 'any' to the widely used 'required'.
        if tool_choice.lower() == "any":
            return "required"
        # Preserve 'auto', 'none', 'required', and specific tool names as-is.
        return tool_choice
    # dicts and other types: return as-is (adapters can validate)
    return tool_choice
22 changes: 22 additions & 0 deletions libs/core/tests/unit_tests/test_tool_choice.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from langchain_core.utils.tool_choice import normalize_tool_choice


def test_normalize_any_string():
    # 'any' is the provider-agnostic alias for 'required'.
    result = normalize_tool_choice("any")
    assert result == "required"


def test_normalize_true_boolean():
    # True means "require some tool call".
    result = normalize_tool_choice(True)
    assert result == "required"


def test_normalize_false_boolean():
    # False is treated as "no preference", i.e. None.
    result = normalize_tool_choice(False)
    assert result is None


def test_normalize_none():
    # None passes through unchanged.
    result = normalize_tool_choice(None)
    assert result is None


def test_preserve_auto_and_required():
    # Canonical strings other than 'any' are preserved as-is.
    for value in ("auto", "required"):
        assert normalize_tool_choice(value) == value
57 changes: 16 additions & 41 deletions libs/partners/openai/langchain_openai/chat_models/base.py
Original file line number Diff line number Diff line change
@@ -1,135 +1,131 @@
"""OpenAI chat wrapper."""

from __future__ import annotations

import base64
import json
import logging
import os
import re
import ssl
import sys
import warnings
from collections.abc import (
AsyncIterator,
Awaitable,
Callable,
Iterator,
Mapping,
Sequence,
)
from functools import partial
from io import BytesIO
from json import JSONDecodeError
from math import ceil
from operator import itemgetter
from typing import (
TYPE_CHECKING,
Any,
Literal,
TypeAlias,
TypeVar,
cast,
)
from urllib.parse import urlparse

import certifi
import openai
import tiktoken
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models import (
LanguageModelInput,
ModelProfileRegistry,
)
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import (
BaseChatModel,
LangSmithParams,
)
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
BaseMessageChunk,
ChatMessage,
ChatMessageChunk,
FunctionMessage,
FunctionMessageChunk,
HumanMessage,
HumanMessageChunk,
InvalidToolCall,
SystemMessage,
SystemMessageChunk,
ToolCall,
ToolMessage,
ToolMessageChunk,
is_data_content_block,
)
from langchain_core.messages import content as types
from langchain_core.messages.ai import (
InputTokenDetails,
OutputTokenDetails,
UsageMetadata,
)
from langchain_core.messages.block_translators.openai import (
_convert_from_v03_ai_message,
convert_to_openai_data_block,
)
from langchain_core.messages.tool import tool_call_chunk
from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser
from langchain_core.output_parsers.openai_tools import (
JsonOutputKeyToolsParser,
PydanticToolsParser,
make_invalid_tool_call,
parse_tool_call,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.runnables import (
Runnable,
RunnableLambda,
RunnableMap,
RunnablePassthrough,
)
from langchain_core.runnables.config import run_in_executor
from langchain_core.tools import BaseTool
from langchain_core.tools.base import _stringify
from langchain_core.utils import get_pydantic_field_names
from langchain_core.utils.function_calling import (
convert_to_openai_function,
convert_to_openai_tool,
)
from langchain_core.utils.pydantic import (
PydanticBaseModel,
TypeBaseModel,
is_basemodel_subclass,
)
from langchain_core.utils.utils import _build_model_kwargs, from_env, secret_from_env
from langchain_core.utils.tool_choice import normalize_tool_choice
from pydantic import (
BaseModel,
ConfigDict,
Field,
SecretStr,
model_validator,
)
from pydantic.v1 import BaseModel as BaseModelV1
from typing_extensions import Self

from langchain_openai.chat_models._client_utils import (
_get_default_async_httpx_client,
_get_default_httpx_client,
_resolve_sync_and_async_api_keys,
)
from langchain_openai.chat_models._compat import (
_convert_from_v1_to_chat_completions,
_convert_from_v1_to_responses,
_convert_to_v03_ai_message,
)

Check failure on line 126 in libs/partners/openai/langchain_openai/chat_models/base.py

View workflow job for this annotation

GitHub Actions / lint (libs/partners/openai, 3.14) / Python 3.14

Ruff (I001)

langchain_openai/chat_models/base.py:3:1: I001 Import block is un-sorted or un-formatted

Check failure on line 126 in libs/partners/openai/langchain_openai/chat_models/base.py

View workflow job for this annotation

GitHub Actions / lint (libs/partners/openai, 3.10) / Python 3.10

Ruff (I001)

langchain_openai/chat_models/base.py:3:1: I001 Import block is un-sorted or un-formatted
from langchain_openai.data._profiles import _PROFILES

if TYPE_CHECKING:
from langchain_core.language_models import ModelProfile
from openai.types.responses import Response

logger = logging.getLogger(__name__)
Expand All @@ -138,14 +134,6 @@
# https://www.python-httpx.org/advanced/ssl/#configuring-client-instances
global_ssl_context = ssl.create_default_context(cafile=certifi.where())

_MODEL_PROFILES = cast(ModelProfileRegistry, _PROFILES)


def _get_default_model_profile(model_name: str) -> ModelProfile:
default = _MODEL_PROFILES.get(model_name) or {}
return default.copy()


WellKnownTools = (
"file_search",
"web_search_preview",
Expand Down Expand Up @@ -573,7 +561,6 @@
!!! version-added "Added in `langchain-openai` 0.3.9"

!!! warning "Behavior changed in `langchain-openai` 0.3.35"

Enabled for default base URL and client.
"""

Expand Down Expand Up @@ -804,7 +791,6 @@
- `'v1'`: v1 of LangChain cross-provider standard.

!!! warning "Behavior changed in `langchain-openai` 1.0.0"

Default updated to `"responses/v1"`.
"""

Expand All @@ -826,15 +812,14 @@
(Defaults to 1)
"""
model = values.get("model_name") or values.get("model") or ""
model_lower = model.lower()

# For o1 models, set temperature=1 if not provided
if model_lower.startswith("o1") and "temperature" not in values:
if model.startswith("o1") and "temperature" not in values:
values["temperature"] = 1

# For gpt-5 models, handle temperature restrictions
# Note that gpt-5-chat models do support temperature
if model_lower.startswith("gpt-5") and "chat" not in model_lower:
if model.startswith("gpt-5") and "chat" not in model:
temperature = values.get("temperature")
if temperature is not None and temperature != 1:
# For gpt-5 (non-chat), only temperature=1 is supported
Expand Down Expand Up @@ -968,13 +953,6 @@
self.async_client = self.root_async_client.chat.completions
return self

@model_validator(mode="after")
def _set_model_profile(self) -> Self:
"""Set model profile if not overridden."""
if self.profile is None:
self.profile = _get_default_model_profile(self.model_name)
return self

@property
def _default_params(self) -> dict[str, Any]:
"""Get the default parameters for calling OpenAI API."""
Expand Down Expand Up @@ -1669,13 +1647,15 @@
model = self.tiktoken_model_name
else:
model = self.model_name

try:
encoding = tiktoken.encoding_for_model(model)
except KeyError:
model_lower = model.lower()
encoder = "cl100k_base"
if model_lower.startswith(("gpt-4o", "gpt-4.1", "gpt-5")):
if (
self.model_name.startswith("gpt-4o")
or self.model_name.startswith("gpt-4.1")
or self.model_name.startswith("gpt-5")
):
encoder = "o200k_base"
encoding = tiktoken.get_encoding(encoder)
return model, encoding
Expand Down Expand Up @@ -1837,6 +1817,8 @@
else:
pass
if tool_choice:
# Normalize common inputs (e.g. 'any' -> 'required', True -> 'required')

Check failure on line 1820 in libs/partners/openai/langchain_openai/chat_models/base.py

View workflow job for this annotation

GitHub Actions / lint (libs/partners/openai, 3.14) / Python 3.14

Ruff (E501)

langchain_openai/chat_models/base.py:1820:89: E501 Line too long (95 > 88)

Check failure on line 1820 in libs/partners/openai/langchain_openai/chat_models/base.py

View workflow job for this annotation

GitHub Actions / lint (libs/partners/openai, 3.10) / Python 3.10

Ruff (E501)

langchain_openai/chat_models/base.py:1820:89: E501 Line too long (95 > 88)
tool_choice = normalize_tool_choice(tool_choice)
if isinstance(tool_choice, str):
# tool_choice is a tool/function name
if tool_choice in tool_names:
Expand All @@ -1846,19 +1828,17 @@
}
elif tool_choice in WellKnownTools:
tool_choice = {"type": tool_choice}
# 'any' is not natively supported by OpenAI API.
# We support 'any' since other models use this instead of 'required'.
elif tool_choice == "any":
tool_choice = "required"
else:
# leave other strings (e.g. 'auto', 'required') as-is
pass
elif isinstance(tool_choice, bool):
tool_choice = "required"
elif isinstance(tool_choice, dict):
pass
elif tool_choice is None:
# explicit None -> no-op
pass
else:
msg = (
f"Unrecognized tool_choice type. Expected str, bool or dict. "
f"Unrecognized tool_choice type. Expected str, bool, dict or None. "
f"Received: {tool_choice}"
)
raise ValueError(msg)
Expand Down Expand Up @@ -2008,11 +1988,9 @@
- `'parsing_error'`: `BaseException | None`

!!! warning "Behavior changed in `langchain-openai` 0.3.12"

Support for `tools` added.

!!! warning "Behavior changed in `langchain-openai` 0.3.21"

Pass `kwargs` through to the model.
"""
if strict is not None and method == "json_mode":
Expand Down Expand Up @@ -3132,15 +3110,12 @@
- `'parsing_error'`: `BaseException | None`

!!! warning "Behavior changed in `langchain-openai` 0.3.0"

`method` default changed from `"function_calling"` to `"json_schema"`.

!!! warning "Behavior changed in `langchain-openai` 0.3.12"

Support for `tools` added.

!!! warning "Behavior changed in `langchain-openai` 0.3.21"

Pass `kwargs` through to the model.

??? note "Example: `schema=Pydantic` class, `method='json_schema'`, `include_raw=False`, `strict=True`"
Expand Down
Loading