Skip to content

Commit ca475c6

Browse files
committed
Front end classes
Monitor tests Remove IndexedFieldDataArray
1 parent cc8fef1 commit ca475c6

File tree

8 files changed

+288
-17
lines changed

8 files changed

+288
-17
lines changed

tests/test_components/test_heat_charge.py

Lines changed: 131 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -254,19 +254,22 @@ def monitors():
254254

255255
mesh_mnt = td.VolumeMeshMonitor(size=(1.6, 2, 3), name="mesh_test")
256256

257+
electric_field_mnt = td.SteadyElectricFieldMonitor(size=(1.6, 2, 3), name="electric_field_test")
258+
257259
return [
258-
temp_mnt1,
259-
temp_mnt2,
260-
temp_mnt3,
261-
temp_mnt4,
262-
volt_mnt1,
263-
volt_mnt2,
264-
volt_mnt3,
265-
volt_mnt4,
266-
capacitance_mnt1,
267-
free_carrier_mnt1,
268-
energy_band_mnt1,
269-
mesh_mnt,
260+
temp_mnt1, # 0
261+
temp_mnt2, # 1
262+
temp_mnt3, # 2
263+
temp_mnt4, # 3
264+
volt_mnt1, # 4
265+
volt_mnt2, # 5
266+
volt_mnt3, # 6
267+
volt_mnt4, # 7
268+
capacitance_mnt1, # 8
269+
free_carrier_mnt1, # 9
270+
energy_band_mnt1, # 10
271+
mesh_mnt, # 11
272+
electric_field_mnt, # 12
270273
]
271274

272275

@@ -519,7 +522,10 @@ def temperature_monitor_data(monitors):
519522
@pytest.fixture(scope="module")
520523
def voltage_monitor_data(monitors):
521524
"""Creates different voltage monitor data."""
522-
_, _, _, _, volt_mnt1, volt_mnt2, volt_mnt3, volt_mnt4, _, _, _, _ = monitors
525+
volt_mnt1 = monitors[4]
526+
volt_mnt2 = monitors[5]
527+
volt_mnt3 = monitors[6]
528+
volt_mnt4 = monitors[7]
523529

524530
# SpatialDataArray
525531
nx, ny, nz = 9, 6, 5
@@ -679,6 +685,76 @@ def energy_band_monitor_data(monitors):
679685
return (eb_data1,)
680686

681687

