Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions docs/source/notebooks/BeginnersGuide.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,37 @@
"# param.npvalue converts the value into numpy before returning it"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"`caskade` collapses all params into a 1D array, even if the param had multiple values itself. To explore the 1D array you can use the \"finders\" as shown below."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"params = secondsim.get_values()\n",
"print(params)\n",
"# Hmmm, I wonder which param goes in place 3?\n",
"# The result is a tuple (Param, index), the index tells you within the param\n",
"# where the index 3 lands. Since phi is a scalar this is just ()\n",
"print(secondsim.find_param(3))\n",
"# Notice if we get index 1 the returned index is more interesting\n",
"print(secondsim.find_param(1))\n",
"\n",
"# Hmmm, I wonder which index the q param corresponds to?\n",
"print(secondsim.find_index(secondsim.q))\n",
"# For multidimensional params, we will get a slice instead\n",
"print(secondsim.find_index(secondsim.x0))\n",
"\n",
"# You can also query lists to get a bunch at once\n",
"print(secondsim.find_param([0, 1, 2]))"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down
123 changes: 123 additions & 0 deletions src/caskade/mixins.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Optional, Mapping, Sequence, Union
from math import prod
import numpy as np

from .param import Param
from .errors import (
Expand Down Expand Up @@ -200,6 +201,128 @@ def _recursive_build_params_dict(
del params[link]
return params

def _array_inspection(self, group: Optional[int] = None):
param_list = self.dynamic_params
param_list = tuple(p for p in param_list if (group is None or p.group == group))
self._check_values(param_list, "array")

x = []
with Memo(self, self.name + ":semi_findidx_active"):
for param in param_list:
if param.online:
shape = param.shape
else:
depth = max(memo.count("|") for memo in param.memos)
shape = param.batch_shape[-depth:] + param.shape
if shape == ():
x.append((param, ()))
else:
for i in range(prod(shape)):
x.append((param, tuple(itm.item() for itm in np.unravel_index(i, shape))))
return x

# Finders
#################################################################
def find_param(
self, idx: Union[int, tuple[int]], group: Optional[int] = None, scheme: str = "array"
) -> tuple[Param, tuple[int]]:
"""
Identify which param is associated with the provided index in the
dynamic params array.

Parameters
----------
idx: Union[int, tuple[int]]
The index in the params array at which we wish to find the
associated param.
group: Optional[int]
If the dynamic params have multiple group values, then this argument
specifies which group to check.
scheme: str
Whether to search the array (default) params or list version of
params. dict is currently unsupported.

Returns
-------
param_info: tuple[Param, tuple[int]]
A tuple with the Param object and the index within the Param value
associated with idx (empty tuple if scalar). If idx is a tuple then
the result is a tuple of these results.
"""
if not isinstance(idx, int):
return tuple(self.find_param(i, group, scheme) for i in idx)

if scheme == "array":
x = self._array_inspection(group)
return x[idx]
elif scheme == "list":
param_list = tuple(p for p in self.dynamic_params if group is None or p.group == group)
return param_list[idx]
elif scheme == "dict":
raise NotImplementedError(
"find_param is not implemented for the dict scheme. The dict has the same structure as the graph and so may be inspected in a variety of other ways."
)
else:
raise ValueError(f"unrecognized scheme: {scheme}")

def find_index(
self, param: Union[Param, tuple[Param], "Module"], scheme: str = "array"
) -> Union[int, slice]:
"""
Identify what index is associated with a param in the dynamic params
array.

Parameters
----------
param: Union[Param, tuple[Param], Module]
The param for which to find the associated index.
scheme: str
Whether to search the array (default) params or list version of
params. dict is currently unsupported.

Returns
-------
param_info: Union[int, slice]
A int giving the index associated with the provided Param object. If
the param is multi-dimensional then the result will be a slice over
all indices associated with that param.
"""
# 1. Handle recursive structures
if isinstance(param, (list, tuple)):
return tuple(self.find_index(p, scheme) for p in param)
if isinstance(param, GetSetValues):
return tuple(
self.find_index(c, scheme)
for c in param.children.values()
if isinstance(c, Param) and c.dynamic
)

groups = self.dynamic_param_groups if len(self.dynamic_param_groups) > 1 else [None]

for group in groups:
if scheme in ["array", "tensor"]:
inspection = self._array_inspection(group)
matches = [i for i, item in enumerate(inspection) if item[0] is param]

if not matches:
continue
idx = matches[0] if len(matches) == 1 else slice(min(matches), max(matches) + 1)

elif scheme == "list":
param_list = [p for p in self.dynamic_params if group is None or p.group == group]
if param not in param_list:
continue
idx = param_list.index(param)
elif scheme == "dict":
raise NotImplementedError("find_index is not implemented for the dict scheme.")
else:
raise ValueError(f"unrecognized scheme: {scheme}")

# Return with group prefix if we are in multi-group mode
return (group, idx) if len(self.dynamic_param_groups) > 1 else idx

raise ValueError(f"Param {param.name} could not be found in dynamic params.")

# To/From Valid
#################################################################
def _transform_params(self, node, init_params, param_list, transform_attr):
Expand Down
15 changes: 12 additions & 3 deletions src/caskade/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,18 @@ def update_graph(self):
super().update_graph()

def param_order(self):
return ", ".join(
tuple(f"{next(iter(p.parents)).name}: {p.name}" for p in self.dynamic_params)
)
res = []
for g in self.dynamic_param_groups:
res.append(
", ".join(
tuple(
f"{next(iter(p.parents)).name}: {p.name}"
for p in self.dynamic_params
if p.group == g
)
)
)
return "\n".join(res)

@property
def dynamic(self) -> bool:
Expand Down
68 changes: 68 additions & 0 deletions tests/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,74 @@ def test_input_methods_multi_hierarchical(multi_hierarchical_sim, params_type):
assert np.allclose(2 * backend.to_numpy(val), sim.run_sim(20, 22, 2 * p0[0]))


def test_finders(sim):
sim.to_dynamic(False)
assert sim.find_param(0)[0] is sim.workers[0].w2
assert sim.find_param(0)[1] == (0, 0)
assert all(a[0] is b for a, b in zip(sim.find_param([19, -1]), [sim.helper.h1, sim.s1]))
with pytest.raises(IndexError):
sim.find_param(100)

assert sim.find_param(0, scheme="list") is sim.workers[0].w2
with pytest.raises(NotImplementedError):
sim.find_param(0, scheme="dict")
with pytest.raises(ValueError):
sim.find_param(0, scheme="funky")

assert sim.find_index(sim.workers[0].w2) == slice(0, 4)
assert sim.find_index((sim.helper.h1, sim.s1)) == (19, 27)
assert sim.find_index(sim.helper) == (19, slice(20, 22))
with pytest.raises(ValueError):
sim.find_index(Param("bad_param"))

assert sim.find_index(sim.workers[0].w2, scheme="list") == 0
with pytest.raises(ValueError):
sim.find_index(Param("bad_param"), scheme="list")
with pytest.raises(NotImplementedError):
sim.find_index(sim.s1, scheme="dict")
with pytest.raises(ValueError):
sim.find_index(sim.s1, scheme="funky")

sim.workers[1].w2.group = 1
sim.helper.h1.group = 1
sim.workers[4].w1.group = 1
assert sim.find_param(0, 1)[0] is sim.workers[1].w2
assert sim.find_param(0, 1)[1] == (0, 0)
assert all(a[0] is b for a, b in zip(sim.find_param([16, -1], 0), [sim.helper.h2, sim.s1]))
with pytest.raises(IndexError):
sim.find_param(25, 0)

assert sim.find_index(sim.workers[0].w2) == (0, slice(0, 4))
assert sim.find_index(sim.workers[1].w2) == (1, slice(0, 4))
assert sim.find_index((sim.helper.h1, sim.s1)) == ((1, 4), (0, 21))
with pytest.raises(ValueError):
sim.find_index(Param("bad_param"))

assert sim.find_index(sim.workers[0].w2, scheme="list") == (0, 0)
with pytest.raises(ValueError):
sim.find_index(Param("bad_param"), scheme="list")
with pytest.raises(NotImplementedError):
sim.find_index(sim.s1, scheme="dict")
with pytest.raises(ValueError):
sim.find_index(sim.s1, scheme="funky")


def test_finders_hierarchical(hierarchical_sim):
sim = hierarchical_sim
sim.to_dynamic(False)
print(sim.param_order())
assert sim.find_param(0)[0] is sim.helper.h1
assert sim.find_param(0)[1] == ()
assert all(a[0] is b for a, b in zip(sim.find_param([19, -1]), [sim.worker.w2, sim.s1]))
with pytest.raises(IndexError):
sim.find_param(100)

assert sim.find_index(sim.worker.w2) == slice(8, 28)
assert sim.find_index((sim.helper.h1, sim.s1)) == (0, 28)
with pytest.raises(ValueError):
sim.find_index(Param("bad_param"))


def nested_double(params):
new_params = {}
for param in params:
Expand Down
Loading