Skip to content

Port to orjson from ujson #8584

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions dspy/clients/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import cloudpickle
import pydantic
import ujson
import orjson
from cachetools import LRUCache
from diskcache import FanoutCache

Expand Down Expand Up @@ -93,7 +93,7 @@ def transform_value(value):
return value

params = {k: transform_value(v) for k, v in request.items() if k not in ignored_args_for_cache_key}
return sha256(ujson.dumps(params, sort_keys=True).encode()).hexdigest()
return sha256(orjson.dumps(params, option=orjson.OPT_SORT_KEYS)).hexdigest()

def get(self, request: dict[str, Any], ignored_args_for_cache_key: list[str] | None = None) -> Any:
try:
Expand Down
4 changes: 2 additions & 2 deletions dspy/clients/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import TYPE_CHECKING, Any

import requests
import ujson
import orjson

from dspy.clients.provider import Provider, TrainingJob
from dspy.clients.utils_finetune import TrainDataFormat, get_finetune_directory
Expand Down Expand Up @@ -318,7 +318,7 @@ def _save_data_to_local_file(train_data: list[dict[str, Any]], data_format: Trai
elif data_format == TrainDataFormat.COMPLETION:
_validate_completion_data(item)

f.write(ujson.dumps(item) + "\n")
f.write(orjson.dumps(item).decode() + "\n")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I recommend changing the open(file_path, "w") call above to "wb" mode, which eliminates the .decode() call here.

The end result will be:

Suggested change
f.write(orjson.dumps(item).decode() + "\n")
f.write(orjson.dumps(item) + b"\n")

return file_path


Expand Down
6 changes: 3 additions & 3 deletions dspy/clients/utils_finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from enum import Enum
from typing import Any, Literal, TypedDict

import ujson
import orjson

import dspy
from dspy.adapters.base import Adapter
Expand Down Expand Up @@ -60,7 +60,7 @@ def get_finetune_directory() -> str:
def write_lines(file_path, data):
with open(file_path, "w") as f:
for item in data:
f.write(ujson.dumps(item) + "\n")
f.write(orjson.dumps(item).decode() + "\n")
Comment on lines 61 to +63
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As above, I recommend eliminating the calls to .decode():

    with open(file_path, "wb") as f:
        for item in data:
            f.write(orjson.dumps(item) + b"\n")

This suggestion applies to the other change in this file.



def save_data(
Expand All @@ -77,7 +77,7 @@ def save_data(
file_path = os.path.abspath(file_path)
with open(file_path, "w") as f:
for item in data:
f.write(ujson.dumps(item) + "\n")
f.write(orjson.dumps(item).decode() + "\n")
return file_path


Expand Down
7 changes: 3 additions & 4 deletions dspy/predict/refine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import textwrap
from typing import Callable

import ujson
import orjson

import dspy
from dspy.adapters.utils import get_field_description_string
Expand Down Expand Up @@ -158,10 +158,9 @@ def __call__(self, lm, lm_kwargs, signature, demos, inputs):
}

advise_kwargs = dict(**modules, **trajectory, **reward, module_names=module_names)
# advise_kwargs = {k: ujson.dumps(recursive_mask(v), indent=2) for k, v in advise_kwargs.items()}
# only dumps if it's a list or dict
advise_kwargs = {
k: v if isinstance(v, str) else ujson.dumps(recursive_mask(v), indent=2)
k: v if isinstance(v, str) else orjson.dumps(recursive_mask(v), option=orjson.OPT_INDENT_2).decode()
for k, v in advise_kwargs.items()
}
advice = dspy.Predict(OfferFeedback)(**advise_kwargs).advice
Expand Down Expand Up @@ -200,7 +199,7 @@ def inspect_modules(program):
def recursive_mask(o):
# If the object is already serializable, return it.
try:
ujson.dumps(o)
orjson.dumps(o)
return o
except TypeError:
pass
Expand Down
8 changes: 4 additions & 4 deletions dspy/primitives/base_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from pathlib import Path

import cloudpickle
import ujson
import orjson

from dspy.utils.saving import get_dependency_versions

Expand Down Expand Up @@ -216,7 +216,7 @@ def save(self, path, save_program=False, modules_to_serialize=None):
"or consider using state-only saving by setting `save_program=False`."
)
with open(path / "metadata.json", "w", encoding="utf-8") as f:
ujson.dump(metadata, f, indent=2, ensure_ascii=False)
f.write(orjson.dumps(metadata, option=orjson.OPT_INDENT_2 | orjson.OPT_APPEND_NEWLINE).decode('utf-8'))

return

Expand All @@ -225,7 +225,7 @@ def save(self, path, save_program=False, modules_to_serialize=None):
if path.suffix == ".json":
try:
with open(path, "w", encoding="utf-8") as f:
f.write(ujson.dumps(state, indent=2 , ensure_ascii=False))
f.write(orjson.dumps(state, option=orjson.OPT_INDENT_2 | orjson.OPT_APPEND_NEWLINE).decode('utf-8'))
except Exception as e:
raise RuntimeError(
f"Failed to save state to {path} with error: {e}. Your DSPy program may contain non "
Expand All @@ -249,7 +249,7 @@ def load(self, path):

if path.suffix == ".json":
with open(path, encoding="utf-8") as f:
state = ujson.loads(f.read())
state = orjson.loads(f.read().encode('utf-8'))
Comment on lines 251 to +252
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the last place I'll leave feedback, but in general the suggestion is "don't decode content only to re-encode it immediately":

            with open(path, "rb") as f:
                state = orjson.loads(f.read())

Also, these are pathlib objects, so this is preferable:

            state = orjson.loads(path.read_bytes())

elif path.suffix == ".pkl":
with open(path, "rb") as f:
state = cloudpickle.load(f)
Expand Down
6 changes: 3 additions & 3 deletions dspy/streaming/streamify.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from typing import TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable, Generator

import litellm
import ujson
import orjson
from anyio import create_memory_object_stream, create_task_group
from anyio.streams.memory import MemoryObjectSendStream
from litellm import ModelResponseStream
Expand Down Expand Up @@ -261,10 +261,10 @@ async def streaming_response(streamer: AsyncGenerator) -> AsyncGenerator:
async for value in streamer:
if isinstance(value, Prediction):
data = {"prediction": dict(value.items(include_dspy=False))}
yield f"data: {ujson.dumps(data)}\n\n"
yield f"data: {orjson.dumps(data).decode()}\n\n"
elif isinstance(value, litellm.ModelResponseStream):
data = {"chunk": value.json()}
yield f"data: {ujson.dumps(data)}\n\n"
yield f"data: {orjson.dumps(data).decode()}\n\n"
elif isinstance(value, str) and value.startswith("data:"):
# The chunk value is an OpenAI-compatible streaming chunk value,
# e.g. "data: {"finish_reason": "stop", "index": 0, "is_finished": True, ...}",
Expand Down
8 changes: 4 additions & 4 deletions dspy/teleprompt/simba_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import textwrap
from typing import Callable

import ujson
import orjson

import dspy
from dspy.adapters.utils import get_field_description_string
Expand Down Expand Up @@ -120,7 +120,7 @@ def append_a_rule(bucket, system, **kwargs):
"module_names": module_names,
}

kwargs = {k: v if isinstance(v, str) else ujson.dumps(recursive_mask(v), indent=2)
kwargs = {k: v if isinstance(v, str) else orjson.dumps(recursive_mask(v), option=orjson.OPT_INDENT_2).decode()
for k, v in kwargs.items()}
advice = dspy.Predict(OfferFeedback)(**kwargs).module_advice

Expand Down Expand Up @@ -194,9 +194,9 @@ def inspect_modules(program):
def recursive_mask(o):
# If the object is already serializable, return it.
try:
ujson.dumps(o)
orjson.dumps(o)
return o
except TypeError:
except (TypeError, orjson.JSONEncodeError):
pass

# If it's a dictionary, apply recursively to its values.
Expand Down
4 changes: 2 additions & 2 deletions dspy/utils/saving.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import TYPE_CHECKING

import cloudpickle
import ujson
import orjson

if TYPE_CHECKING:
from dspy.primitives.module import Module
Expand Down Expand Up @@ -40,7 +40,7 @@ def load(path: str) -> "Module":
raise FileNotFoundError(f"The path '{path}' does not exist.")

with open(path / "metadata.json") as f:
metadata = ujson.load(f)
metadata = orjson.loads(f.read())

dependency_versions = get_dependency_versions()
saved_dependency_versions = metadata["dependency_versions"]
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ dependencies = [
"datasets>=2.14.6", # needed for Bootstrap's Hasher
"regex>=2023.10.3",
"datasets>=2.14.6", # needed for Bootstrap's Hasher
"ujson>=5.8.0",
"orjson>=3.9.0",
"tqdm>=4.66.1",
"requests>=2.31.0",
"optuna>=3.4.0",
Expand Down
Loading
Loading