Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions docs/api/api_symmetry.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,11 @@
:members:
:undoc-members:
```

## Subgroup enumeration

```{eval-rst}
.. automodule:: spgrep.symmetry.subgroup
:members:
:undoc-members:
```
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ dynamic = [
"version",
]
dependencies = [
"hsnf>=0.3.16",
"numpy>=1.20.1",
"spglib>=2.7",
"typing-extensions",
Expand Down
125 changes: 125 additions & 0 deletions src/spgrep/symmetry/subgroup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
"""Isotropy subgroup of space group."""

from __future__ import annotations

from queue import Queue

from spgrep.rep.group import get_identity_index, get_inverse_index
from spgrep.utils import (
NDArrayInt,
)


def enumerate_point_subgroup(
table: NDArrayInt, preserve_sublattice: list[bool], return_conjugacy_class: bool = True
) -> list[list[int]]:
"""Enumerate conjugacy subgroups of point group.

Parameters
----------
table: array[int], (order, order)
Multiplication table of group
preserve_sublattice: list[bool]
Specify ``preserve_sublattice[i] = True`` if the ``i``-th operation preserves translational subgroup of isotropy subgroup
return_conjugacy_class: bool, default=True
If true, return representatives of conjugacy classes.

Returns
-------
subgroups: list[list[int]]
"""
order = len(table)
identity = get_identity_index(table)
# Represent choice of elements by bit array
st = {1 << identity}
for i in range(order):
if (not preserve_sublattice[i]) or (i == identity):
continue
if (1 << i) in st:
# Already visited
continue

next_st = set()
for bits in st:
elements = _decode_bits(bits, order)
assert _is_subgroup(elements, table)
generated = _traverse(elements + [i], identity, table)
next_st.add(sum(1 << idx for idx in set(generated)))

st = st.union(next_st)

if not return_conjugacy_class:
subgroups = []
for bits in sorted(st):
subgroups.append(_decode_bits(bits, order))
return subgroups

# Group by conjugacy classes
found = set()
ret = []
for bits in sorted(st):
if bits in found:
continue
elements = _decode_bits(bits, order)
ret.append(elements)
for i in range(order):
if not preserve_sublattice[i]:
continue
inv = get_inverse_index(table, i)
conj = [int(table[inv, table[idx, i]]) for idx in elements]
found.add(sum(1 << idx for idx in set(conj)))

assert found == st
return ret


def enumerate_point_subgroup_naive(table, preserve_sublattice: list[bool]):
"""Enumerate conjugacy subgroups of point group in brute force."""
order = len(table)
ret = []
for bits in range(1, 1 << order):
elements = _decode_bits(bits, order)
if not all([preserve_sublattice[idx] for idx in elements]):
continue

if _is_subgroup(elements, table):
ret.append(bits)

return ret


def _decode_bits(bits: int, order: int) -> list[int]:
elements = [idx for idx in range(order) if (bits >> idx) & 1 == 1]
return elements


def _is_subgroup(elements: list[int], table: NDArrayInt) -> bool:
subtable = table[elements][:, elements]
for i in range(len(subtable)):
if (set(subtable[i]) != set(elements)) or (set(subtable[:, i]) != set(elements)):
return False
return True


def _traverse(
generators: list[int],
identity: int,
table: NDArrayInt,
) -> list[int]:
"""Traverse group elements from generators."""
visited = [False for _ in range(len(table))]
que = Queue() # type: ignore
que.put(identity)

while not que.empty():
g = que.get()
if visited[g]:
continue
visited[g] = True

for h in generators:
gh = int(table[g, h]) # cast np.int64 to int
if not visited[gh]:
que.put(gh)

return sorted([i for i, v in enumerate(visited) if v])
16 changes: 16 additions & 0 deletions tests/symmetry/test_symmetry_isotropy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from __future__ import annotations

import numpy as np

from spgrep.symmetry.group import get_cayley_table
from spgrep.symmetry.pointgroup import pg_dataset
from spgrep.symmetry.subgroup import enumerate_point_subgroup, enumerate_point_subgroup_naive


def test_enumerate_point_subgroup():
pointgroup = pg_dataset["4/mmm"][0]
table = get_cayley_table(np.array(pointgroup))
flags = [True for _ in range(len(pointgroup))]
subgroups_actual = enumerate_point_subgroup(table, flags, return_conjugacy_class=False)
subgroups_expect = enumerate_point_subgroup_naive(table, flags)
assert len(subgroups_actual) == len(subgroups_expect)
44 changes: 44 additions & 0 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading