diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index d0ac44c8..14c39a5e 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -126,7 +126,7 @@ jobs: test_minimum_verisons: name: Test Minimum Versions runs-on: ubuntu-latest - timeout-minutes: 10 + timeout-minutes: 20 steps: - uses: actions/checkout@v4 - uses: jupyterlab/maintainer-tools/.github/actions/base-setup@v1 @@ -143,6 +143,11 @@ jobs: run: | hatch -vv run test:nowarn + - name: Run the unit tests with orjson installed + run: | + hatch -e test run pip install orjson + hatch -vv run test:nowarn + test_prereleases: name: Test Prereleases timeout-minutes: 10 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5c74d89a..c18bafea 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -37,14 +37,18 @@ repos: types_or: [yaml, html, json] - repo: https://github.com/pre-commit/mirrors-mypy - rev: "v1.18.2" + rev: "v1.19.0" hooks: - id: mypy files: jupyter_client stages: [manual] args: ["--install-types", "--non-interactive"] additional_dependencies: - ["traitlets>=5.13", "ipykernel>=6.26", "jupyter_core>=5.3.2"] + - traitlets>=5.13 + - ipykernel>=6.26 + - jupyter_core>=5.3.2 + - orjson>=3.11.4 + - msgpack-types - repo: https://github.com/adamchainz/blacken-docs rev: "1.20.0" diff --git a/jupyter_client/session.py b/jupyter_client/session.py index c58067ad..aa28be68 100644 --- a/jupyter_client/session.py +++ b/jupyter_client/session.py @@ -13,6 +13,7 @@ # Distributed under the terms of the Modified BSD License. from __future__ import annotations +import functools import hashlib import hmac import json @@ -33,6 +34,7 @@ from traitlets import ( Any, Bool, + Callable, CBytes, CUnicode, Dict, @@ -125,6 +127,41 @@ def json_unpacker(s: str | bytes) -> t.Any: return json.loads(s) +try: + import orjson +except ModuleNotFoundError: + has_orjson = False + orjson_packer, orjson_unpacker = json_packer, json_unpacker +else: + has_orjson = True + + def orjson_packer( + obj: t.Any, *, option: int | None = orjson.OPT_NAIVE_UTC | orjson.OPT_UTC_Z + ) -> bytes: + """Convert a json object to a bytes using orjson with fallback to json_packer.""" + try: + return orjson.dumps(obj, default=json_default, option=option) + except Exception: + return json_packer(obj) + + def orjson_unpacker(s: str | bytes) -> t.Any: + """Convert a json bytes or string to an object using orjson with fallback to json_unpacker.""" + try: + return orjson.loads(s) + except Exception: + return json_unpacker(s) + + +try: + import msgpack +except ModuleNotFoundError: + has_msgpack = False +else: + has_msgpack = True + msgpack_packer = functools.partial(msgpack.packb, default=json_default) + msgpack_unpacker = msgpack.unpackb + + def pickle_packer(o: t.Any) -> bytes: """Pack an object using the pickle module.""" return pickle.dumps(squash_dates(o), PICKLE_PROTOCOL) @@ -132,8 +169,6 @@ def pickle_packer(o: t.Any) -> bytes: pickle_unpacker = pickle.loads -default_packer = json_packer -default_unpacker = json_unpacker DELIM = b"" # singleton dummy tracker, which will always report as done @@ -316,7 +351,7 @@ class Session(Configurable): debug : bool whether to trigger extra debugging statements - packer/unpacker : str : 'json', 'pickle' or import_string + packer/unpacker : str : 'orjson', 'json', 'pickle', 'msgpack' or import_string importstrings for methods to serialize message parts. If just 'json' or 'pickle', predefined JSON and pickle packers will be used. Otherwise, the entire importstring must be used. @@ -351,48 +386,42 @@ class Session(Configurable): """, ) + # serialization traits: packer = DottedObjectName( - "json", + "orjson" if has_orjson else "json", config=True, help="""The name of the packer for serializing messages. Should be one of 'json', 'pickle', or an import name for a custom callable serializer.""", ) - - @observe("packer") - def _packer_changed(self, change: t.Any) -> None: - new = change["new"] - if new.lower() == "json": - self.pack = json_packer - self.unpack = json_unpacker - self.unpacker = new - elif new.lower() == "pickle": - self.pack = pickle_packer - self.unpack = pickle_unpacker - self.unpacker = new - else: - self.pack = import_item(str(new)) - unpacker = DottedObjectName( - "json", + "orjson" if has_orjson else "json", config=True, help="""The name of the unpacker for unserializing messages. Only used with custom functions for `packer`.""", ) - - @observe("unpacker") - def _unpacker_changed(self, change: t.Any) -> None: - new = change["new"] - if new.lower() == "json": - self.pack = json_packer - self.unpack = json_unpacker - self.packer = new - elif new.lower() == "pickle": - self.pack = pickle_packer - self.unpack = pickle_unpacker - self.packer = new + pack = Callable(orjson_packer if has_orjson else json_packer) # the actual packer function + unpack = Callable( + orjson_unpacker if has_orjson else json_unpacker + ) # the actual unpacker function + + @observe("packer", "unpacker") + def _packer_unpacker_changed(self, change: t.Any) -> None: + new = change["new"].lower() + if new == "orjson" and has_orjson: + self.pack, self.unpack = orjson_packer, orjson_unpacker + elif new == "json" or new == "orjson": + self.pack, self.unpack = json_packer, json_unpacker + elif new == "pickle": + self.pack, self.unpack = pickle_packer, pickle_unpacker + elif new == "msgpack" and has_msgpack: + self.pack, self.unpack = msgpack_packer, msgpack_unpacker else: - self.unpack = import_item(str(new)) + obj = import_item(str(change["new"])) + name = "pack" if change["name"] == "packer" else "unpack" + self.set_trait(name, obj) + return + self.packer = self.unpacker = change["new"] session = CUnicode("", config=True, help="""The UUID identifying this session.""") @@ -417,8 +446,7 @@ def _session_changed(self, change: t.Any) -> None: metadata = Dict( {}, config=True, - help="Metadata dictionary, which serves as the default top-level metadata dict for each " - "message.", + help="Metadata dictionary, which serves as the default top-level metadata dict for each message.", ) # if 0, no adapting to do. @@ -487,25 +515,6 @@ def _keyfile_changed(self, change: t.Any) -> None: # for protecting against sends from forks pid = Integer() - # serialization traits: - - pack = Any(default_packer) # the actual packer function - - @observe("pack") - def _pack_changed(self, change: t.Any) -> None: - new = change["new"] - if not callable(new): - raise TypeError("packer must be callable, not %s" % type(new)) - - unpack = Any(default_unpacker) # the actual packer function - - @observe("unpack") - def _unpack_changed(self, change: t.Any) -> None: - # unpacker is not checked - it is assumed to be - new = change["new"] - if not callable(new): - raise TypeError("unpacker must be callable, not %s" % type(new)) - # thresholds: copy_threshold = Integer( 2**16, @@ -515,8 +524,7 @@ def _unpack_changed(self, change: t.Any) -> None: buffer_threshold = Integer( MAX_BYTES, config=True, - help="Threshold (in bytes) beyond which an object's buffer should be extracted to avoid " - "pickling.", + help="Threshold (in bytes) beyond which an object's buffer should be extracted to avoid pickling.", ) item_threshold = Integer( MAX_ITEMS, @@ -534,7 +542,7 @@ def __init__(self, **kwargs: t.Any) -> None: debug : bool whether to trigger extra debugging statements - packer/unpacker : str : 'json', 'pickle' or import_string + packer/unpacker : str : 'orjson', 'json', 'pickle', 'msgpack' or import_string importstrings for methods to serialize message parts. If just 'json' or 'pickle', predefined JSON and pickle packers will be used. Otherwise, the entire importstring must be used. @@ -626,10 +634,7 @@ def _check_packers(self) -> None: unpacked = unpack(packed) assert unpacked == msg_list except Exception as e: - msg = ( - f"unpacker '{self.unpacker}' could not handle output from packer" - f" '{self.packer}': {e}" - ) + msg = f"unpacker {self.unpacker!r} could not handle output from packer {self.packer!r}: {e}" raise ValueError(msg) from e # check datetime support diff --git a/pyproject.toml b/pyproject.toml index 0a3b4653..fcda69f3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,6 +55,7 @@ test = [ "pytest-jupyter[client]>=0.6.2", "pytest-cov", "pytest-timeout", + "msgpack" ] docs = [ "ipykernel", @@ -65,6 +66,7 @@ docs = [ "sphinxcontrib-spelling", "sphinx-autodoc-typehints", ] +orjson = ["orjson"] # When orjson is installed it will be used for faster pack and unpack [project.scripts] jupyter-kernelspec = "jupyter_client.kernelspecapp:KernelSpecApp.launch_instance" diff --git a/tests/test_session.py b/tests/test_session.py index de817423..61a298d5 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -9,12 +9,14 @@ import uuid import warnings from datetime import datetime +from pickle import PicklingError from unittest import mock import pytest import zmq from dateutil.tz import tzlocal from tornado import ioloop +from traitlets import TraitError from zmq.eventloop.zmqstream import ZMQStream from jupyter_client import jsonutil @@ -41,6 +43,16 @@ def session(): return ss.Session() +serializers = [ + ("json", ss.json_packer, ss.json_unpacker), + ("pickle", ss.pickle_packer, ss.pickle_unpacker), +] +if ss.has_orjson: + serializers.append(("orjson", ss.orjson_packer, ss.orjson_unpacker)) +if ss.has_msgpack: + serializers.append(("msgpack", ss.msgpack_packer, ss.msgpack_unpacker)) + + @pytest.mark.usefixtures("no_copy_threshold") class TestSession: def assertEqual(self, a, b): @@ -64,7 +76,11 @@ def test_msg(self, session): self.assertEqual(msg["header"]["msg_type"], "execute") self.assertEqual(msg["msg_type"], "execute") - def test_serialize(self, session): + @pytest.mark.parametrize(["packer", "pack", "unpack"], serializers) + def test_serialize(self, session, packer, pack, unpack): + session.packer = packer + assert session.pack is pack + assert session.unpack is unpack msg = session.msg("execute", content=dict(a=10, b=1.1)) msg_list = session.serialize(msg, ident=b"foo") ident, msg_list = session.feed_identities(msg_list) @@ -234,16 +250,14 @@ async def test_send(self, session): def test_args(self, session): """initialization arguments for Session""" s = session - self.assertTrue(s.pack is ss.default_packer) - self.assertTrue(s.unpack is ss.default_unpacker) self.assertEqual(s.username, os.environ.get("USER", "username")) s = ss.Session() self.assertEqual(s.username, os.environ.get("USER", "username")) - with pytest.raises(TypeError): + with pytest.raises(TraitError): ss.Session(pack="hi") - with pytest.raises(TypeError): + with pytest.raises(TraitError): ss.Session(unpack="hi") u = str(uuid.uuid4()) s = ss.Session(username="carrot", session=u) @@ -491,11 +505,6 @@ async def test_send_raw(self, session): B.close() ctx.term() - def test_set_packer(self, session): - s = session - s.packer = "json" - s.unpacker = "json" - def test_clone(self, session): s = session s._add_digest("initial") @@ -515,14 +524,45 @@ def test_squash_unicode(): assert ss.squash_unicode("hi") == b"hi" -def test_json_packer(): - ss.json_packer(dict(a=1)) - with pytest.raises(ValueError): - ss.json_packer(dict(a=ss.Session())) - ss.json_packer(dict(a=datetime(2021, 4, 1, 12, tzinfo=tzlocal()))) +@pytest.mark.parametrize( + ["description", "data"], + [ + ("dict", [{"a": 1}, [{"a": 1}]]), + ("infinite", [math.inf, ["inf", None]]), + ("datetime", [datetime(2021, 4, 1, 12, tzinfo=tzlocal()), []]), + ], +) +@pytest.mark.parametrize(["packer", "pack", "unpack"], serializers) +def test_serialize_objects(packer, pack, unpack, description, data): + data_in, data_out_options = data with warnings.catch_warnings(): warnings.simplefilter("ignore") - ss.json_packer(dict(a=math.inf)) + value = pack(data_in) + unpacked = unpack(value) + if (description == "infinite") and (packer in ["pickle", "msgpack"]): + assert math.isinf(unpacked) + elif description == "datetime": + assert data_in == jsonutil.parse_date(unpacked) + else: + assert unpacked in data_out_options + + +@pytest.mark.parametrize(["packer", "pack", "unpack"], serializers) +def test_cannot_serialize(session, packer, pack, unpack): + data = {"a": session} + with pytest.raises((TypeError, ValueError, PicklingError)): + pack(data) + + +@pytest.mark.parametrize("mode", ["packer", "unpacker"]) +@pytest.mark.parametrize(["packer", "pack", "unpack"], serializers) +def test_pack_unpack(session, packer, pack, unpack, mode): + s: ss.Session = session + s.set_trait(mode, packer) + assert s.pack is pack + assert s.unpack is unpack + mode_reverse = "unpacker" if mode == "packer" else "packer" + assert getattr(s, mode_reverse) == packer def test_message_cls():