diff --git a/orsopy/fileio/base.py b/orsopy/fileio/base.py index 87b2c855..cf8154b8 100644 --- a/orsopy/fileio/base.py +++ b/orsopy/fileio/base.py @@ -870,7 +870,7 @@ def _read_header_data(file: Union[TextIO, str], validate: bool = False) -> Tuple # numerical array and start collecting the numbers for this # dataset _d = np.array([np.fromstring(v, dtype=float, sep=" ") for v in _ds_lines]) - data.append(_d) + data.append(_d.T) _ds_lines = [] # append '---' to signify the start of a new yaml document @@ -883,7 +883,7 @@ def _read_header_data(file: Union[TextIO, str], validate: bool = False) -> Tuple # append the last numerical array _d = np.array([np.fromstring(v, dtype=float, sep=" ") for v in _ds_lines]) - data.append(_d) + data.append(_d.T) yml = "".join(header) diff --git a/orsopy/fileio/orso.py b/orsopy/fileio/orso.py index 2f830015..be551caa 100644 --- a/orsopy/fileio/orso.py +++ b/orsopy/fileio/orso.py @@ -166,8 +166,11 @@ class OrsoDataset: data: Union[np.ndarray, Sequence[np.ndarray], Sequence[Sequence]] def __post_init__(self): - if self.data.shape[1] != len(self.info.columns): + if len(self.data) != len(self.info.columns): raise ValueError("Data has to have the same number of columns as header") + column_lengths = set(len(c) for c in self.data) + if len(column_lengths) > 1: + raise ValueError("Columns must all have the same length in first dimension") def header(self) -> str: """ @@ -252,13 +255,13 @@ def save_orso( ds1 = datasets[0] header += ds1.header() - np.savetxt(f, ds1.data, header=header, fmt="%-22.16e") + np.savetxt(f, np.asarray(ds1.data).T, header=header, fmt="%-22.16e") for dsi in datasets[1:]: # write an optional spacer string between dataset e.g. \n f.write(data_separator) hi = ds1.diff_header(dsi) - np.savetxt(f, dsi.data, header=hi, fmt="%-22.16e") + np.savetxt(f, np.asarray(dsi.data).T, header=hi, fmt="%-22.16e") def load_orso(fname: Union[TextIO, str]) -> List[OrsoDataset]: diff --git a/orsopy/fileio/tests/test_orso.py b/orsopy/fileio/tests/test_orso.py index 5ff3823a..8fc32772 100644 --- a/orsopy/fileio/tests/test_orso.py +++ b/orsopy/fileio/tests/test_orso.py @@ -107,8 +107,8 @@ def test_write_read(self): # test write and read of multiple datasets info = fileio.Orso.empty() info2 = fileio.Orso.empty() - data = np.zeros((100, 3)) - data[:] = np.arange(100.0)[:, None] + data = np.zeros((3, 100)) + data[:] = np.arange(100.0)[None, :] info.columns = [ fileio.Column("Qz", "1/angstrom"), @@ -190,14 +190,14 @@ def test_unique_dataset(self): info2.data_set = 0 info2.columns = [Column("stuff")] * 4 - ds = OrsoDataset(info, np.empty((2, 4))) - ds2 = OrsoDataset(info2, np.empty((2, 4))) + ds = OrsoDataset(info, np.empty((4, 2))) + ds2 = OrsoDataset(info2, np.empty((4, 2))) with pytest.raises(ValueError): fileio.save_orso([ds, ds2], "test_data_set.ort") with pytest.raises(ValueError): - OrsoDataset(info, np.empty((2, 5))) + OrsoDataset(info, np.empty((5, 2))) def test_user_data(self): # test write and read of userdata @@ -208,8 +208,8 @@ def test_user_data(self): fileio.ErrorColumn("R"), ] - data = np.zeros((100, 3)) - data[:] = np.arange(100.0)[:, None] + data = np.zeros((3, 100)) + data[:] = np.arange(100.0)[None, :] dct = {"ci": "1", "foo": ["bar", 1, 2, 3.5]} info.user_data = dct ds = fileio.OrsoDataset(info, data) @@ -247,7 +247,7 @@ def test_save_numpy_scalar_dtypes(self): info = fileio.Orso.empty() info.data_source.measurement.instrument_settings.wavelength = Value(np.float64(10.0)) info.data_source.measurement.instrument_settings.incident_angle = Value(np.int32(2)) - ds = fileio.orso.OrsoDataset(info, np.arange(20.).reshape(10, 2)) + ds = fileio.orso.OrsoDataset(info, np.arange(20.).reshape(2, 10)) # .ort test: fileio.save_orso([ds], "test_numpy.ort") ls = fileio.load_orso("test_numpy.ort") diff --git a/orsopy/fileio/tests/test_schema.py b/orsopy/fileio/tests/test_schema.py index 10cfb2fd..aa55698d 100644 --- a/orsopy/fileio/tests/test_schema.py +++ b/orsopy/fileio/tests/test_schema.py @@ -18,7 +18,7 @@ def test_example_ort(self): schema = json.load(f) dct_list, data, version = _read_header_data(pth / "test_example.ort", validate=True) - assert data[0].shape == (2, 4) + assert data[0].shape == (4, 2) assert version == "0.1" # d contains datetime.datetime objects, which would fail the @@ -34,4 +34,4 @@ def test_example_ort(self): assert len(dct_list) == 2 assert dct_list[1]["data_set"] == "spin_down" assert data[1].shape == (4, 4) - np.testing.assert_allclose(data[1][2:], data[0]) + np.testing.assert_allclose(data[1][:, 2:], data[0])