Skip to content

Commit 7d614e2

Browse files
authored
fix comparison table API scraping (#530)
1 parent 471385a commit 7d614e2

File tree

7 files changed

+111
-61
lines changed

7 files changed

+111
-61
lines changed

cupynumeric/_sphinxext/_comparison_config.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -83,12 +83,11 @@ class SectionConfig:
8383
UFUNCS = (numpy.ufunc,)
8484

8585
NUMPY_CONFIGS = [
86-
SectionConfig("Module-Level", None, types=FUNCTIONS),
87-
SectionConfig("Ufuncs", None, types=UFUNCS),
88-
SectionConfig("Multi-Dimensional Array", "ndarray", types=METHODS),
89-
SectionConfig("Linear Algebra", "linalg", types=FUNCTIONS),
90-
SectionConfig("Discrete Fourier Transform", "fft", types=FUNCTIONS),
91-
SectionConfig("Random Sampling", "random", types=FUNCTIONS),
86+
SectionConfig("Module-Level", None),
87+
SectionConfig("Multi-Dimensional Array", "ndarray"),
88+
SectionConfig("Linear Algebra", "linalg"),
89+
SectionConfig("Discrete Fourier Transform", "fft"),
90+
SectionConfig("Random Sampling", "random"),
9291
]
9392

9493
CONVOLVE = ("convolve", "correlate")

cupynumeric/_sphinxext/_comparison_util.py

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@
1616

1717
from dataclasses import dataclass
1818
from types import ModuleType
19-
from typing import TYPE_CHECKING, Any, Iterable, Iterator, Type
19+
from typing import TYPE_CHECKING, Any, Iterable, Iterator
2020

21-
from .._utils.coverage import is_implemented, is_multi, is_single
21+
from .._utils.coverage import is_implemented, is_multi, is_single, is_wrapped
2222
from ._comparison_config import MISSING_NP_REFS, SKIP
2323

2424
if TYPE_CHECKING:
@@ -73,17 +73,31 @@ def _lgref(name: str, obj: Any, implemented: bool) -> str:
7373
return f":{role}:`{full_name}`"
7474

7575

76-
def filter_names(
76+
def filter_wrapped_names(
7777
obj: Any,
78-
types: tuple[Type[Any], ...] | None = None,
78+
*,
7979
skip: Iterable[str] = (),
8080
) -> Iterator[str]:
8181
names = (n for n in dir(obj)) # every name in the module or class
82+
names = (
83+
n for n in names if is_wrapped(getattr(obj, n))
84+
) # that is wrapped
85+
names = (n for n in names if n not in skip) # except the ones we skip
86+
names = (n for n in names if not n.startswith("_")) # or any private names
87+
return names
88+
89+
90+
def filter_type_names(
91+
obj: Any,
92+
*,
93+
skip: Iterable[str] = (),
94+
) -> Iterator[str]:
95+
names = (n for n in dir(obj)) # every name in the module or class
96+
names = (
97+
n for n in names if isinstance(getattr(obj, n), type)
98+
) # that is a type (class, dtype, etc)
8299
names = (n for n in names if n not in skip) # except the ones we skip
83100
names = (n for n in names if not n.startswith("_")) # or any private names
84-
if types:
85-
# optionally filtered by type
86-
names = (n for n in names if isinstance(getattr(obj, n), types))
87101
return names
88102

89103

@@ -123,9 +137,14 @@ def generate_section(config: SectionConfig) -> SectionDetail:
123137
names: Iterable[str]
124138

125139
if config.names:
126-
names = config.names
140+
names = set(config.names)
127141
else:
128-
names = filter_names(np_obj, config.types, skip=SKIP)
142+
wrapped_names = filter_wrapped_names(lg_obj, skip=SKIP)
143+
type_names = filter_type_names(lg_obj, skip=SKIP)
144+
names = set(wrapped_names) | set(type_names)
145+
146+
# we can omit anything that isn't in np namespace to begin with
147+
names = {n for n in names if n in dir(np_obj)}
129148

130149
items = [get_item(name, np_obj, lg_obj) for name in names]
131150

cupynumeric/_utils/coverage.py

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import warnings
1818
from dataclasses import dataclass
1919
from functools import WRAPPER_ASSIGNMENTS, wraps
20-
from types import BuiltinFunctionType, FunctionType, ModuleType
20+
from types import BuiltinFunctionType, ModuleType
2121
from typing import Any, Callable, Container, Iterable, Mapping, Protocol, cast
2222

2323
from legate.core import track_provenance
@@ -69,7 +69,7 @@ class CuWrapperMetadata:
6969

7070

7171
class CuWrapped(AnyCallable, Protocol):
72-
_cupynumeric: CuWrapperMetadata
72+
_cupynumeric_metadata: CuWrapperMetadata
7373
__wrapped__: AnyCallable
7474
__name__: str
7575
__qualname__: str
@@ -116,7 +116,7 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
116116
multi = "Multiple GPUs" in (getattr(func, "__doc__", None) or "")
117117
single = "Single GPU" in (getattr(func, "__doc__", None) or "") or multi
118118

119-
wrapper._cupynumeric = CuWrapperMetadata(
119+
wrapper._cupynumeric_metadata = CuWrapperMetadata(
120120
implemented=True, single=single, multi=multi
121121
)
122122

@@ -185,7 +185,7 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
185185
--------
186186
{name}
187187
"""
188-
wrapper._cupynumeric = CuWrapperMetadata(implemented=False)
188+
wrapper._cupynumeric_metadata = CuWrapperMetadata(implemented=False)
189189

190190
return wrapper
191191

@@ -242,7 +242,7 @@ def clone_module(
242242
# Only need to wrap things that are in the origin module to begin with
243243
if attr not in origin_module.__dict__:
244244
continue
245-
if isinstance(value, (FunctionType, lgufunc)) or (
245+
if should_wrap(value) or (
246246
include_builtin_function_type
247247
and isinstance(value, BuiltinFunctionType)
248248
):
@@ -273,7 +273,7 @@ def clone_module(
273273
from numpy import ufunc as npufunc
274274

275275
for attr, value in missing.items():
276-
if isinstance(value, (FunctionType, npufunc)) or (
276+
if should_wrap(value) or (
277277
include_builtin_function_type
278278
and isinstance(value, BuiltinFunctionType)
279279
):
@@ -300,13 +300,19 @@ def clone_module(
300300

301301

302302
def should_wrap(obj: object) -> bool:
303-
# custom callables, e.g. cython used in np2, do not inherit anything. See
304-
# https://github.com/nv-legate/cupynumeric.internal/issues/179#issuecomment-2423813051
303+
from numpy import ufunc as npufunc
304+
305+
from .._ufunc.ufunc import ufunc as lgufunc
306+
307+
# Custom callables, e.g. cython functions used in np2, do not inherit
308+
# anything, so we check callable() instead (and include the __get__/__set__
309+
# checks to filter out classes). OTOH ufuncs need to be checked specially
310+
# because they do not have __get__.
305311
return (
306312
callable(obj)
307313
and hasattr(obj, "__get__")
308314
and not hasattr(obj, "__set__")
309-
)
315+
) or isinstance(obj, (lgufunc, npufunc))
310316

311317

312318
def clone_class(
@@ -363,13 +369,17 @@ def _clone_class(cls: type) -> type:
363369
return _clone_class
364370

365371

372+
def is_wrapped(obj: Any) -> bool:
373+
return hasattr(obj, "_cupynumeric_metadata")
374+
375+
366376
def is_implemented(obj: Any) -> bool:
367-
return hasattr(obj, "_cupynumeric") and obj._cupynumeric.implemented
377+
return is_wrapped(obj) and obj._cupynumeric_metadata.implemented
368378

369379

370380
def is_single(obj: Any) -> bool:
371-
return hasattr(obj, "_cupynumeric") and obj._cupynumeric.single
381+
return is_wrapped(obj) and obj._cupynumeric_metadata.single
372382

373383

374384
def is_multi(obj: Any) -> bool:
375-
return hasattr(obj, "_cupynumeric") and obj._cupynumeric.multi
385+
return is_wrapped(obj) and obj._cupynumeric_metadata.multi

tests/integration/test_array_fallback.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def test_unimplemented_method_self_fallback():
2727
# to verify a behaviour of unimplemented ndarray method wrappers. If std
2828
# becomes implemeneted in the future, this assertion will start to fail,
2929
# and a new (unimplemented) ndarray method should be found to replace it
30-
assert not ones.std._cupynumeric.implemented
30+
assert not ones.std._cupynumeric_metadata.implemented
3131

3232
ones.std()
3333

tests/integration/test_fallback.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def test_ufunc():
3333
# methods. If logical_and.accumulate becomes implemented in the future,
3434
# this assertion will start to fail, and a new (unimplemented) ufunc method
3535
# should be found to replace it
36-
assert not num.logical_and.accumulate._cupynumeric.implemented
36+
assert not num.logical_and.accumulate._cupynumeric_metadata.implemented
3737

3838
out_num = num.logical_and.accumulate(in_num)
3939
out_np = np.logical_and.accumulate(in_np)

tests/unit/cupynumeric/_sphinxext/test__comparison_util.py

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -35,26 +35,31 @@ def test_get_namespaces_attr(attr):
3535
assert res[1] is getattr(num, attr)
3636

3737

38+
class _wrapped:
39+
class _cupynumeric_metadata:
40+
implemeneted = True
41+
42+
3843
class _TestObj:
39-
a = 10
40-
b = 10.2
41-
c = "str"
42-
_priv = "priv"
44+
a = _wrapped
45+
b = _wrapped
46+
c = _wrapped
47+
d = 10
48+
_priv = _wrapped
4349

4450

45-
class Test_filter_names:
51+
class Test_filter_wrapped_names:
4652
def test_default(self):
47-
assert set(m.filter_names(_TestObj)) == {"a", "b", "c"}
48-
49-
def test_types(self):
50-
assert set(m.filter_names(_TestObj, (int,))) == {"a"}
51-
assert set(m.filter_names(_TestObj, (int, str))) == {"a", "c"}
52-
assert set(m.filter_names(_TestObj, (int, set))) == {"a"}
53-
assert set(m.filter_names(_TestObj, (set,))) == set()
53+
assert set(m.filter_wrapped_names(_TestObj())) == {"a", "b", "c"}
5454

5555
def test_skip(self):
56-
assert set(m.filter_names(_TestObj, skip=("a",))) == {"b", "c"}
57-
assert set(m.filter_names(_TestObj, skip=("a", "c"))) == {"b"}
56+
assert set(m.filter_wrapped_names(_TestObj(), skip=("a",))) == {
57+
"b",
58+
"c",
59+
}
60+
assert set(m.filter_wrapped_names(_TestObj(), skip=("a", "c"))) == {
61+
"b"
62+
}
5863

5964

6065
if __name__ == "__main__":

tests/unit/cupynumeric/_utils/test_coverage.py

Lines changed: 35 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,23 @@ def __call__(self, a: int, b: int) -> int:
117117
_test_ufunc = _Test_ufunc()
118118

119119

120+
class Test_helpers:
121+
def test_is_wrapped_true(self) -> None:
122+
wrapped = m.implemented(_test_func, "foo", "_test_func")
123+
assert m.is_wrapped(wrapped)
124+
125+
def test_is_wrapped_false(self) -> None:
126+
assert not m.is_wrapped(10)
127+
128+
def test_is_implemented_true(self) -> None:
129+
wrapped = m.implemented(_test_func, "foo", "_test_func")
130+
assert m.is_implemented(wrapped)
131+
132+
def test_is_implemented_false(self) -> None:
133+
wrapped = m.unimplemented(_test_func, "foo", "_test_func")
134+
assert not m.is_implemented(wrapped)
135+
136+
120137
class Test_implemented:
121138
@patch("cupynumeric.runtime.record_api_call")
122139
def test_reporting_True_func(
@@ -347,10 +364,10 @@ def test_report_coverage_True(self) -> None:
347364
assert _Dest.attr2 == 30
348365

349366
assert _Dest.function1.__wrapped__ is _OriginMod.function1
350-
assert not _Dest.function1._cupynumeric.implemented
367+
assert not _Dest.function1._cupynumeric_metadata.implemented
351368

352369
assert _Dest.function2.__wrapped__
353-
assert _Dest.function2._cupynumeric.implemented
370+
assert _Dest.function2._cupynumeric_metadata.implemented
354371

355372
assert not hasattr(_Dest.extra, "_cupynumeric")
356373

@@ -373,10 +390,10 @@ def test_report_coverage_False(self) -> None:
373390
assert _Dest.attr2 == 30
374391

375392
assert _Dest.function1.__wrapped__ is _OriginMod.function1
376-
assert not _Dest.function1._cupynumeric.implemented
393+
assert not _Dest.function1._cupynumeric_metadata.implemented
377394

378395
assert _Dest.function2.__wrapped__
379-
assert _Dest.function2._cupynumeric.implemented
396+
assert _Dest.function2._cupynumeric_metadata.implemented
380397

381398
assert not hasattr(_Dest.extra, "_cupynumeric")
382399

@@ -428,10 +445,10 @@ def test_report_coverage_True(self) -> None:
428445
assert _Test_ndarray.attr2 == 30
429446

430447
assert _Test_ndarray.foo.__wrapped__ is _Orig_ndarray.foo
431-
assert not _Test_ndarray.foo._cupynumeric.implemented
448+
assert not _Test_ndarray.foo._cupynumeric_metadata.implemented
432449

433450
assert _Test_ndarray.bar.__wrapped__
434-
assert _Test_ndarray.bar._cupynumeric.implemented
451+
assert _Test_ndarray.bar._cupynumeric_metadata.implemented
435452

436453
assert not hasattr(_Test_ndarray.extra, "_cupynumeric")
437454

@@ -447,10 +464,10 @@ def test_report_coverage_False(self) -> None:
447464
assert _Test_ndarray.attr2 == 30
448465

449466
assert _Test_ndarray.foo.__wrapped__ is _Orig_ndarray.foo
450-
assert not _Test_ndarray.foo._cupynumeric.implemented
467+
assert not _Test_ndarray.foo._cupynumeric_metadata.implemented
451468

452469
assert _Test_ndarray.bar.__wrapped__
453-
assert _Test_ndarray.bar._cupynumeric.implemented
470+
assert _Test_ndarray.bar._cupynumeric_metadata.implemented
454471

455472
assert not hasattr(_Test_ndarray.extra, "_cupynumeric")
456473

@@ -469,32 +486,32 @@ def test_ufunc_methods_binary() -> None:
469486

470487
# reduce is implemented
471488
assert np.add.reduce.__wrapped__
472-
assert np.add.reduce._cupynumeric.implemented
489+
assert np.add.reduce._cupynumeric_metadata.implemented
473490

474491
# the rest are not
475492
assert np.add.reduceat.__wrapped__
476-
assert not np.add.reduceat._cupynumeric.implemented
493+
assert not np.add.reduceat._cupynumeric_metadata.implemented
477494
assert np.add.outer.__wrapped__
478-
assert not np.add.outer._cupynumeric.implemented
495+
assert not np.add.outer._cupynumeric_metadata.implemented
479496
assert np.add.at.__wrapped__
480-
assert not np.add.at._cupynumeric.implemented
497+
assert not np.add.at._cupynumeric_metadata.implemented
481498
assert np.add.accumulate.__wrapped__
482-
assert not np.add.accumulate._cupynumeric.implemented
499+
assert not np.add.accumulate._cupynumeric_metadata.implemented
483500

484501

485502
def test_ufunc_methods_unary() -> None:
486503
import cupynumeric as np
487504

488505
assert np.negative.reduce.__wrapped__
489-
assert not np.negative.reduce._cupynumeric.implemented
506+
assert not np.negative.reduce._cupynumeric_metadata.implemented
490507
assert np.negative.reduceat.__wrapped__
491-
assert not np.negative.reduceat._cupynumeric.implemented
508+
assert not np.negative.reduceat._cupynumeric_metadata.implemented
492509
assert np.negative.outer.__wrapped__
493-
assert not np.negative.outer._cupynumeric.implemented
510+
assert not np.negative.outer._cupynumeric_metadata.implemented
494511
assert np.negative.at.__wrapped__
495-
assert not np.negative.at._cupynumeric.implemented
512+
assert not np.negative.at._cupynumeric_metadata.implemented
496513
assert np.negative.accumulate.__wrapped__
497-
assert not np.negative.accumulate._cupynumeric.implemented
514+
assert not np.negative.accumulate._cupynumeric_metadata.implemented
498515

499516

500517
if __name__ == "__main__":

0 commit comments

Comments
 (0)