688+
@pytest.fixture(scope="module")
689+
def electric_field_monitor_data(monitors):
690+
"""Creates different electric field monitor data."""
691+
monitor = monitors[12]
692+
693+
# TetrahedralGridDataset
694+
tet_grid_points = td.PointDataArray(
695+
[[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [1.0, 1.0, 0.0], [0.0, 0.0, 1.0]],
696+
dims=("index", "axis"),
697+
)
698+
699+
tet_grid_cells = td.CellDataArray(
700+
[[0, 1, 2, 4], [1, 2, 3, 4]],
701+
dims=("cell_index", "vertex_index"),
702+
)
703+
704+
tet_grid_values = td.PointDataArray(
705+
[[0.0, 1.0, 0.0], [1.0, 1.0, 1.0], [3.0, 5.0, 1.0], [4.0, 5.0, 3.0], [5.0, 2.0, 1.0]],
706+
dims=(
707+
"index",
708+
"axis",
709+
),
710+
name="T",
711+
)
712+
713+
tet_grid = td.TetrahedralGridDataset(
714+
points=tet_grid_points,
715+
cells=tet_grid_cells,
716+
values=tet_grid_values,
717+
)
718+
719+
mnt_data1 = td.SteadyElectricFieldData(monitor=monitor, E=tet_grid)
720+
721+
# TriangularGridDataset
722+
tri_grid_points = td.PointDataArray(
723+
[[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]],
724+
dims=("index", "axis"),
725+
)
726+
727+
tri_grid_cells = td.CellDataArray(
728+
[[0, 1, 2], [1, 2, 3]],
729+
dims=("cell_index", "vertex_index"),
730+
)
731+
732+
tri_grid_values = td.IndexedFieldVoltageDataArray(
733+
[
734+
[[1.0, 1.5], [-1.0, 1.1], [5.1, 0.0]],
735+
[[1.0, 1.5], [-1.0, 1.1], [5.1, 0.0]],
736+
[[1.0, 1.5], [-1.0, 1.1], [5.1, 0.0]],
737+
[[1.0, 1.5], [-1.0, 1.1], [5.1, 0.0]],
738+
],
739+
coords={"index": np.arange(4), "axis": np.arange(3), "voltage": [-1, 1]},
740+
name="T",
741+
)
742+
743+
tri_grid = td.TriangularGridDataset(
744+
normal_axis=1,
745+
normal_pos=0,
746+
points=tri_grid_points,
747+
cells=tri_grid_cells,
748+
values=tri_grid_values,
749+
)
750+
751+
mnt_data2 = td.SteadyElectricFieldData(monitor=monitor, E=tri_grid)
752+
753+
mnt_data3 = td.SteadyElectricFieldData(monitor=monitor, E=None)
754+
755+
return (mnt_data1, mnt_data2, mnt_data3)
756+
757+
682758
@pytest.fixture(scope="module")
683759
def simulation_data(
684760
heat_simulation,
@@ -853,11 +929,52 @@ def test_monitor_crosses_medium(mediums, structures, heat_simulation, conduction
853929

854930

855931
def test_heat_charge_mnt_data(
856-
temperature_monitor_data, voltage_monitor_data, capacitance_monitor_data
932+
temperature_monitor_data, voltage_monitor_data, electric_field_monitor_data
857933
):
858934
"""Tests whether different heat-charge monitor data can be created."""
859935
assert len(temperature_monitor_data) == 4, "Expected 4 temperature monitor data entries."
860936
assert len(voltage_monitor_data) == 4, "Expected 4 voltage monitor data entries."
937+
assert len(electric_field_monitor_data) == 3, "Expected 3 electric field monitor data entries."
938+
939+
for mnt_data in electric_field_monitor_data:
940+
assert "E" in mnt_data.field_components.keys()
941+
942+
symm_data = mnt_data.symmetry_expanded_copy
943+
assert symm_data.E == mnt_data.E
944+
945+
names = mnt_data.field_name("abs^2")
946+
assert names == "E²"
947+
names = mnt_data.field_name()
948+
assert names == "E"
949+
950+
# make sure an error is raised if we don't use a field data array
951+
# TriangularGridDataset
952+
tri_grid_points = td.PointDataArray(
953+
[[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]],
954+
dims=("index", "axis"),
955+
)
956+
957+
tri_grid_cells = td.CellDataArray(
958+
[[0, 1, 2], [1, 2, 3]],
959+
dims=("cell_index", "vertex_index"),
960+
)
961+
962+
tri_grid_values = td.IndexedDataArray(
963+
[1.0, 2.0, 3.0, 4.0],
964+
dims=("index",),
965+
name="T",
966+
)
967+
968+
tri_grid = td.TriangularGridDataset(
969+
normal_axis=1,
970+
normal_pos=0,
971+
points=tri_grid_points,
972+
cells=tri_grid_cells,
973+
values=tri_grid_values,
974+
)
975+
976+
with pytest.raises(pd.ValidationError):
977+
_ = mnt_data.updated_copy(E=tri_grid)
861978

862979

863980
def test_grid_spec_validation(grid_specs):

tidy3d/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
)
3838
from tidy3d.components.tcad.data.types import (
3939
SteadyCapacitanceData,
40+
SteadyElectricFieldData,
4041
SteadyEnergyBandData,
4142
SteadyFreeCarrierData,
4243
SteadyPotentialData,
@@ -53,6 +54,7 @@
5354
from tidy3d.components.tcad.mesher import VolumeMesher
5455
from tidy3d.components.tcad.monitors.charge import (
5556
SteadyCapacitanceMonitor,
57+
SteadyElectricFieldMonitor,
5658
SteadyEnergyBandMonitor,
5759
SteadyFreeCarrierMonitor,
5860
SteadyPotentialMonitor,
@@ -136,6 +138,7 @@
136138
FluxTimeDataArray,
137139
HeatDataArray,
138140
IndexedDataArray,
141+
IndexedFieldVoltageDataArray,
139142
IndexedTimeDataArray,
140143
IndexedVoltageDataArray,
141144
ModeAmpsDataArray,
@@ -565,6 +568,7 @@ def set_logging_level(level: str) -> None:
565568
"HuraySurfaceRoughness",
566569
"IndexPerturbation",
567570
"IndexedDataArray",
571+
"IndexedFieldVoltageDataArray",
568572
"IndexedTimeDataArray",
569573
"IndexedVoltageDataArray",
570574
"InsulatingBC",
@@ -653,6 +657,8 @@ def set_logging_level(level: str) -> None:
653657
"Staircasing",
654658
"SteadyCapacitanceData",
655659
"SteadyCapacitanceMonitor",
660+
"SteadyElectricFieldData",
661+
"SteadyElectricFieldMonitor",
656662
"SteadyEnergyBandData",
657663
"SteadyEnergyBandMonitor",
658664
"SteadyFreeCarrierData",

tidy3d/components/data/data_array.py

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1175,9 +1175,9 @@ class SteadyVoltageDataArray(DataArray):
11751175

11761176

11771177
class PointDataArray(DataArray):
1178-
"""A two-dimensional array that stores coordinates of a collection of points.
1178+
"""A two-dimensional array that stores coordinates/field components for a collection of points.
11791179
Dimension ``index`` denotes the index of a point in the collection, and dimension ``axis``
1180-
denotes the point's coordinate along that axis.
1180+
denotes the field component (or point coordinate) in that direction.
11811181
11821182
Example
11831183
-------
@@ -1188,6 +1188,14 @@ class PointDataArray(DataArray):
11881188
>>> point3 = point_array.sel(index=3)
11891189
>>> # get x coordinates of all points
11901190
>>> x_coords = point_array.sel(axis=0)
1191+
>>>
1192+
>>> field_da = PointDataArray(
1193+
... np.random.random((120, 3)), coords=dict(index=np.arange(120), axis=np.arange(3)),
1194+
... )
1195+
>>> # get field of point number 90
1196+
>>> field_point90 = field_da.sel(index=90)
1197+
>>> # get z component of all points
1198+
>>> z_field = field_da.sel(axis=2)
11911199
"""
11921200

11931201
__slots__ = ()
@@ -1265,6 +1273,20 @@ class IndexedTimeDataArray(DataArray):
12651273
_dims = ("index", "t")
12661274

12671275

1276+
class IndexedFieldVoltageDataArray(DataArray):
1277+
"""Stores indexed values of vector fields for different voltages. It is typically used
1278+
in conjuction with a ``PointDataArray`` to store point-associated vector data.
1279+
Example
1280+
-------
1281+
>>> indexed_array = IndexedFieldVoltageDataArray(
1282+
... (1+1j) * np.random.random((4,3,2)), coords=dict(index=np.arange(4), axis=np.arange(3), voltage=[-1, 1])
1283+
... )
1284+
"""
1285+
1286+
__slots__ = ()
1287+
_dims = ("index", "axis", "voltage")
1288+
1289+
12681290
class SpatialVoltageDataArray(AbstractSpatialDataArray):
12691291
"""Spatial distribution with voltage mapping.
12701292
@@ -1319,11 +1341,18 @@ class PerturbationCoefficientDataArray(DataArray):
13191341
PointDataArray,
13201342
CellDataArray,
13211343
IndexedDataArray,
1344+
IndexedFieldVoltageDataArray,
13221345
IndexedVoltageDataArray,
13231346
SpatialVoltageDataArray,
13241347
PerturbationCoefficientDataArray,
13251348
IndexedTimeDataArray,
13261349
]
13271350
DATA_ARRAY_MAP = {data_array.__name__: data_array for data_array in DATA_ARRAY_TYPES}
13281351

1329-
IndexedDataArrayTypes = Union[IndexedDataArray, IndexedVoltageDataArray, IndexedTimeDataArray]
1352+
IndexedDataArrayTypes = Union[
1353+
IndexedDataArray,
1354+
IndexedVoltageDataArray,
1355+
IndexedTimeDataArray,
1356+
IndexedFieldVoltageDataArray,
1357+
PointDataArray,
1358+
]

tidy3d/components/data/unstructured/base.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -618,6 +618,20 @@ def to_vtu(self, fname: str):
618618
writer.SetInputData(self._vtk_obj)
619619
writer.Write()
620620

621+
@classmethod
622+
@requires_vtk
623+
def _cell_to_point_data(
624+
cls,
625+
vtk_obj,
626+
):
627+
"""Get point data values from a VTK object."""
628+
629+
cellDataToPointData = vtk["mod"].vtkCellDataToPointData()
630+
cellDataToPointData.SetInputData(vtk_obj)
631+
cellDataToPointData.Update()
632+
633+
return cellDataToPointData.GetOutput()
634+
621635
@classmethod
622636
@requires_vtk
623637
def _get_values_from_vtk(

0 commit comments

Comments
 (0)