diff --git a/docs/_static/generate_images.py b/docs/_static/generate_images.py index c8ca6a1..ff8a577 100644 --- a/docs/_static/generate_images.py +++ b/docs/_static/generate_images.py @@ -333,7 +333,7 @@ def logo_scene() -> StructureScene: ) data = dict(zip(range(1, 1 + n_verts), colour_vals)) - scene.set_atom_data("gradient", data) + scene.set_atom_data("gradient", by_index=data) scene.view.look_along(look_dir) scene.view.perspective = 0.12 @@ -664,12 +664,12 @@ def red_blue(t: float) -> tuple[float, float, float]: type_dict: dict[int, object] = {} for i in range(n_outer): type_dict[i] = ["Fe", "Co", "Ni"][i % 3] - multi_scene.set_atom_data("metal", type_dict) + multi_scene.set_atom_data("metal", by_index=type_dict) # Inner ring: numerical charge. charge_dict: dict[int, object] = {} for i in range(n_inner): charge_dict[n_outer + i] = float(i) / max(n_inner - 1, 1) - multi_scene.set_atom_data("charge", charge_dict) + multi_scene.set_atom_data("charge", by_index=charge_dict) multi_scene.render_mpl( OUT / "colour_by_multi.svg", colour_by=["metal", "charge"], diff --git a/docs/changelog.rst b/docs/changelog.rst index 0cef204..3747c71 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -4,6 +4,21 @@ Changelog 0.19.0 ------ +- :meth:`~hofmann.StructureScene.set_atom_data` gains ``by_species`` + and ``by_index`` keyword arguments for sparse per-atom metadata + assignment. ``by_species`` maps species labels to values; + ``by_index`` maps atom indices. Both can be combined in one call, + with ``by_index`` taking precedence at overlapping atoms. + +- :meth:`~hofmann.StructureScene.set_atom_data` no longer accepts a + dict as its positional ``values`` argument. Use ``by_index=`` + instead. + +- Missing entries in sparse categorical atom data are now filled with + ``None`` (object-dtype) instead of empty strings. Both are treated + as missing by the rendering pipeline; ``None`` cannot collide with + a real label. + - The ``AtomData`` container is no longer re-exported from ``hofmann`` or ``hofmann.model``. The only supported way to obtain an instance is to read the diff --git a/docs/colouring.rst b/docs/colouring.rst index b8f1553..a9403cb 100644 --- a/docs/colouring.rst +++ b/docs/colouring.rst @@ -81,14 +81,46 @@ String arrays assign a distinct colour to each unique value. :align: center :alt: Ring of atoms coloured by categorical site labels -Atoms with ``NaN`` (numeric) or ``""`` (categorical) values fall +Atoms with ``NaN`` (numeric) or ``None`` (categorical) values fall back to their species colour. This is useful when metadata is only available for a subset of atoms: .. code-block:: python # Only colour specific atoms by charge; the rest keep species colours. - scene.set_atom_data("charge", {0: 1.2, 3: -0.8, 5: 0.4}) + scene.set_atom_data("charge", by_index={0: 1.2, 3: -0.8, 5: 0.4}) + +Sparse assignment +----------------- + +Use ``by_species`` or ``by_index`` to assign metadata to a subset of +atoms without building a full-length array: + +.. code-block:: python + + # All Mn atoms get charge 2.0. + scene.set_atom_data("charge", by_species={"Mn": 2.0}) + + # Specific atoms by index. + scene.set_atom_data("charge", by_index={0: 1.2, 3: -0.8}) + +Both forms can be combined in a single call. ``by_index`` values +take precedence where they overlap with ``by_species``: + +.. code-block:: python + + # All Mn atoms charge 2.0, except atom 3 (defect site) at 1.9. + scene.set_atom_data( + "charge", + by_species={"Mn": 2.0}, + by_index={3: 1.9}, + ) + +For trajectory data, ``by_species`` accepts 2-D arrays of shape +``(n_frames, n_species_atoms)`` and ``by_index`` accepts 1-D arrays +of length ``n_frames``. Either of these promotes the output to 2-D. +Scalar and 1-D ``by_species`` values and scalar ``by_index`` values +broadcast across frames automatically. Custom colouring functions -------------------------- @@ -128,9 +160,9 @@ the inner ring uses a numerical charge gradient: .. code-block:: python # Outer atoms: categorical type. - scene.set_atom_data("metal", {0: "Fe", 1: "Co", 2: "Ni"}) + scene.set_atom_data("metal", by_index={0: "Fe", 1: "Co", 2: "Ni"}) # Inner atoms: numerical charge. - scene.set_atom_data("charge", {12: 0.0, 13: 0.3}) + scene.set_atom_data("charge", by_index={12: 0.0, 13: 0.3}) scene.render_mpl( "output.svg", colour_by=["metal", "charge"], @@ -160,7 +192,7 @@ polyhedra without any additional configuration. # No colour on the spec -- polyhedra inherit from colour_by. spec = PolyhedronSpec(centre="M", alpha=0.4) - scene.set_atom_data("val", {0: 0.0, 1: 0.5, 2: 1.0}) + scene.set_atom_data("val", by_index={0: 0.0, 1: 0.5, 2: 1.0}) scene.render_mpl( "output.svg", colour_by="val", cmap="coolwarm", diff --git a/src/hofmann/model/colour.py b/src/hofmann/model/colour.py index ef24f7f..1440d35 100644 --- a/src/hofmann/model/colour.py +++ b/src/hofmann/model/colour.py @@ -186,8 +186,8 @@ def _resolve_atom_colours( non-empty for categorical) determines the atom's colour. This allows different colouring rules for different atom subsets:: - scene.set_atom_data("metal_type", {0: "Fe", 2: "Co"}) - scene.set_atom_data("o_coord", {1: 4, 3: 6}) + scene.set_atom_data("metal_type", by_index={0: "Fe", 2: "Co"}) + scene.set_atom_data("o_coord", by_index={1: 4, 3: 6}) scene.render_mpl( colour_by=["metal_type", "o_coord"], cmap=["Set1", "Blues"], diff --git a/src/hofmann/model/structure_scene.py b/src/hofmann/model/structure_scene.py index 55068ae..7418ddc 100644 --- a/src/hofmann/model/structure_scene.py +++ b/src/hofmann/model/structure_scene.py @@ -26,6 +26,9 @@ class StructureScene: """Top-level scene holding atoms, frames, styles, bond rules, and view. + The :attr:`view` (camera/projection state) and :attr:`atom_data` + (per-atom metadata) properties are documented individually below. + Attributes: species: One label per atom. frames: List of coordinate snapshots. Each :class:`Frame` may @@ -33,12 +36,7 @@ class StructureScene: atom_styles: Mapping from species label to visual style. bond_specs: Declarative bond detection rules. polyhedra: Declarative polyhedron rendering rules. - view: Camera / projection state. title: Scene title for display. - atom_data: Per-atom metadata container: read-only view with a - Mapping-style interface. See :meth:`set_atom_data`, - :meth:`del_atom_data`, and :meth:`clear_2d_atom_data` for - modifications. """ def __init__( @@ -375,10 +373,199 @@ def centre_on(self, atom_index: int, *, frame: int = 0) -> None: ) self.view.centre = self.frames[frame].coords[atom_index].copy() + def _coerce_sparse_atom_data( + self, + key: str, + *, + by_species: dict[str, object], + by_index: dict[int, object], + ) -> np.ndarray: + """Resolve sparse by_species/by_index dicts into a dense array. + + Builds a 1-D ``(n_atoms,)`` array by default. Promotes to + 2-D ``(n_frames, n_atoms)`` if any ``by_species`` value is + 2-D or any ``by_index`` value is 1-D. + + ``by_index`` values overwrite ``by_species`` values at + overlapping atoms. When promoted to 2-D, scalar and 1-D + ``by_species`` values are broadcast across the frame axis. + + Args: + key: Metadata key (for error messages). + by_species: Species-label-to-value mapping. + by_index: Atom-index-to-value mapping. + + Returns: + Dense array ready for ``_atom_data._set``. + + Raises: + ValueError: If a species label is unknown, an index is + out of range, or a value has the wrong shape. + TypeError: If values contain a mixture of string and + numeric types. + """ + n_atoms = len(self.species) + n_frames = len(self.frames) + + # --- Validate keys --- + known = set(self.species) + for label in by_species: + if label not in known: + raise ValueError( + f"atom_data[{key!r}]: unknown species {label!r} " + f"(not present in scene)" + ) + for idx in by_index: + if not 0 <= idx < n_atoms: + raise ValueError( + f"atom index {idx} out of range for {n_atoms} atoms" + ) + + # --- Coerce values and infer dtype / dimensionality --- + seen_str = False + seen_num = False + promotes_2d = False + + def _classify_scalar(v: object) -> None: + """Update seen_str / seen_num from a single scalar.""" + nonlocal seen_str, seen_num + if v is None: + return # missing sentinel; does not determine dtype + if isinstance(v, str): + seen_str = True + else: + seen_num = True + + def _classify_array(a: np.ndarray) -> None: + """Update seen_str / seen_num from a numpy array's dtype.""" + nonlocal seen_str, seen_num + if a.dtype.kind == "U": + seen_str = True + elif a.dtype.kind == "O": + # Object arrays may contain strings, numerics, or + # None sentinels. Classify from non-None elements. + for v in a.ravel(): + if seen_str and seen_num: + break + _classify_scalar(v) + else: + seen_num = True + + # Pre-process by_species values. + species_arr = np.array(self.species) + species_entries: list[tuple[np.ndarray, np.ndarray]] = [] + for label, val in by_species.items(): + mask = species_arr == label + n_sp = int(mask.sum()) + a = np.asarray(val) + if a.ndim == 0: + _classify_scalar(a.item()) + elif a.ndim == 1: + if len(a) != n_sp: + raise ValueError( + f"atom_data[{key!r}]: by_species[{label!r}] has " + f"length {len(a)} but species {label!r} has " + f"{n_sp} atoms" + ) + _classify_array(a) + elif a.ndim == 2: + if a.shape != (n_frames, n_sp): + raise ValueError( + f"atom_data[{key!r}]: by_species[{label!r}] has " + f"shape {a.shape} but expected " + f"({n_frames}, {n_sp}) for {n_frames} frames " + f"and {n_sp} atoms of species {label!r}" + ) + promotes_2d = True + _classify_array(a) + else: + raise ValueError( + f"atom_data[{key!r}]: by_species[{label!r}] must be " + f"scalar, 1-D, or 2-D, got {a.ndim}-D" + ) + species_entries.append((mask, a)) + + # Pre-process by_index values. + index_entries: list[tuple[int, np.ndarray]] = [] + for idx, val in by_index.items(): + a = np.asarray(val) + if a.ndim == 0: + _classify_scalar(a.item()) + elif a.ndim == 1: + if len(a) != n_frames: + raise ValueError( + f"atom_data[{key!r}]: by_index[{idx}] has " + f"length {len(a)} but expected {n_frames} frames" + ) + promotes_2d = True + _classify_array(a) + else: + raise ValueError( + f"atom_data[{key!r}]: by_index[{idx}] must be " + f"scalar or 1-D, got {a.ndim}-D" + ) + index_entries.append((idx, a)) + + # Dtype inference. If all values are None (missing + # sentinels), neither flag is set and the default is numeric + # (NaN fill). + if seen_str and seen_num: + raise TypeError( + f"atom_data[{key!r}] has mixed string and numeric " + f"values; all values must be the same type " + f"(string or numeric)" + ) + is_categorical = seen_str + + # --- Allocate output --- + arr: np.ndarray + if promotes_2d: + if is_categorical: + arr = np.empty((n_frames, n_atoms), dtype=object) + arr[:] = None + else: + arr = np.full((n_frames, n_atoms), np.nan) + else: + if is_categorical: + arr = np.array([None] * n_atoms, dtype=object) + else: + arr = np.full(n_atoms, np.nan) + + # --- Fill from by_species (first, lower precedence) --- + for mask, a in species_entries: + if promotes_2d: + if a.ndim == 0: + arr[:, mask] = a.item() + elif a.ndim == 1: + # Broadcast static per-atom across frames. + arr[:, mask] = a[np.newaxis, :] + else: + arr[:, mask] = a + else: + if a.ndim == 0: + arr[mask] = a.item() + else: + arr[mask] = a + + # --- Fill from by_index (second, higher precedence) --- + for idx, a in index_entries: + if promotes_2d: + if a.ndim == 0: + arr[:, idx] = a.item() + else: + arr[:, idx] = a + else: + arr[idx] = a.item() if a.ndim == 0 else a + + return arr + def set_atom_data( self, key: str, - values: ArrayLike | dict[int, object], + values: ArrayLike | None = None, + *, + by_species: dict[str, object] | None = None, + by_index: dict[int, object] | None = None, ) -> None: """Set per-atom metadata for colourmap-based rendering. @@ -388,77 +575,88 @@ def set_atom_data( (for example after extending the trajectory) use :meth:`clear_2d_atom_data`. - A 2-D *values* array is validated in a single walk against - the container's prospective post-write state: the new array - must have ``shape[0] == len(self.frames)``, and any other - stored 2-D entry not being overridden must agree. For the - common single-2-D-entry case, this means an in-place - reassignment at a new shape after ``scene.frames.append(...)`` - just works -- the stored version of *key* is treated as - overridden and skipped. Multi-entry scenes still need - :meth:`clear_2d_atom_data` to drop the other stale entries. + Provide data in one of two forms: + + - **Full array** via *values*: a 1-D array-like of length + ``n_atoms`` (same value every frame) or a 2-D array-like of + shape ``(n_frames, n_atoms)`` (per-frame values). + - **Sparse** via *by_species* and/or *by_index*: maps species + labels or atom indices to values. See below for shape rules + and precedence. + + Mixing *values* with *by_species* or *by_index* raises + :class:`ValueError`. + + **by_species** maps species labels to values. Scalars broadcast + to all atoms of the species; 1-D arrays (length = count of that + species' atoms) assign per-atom; 2-D arrays of shape + ``(n_frames, n_species_atoms)`` assign per-frame. A 1-D array + is always interpreted as static per-atom, even when its length + equals ``n_frames``. + + **by_index** maps atom indices to values. Scalars are static; + 1-D arrays of length ``n_frames`` are per-frame. + + When both are provided, *by_index* values take precedence over + *by_species* at overlapping atoms. + + Unspecified atoms are filled with ``NaN`` (numeric) or ``None`` + (categorical, stored as object-dtype). + + A 2-D *values* array, or any ``by_*`` form that promotes to + 2-D, is validated against the container's prospective post-write + state: the array's frame count must match ``len(self.frames)``. Args: key: Name for this metadata (e.g. ``"charge"``, ``"site"``). - values: Either a 1-D array-like of length ``n_atoms`` (same - value every frame), a 2-D array-like of shape - ``(n_frames, n_atoms)`` (per-frame values), or a dict - mapping atom indices to values (always 1-D). When a - dict is given, the fill value for missing atoms is - inferred from the first entry: ``NaN`` for numeric - values or ``""`` for string values. All values in a - dict must be of compatible types (all numeric or all - strings). + values: Full-length array-like. Must not be a dict; use + *by_index* for sparse assignment by atom index. + by_species: Maps species labels to scalar, 1-D, or 2-D + values. All keys must be present in + ``scene.species``. + by_index: Maps atom indices to scalar or 1-D values. + All keys must be in ``range(len(scene.species))``. Raises: - ValueError: If an array-like has the wrong length or - shape for *n_atoms*, if a 2-D array's leading - dimension does not match ``len(self.frames)``, if a - non-overridden stored 2-D entry is stale relative to - ``len(self.frames)`` (the error names the stale key - and points at :meth:`clear_2d_atom_data` for - recovery), if the coerced array has an unsupported - dtype (only bool, integer, float, string, and object - are accepted), or if a dict contains indices outside - the valid range. - TypeError: If a dict contains a mixture of string and - numeric values. + ValueError: If *values* is mixed with *by_species* or + *by_index*; if all three are absent; if a species + label is unknown; if an atom index is out of range; + if an array has the wrong shape for its context; or + if a 2-D array's frame count does not match + ``len(self.frames)``. + TypeError: If a dict is passed as *values* (use + ``by_index=`` instead), or if values contain a + mixture of string and numeric types. See Also: :meth:`del_atom_data`: Remove a single entry. :meth:`clear_2d_atom_data`: Remove all 2-D entries. """ - n_atoms = len(self.species) - if isinstance(values, dict): - if not values: - raise ValueError("values dict must not be empty") - for idx in values: - if not 0 <= idx < n_atoms: - raise ValueError( - f"atom index {idx} out of range for " - f"{n_atoms} atoms" - ) - sample = next(iter(values.values())) - is_str = isinstance(sample, str) - for idx, val in values.items(): - if isinstance(val, str) != is_str: - raise TypeError( - f"atom_data dict values must all be the same " - f"type (string or numeric), but index {idx} " - f"has type {type(val).__name__!r}" - ) - if is_str: - arr = np.array([""] * n_atoms, dtype=object) - for idx, val in values.items(): - arr[idx] = val - else: - arr = np.full(n_atoms, np.nan) - for idx, val in values.items(): - arr[idx] = val - else: + raise TypeError( + "values must be array-like; use by_index= for sparse " + "assignment by atom index" + ) + has_values = values is not None + has_sparse = bool(by_species) or bool(by_index) + if has_values and has_sparse: + raise ValueError( + "cannot mix positional values with by_species or by_index" + ) + if not has_values and not has_sparse: + raise ValueError( + "set_atom_data requires values, by_species, or by_index" + ) + + if has_values: arr = np.asarray(values) + else: + arr = self._coerce_sparse_atom_data( + key, + by_species=by_species or {}, + by_index=by_index or {}, + ) self._atom_data._set( key, arr, expected_frames=len(self.frames) @@ -491,17 +689,12 @@ def clear_2d_atom_data(self) -> None: as overridden by the pending write -- and this method is unnecessary. - Multi-entry recovery workflow:: - - scene.frames.append(new_frame) - scene.clear_2d_atom_data() - scene.set_atom_data("energy", new_energy_2d) - scene.set_atom_data("forces", new_forces_2d) - scene.render_mpl(...) + The multi-entry recovery workflow is: call this method, + then re-assign each 2-D key via :meth:`set_atom_data` at + the new shape, then render. See Also: - :meth:`set_atom_data`: Canonical write entry point; also - handles single-entry in-place reassignment. + :meth:`set_atom_data`: Canonical write entry point. :meth:`del_atom_data`: Remove a single entry. """ self._atom_data._clear_2d() diff --git a/tests/test_model/test_structure_scene.py b/tests/test_model/test_structure_scene.py index 353ebed..2f75a2c 100644 --- a/tests/test_model/test_structure_scene.py +++ b/tests/test_model/test_structure_scene.py @@ -232,31 +232,13 @@ def test_wrong_length_raises(self): with pytest.raises(ValueError, match="length 3"): scene.set_atom_data("charge", np.array([1.0, 2.0])) - def test_sparse_dict_numeric(self): + def test_dict_in_values_raises_type_error(self): scene = self._scene() - scene.set_atom_data("charge", {0: 1.5, 2: -0.3}) - arr = scene.atom_data["charge"] - assert arr[0] == pytest.approx(1.5) - assert np.isnan(arr[1]) - assert arr[2] == pytest.approx(-0.3) - - def test_sparse_dict_string(self): - scene = self._scene() - scene.set_atom_data("site", {1: "4a"}) - arr = scene.atom_data["site"] - assert arr[0] == "" - assert arr[1] == "4a" - assert arr[2] == "" - - def test_sparse_dict_out_of_range_raises(self): - scene = self._scene() - with pytest.raises(ValueError, match="out of range"): - scene.set_atom_data("charge", {5: 1.0}) - - def test_sparse_dict_empty_raises(self): - scene = self._scene() - with pytest.raises(ValueError, match="must not be empty"): - scene.set_atom_data("charge", {}) + with pytest.raises( + TypeError, + match="values must be array-like.*by_index", + ): + scene.set_atom_data("charge", {0: 1.0}) def test_overwrite_existing_key(self): scene = self._scene() @@ -273,12 +255,6 @@ def test_multiple_keys(self): assert "charge" in scene.atom_data assert "site" in scene.atom_data - def test_sparse_dict_mixed_types_raises(self): - """Dict with mixed string and numeric values raises TypeError.""" - scene = self._scene() - with pytest.raises(TypeError, match="same type"): - scene.set_atom_data("bad", {0: 1, 2: "text"}) - def test_2d_numeric_array(self): """A (n_frames, n_atoms) numeric array is accepted.""" coords = np.zeros((3, 3)) @@ -329,6 +305,339 @@ def test_2d_categorical_array(self): assert scene.atom_data["site"].shape == (2, 3) assert scene.atom_data["site"][0, 1] == "8b" + def test_by_index_numeric_scalar(self): + scene = self._scene() + scene.set_atom_data("charge", by_index={0: 1.5, 2: -0.3}) + arr = scene.atom_data["charge"] + assert arr[0] == pytest.approx(1.5) + assert np.isnan(arr[1]) + assert arr[2] == pytest.approx(-0.3) + + def test_by_index_out_of_range_raises(self): + scene = self._scene() + with pytest.raises(ValueError, match="out of range"): + scene.set_atom_data("charge", by_index={5: 1.0}) + + def test_by_index_negative_index_raises(self): + scene = self._scene() + with pytest.raises(ValueError, match="out of range"): + scene.set_atom_data("charge", by_index={-1: 1.0}) + + def test_values_and_by_index_raises(self): + scene = self._scene() + with pytest.raises(ValueError, match="cannot mix"): + scene.set_atom_data("charge", [1.0, 2.0, 3.0], by_index={0: 1.0}) + + def test_all_omitted_raises(self): + scene = self._scene() + with pytest.raises(ValueError): + scene.set_atom_data("charge") + + def test_by_index_categorical(self): + scene = self._scene() + scene.set_atom_data("site", by_index={1: "4a"}) + arr = scene.atom_data["site"] + assert arr[0] is None + assert arr[1] == "4a" + assert arr[2] is None + assert arr.dtype == object + + def test_by_index_mixed_types_raises(self): + scene = self._scene() + with pytest.raises(TypeError, match="same type"): + scene.set_atom_data("bad", by_index={0: 1.0, 2: "text"}) + + def test_by_species_scalar(self): + """Scalar broadcasts to all atoms of a species.""" + coords = np.zeros((4, 3)) + scene = StructureScene( + species=["Mn", "Mn", "O", "O"], + frames=[Frame(coords=coords)], + ) + scene.set_atom_data("charge", by_species={"Mn": 2.0}) + arr = scene.atom_data["charge"] + assert arr[0] == pytest.approx(2.0) + assert arr[1] == pytest.approx(2.0) + assert np.isnan(arr[2]) + assert np.isnan(arr[3]) + + def test_by_species_1d_array(self): + """1-D array assigns per-atom values within the species.""" + coords = np.zeros((4, 3)) + scene = StructureScene( + species=["Mn", "Mn", "O", "O"], + frames=[Frame(coords=coords)], + ) + scene.set_atom_data("charge", by_species={"Mn": [2.0, 3.0]}) + arr = scene.atom_data["charge"] + assert arr[0] == pytest.approx(2.0) + assert arr[1] == pytest.approx(3.0) + assert np.isnan(arr[2]) + assert np.isnan(arr[3]) + + def test_by_species_multiple_species(self): + """Multiple species in one call.""" + coords = np.zeros((4, 3)) + scene = StructureScene( + species=["Mn", "Mn", "O", "O"], + frames=[Frame(coords=coords)], + ) + scene.set_atom_data("charge", by_species={"Mn": 2.0, "O": -2.0}) + arr = scene.atom_data["charge"] + np.testing.assert_array_almost_equal(arr, [2.0, 2.0, -2.0, -2.0]) + + def test_by_species_categorical(self): + """String values produce object-dtype with None fill.""" + coords = np.zeros((4, 3)) + scene = StructureScene( + species=["Mn", "Mn", "O", "O"], + frames=[Frame(coords=coords)], + ) + scene.set_atom_data("site", by_species={"Mn": "oct"}) + arr = scene.atom_data["site"] + assert arr[0] == "oct" + assert arr[1] == "oct" + assert arr[2] is None + assert arr[3] is None + assert arr.dtype == object + + def test_by_species_unknown_raises(self): + scene = self._scene() + with pytest.raises(ValueError, match="unknown species.*Xe"): + scene.set_atom_data("charge", by_species={"Xe": 1.0}) + + def test_by_species_wrong_length_raises(self): + coords = np.zeros((4, 3)) + scene = StructureScene( + species=["Mn", "Mn", "O", "O"], + frames=[Frame(coords=coords)], + ) + with pytest.raises(ValueError, match="2"): + scene.set_atom_data("charge", by_species={"Mn": [1.0, 2.0, 3.0]}) + + def test_combined_by_index_overrides_by_species(self): + """by_index wins where it overlaps with by_species.""" + coords = np.zeros((4, 3)) + scene = StructureScene( + species=["Mn", "Mn", "O", "O"], + frames=[Frame(coords=coords)], + ) + scene.set_atom_data( + "charge", + by_species={"Mn": 2.0}, + by_index={0: 1.9}, + ) + arr = scene.atom_data["charge"] + assert arr[0] == pytest.approx(1.9) # by_index wins + assert arr[1] == pytest.approx(2.0) # by_species + assert np.isnan(arr[2]) + assert np.isnan(arr[3]) + + def test_combined_none_overrides_by_species(self): + """by_index={i: None} clears the by_species value for that atom.""" + coords = np.zeros((4, 3)) + scene = StructureScene( + species=["Mn", "Mn", "O", "O"], + frames=[Frame(coords=coords)], + ) + scene.set_atom_data( + "site", + by_species={"Mn": "oct"}, + by_index={0: None}, + ) + arr = scene.atom_data["site"] + assert arr[0] is None # by_index None overrides by_species + assert arr[1] == "oct" # by_species + assert arr[2] is None # unspecified + assert arr[3] is None # unspecified + assert arr.dtype == object + + def test_combined_mixed_types_raises(self): + """String in by_species + numeric in by_index is a TypeError.""" + coords = np.zeros((4, 3)) + scene = StructureScene( + species=["Mn", "Mn", "O", "O"], + frames=[Frame(coords=coords)], + ) + with pytest.raises(TypeError, match="same type"): + scene.set_atom_data( + "bad", + by_species={"Mn": "oct"}, + by_index={2: 1.0}, + ) + + def test_values_and_by_species_raises(self): + scene = self._scene() + with pytest.raises(ValueError, match="cannot mix"): + scene.set_atom_data("charge", [1.0, 2.0, 3.0], by_species={"A": 1.0}) + + def test_by_species_2d_promotes(self): + """A 2-D by_species value promotes output to (n_frames, n_atoms).""" + coords = np.zeros((4, 3)) + scene = StructureScene( + species=["Mn", "Mn", "O", "O"], + frames=[Frame(coords=coords), Frame(coords=coords)], + ) + scene.set_atom_data( + "charge", + by_species={"Mn": np.array([[1.0, 2.0], [3.0, 4.0]])}, + ) + arr = scene.atom_data["charge"] + assert arr.shape == (2, 4) + assert arr[0, 0] == pytest.approx(1.0) + assert arr[1, 1] == pytest.approx(4.0) + assert np.isnan(arr[0, 2]) + + def test_by_index_1d_promotes(self): + """A 1-D by_index value of length n_frames promotes to 2-D.""" + coords = np.zeros((3, 3)) + scene = StructureScene( + species=["A", "B", "C"], + frames=[Frame(coords=coords), Frame(coords=coords)], + ) + scene.set_atom_data( + "charge", + by_index={0: [10.0, 20.0]}, + ) + arr = scene.atom_data["charge"] + assert arr.shape == (2, 3) + assert arr[0, 0] == pytest.approx(10.0) + assert arr[1, 0] == pytest.approx(20.0) + assert np.isnan(arr[0, 1]) + + def test_by_species_scalar_broadcasts_when_promoted(self): + """Scalar by_species broadcasts across frames when 2-D.""" + coords = np.zeros((4, 3)) + scene = StructureScene( + species=["Mn", "Mn", "O", "O"], + frames=[Frame(coords=coords), Frame(coords=coords)], + ) + scene.set_atom_data( + "charge", + by_species={"Mn": 2.0}, + by_index={2: [5.0, 6.0]}, # promotes to 2-D + ) + arr = scene.atom_data["charge"] + assert arr.shape == (2, 4) + assert arr[0, 0] == pytest.approx(2.0) + assert arr[1, 0] == pytest.approx(2.0) + assert arr[0, 2] == pytest.approx(5.0) + assert arr[1, 2] == pytest.approx(6.0) + + def test_by_species_1d_broadcasts_when_promoted(self): + """1-D by_species broadcasts across frames when promoted by by_index.""" + coords = np.zeros((4, 3)) + scene = StructureScene( + species=["Mn", "Mn", "O", "O"], + frames=[Frame(coords=coords), Frame(coords=coords)], + ) + scene.set_atom_data( + "charge", + by_species={"Mn": [1.0, 2.0]}, + by_index={2: [5.0, 6.0]}, + ) + arr = scene.atom_data["charge"] + assert arr.shape == (2, 4) + assert arr[0, 0] == pytest.approx(1.0) # Mn[0] frame 0 + assert arr[1, 0] == pytest.approx(1.0) # Mn[0] frame 1 (broadcast) + assert arr[0, 1] == pytest.approx(2.0) # Mn[1] frame 0 + assert arr[1, 1] == pytest.approx(2.0) # Mn[1] frame 1 (broadcast) + + def test_by_species_2d_wrong_shape_raises(self): + coords = np.zeros((4, 3)) + scene = StructureScene( + species=["Mn", "Mn", "O", "O"], + frames=[Frame(coords=coords), Frame(coords=coords)], + ) + with pytest.raises(ValueError, match="2 frames"): + scene.set_atom_data( + "charge", + by_species={"Mn": np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])}, + ) + + def test_by_index_1d_wrong_length_raises(self): + coords = np.zeros((3, 3)) + scene = StructureScene( + species=["A", "B", "C"], + frames=[Frame(coords=coords), Frame(coords=coords)], + ) + with pytest.raises(ValueError, match="2 frames"): + scene.set_atom_data("charge", by_index={0: [1.0, 2.0, 3.0]}) + + def test_by_index_2d_raises(self): + scene = self._scene() + with pytest.raises(ValueError): + scene.set_atom_data( + "charge", by_index={0: np.array([[1.0]])} + ) + + def test_by_species_3d_raises(self): + coords = np.zeros((4, 3)) + scene = StructureScene( + species=["Mn", "Mn", "O", "O"], + frames=[Frame(coords=coords)], + ) + with pytest.raises(ValueError, match="2-D"): + scene.set_atom_data( + "charge", + by_species={"Mn": np.zeros((2, 2, 2))}, + ) + + def test_by_index_scalar_broadcasts_when_promoted(self): + """Scalar by_index broadcasts across frames when output is 2-D.""" + coords = np.zeros((4, 3)) + scene = StructureScene( + species=["Mn", "Mn", "O", "O"], + frames=[Frame(coords=coords), Frame(coords=coords)], + ) + scene.set_atom_data( + "charge", + by_species={"Mn": np.array([[1.0, 2.0], [3.0, 4.0]])}, + by_index={2: 9.0}, + ) + arr = scene.atom_data["charge"] + assert arr.shape == (2, 4) + assert arr[0, 2] == pytest.approx(9.0) + assert arr[1, 2] == pytest.approx(9.0) + + def test_2d_categorical_via_by_species(self): + """2-D categorical sparse path allocates object-dtype with None.""" + coords = np.zeros((4, 3)) + scene = StructureScene( + species=["Mn", "Mn", "O", "O"], + frames=[Frame(coords=coords), Frame(coords=coords)], + ) + scene.set_atom_data( + "site", + by_species={ + "Mn": np.array([["oct", "tet"], ["tet", "oct"]], dtype=object), + }, + ) + arr = scene.atom_data["site"] + assert arr.shape == (2, 4) + assert arr[0, 0] == "oct" + assert arr[1, 0] == "tet" + assert arr[0, 2] is None + assert arr.dtype == object + + def test_none_in_object_array_not_misclassified(self): + """Object array of Nones mixed with numeric values is not categorical.""" + coords = np.zeros((4, 3)) + scene = StructureScene( + species=["Mn", "Mn", "O", "O"], + frames=[Frame(coords=coords)], + ) + scene.set_atom_data( + "charge", + by_species={"Mn": np.array([None, None], dtype=object)}, + by_index={2: 1.0}, + ) + arr = scene.atom_data["charge"] + assert np.isnan(arr[0]) + assert np.isnan(arr[1]) + assert arr[2] == pytest.approx(1.0) + assert np.isnan(arr[3]) + class TestAtomDataWriteMethods: """Tests for del_atom_data, clear_2d_atom_data, setter removal.""" diff --git a/tests/test_rendering/test_painter.py b/tests/test_rendering/test_painter.py index 3e92d3e..48634de 100644 --- a/tests/test_rendering/test_painter.py +++ b/tests/test_rendering/test_painter.py @@ -1080,7 +1080,7 @@ def test_callable_cmap(self): def test_colour_by_with_nan(self): """Atoms with NaN data fall back to species colour.""" scene = _minimal_scene() - scene.set_atom_data("charge", {0: 1.0}) # atom 1 gets NaN + scene.set_atom_data("charge", by_index={0: 1.0}) # atom 1 gets NaN fig = render_mpl(scene, colour_by="charge", show=False) assert isinstance(fig, Figure) plt.close(fig) @@ -1096,8 +1096,8 @@ def test_convenience_method(self): def test_list_colour_by_smoke(self): """render_mpl with a list of colour_by keys produces a Figure.""" scene = _minimal_scene() - scene.set_atom_data("a", {0: 1.0}) - scene.set_atom_data("b", {1: 2.0}) + scene.set_atom_data("a", by_index={0: 1.0}) + scene.set_atom_data("b", by_index={1: 2.0}) fig = render_mpl( scene, colour_by=["a", "b"], cmap=["viridis", "plasma"], show=False, @@ -1112,7 +1112,7 @@ def test_polyhedra_inherit_colour_by(self): scene = _octahedron_scene() # Ti is atom 0; give it a numerical value so it gets a cmap colour. n_atoms = len(scene.species) - scene.set_atom_data("val", {0: 0.5}) + scene.set_atom_data("val", by_index={0: 0.5}) def red(_v: float) -> tuple[float, float, float]: return (1.0, 0.0, 0.0) @@ -1131,7 +1131,7 @@ def test_polyhedra_spec_colour_overrides_colour_by(self): scene = _octahedron_scene(colour=(0.0, 0.0, 1.0)) n_atoms = len(scene.species) - scene.set_atom_data("val", {0: 0.5}) + scene.set_atom_data("val", by_index={0: 0.5}) def red(_v: float) -> tuple[float, float, float]: return (1.0, 0.0, 0.0) @@ -1235,7 +1235,7 @@ def test_per_frame_list_colour_by(self): scene = _minimal_scene() scene.frames.append(Frame(coords=scene.frames[0].coords.copy())) - scene.set_atom_data("static", {0: 1.0}) + scene.set_atom_data("static", by_index={0: 1.0}) scene.set_atom_data( "dynamic", np.array([[np.nan, 0.0], [np.nan, 1.0]]), )