From 6adc3f1472e6eeb6be181462a037d9604a7f3c64 Mon Sep 17 00:00:00 2001 From: ahuang11 Date: Thu, 13 Jun 2024 17:40:57 -0700 Subject: [PATCH 1/7] Support HoloViews version --- streamjoy/_utils.py | 2 +- streamjoy/serializers.py | 15 +++++++++------ streamjoy/streams.py | 11 ++++++----- streamjoy/wrappers.py | 1 + 4 files changed, 17 insertions(+), 12 deletions(-) diff --git a/streamjoy/_utils.py b/streamjoy/_utils.py index 3cb5900..91c5cb1 100644 --- a/streamjoy/_utils.py +++ b/streamjoy/_utils.py @@ -186,7 +186,7 @@ def get_first(iterable): def get_result(future: Future) -> Any: if isinstance(future, Future): - return future.result() + return future.result(timeout=30) elif hasattr(future, "compute"): return future.compute() else: diff --git a/streamjoy/serializers.py b/streamjoy/serializers.py index 268e3b7..0997498 100644 --- a/streamjoy/serializers.py +++ b/streamjoy/serializers.py @@ -1,6 +1,7 @@ from __future__ import annotations import logging +from packaging import version from inspect import isgenerator from pathlib import Path from typing import TYPE_CHECKING, Any, Callable @@ -399,12 +400,14 @@ def _select_element(hv_obj, key): clims=clims, ) - if kwargs.get("processes"): - logging.warning( - "HoloViews rendering does not support processes; " - "setting processes=False." 
- ) - kwargs["processes"] = False + if version.parse(hv.__version__) < version.parse("1.19.0"): + if kwargs.get("processes"): + logging.warning( + "HoloViews<1.19.0 rendering does not support processes; " + "setting processes=False; to use processes, upgrade with " + "`pip install 'holoviews>=1.19.0'`" + ) + kwargs["processes"] = False return Serialized(resources, renderer, renderer_iterables, renderer_kwargs, kwargs) diff --git a/streamjoy/streams.py b/streamjoy/streams.py index c4aefdd..333742d 100644 --- a/streamjoy/streams.py +++ b/streamjoy/streams.py @@ -1105,7 +1105,7 @@ def _open_buffer( """ ], # noqa: E501 ) - player = pn.widgets.Player( + self._player = pn.widgets.Player( name="Time", start=0, value=0, @@ -1124,8 +1124,8 @@ def _open_buffer( """ ], ) - player.jslink(tabs, value="active", bidirectional=True) - self._column.objects = [tabs, player] + self._player.jslink(tabs, value="active", bidirectional=True) + self._column.objects = [tabs, self._player] yield tabs image = tabs.objects[0] width = image.object.width @@ -1136,7 +1136,7 @@ def _open_buffer( width=width + 50, height=height, ) - player.param.update( + self._player.param.update( width=width, end=len(tabs) - 1, ) @@ -1151,7 +1151,7 @@ def _open_buffer( max_height=int(height * 1.5), sizing_mode=sizing_mode, ) - player.param.update( + self._player.param.update( max_height=150, max_width=450, sizing_mode=sizing_mode, @@ -1216,6 +1216,7 @@ def _write_images(self, buf: pn.Tabs, images: list[Future], **write_kwargs) -> N self.fps, **write_kwargs, ) + self._player.end = len(buf) del image def write( diff --git a/streamjoy/wrappers.py b/streamjoy/wrappers.py index 29f1f29..cf44ea9 100644 --- a/streamjoy/wrappers.py +++ b/streamjoy/wrappers.py @@ -112,6 +112,7 @@ def wrapped(*args, **kwargs) -> Path | BytesIO: import holoviews as hv backend = kwargs.get("backend", hv.Store.current_backend) + hv.extension(backend) output = renderer(*args, **kwargs) hv_obj = output From 
bf84f58c9dadb17d35094713dafe45fe07249356 Mon Sep 17 00:00:00 2001 From: ahuang11 Date: Thu, 13 Jun 2024 19:02:18 -0700 Subject: [PATCH 2/7] Try to speed things up --- streamjoy/_utils.py | 39 +++++++++++++++++++++++++++++---------- streamjoy/serializers.py | 26 +++++++++++++++++++++++--- streamjoy/settings.py | 1 + streamjoy/streams.py | 14 +++----------- 4 files changed, 56 insertions(+), 24 deletions(-) diff --git a/streamjoy/_utils.py b/streamjoy/_utils.py index 91c5cb1..e734ce9 100644 --- a/streamjoy/_utils.py +++ b/streamjoy/_utils.py @@ -8,11 +8,14 @@ from itertools import islice from pathlib import Path from typing import TYPE_CHECKING, Any, Callable +from itertools import zip_longest import imageio.v3 as iio import numpy as np import param -from dask.distributed import Client, Future, get_client +import dask +from dask.distributed import Client, Future, get_client as _get_client +from dask.diagnostics import ProgressBar from .models import Paused from .settings import config @@ -122,7 +125,7 @@ def get_distributed_client(client: Client | None = None, **kwargs) -> Client: return client try: - client = get_client() + client = _get_client() except ValueError: client = Client(**kwargs) return client @@ -184,9 +187,10 @@ def get_first(iterable): return next(islice(iterable, 0, 1), None) -def get_result(future: Future) -> Any: +def get_result(future: Future, timeout: int | None = None) -> Any: + timeout = get_config_default("timeout", timeout, warn=False) if isinstance(future, Future): - return future.result(timeout=30) + return future.result(timeout=timeout) elif hasattr(future, "compute"): return future.compute() else: @@ -292,13 +296,28 @@ def validate_renderer_iterables( ) -def map_over(client, func, resources, batch_size, *args, **kwargs): - try: - return client.map(func, resources, *args, batch_size=batch_size, **kwargs) - except TypeError: - return [ - client.submit(func, resource, *args, **kwargs) for resource in resources +def map_over(client, func, 
resources, batch_size, *args, processes=True, wait=False, progress_bar=None, **kwargs): + num_retries = get_config_default("num_retries", None, warn=False) + if processes: + try: + resources = client.map(func, resources, *args, batch_size=batch_size, retries=num_retries, **kwargs) + except TypeError: + resources = [ + client.submit(func, resource, *args, **kwargs) for resource in resources + ] + if wait: + resources = client.gather(resources) + else: + func = dask.delayed(func) + jobs = [ + func(resource, *iterable, **kwargs) + for resource, *iterable in zip_longest(resources, *args) ] + if not progress_bar: + progress_bar = ProgressBar(minimum=3) + with progress_bar: + resources = dask.compute(jobs, scheduler="threads")[0] + return resources def repeat_frame( diff --git a/streamjoy/serializers.py b/streamjoy/serializers.py index 0997498..55c1a9d 100644 --- a/streamjoy/serializers.py +++ b/streamjoy/serializers.py @@ -344,12 +344,13 @@ def serialize_holoviews( backend = kwargs.get("backend", hv.Store.current_backend) - def _select_element(hv_obj, key): + def _select_element(key, hv_obj=None): + hv.extension(backend) try: resource = hv_obj[key] except Exception: resource = hv_obj.select(**{kdims[0].name: key}) - return resource + return resource.opts(title=str(key), backend=backend) hv_obj = resources if isinstance(hv_obj, (hv.core.spaces.DynamicMap, hv.core.spaces.HoloMap)): @@ -371,7 +372,25 @@ def _select_element(hv_obj, key): if len(kdims) > 1: raise ValueError("Can only handle 1D HoloViews objects.") - resources = [_select_element(hv_obj, key).opts(title=str(key)) for key in keys] + # if isinstance(hv_map, hv.core.spaces.DynamicMap): + # logging.warning( + # "HoloViews DynamicMap objects may be slow to serialize " + # "due to the need to render each frame individually..." 
+ # ) + resources = [ + _select_element(key, hv_obj=hv_obj) + for key in keys[: kwargs.get("max_frames")] + ] + # else: + # client = _utils.get_distributed_client() + # resources = _utils.map_over( + # client, + # _select_element, + # keys[: kwargs.get("max_frames")], + # kwargs.get("batch_size"), + # hv_obj=hv_obj, + # wait=True + # ) renderer_kwargs = renderer_kwargs or {} renderer_kwargs.update(_utils.pop_from_cls(stream_cls, kwargs)) @@ -603,6 +622,7 @@ def serialize_appropriately( obj_handler = _select_obj_handler(resources) _utils.validate_renderer_iterables(resources, renderer_iterables) + kwargs["max_frames"] = _utils.get_max_frames(len(resources), kwargs.get("max_frames")) serialized = obj_handler( stream_cls, resources, diff --git a/streamjoy/settings.py b/streamjoy/settings.py index 47edec5..c9b2fa0 100644 --- a/streamjoy/settings.py +++ b/streamjoy/settings.py @@ -3,6 +3,7 @@ "fps": 8, "max_frames": 50, # dask + "timeout": 30, "batch_size": 10, "processes": True, "threads_per_worker": None, diff --git a/streamjoy/streams.py b/streamjoy/streams.py index 333742d..2e4ea15 100644 --- a/streamjoy/streams.py +++ b/streamjoy/streams.py @@ -7,12 +7,10 @@ from contextlib import contextmanager from functools import partial from io import BytesIO -from itertools import zip_longest from pathlib import Path from textwrap import indent from typing import TYPE_CHECKING, Any, Callable -import dask.delayed import imageio.v3 as iio import numpy as np import param @@ -439,23 +437,17 @@ def _render_images( in_memory=self.in_memory, ) - if renderer and self.processes: + if renderer: resources = _utils.map_over( self.client, renderer, resources, batch_size, + processes=self.processes, + progress_bar=self._progress_bar, *renderer_iterables, **renderer_kwargs, ) - elif renderer and not self.processes: - renderer = dask.delayed(renderer) - jobs = [ - renderer(resource, *iterable, **renderer_kwargs) - for resource, *iterable in zip_longest(resources, *renderer_iterables) - ] - 
with self._progress_bar: - resources = dask.compute(jobs, scheduler="threads")[0] resource_0 = _utils.get_result(_utils.get_first(resources)) is_like_image = isinstance(resource_0, np.ndarray) and resource_0.ndim == 3 From 996685f8207197a73db0f1d72c7981872815d626 Mon Sep 17 00:00:00 2001 From: ahuang11 Date: Thu, 13 Jun 2024 23:17:20 -0700 Subject: [PATCH 3/7] wooo fixed --- streamjoy/_utils.py | 7 +++++-- streamjoy/settings.py | 4 ++-- streamjoy/wrappers.py | 32 ++++++++++++++++++++------------ 3 files changed, 27 insertions(+), 16 deletions(-) diff --git a/streamjoy/_utils.py b/streamjoy/_utils.py index e734ce9..5aa460a 100644 --- a/streamjoy/_utils.py +++ b/streamjoy/_utils.py @@ -1,5 +1,6 @@ from __future__ import annotations +import random import inspect import logging import os @@ -370,10 +371,12 @@ def get_webdriver_path(webdriver: str): from webdriver_manager.chrome import ChromeDriverManager webdriver_path = ChromeDriverManager().install() + os.environ["BOKEH_CHROMEDRIVER_PATH"] = webdriver_path elif webdriver.lower() == "firefox": from webdriver_manager.firefox import GeckoDriverManager webdriver_path = GeckoDriverManager().install() + os.environ["geckodriver"] = webdriver_path return webdriver_path @@ -400,12 +403,12 @@ def get_webdriver(webdriver: tuple[str, str] | Callable) -> BaseWebDriver: options.add_argument("--headless") options.add_argument("--disable-extensions") webdriver_path = webdriver_path or get_webdriver_path("firefox") - driver = WebDriver(service=Service(webdriver_path), options=options) + driver = WebDriver(service=Service(webdriver_path, port=random.randint(4000, 5000)), options=options) else: raise NotImplementedError( f"Webdriver {webdriver_key} not supported; " f"use 'chrome' or 'firefox', or pass a custom callable." 
) - + print("CREATED WEBDRIVER") return driver diff --git a/streamjoy/settings.py b/streamjoy/settings.py index c9b2fa0..543c121 100644 --- a/streamjoy/settings.py +++ b/streamjoy/settings.py @@ -3,7 +3,7 @@ "fps": 8, "max_frames": 50, # dask - "timeout": 30, + "timeout": 120, "batch_size": 10, "processes": True, "threads_per_worker": None, @@ -18,7 +18,7 @@ # matplotlib "max_open_warning": 100, # holoviews - "webdriver": "firefox", + "webdriver": "chrome", "num_retries": 5, # output "in_memory": False, diff --git a/streamjoy/wrappers.py b/streamjoy/wrappers.py index cf44ea9..07f5b7b 100644 --- a/streamjoy/wrappers.py +++ b/streamjoy/wrappers.py @@ -7,6 +7,9 @@ from pathlib import Path from typing import Any, Callable +from dask.distributed import get_worker +from dask.distributed import Lock + from . import _utils from .models import Paused from .settings import config @@ -98,7 +101,6 @@ def wrap_holoviews( Returns: The wrapped function. """ - webdriver = _utils.get_config_default("webdriver", webdriver, warn=False) if isinstance(webdriver, str): webdriver = (webdriver, _utils.get_webdriver_path(webdriver)) @@ -110,7 +112,6 @@ def wrapper(renderer): @wraps(renderer) def wrapped(*args, **kwargs) -> Path | BytesIO: import holoviews as hv - backend = kwargs.get("backend", hv.Store.current_backend) hv.extension(backend) output = renderer(*args, **kwargs) @@ -128,29 +129,36 @@ def wrapped(*args, **kwargs) -> Path | BytesIO: fsspec_fs=fsspec_fs, ) if backend == "bokeh": + import os from bokeh.io.export import get_screenshot_as_png + from bokeh.io.webdriver import webdriver_control retries = _utils.get_config_default( "num_retries", num_retries, warn=False ) for r in range(retries): try: - driver = _utils.get_webdriver(webdriver) - with driver: + worker = get_worker() + lock = Lock(worker.id) + with lock: + if not hasattr(worker, "_driver"): + worker._driver = webdriver_control.create() + driver = worker._driver image = get_screenshot_as_png( hv.render(hv_obj, 
backend=backend), driver=driver ) - if fsspec_fs: - with fsspec_fs.open(uri, "wb") as f: - image.save(f, format="png") - else: - image.save(uri, format="png") - break + if fsspec_fs: + with fsspec_fs.open(uri, "wb") as f: + image.save(f, format="png") + else: + image.save(uri, format="png") + break except Exception as e: + seconds = r * 5 logging.warning( - f"Failed to save image: {e}, retrying in {r * 2}s" + f"Failed to save image: {e}, retrying in {seconds}s" ) - time.sleep(r * 2) + time.sleep(seconds) if r == retries - 1: raise e else: From 84f85f1619c20892434e451e10a72a3626b1c381 Mon Sep 17 00:00:00 2001 From: ahuang11 Date: Thu, 13 Jun 2024 23:19:42 -0700 Subject: [PATCH 4/7] cleanup --- streamjoy/_utils.py | 1 - streamjoy/wrappers.py | 10 +++------- 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/streamjoy/_utils.py b/streamjoy/_utils.py index 5aa460a..70a739f 100644 --- a/streamjoy/_utils.py +++ b/streamjoy/_utils.py @@ -410,5 +410,4 @@ def get_webdriver(webdriver: tuple[str, str] | Callable) -> BaseWebDriver: f"Webdriver {webdriver_key} not supported; " f"use 'chrome' or 'firefox', or pass a custom callable." ) - print("CREATED WEBDRIVER") return driver diff --git a/streamjoy/wrappers.py b/streamjoy/wrappers.py index 07f5b7b..1d4f538 100644 --- a/streamjoy/wrappers.py +++ b/streamjoy/wrappers.py @@ -8,7 +8,6 @@ from typing import Any, Callable from dask.distributed import get_worker -from dask.distributed import Lock from . 
import _utils from .models import Paused @@ -129,9 +128,7 @@ def wrapped(*args, **kwargs) -> Path | BytesIO: fsspec_fs=fsspec_fs, ) if backend == "bokeh": - import os from bokeh.io.export import get_screenshot_as_png - from bokeh.io.webdriver import webdriver_control retries = _utils.get_config_default( "num_retries", num_retries, warn=False @@ -139,10 +136,9 @@ def wrapped(*args, **kwargs) -> Path | BytesIO: for r in range(retries): try: worker = get_worker() - lock = Lock(worker.id) - with lock: + with worker._lock: if not hasattr(worker, "_driver"): - worker._driver = webdriver_control.create() + worker._driver = _utils.get_webdriver(webdriver) driver = worker._driver image = get_screenshot_as_png( hv.render(hv_obj, backend=backend), driver=driver @@ -154,7 +150,7 @@ def wrapped(*args, **kwargs) -> Path | BytesIO: image.save(uri, format="png") break except Exception as e: - seconds = r * 5 + seconds = r * 2 logging.warning( f"Failed to save image: {e}, retrying in {seconds}s" ) From 807c2fb25b4f568a827c0d6f01ef4d87eed88a4b Mon Sep 17 00:00:00 2001 From: ahuang11 Date: Thu, 13 Jun 2024 23:22:18 -0700 Subject: [PATCH 5/7] use diff lock --- streamjoy/wrappers.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/streamjoy/wrappers.py b/streamjoy/wrappers.py index 1d4f538..a034498 100644 --- a/streamjoy/wrappers.py +++ b/streamjoy/wrappers.py @@ -7,7 +7,7 @@ from pathlib import Path from typing import Any, Callable -from dask.distributed import get_worker +from dask.distributed import get_worker, Lock from . 
import _utils from .models import Paused @@ -136,7 +136,8 @@ def wrapped(*args, **kwargs) -> Path | BytesIO: for r in range(retries): try: worker = get_worker() - with worker._lock: + lock = Lock(worker.id) + with lock: if not hasattr(worker, "_driver"): worker._driver = _utils.get_webdriver(webdriver) driver = worker._driver From 7ef0a6cd121e2c7df5a43be7c544cb73e4ae2398 Mon Sep 17 00:00:00 2001 From: ahuang11 Date: Fri, 14 Jun 2024 22:38:34 -0700 Subject: [PATCH 6/7] cleanup drivers --- streamjoy/_utils.py | 8 +++++++- streamjoy/streams.py | 26 ++++++++++++++++---------- 2 files changed, 23 insertions(+), 11 deletions(-) diff --git a/streamjoy/_utils.py b/streamjoy/_utils.py index 70a739f..2a57b5b 100644 --- a/streamjoy/_utils.py +++ b/streamjoy/_utils.py @@ -15,7 +15,7 @@ import numpy as np import param import dask -from dask.distributed import Client, Future, get_client as _get_client +from dask.distributed import Client, Future, get_client as _get_client, get_worker as _get_worker from dask.diagnostics import ProgressBar from .models import Paused @@ -411,3 +411,9 @@ def get_webdriver(webdriver: tuple[str, str] | Callable) -> BaseWebDriver: f"use 'chrome' or 'firefox', or pass a custom callable." 
) return driver + + +def cleanup_driver(driver): + worker = _get_worker() + if hasattr(worker, "_driver"): + worker._driver.quit() diff --git a/streamjoy/streams.py b/streamjoy/streams.py index 2e4ea15..26fc8c2 100644 --- a/streamjoy/streams.py +++ b/streamjoy/streams.py @@ -438,16 +438,22 @@ def _render_images( ) if renderer: - resources = _utils.map_over( - self.client, - renderer, - resources, - batch_size, - processes=self.processes, - progress_bar=self._progress_bar, - *renderer_iterables, - **renderer_kwargs, - ) + try: + resources = _utils.map_over( + self.client, + renderer, + resources, + batch_size, + processes=self.processes, + progress_bar=self._progress_bar, + *renderer_iterables, + **renderer_kwargs, + ) + finally: + self.client.map( + _utils.cleanup_driver, + range(len(self.client.scheduler_info()["workers"])), + ) resource_0 = _utils.get_result(_utils.get_first(resources)) is_like_image = isinstance(resource_0, np.ndarray) and resource_0.ndim == 3 From ad63875e73a90819cf93b558b4d90d1ce8207eb6 Mon Sep 17 00:00:00 2001 From: ahuang11 Date: Fri, 14 Jun 2024 22:41:55 -0700 Subject: [PATCH 7/7] add todo --- streamjoy/serializers.py | 19 +++---------------- 1 file changed, 3 insertions(+), 16 deletions(-) diff --git a/streamjoy/serializers.py b/streamjoy/serializers.py index 55c1a9d..c429ac7 100644 --- a/streamjoy/serializers.py +++ b/streamjoy/serializers.py @@ -372,26 +372,13 @@ def _select_element(key, hv_obj=None): if len(kdims) > 1: raise ValueError("Can only handle 1D HoloViews objects.") - # if isinstance(hv_map, hv.core.spaces.DynamicMap): - # logging.warning( - # "HoloViews DynamicMap objects may be slow to serialize " - # "due to the need to render each frame individually..." 
- # ) + # TODO: experiment with this as keys instead and push holoviews object as iterables + # for i in range(nframes): + # plot.update(i) resources = [ _select_element(key, hv_obj=hv_obj) for key in keys[: kwargs.get("max_frames")] ] - # else: - # client = _utils.get_distributed_client() - # resources = _utils.map_over( - # client, - # _select_element, - # keys[: kwargs.get("max_frames")], - # kwargs.get("batch_size"), - # hv_obj=hv_obj, - # wait=True - # ) - renderer_kwargs = renderer_kwargs or {} renderer_kwargs.update(_utils.pop_from_cls(stream_cls, kwargs))