Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 39 additions & 12 deletions streamjoy/_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import random
import inspect
import logging
import os
Expand All @@ -8,11 +9,14 @@
from itertools import islice
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable
from itertools import zip_longest

import imageio.v3 as iio
import numpy as np
import param
from dask.distributed import Client, Future, get_client
import dask
from dask.distributed import Client, Future, get_client as _get_client, get_worker as _get_worker
from dask.diagnostics import ProgressBar

from .models import Paused
from .settings import config
Expand Down Expand Up @@ -122,7 +126,7 @@ def get_distributed_client(client: Client | None = None, **kwargs) -> Client:
return client

try:
client = get_client()
client = _get_client()
except ValueError:
client = Client(**kwargs)
return client
Expand Down Expand Up @@ -184,9 +188,10 @@ def get_first(iterable):
return next(islice(iterable, 0, 1), None)


def get_result(future: Future) -> Any:
def get_result(future: Future, timeout: int | None = None) -> Any:
timeout = get_config_default("timeout", timeout, warn=False)
if isinstance(future, Future):
return future.result()
return future.result(timeout=timeout)
elif hasattr(future, "compute"):
return future.compute()
else:
Expand Down Expand Up @@ -292,13 +297,28 @@ def validate_renderer_iterables(
)


def map_over(client, func, resources, batch_size, *args, **kwargs):
try:
return client.map(func, resources, *args, batch_size=batch_size, **kwargs)
except TypeError:
return [
client.submit(func, resource, *args, **kwargs) for resource in resources
def map_over(client, func, resources, batch_size, *args, processes=True, wait=False, progress_bar=None, **kwargs):
num_retries = get_config_default("num_retries", None, warn=False)
if processes:
try:
resources = client.map(func, resources, *args, batch_size=batch_size, retries=num_retries, **kwargs)
except TypeError:
resources = [
client.submit(func, resource, *args, **kwargs) for resource in resources
]
if wait:
resources = client.gather(resources)
else:
func = dask.delayed(func)
jobs = [
func(resource, *iterable, **kwargs)
for resource, *iterable in zip_longest(resources, *args)
]
if not progress_bar:
progress_bar = ProgressBar(minimum=3)
with progress_bar:
resources = dask.compute(jobs, scheduler="threads")[0]
return resources


def repeat_frame(
Expand Down Expand Up @@ -351,10 +371,12 @@ def get_webdriver_path(webdriver: str):
from webdriver_manager.chrome import ChromeDriverManager

webdriver_path = ChromeDriverManager().install()
os.environ["BOKEH_CHROMEDRIVER_PATH"] = webdriver_path
elif webdriver.lower() == "firefox":
from webdriver_manager.firefox import GeckoDriverManager

webdriver_path = GeckoDriverManager().install()
os.environ["geckodriver"] = webdriver_path
return webdriver_path


Expand All @@ -381,12 +403,17 @@ def get_webdriver(webdriver: tuple[str, str] | Callable) -> BaseWebDriver:
options.add_argument("--headless")
options.add_argument("--disable-extensions")
webdriver_path = webdriver_path or get_webdriver_path("firefox")
driver = WebDriver(service=Service(webdriver_path), options=options)
driver = WebDriver(service=Service(webdriver_path, port=random.randint(4000, 5000)), options=options)

else:
raise NotImplementedError(
f"Webdriver {webdriver_key} not supported; "
f"use 'chrome' or 'firefox', or pass a custom callable."
)

return driver


def cleanup_driver(driver):
worker = _get_worker()
if hasattr(worker, "_driver"):
worker._driver.quit()
30 changes: 20 additions & 10 deletions streamjoy/serializers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import logging
from packaging import version
from inspect import isgenerator
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable
Expand Down Expand Up @@ -343,12 +344,13 @@ def serialize_holoviews(

backend = kwargs.get("backend", hv.Store.current_backend)

def _select_element(hv_obj, key):
def _select_element(key, hv_obj=None):
hv.extension(backend)
try:
resource = hv_obj[key]
except Exception:
resource = hv_obj.select(**{kdims[0].name: key})
return resource
return resource.opts(title=str(key), backend=backend)

hv_obj = resources
if isinstance(hv_obj, (hv.core.spaces.DynamicMap, hv.core.spaces.HoloMap)):
Expand All @@ -370,8 +372,13 @@ def _select_element(hv_obj, key):
if len(kdims) > 1:
raise ValueError("Can only handle 1D HoloViews objects.")

resources = [_select_element(hv_obj, key).opts(title=str(key)) for key in keys]

# TODO: experiment with this as keys instead and push holoviews object as iterables
# for i in range(nframes):
# plot.update(i)
resources = [
_select_element(key, hv_obj=hv_obj)
for key in keys[: kwargs.get("max_frames")]
]
renderer_kwargs = renderer_kwargs or {}
renderer_kwargs.update(_utils.pop_from_cls(stream_cls, kwargs))

Expand Down Expand Up @@ -399,12 +406,14 @@ def _select_element(hv_obj, key):
clims=clims,
)

if kwargs.get("processes"):
logging.warning(
"HoloViews rendering does not support processes; "
"setting processes=False."
)
kwargs["processes"] = False
if version.parse(hv.__version__) < version.parse("1.19.0"):
if kwargs.get("processes"):
logging.warning(
"HoloViews<1.19.0 rendering does not support processes; "
"setting processes=False; to use processes, upgrade wih "
"`pip install 'holoviews>=1.19.0'`"
)
kwargs["processes"] = False
return Serialized(resources, renderer, renderer_iterables, renderer_kwargs, kwargs)


Expand Down Expand Up @@ -600,6 +609,7 @@ def serialize_appropriately(
obj_handler = _select_obj_handler(resources)
_utils.validate_renderer_iterables(resources, renderer_iterables)

kwargs["max_frames"] = _utils.get_max_frames(len(resources), kwargs.get("max_frames"))
serialized = obj_handler(
stream_cls,
resources,
Expand Down
3 changes: 2 additions & 1 deletion streamjoy/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"fps": 8,
"max_frames": 50,
# dask
"timeout": 120,
"batch_size": 10,
"processes": True,
"threads_per_worker": None,
Expand All @@ -17,7 +18,7 @@
# matplotlib
"max_open_warning": 100,
# holoviews
"webdriver": "firefox",
"webdriver": "chrome",
"num_retries": 5,
# output
"in_memory": False,
Expand Down
47 changes: 23 additions & 24 deletions streamjoy/streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,10 @@
from contextlib import contextmanager
from functools import partial
from io import BytesIO
from itertools import zip_longest
from pathlib import Path
from textwrap import indent
from typing import TYPE_CHECKING, Any, Callable

import dask.delayed
import imageio.v3 as iio
import numpy as np
import param
Expand Down Expand Up @@ -439,23 +437,23 @@ def _render_images(
in_memory=self.in_memory,
)

if renderer and self.processes:
resources = _utils.map_over(
self.client,
renderer,
resources,
batch_size,
*renderer_iterables,
**renderer_kwargs,
)
elif renderer and not self.processes:
renderer = dask.delayed(renderer)
jobs = [
renderer(resource, *iterable, **renderer_kwargs)
for resource, *iterable in zip_longest(resources, *renderer_iterables)
]
with self._progress_bar:
resources = dask.compute(jobs, scheduler="threads")[0]
if renderer:
try:
resources = _utils.map_over(
self.client,
renderer,
resources,
batch_size,
processes=self.processes,
progress_bar=self._progress_bar,
*renderer_iterables,
**renderer_kwargs,
)
finally:
self.client.map(
_utils.cleanup_driver,
range(len(self.client.scheduler_info()["workers"])),
)
resource_0 = _utils.get_result(_utils.get_first(resources))

is_like_image = isinstance(resource_0, np.ndarray) and resource_0.ndim == 3
Expand Down Expand Up @@ -1105,7 +1103,7 @@ def _open_buffer(
"""
], # noqa: E501
)
player = pn.widgets.Player(
self._player = pn.widgets.Player(
name="Time",
start=0,
value=0,
Expand All @@ -1124,8 +1122,8 @@ def _open_buffer(
"""
],
)
player.jslink(tabs, value="active", bidirectional=True)
self._column.objects = [tabs, player]
self._player.jslink(tabs, value="active", bidirectional=True)
self._column.objects = [tabs, self._player]
yield tabs
image = tabs.objects[0]
width = image.object.width
Expand All @@ -1136,7 +1134,7 @@ def _open_buffer(
width=width + 50,
height=height,
)
player.param.update(
self._player.param.update(
width=width,
end=len(tabs) - 1,
)
Expand All @@ -1151,7 +1149,7 @@ def _open_buffer(
max_height=int(height * 1.5),
sizing_mode=sizing_mode,
)
player.param.update(
self._player.param.update(
max_height=150,
max_width=450,
sizing_mode=sizing_mode,
Expand Down Expand Up @@ -1216,6 +1214,7 @@ def _write_images(self, buf: pn.Tabs, images: list[Future], **write_kwargs) -> N
self.fps,
**write_kwargs,
)
self._player.end = len(buf)
del image

def write(
Expand Down
30 changes: 18 additions & 12 deletions streamjoy/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from pathlib import Path
from typing import Any, Callable

from dask.distributed import get_worker, Lock

from . import _utils
from .models import Paused
from .settings import config
Expand Down Expand Up @@ -98,7 +100,6 @@ def wrap_holoviews(
Returns:
The wrapped function.
"""

webdriver = _utils.get_config_default("webdriver", webdriver, warn=False)
if isinstance(webdriver, str):
webdriver = (webdriver, _utils.get_webdriver_path(webdriver))
Expand All @@ -110,8 +111,8 @@ def wrapper(renderer):
@wraps(renderer)
def wrapped(*args, **kwargs) -> Path | BytesIO:
import holoviews as hv

backend = kwargs.get("backend", hv.Store.current_backend)
hv.extension(backend)
output = renderer(*args, **kwargs)

hv_obj = output
Expand All @@ -134,22 +135,27 @@ def wrapped(*args, **kwargs) -> Path | BytesIO:
)
for r in range(retries):
try:
driver = _utils.get_webdriver(webdriver)
with driver:
worker = get_worker()
lock = Lock(worker.id)
with lock:
if not hasattr(worker, "_driver"):
worker._driver = _utils.get_webdriver(webdriver)
Comment thread
ahuang11 marked this conversation as resolved.
driver = worker._driver
image = get_screenshot_as_png(
hv.render(hv_obj, backend=backend), driver=driver
)
if fsspec_fs:
with fsspec_fs.open(uri, "wb") as f:
image.save(f, format="png")
else:
image.save(uri, format="png")
break
if fsspec_fs:
with fsspec_fs.open(uri, "wb") as f:
image.save(f, format="png")
else:
image.save(uri, format="png")
break
except Exception as e:
seconds = r * 2
logging.warning(
f"Failed to save image: {e}, retrying in {r * 2}s"
f"Failed to save image: {e}, retrying in {seconds}s"
)
time.sleep(r * 2)
time.sleep(seconds)
if r == retries - 1:
raise e
else:
Expand Down