Skip to content

Transform on GPU does not run within multiprocess pytorch DataLoader #3

@bentaculum

Description

@bentaculum

This minimal example

import numpy as np
import torch
import augmend
import gputools

# Device the loaded batches are moved to in the loop at the end of the script.
device = torch.device("cuda")

# Augmentation pipeline holding a single elastic transform over axes 0 and 1.
# With use_gpu=True the transform goes through gputools/pyopencl (see the
# traceback in this report).
aug = augmend.Augmend()
aug.add(augmend.Elastic(axis=(0, 1), use_gpu=True))


class AugmendDataset(torch.utils.data.Dataset):
    """Dataset of random images that are augmented on every access.

    Holds a float32 array of shape (10, 100, 100) filled with random
    values and passes each requested slice through the callable ``aug``.
    """

    def __init__(self, aug):
        """Create the random data and remember the augmentation callable."""
        self.data = np.random.rand(10, 100, 100).astype(np.float32)
        self.aug = aug

    def __len__(self):
        """Return the number of images stored along the first axis."""
        return self.data.shape[0]

    def __getitem__(self, idx):
        """Augment the image at ``idx`` and return it as a float tensor."""
        augmented = self.aug(self.data[idx])
        return torch.as_tensor(augmented, dtype=torch.float)


# Wrap the augmenting dataset in a DataLoader that uses worker processes.
ds = AugmendDataset(aug)
dl = torch.utils.data.DataLoader(
    ds,
    num_workers=2  # the reported error does not occur with num_workers=0
)


# Iterating triggers __getitem__ inside the workers, which is where the
# pyopencl error shown in the traceback is raised.
for x in dl:
    x = x.to(device)

leads to the following error

  File "/home/gallusse/miniconda3/envs/tarrow_v2/lib/python3.8/site-packages/torch/utils/data/_utils/worker.py", line 287, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/gallusse/miniconda3/envs/tarrow_v2/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 49, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/gallusse/miniconda3/envs/tarrow_v2/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 49, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/gallusse/code/tarrow/local_playground/gpu_fail.py", line 18, in __getitem__
    x = self.aug(self.data[idx])
  File "/home/gallusse/miniconda3/envs/tarrow_v2/lib/python3.8/site-packages/augmend/augmend.py", line 50, in __call__
    x = self._call(x)
  File "/home/gallusse/miniconda3/envs/tarrow_v2/lib/python3.8/site-packages/augmend/augmend.py", line 110, in _call
    x = trans(x, rng=self._rng)
  File "/home/gallusse/miniconda3/envs/tarrow_v2/lib/python3.8/site-packages/augmend/transforms/base.py", line 25, in __call__
    return map_single_func_tree(_apply, zip_trees(self.tree, x))
  File "/home/gallusse/miniconda3/envs/tarrow_v2/lib/python3.8/site-packages/augmend/utils.py", line 100, in map_single_func_tree
    else type(x)(map(partial(map_single_func_tree, func), x))
  File "/home/gallusse/miniconda3/envs/tarrow_v2/lib/python3.8/site-packages/augmend/utils.py", line 99, in map_single_func_tree
    return func(x) if _is_leaf_node(x) \
  File "/home/gallusse/miniconda3/envs/tarrow_v2/lib/python3.8/site-packages/augmend/transforms/base.py", line 23, in _apply
    return trans(_x, rng=rng)
  File "/home/gallusse/miniconda3/envs/tarrow_v2/lib/python3.8/site-packages/augmend/transforms/base.py", line 43, in __call__
    return self._transform_func(x,
  File "/home/gallusse/miniconda3/envs/tarrow_v2/lib/python3.8/site-packages/augmend/transforms/elastic.py", line 173, in transform_elastic
    res = _zoom_and_transform_gpu(img, dxs_coarse=dxs_coarse, order=order)
  File "/home/gallusse/miniconda3/envs/tarrow_v2/lib/python3.8/site-packages/augmend/transforms/elastic.py", line 56, in _zoom_and_transform_gpu
    img_im = OCLImage.from_array(img)
  File "/home/gallusse/miniconda3/envs/tarrow_v2/lib/python3.8/site-packages/gputools/core/ocltypes.py", line 208, in from_array
    res = cl.image_from_array(ctx, prepare(arr), num_channels=num_channels,
  File "/home/gallusse/miniconda3/envs/tarrow_v2/lib/python3.8/site-packages/pyopencl/__init__.py", line 1894, in image_from_array
    return Image(ctx, mode_flags | mem_flags.COPY_HOST_PTR,
  File "/home/gallusse/miniconda3/envs/tarrow_v2/lib/python3.8/site-packages/pyopencl/__init__.py", line 974, in image_init
    if context._get_cl_version() >= (1, 2) and get_cl_header_version() >= (1, 2):
  File "/home/gallusse/miniconda3/envs/tarrow_v2/lib/python3.8/site-packages/pytools/__init__.py", line 706, in wrapper
    result = function(obj, *args, **kwargs)
  File "/home/gallusse/miniconda3/envs/tarrow_v2/lib/python3.8/site-packages/pyopencl/__init__.py", line 698, in context_get_cl_version
    return self.devices[0].platform._get_cl_version()
  File "/home/gallusse/miniconda3/envs/tarrow_v2/lib/python3.8/site-packages/pyopencl/__init__.py", line 633, in generic_get_cl_version
    version_string = self.version
  File "/home/gallusse/miniconda3/envs/tarrow_v2/lib/python3.8/site-packages/pyopencl/__init__.py", line 1362, in result
    return info_method(self, info_attr)
pyopencl._cl.LogicError: clGetPlatformInfo failed: INVALID_DEVICE

Not sure whether this can be fixed here or in gputools; it might be a pyopencl issue. Note that in the main process, i.e. with DataLoader(num_workers=0), the same code runs without issues.
GPU transforms inside DataLoader workers would be great to have for 3D augmentations.

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions