Skip to content

Commit 28feb22

Browse files
committed
change it inlcude data and support os.Pathlike
1 parent 1c386c7 commit 28feb22

File tree

5 files changed

+47
-30
lines changed

5 files changed

+47
-30
lines changed

pyproject.toml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,4 +26,10 @@ dependencies = [
2626

2727
[tool.setuptools.packages]
2828
find = {} # Scanning implicit namespaces is active by default
29-
#1
29+
#1
30+
31+
[tool.setuptools.package-data]
32+
# The key is the package name.
33+
# The value is a list of file patterns to include.
34+
# "data/**/*" means include all files in the 'data' subdirectory, recursively.
35+
visualbench = ["data/**/*"]

tests/__init__.py

Whitespace-only changes.

visualbench/data/__init__.py

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import os
22
from collections.abc import Sequence
3+
from importlib import resources
4+
from pathlib import Path
35

46
import numpy as np
57
import torch
@@ -8,48 +10,51 @@
810
from ..utils import normalize, to_3HW
911
from ..utils.image import _imread
1012

11-
_path = os.path.dirname(__file__)
13+
# _path = os.path.dirname(__file__)
1214

13-
QRCODE96 = os.path.join(_path, 'qr-96.jpg')
15+
def _get_path(fname:str):
16+
return Path(str(resources.files("visualbench").joinpath("data", fname)))
17+
18+
QRCODE96 = _get_path('qr-96.jpg')
1419
"""QR code that links to my github account"""
1520

16-
ATTNGRAD96 = os.path.join(_path, 'attngrad-96.png')
21+
ATTNGRAD96 = _get_path('attngrad-96.png')
1722
"""Piece of gradient of some model from transformers library except I don't remember which one"""
1823

19-
SANIC96 = os.path.join(_path, 'sanic-96.jpg')
24+
SANIC96 = _get_path('sanic-96.jpg')
2025
"""is an 8 year old image from my images folder and I think it is a screenshot from one of the sanic games"""
2126

22-
FROG96 = os.path.join(_path, 'frog-96.png')
27+
FROG96 = _get_path('frog-96.png')
2328
"""frame from https://www.youtube.com/@NinjaFrog777/videos"""
2429

25-
WEEVIL96 = os.path.join(_path, 'weevil-96.png')
30+
WEEVIL96 = _get_path('weevil-96.png')
2631
"""is from http://growingorganic.com/ipm-guide/weevils/"""
2732

28-
TEST96 = os.path.join(_path, 'test-96.jpg')
33+
TEST96 = _get_path('test-96.jpg')
2934
"""this was generated in like 2012 ago by google doodle generator and its still my favourite image and it is called test"""
3035

31-
MAZE96 = os.path.join(_path, 'maze-96.png')
36+
MAZE96 = _get_path('maze-96.png')
3237
"""a generic maze"""
3338

34-
TEXT96 = os.path.join(_path, 'text-96.png')
39+
TEXT96 = _get_path('text-96.png')
3540
"""lorem ipsum from lorem ipsum text"""
3641

37-
GEOM96 = os.path.join(_path, 'geometry-96.png')
42+
GEOM96 = _get_path('geometry-96.png')
3843
"""CC0 image from wikicommons, SORRY I CAN'T FIND THE LINK ANYMORE!!!"""
3944

40-
RUBIC96 = os.path.join(_path, 'rubic-96.png')
45+
RUBIC96 = _get_path('rubic-96.png')
4146
"""is from https://speedsolving.fandom.com/wiki/Rubik%27s_Cube?file=Rubik%27s_Cube_transparency.png"""
4247

43-
SPIRAL96 = os.path.join(_path, 'spiral-96.png')
48+
SPIRAL96 = _get_path('spiral-96.png')
4449
"""A colorful spiral"""
4550

46-
BIANG96 = os.path.join(_path, 'biang-96.png')
51+
BIANG96 = _get_path('biang-96.png')
4752
"""apparently its the hardest hieroglyph and it is from https://commons.wikimedia.org/wiki/File:Bi%C3%A1ng_%28regular_script%29.svg"""
4853

49-
EMOJIS96 = os.path.join(_path, 'emojis-96.png')
54+
EMOJIS96 = _get_path('emojis-96.png')
5055
"""some random emojis"""
5156

52-
GRID96 = os.path.join(_path, 'grid-96.png')
57+
GRID96 = _get_path('grid-96.png')
5358
"""Grid of black and white cells"""
5459

5560
def get_qrcode():
@@ -165,7 +170,7 @@ def get_ill_conditioned(size: int | tuple[int,int], cond:float=1e17):
165170

166171
def get_font_dict(dtype=torch.bool, device=None):
167172
"""returns a dictionary which maps letters, numbers and +-*/|. to 3x3 binary images."""
168-
path = os.path.join(_path, '3x3 font.jpeg')
173+
path = _get_path('3x3 font.jpeg')
169174
image = to_3HW(_imread(path).float()).mean(0)
170175
image = torch.where(image > 128, 1, 0).contiguous().to(dtype=dtype, device=device)
171176

visualbench/utils/format.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
import warnings
2+
from os import PathLike
23
from typing import Any
34

45
import numpy as np
56
import torch
67

78
from .image import _imread
89

10+
911
def _imread_normalize(x) -> torch.Tensor:
1012
x = _imread(x).float()
1113
x -= x.mean()
@@ -22,7 +24,7 @@ def is_scalar(x: Any) -> bool:
2224
return isinstance(x, (int,float,bool))
2325

2426
def totensor(x, device=None, dtype=None, clone=None) -> torch.Tensor:
25-
if isinstance(x, str): x = _imread_normalize(x)
27+
if isinstance(x, (str, PathLike)): x = _imread_normalize(x)
2628

2729
if isinstance(x, torch.Tensor): x = x.to(dtype=dtype, device=device, copy=False)
2830

@@ -42,7 +44,7 @@ def totensor(x, device=None, dtype=None, clone=None) -> torch.Tensor:
4244
return x
4345

4446
def tonumpy(x) -> np.ndarray:
45-
if isinstance(x, str): x = _imread_normalize(x)
47+
if isinstance(x, (str, PathLike)): x = _imread_normalize(x)
4648
if isinstance(x, np.ndarray): return x
4749
if isinstance(x, torch.Tensor): return x.numpy(force=True)
4850
return np.asarray(x)

visualbench/utils/image.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,39 @@
11
import warnings
2+
from os import PathLike
3+
24
import numpy as np
35
import torch
46

5-
def _imread_skimage(path:str) -> np.ndarray:
7+
8+
def _imread_skimage(path:str|PathLike) -> np.ndarray:
69
import skimage
710
return skimage.io.imread(path)
811

9-
def _imread_plt(path:str) -> np.ndarray:
12+
def _imread_plt(path:str|PathLike) -> np.ndarray:
1013
import matplotlib.pyplot as plt
11-
return plt.imread(path)
14+
return plt.imread(str(path))
1215

13-
def _imread_cv2(path):
16+
def _imread_cv2(path: str|PathLike) -> np.ndarray:
1417
import cv2
15-
image = cv2.imread(path) # pylint:disable=no-member
18+
image = cv2.imread(str(path)) # pylint:disable=no-member
19+
assert image is not None
1620
if image.ndim == 3: image = image[:, :, ::-1] # BRG -> RGB
1721
return image
1822

19-
def _imread_imageio(path):
23+
def _imread_imageio(path: str|PathLike):
2024
from imageio import v3
21-
return v3.imread(path)
25+
return v3.imread(str(path))
2226

23-
def _imread_pil(path:str) -> np.ndarray:
27+
def _imread_pil(path: str|PathLike) -> np.ndarray:
2428
import PIL.Image
2529
return np.array(PIL.Image.open(path))
2630

27-
def _imread_torchvision(path:str, dtype=None, device=None) -> torch.Tensor:
31+
def _imread_torchvision(path: str|PathLike, dtype=None, device=None) -> torch.Tensor:
2832
import torchvision
29-
return torchvision.io.read_image(path).to(dtype=dtype, device=device, copy=
33+
return torchvision.io.read_image(str(path)).to(dtype=dtype, device=device, copy=
3034
False)
3135

32-
def _imread(path: str) -> torch.Tensor:
36+
def _imread(path: str|PathLike) -> torch.Tensor:
3337
try: return _imread_torchvision(path)
3438
except Exception:
3539
img = None

0 commit comments

Comments
 (0)