diff --git a/streamjoy/_utils.py b/streamjoy/_utils.py index 3cb5900..2a57b5b 100644 --- a/streamjoy/_utils.py +++ b/streamjoy/_utils.py @@ -1,5 +1,6 @@ from __future__ import annotations +import random import inspect import logging import os @@ -8,11 +9,14 @@ from itertools import islice from pathlib import Path from typing import TYPE_CHECKING, Any, Callable +from itertools import zip_longest import imageio.v3 as iio import numpy as np import param -from dask.distributed import Client, Future, get_client +import dask +from dask.distributed import Client, Future, get_client as _get_client, get_worker as _get_worker +from dask.diagnostics import ProgressBar from .models import Paused from .settings import config @@ -122,7 +126,7 @@ def get_distributed_client(client: Client | None = None, **kwargs) -> Client: return client try: - client = get_client() + client = _get_client() except ValueError: client = Client(**kwargs) return client @@ -184,9 +188,10 @@ def get_first(iterable): return next(islice(iterable, 0, 1), None) -def get_result(future: Future) -> Any: +def get_result(future: Future, timeout: int | None = None) -> Any: + timeout = get_config_default("timeout", timeout, warn=False) if isinstance(future, Future): - return future.result() + return future.result(timeout=timeout) elif hasattr(future, "compute"): return future.compute() else: @@ -292,13 +297,28 @@ def validate_renderer_iterables( ) -def map_over(client, func, resources, batch_size, *args, **kwargs): - try: - return client.map(func, resources, *args, batch_size=batch_size, **kwargs) - except TypeError: - return [ - client.submit(func, resource, *args, **kwargs) for resource in resources +def map_over(client, func, resources, batch_size, *args, processes=True, wait=False, progress_bar=None, **kwargs): + num_retries = get_config_default("num_retries", None, warn=False) + if processes: + try: + resources = client.map(func, resources, *args, batch_size=batch_size, retries=num_retries, **kwargs) + except TypeError: + resources = [ + client.submit(func, resource, *args, **kwargs) for resource in resources + ] + if wait: + resources = client.gather(resources) + else: + func = dask.delayed(func) + jobs = [ + func(resource, *iterable, **kwargs) + for resource, *iterable in zip_longest(resources, *args) ] + if not progress_bar: + progress_bar = ProgressBar(minimum=3) + with progress_bar: + resources = dask.compute(jobs, scheduler="threads")[0] + return resources def repeat_frame( @@ -351,10 +371,12 @@ def get_webdriver_path(webdriver: str): from webdriver_manager.chrome import ChromeDriverManager webdriver_path = ChromeDriverManager().install() + os.environ["BOKEH_CHROMEDRIVER_PATH"] = webdriver_path elif webdriver.lower() == "firefox": from webdriver_manager.firefox import GeckoDriverManager webdriver_path = GeckoDriverManager().install() + os.environ["geckodriver"] = webdriver_path return webdriver_path @@ -381,12 +403,17 @@ def get_webdriver(webdriver: tuple[str, str] | Callable) -> BaseWebDriver: options.add_argument("--headless") options.add_argument("--disable-extensions") webdriver_path = webdriver_path or get_webdriver_path("firefox") - driver = WebDriver(service=Service(webdriver_path), options=options) + driver = WebDriver(service=Service(webdriver_path, port=random.randint(4000, 5000)), options=options) else: raise NotImplementedError( f"Webdriver {webdriver_key} not supported; " f"use 'chrome' or 'firefox', or pass a custom callable." ) - return driver + + +def cleanup_driver(driver): + worker = _get_worker() + if hasattr(worker, "_driver"): + worker._driver.quit() diff --git a/streamjoy/serializers.py b/streamjoy/serializers.py index 268e3b7..c429ac7 100644 --- a/streamjoy/serializers.py +++ b/streamjoy/serializers.py @@ -1,6 +1,7 @@ from __future__ import annotations import logging +from packaging import version from inspect import isgenerator from pathlib import Path from typing import TYPE_CHECKING, Any, Callable @@ -343,12 +344,13 @@ def serialize_holoviews( backend = kwargs.get("backend", hv.Store.current_backend) - def _select_element(hv_obj, key): + def _select_element(key, hv_obj=None): + hv.extension(backend) try: resource = hv_obj[key] except Exception: resource = hv_obj.select(**{kdims[0].name: key}) - return resource + return resource.opts(title=str(key), backend=backend) hv_obj = resources if isinstance(hv_obj, (hv.core.spaces.DynamicMap, hv.core.spaces.HoloMap)): @@ -370,8 +372,13 @@ def _select_element(hv_obj, key): if len(kdims) > 1: raise ValueError("Can only handle 1D HoloViews objects.") - resources = [_select_element(hv_obj, key).opts(title=str(key)) for key in keys] - + # TODO: experiment with this as keys instead and push holoviews object as iterables + # for i in range(nframes): + # plot.update(i) + resources = [ + _select_element(key, hv_obj=hv_obj) + for key in keys[: kwargs.get("max_frames")] + ] renderer_kwargs = renderer_kwargs or {} renderer_kwargs.update(_utils.pop_from_cls(stream_cls, kwargs)) @@ -399,12 +406,14 @@ def _select_element(hv_obj, key): clims=clims, ) - if kwargs.get("processes"): - logging.warning( - "HoloViews rendering does not support processes; " - "setting processes=False." - ) - kwargs["processes"] = False + if version.parse(hv.__version__) < version.parse("1.19.0"): + if kwargs.get("processes"): + logging.warning( + "HoloViews<1.19.0 rendering does not support processes; " + "setting processes=False; to use processes, upgrade wih " + "`pip install 'holoviews>=1.19.0'`" + ) + kwargs["processes"] = False return Serialized(resources, renderer, renderer_iterables, renderer_kwargs, kwargs) @@ -600,6 +609,7 @@ def serialize_appropriately( obj_handler = _select_obj_handler(resources) _utils.validate_renderer_iterables(resources, renderer_iterables) + kwargs["max_frames"] = _utils.get_max_frames(len(resources), kwargs.get("max_frames")) serialized = obj_handler( stream_cls, resources, diff --git a/streamjoy/settings.py b/streamjoy/settings.py index 47edec5..543c121 100644 --- a/streamjoy/settings.py +++ b/streamjoy/settings.py @@ -3,6 +3,7 @@ "fps": 8, "max_frames": 50, # dask + "timeout": 120, "batch_size": 10, "processes": True, "threads_per_worker": None, @@ -17,7 +18,7 @@ # matplotlib "max_open_warning": 100, # holoviews - "webdriver": "firefox", + "webdriver": "chrome", "num_retries": 5, # output "in_memory": False, diff --git a/streamjoy/streams.py b/streamjoy/streams.py index c4aefdd..26fc8c2 100644 --- a/streamjoy/streams.py +++ b/streamjoy/streams.py @@ -7,12 +7,10 @@ from contextlib import contextmanager from functools import partial from io import BytesIO -from itertools import zip_longest from pathlib import Path from textwrap import indent from typing import TYPE_CHECKING, Any, Callable -import dask.delayed import imageio.v3 as iio import numpy as np import param @@ -439,23 +437,23 @@ def _render_images( in_memory=self.in_memory, ) - if renderer and self.processes: - resources = _utils.map_over( - self.client, - renderer, - resources, - batch_size, - *renderer_iterables, - **renderer_kwargs, - ) - elif renderer and not self.processes: - renderer = dask.delayed(renderer) - jobs = [ - renderer(resource, *iterable, **renderer_kwargs) - for resource, *iterable in zip_longest(resources, *renderer_iterables) - ] - with self._progress_bar: - resources = dask.compute(jobs, scheduler="threads")[0] + if renderer: + try: + resources = _utils.map_over( + self.client, + renderer, + resources, + batch_size, + processes=self.processes, + progress_bar=self._progress_bar, + *renderer_iterables, + **renderer_kwargs, + ) + finally: + self.client.map( + _utils.cleanup_driver, + range(len(self.client.scheduler_info()["workers"])), + ) resource_0 = _utils.get_result(_utils.get_first(resources)) is_like_image = isinstance(resource_0, np.ndarray) and resource_0.ndim == 3 @@ -1105,7 +1103,7 @@ def _open_buffer( """ ], # noqa: E501 ) - player = pn.widgets.Player( + self._player = pn.widgets.Player( name="Time", start=0, value=0, @@ -1124,8 +1122,8 @@ def _open_buffer( """ ], ) - player.jslink(tabs, value="active", bidirectional=True) - self._column.objects = [tabs, player] + self._player.jslink(tabs, value="active", bidirectional=True) + self._column.objects = [tabs, self._player] yield tabs image = tabs.objects[0] width = image.object.width @@ -1136,7 +1134,7 @@ def _open_buffer( width=width + 50, height=height, ) - player.param.update( + self._player.param.update( width=width, end=len(tabs) - 1, ) @@ -1151,7 +1149,7 @@ def _open_buffer( max_height=int(height * 1.5), sizing_mode=sizing_mode, ) - player.param.update( + self._player.param.update( max_height=150, max_width=450, sizing_mode=sizing_mode, @@ -1216,6 +1214,7 @@ def _write_images(self, buf: pn.Tabs, images: list[Future], **write_kwargs) -> N self.fps, **write_kwargs, ) + self._player.end = len(buf) del image def write( diff --git a/streamjoy/wrappers.py b/streamjoy/wrappers.py index 29f1f29..a034498 100644 --- a/streamjoy/wrappers.py +++ b/streamjoy/wrappers.py @@ -7,6 +7,8 @@ from pathlib import Path from typing import Any, Callable +from dask.distributed import get_worker, Lock + from . import _utils from .models import Paused from .settings import config @@ -98,7 +100,6 @@ def wrap_holoviews( Returns: The wrapped function. """ - webdriver = _utils.get_config_default("webdriver", webdriver, warn=False) if isinstance(webdriver, str): webdriver = (webdriver, _utils.get_webdriver_path(webdriver)) @@ -110,8 +111,8 @@ def wrapper(renderer): @wraps(renderer) def wrapped(*args, **kwargs) -> Path | BytesIO: import holoviews as hv - backend = kwargs.get("backend", hv.Store.current_backend) + hv.extension(backend) output = renderer(*args, **kwargs) hv_obj = output @@ -134,22 +135,27 @@ def wrapped(*args, **kwargs) -> Path | BytesIO: ) for r in range(retries): try: - driver = _utils.get_webdriver(webdriver) - with driver: + worker = get_worker() + lock = Lock(worker.id) + with lock: + if not hasattr(worker, "_driver"): + worker._driver = _utils.get_webdriver(webdriver) + driver = worker._driver image = get_screenshot_as_png( hv.render(hv_obj, backend=backend), driver=driver ) - if fsspec_fs: - with fsspec_fs.open(uri, "wb") as f: - image.save(f, format="png") - else: - image.save(uri, format="png") - break + if fsspec_fs: + with fsspec_fs.open(uri, "wb") as f: + image.save(f, format="png") + else: + image.save(uri, format="png") + break except Exception as e: + seconds = r * 2 logging.warning( - f"Failed to save image: {e}, retrying in {r * 2}s" + f"Failed to save image: {e}, retrying in {seconds}s" ) - time.sleep(r * 2) + time.sleep(seconds) if r == retries - 1: raise e else: