Blackrock add block validation #1740

Open
wants to merge 13 commits into master
neo/rawio/blackrockrawio.py (146 changes: 86 additions & 60 deletions)
@@ -323,19 +323,21 @@ def _parse_header(self):
self.__nsx_data_header = {}

for nsx_nb in self._avail_nsx:
spec = self.__nsx_spec[nsx_nb] = self.__extract_nsx_file_spec(nsx_nb)
spec_version = self.__nsx_spec[nsx_nb] = self.__extract_nsx_file_spec(nsx_nb)
# read nsx headers
self.__nsx_basic_header[nsx_nb], self.__nsx_ext_header[nsx_nb] = self.__nsx_header_reader[spec](nsx_nb)
nsx_header_reader = self.__nsx_header_reader[spec_version]
self.__nsx_basic_header[nsx_nb], self.__nsx_ext_header[nsx_nb] = nsx_header_reader(nsx_nb)

# The only way to know if it is the PTP-variant of file spec 3.0
# The only way to know if it is the Precision Time Protocol (PTP) variant of file spec 3.0
# is to check for nanosecond timestamp resolution.
if (
is_ptp_variant = (
"timestamp_resolution" in self.__nsx_basic_header[nsx_nb].dtype.names
and self.__nsx_basic_header[nsx_nb]["timestamp_resolution"] == 1_000_000_000
):
)
if is_ptp_variant:
nsx_dataheader_reader = self.__nsx_dataheader_reader["3.0-ptp"]
else:
nsx_dataheader_reader = self.__nsx_dataheader_reader[spec]
nsx_dataheader_reader = self.__nsx_dataheader_reader[spec_version]
# for nsxdef get_analogsignal_shape(self, block_index, seg_index):
self.__nsx_data_header[nsx_nb] = nsx_dataheader_reader(nsx_nb)
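For readers skimming the diff, here is a minimal standalone sketch of the nanosecond-resolution check above that selects the PTP variant of file spec 3.0; the structured dtype and values are illustrative assumptions, not read from a real file.

import numpy as np

# Illustrative basic-header record; a real .nsx basic header has many more fields.
basic_header = np.zeros(1, dtype=[("timestamp_resolution", "uint64"), ("period", "uint32")])[0]
basic_header["timestamp_resolution"] = 1_000_000_000  # nanosecond resolution marks the PTP variant

is_ptp_variant = (
    "timestamp_resolution" in basic_header.dtype.names
    and basic_header["timestamp_resolution"] == 1_000_000_000
)
print(is_ptp_variant)  # True -> the "3.0-ptp" data header reader would be chosen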

@@ -355,8 +357,12 @@ def _parse_header(self):
else:
raise (ValueError("nsx_to_load is wrong"))

if not all(nsx_nb in self._avail_nsx for nsx_nb in self.nsx_to_load):
raise FileNotFoundError(f"nsx_to_load does not match available nsx list")
missing_nsx_files = [nsx_nb for nsx_nb in self.nsx_to_load if nsx_nb not in self._avail_nsx]
if missing_nsx_files:
missing_list = ", ".join(f"ns{nsx_nb}" for nsx_nb in missing_nsx_files)
raise FileNotFoundError(
f"Requested NSX file(s) not found: {missing_list}. Available NSX files: {self._avail_nsx}"
)

# check that all files come from the same specification
all_spec = [self.__nsx_spec[nsx_nb] for nsx_nb in self.nsx_to_load]
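As a quick usage illustration of the new missing-file check earlier in this hunk (file numbers invented for the example), requesting an absent .ns5 now fails with a descriptive message:

nsx_to_load = [5, 6]  # hypothetical request
_avail_nsx = [6]      # hypothetical files found next to the recording
missing_nsx_files = [nsx_nb for nsx_nb in nsx_to_load if nsx_nb not in _avail_nsx]
if missing_nsx_files:
    missing_list = ", ".join(f"ns{nsx_nb}" for nsx_nb in missing_nsx_files)
    raise FileNotFoundError(
        f"Requested NSX file(s) not found: {missing_list}. Available NSX files: {_avail_nsx}"
    )
# -> FileNotFoundError: Requested NSX file(s) not found: ns5. Available NSX files: [6]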
@@ -381,27 +387,29 @@ def _parse_header(self):
self.sig_sampling_rates = {}
if len(self.nsx_to_load) > 0:
for nsx_nb in self.nsx_to_load:
spec = self.__nsx_spec[nsx_nb]
# The only way to know if it is the PTP-variant of file spec 3.0
basic_header = self.__nsx_basic_header[nsx_nb]
spec_version = self.__nsx_spec[nsx_nb]
# The only way to know if it is the Precision Time Protocol (PTP) variant of file spec 3.0
# is to check for nanosecond timestamp resolution.
if (
"timestamp_resolution" in self.__nsx_basic_header[nsx_nb].dtype.names
and self.__nsx_basic_header[nsx_nb]["timestamp_resolution"] == 1_000_000_000
):
is_ptp_variant = (
"timestamp_resolution" in basic_header.dtype.names
and basic_header["timestamp_resolution"] == 1_000_000_000
)
if is_ptp_variant:
_data_reader_fun = self.__nsx_data_reader["3.0-ptp"]
else:
_data_reader_fun = self.__nsx_data_reader[spec]
_data_reader_fun = self.__nsx_data_reader[spec_version]
self.nsx_datas[nsx_nb] = _data_reader_fun(nsx_nb)

sr = float(self.main_sampling_rate / self.__nsx_basic_header[nsx_nb]["period"])
sr = float(self.main_sampling_rate / basic_header["period"])
self.sig_sampling_rates[nsx_nb] = sr

if spec in ["2.2", "2.3", "3.0"]:
if spec_version in ["2.2", "2.3", "3.0"]:
ext_header = self.__nsx_ext_header[nsx_nb]
elif spec == "2.1":
elif spec_version == "2.1":
ext_header = []
keys = ["labels", "units", "min_analog_val", "max_analog_val", "min_digital_val", "max_digital_val"]
params = self.__nsx_params[spec](nsx_nb)
params = self.__nsx_params[spec_version](nsx_nb)
for i in range(len(params["labels"])):
d = {}
for key in keys:
@@ -415,11 +423,11 @@ def _parse_header(self):
signal_buffers.append((stream_name, buffer_id))
signal_streams.append((stream_name, stream_id, buffer_id))
for i, chan in enumerate(ext_header):
if spec in ["2.2", "2.3", "3.0"]:
if spec_version in ["2.2", "2.3", "3.0"]:
ch_name = chan["electrode_label"].decode()
ch_id = str(chan["electrode_id"])
units = chan["units"].decode()
elif spec == "2.1":
elif spec_version == "2.1":
ch_name = chan["labels"]
ch_id = str(self.__nsx_ext_header[nsx_nb][i]["electrode_id"])
units = chan["units"]
@@ -809,7 +817,7 @@ def __extract_nsx_file_spec(self, nsx_nb):
"""
Extract file specification from an .nsx file.
"""
filename = ".".join([self._filenames["nsx"], f"ns{nsx_nb}"])
filename = f"{self._filenames['nsx']}.ns{nsx_nb}"

# Header structure of files specification 2.2 and higher. For files 2.1
# and lower, the entries ver_major and ver_minor are not supported.
@@ -829,7 +837,7 @@ def __extract_nev_file_spec(self):
"""
Extract file specification from an .nev file
"""
filename = ".".join([self._filenames["nev"], "nev"])
filename = f"{self._filenames['nev']}.nev"
# Header structure of files specification 2.2 and higher. For files 2.1
# and lower, the entries ver_major and ver_minor are not supported.
dt0 = [("file_id", "S8"), ("ver_major", "uint8"), ("ver_minor", "uint8")]
@@ -879,7 +887,7 @@ def __read_nsx_header_variant_b(self, nsx_nb):
"""
Extract nsx header information from a 2.2 or 2.3 .nsx file
"""
filename = ".".join([self._filenames["nsx"], f"ns{nsx_nb}"])
filename = f"{self._filenames['nsx']}.ns{nsx_nb}"

# basic header (file_id: NEURALCD)
dt0 = [
@@ -911,7 +919,6 @@ def __read_nsx_header_variant_b(self, nsx_nb):

# extended header (type: CC)
offset_dt0 = np.dtype(dt0).itemsize
shape = nsx_basic_header["channel_count"]
dt1 = [
("type", "S2"),
("electrode_id", "uint16"),
@@ -930,28 +937,32 @@ def __read_nsx_header_variant_b(self, nsx_nb):
# filter settings used to create nsx from source signal
("hi_freq_corner", "uint32"),
("hi_freq_order", "uint32"),
("hi_freq_type", "uint16"), # 0=None, 1=Butterworth
("hi_freq_type", "uint16"), # 0=None, 1=Butterworth, 2=Chebyshev
("lo_freq_corner", "uint32"),
("lo_freq_order", "uint32"),
("lo_freq_type", "uint16"),
] # 0=None, 1=Butterworth
] # 0=None, 1=Butterworth, 2=Chebyshev

nsx_ext_header = np.memmap(filename, shape=shape, offset=offset_dt0, dtype=dt1, mode="r")
channel_count = int(nsx_basic_header["channel_count"])
nsx_ext_header = np.memmap(filename, shape=channel_count, offset=offset_dt0, dtype=dt1, mode="r")

return nsx_basic_header, nsx_ext_header

def __read_nsx_dataheader(self, nsx_nb, offset):
"""
Reads data header following the given offset of an nsx file.
"""
filename = ".".join([self._filenames["nsx"], f"ns{nsx_nb}"])
filename = f"{self._filenames['nsx']}.ns{nsx_nb}"

ts_size = "uint64" if self.__nsx_basic_header[nsx_nb]["ver_major"] >= 3 else "uint32"
major_version = self.__nsx_basic_header[nsx_nb]["ver_major"]
ts_size = "uint64" if major_version >= 3 else "uint32"

# dtypes for the data header; the header flag is always set to 1
dt2 = [("header_flag", "uint8"), ("timestamp", ts_size), ("nb_data_points", "uint32")]

# dtypes data header
dt2 = [("header", "uint8"), ("timestamp", ts_size), ("nb_data_points", "uint32")]
packet_header = np.memmap(filename, dtype=dt2, shape=1, offset=offset, mode="r")[0]

return np.memmap(filename, dtype=dt2, shape=1, offset=offset, mode="r")[0]
return packet_header

def __read_nsx_dataheader_variant_a(self, nsx_nb, filesize=None, offset=None):
"""
@@ -971,32 +982,46 @@ def __read_nsx_dataheader_variant_b(
Reads the nsx data header for each data block following the offset of
file spec 2.2, 2.3, and 3.0.
"""
filename = ".".join([self._filenames["nsx"], f"ns{nsx_nb}"])
filename = f"{self._filenames['nsx']}.ns{nsx_nb}"

filesize = self.__get_file_size(filename)
filesize_bytes = self.__get_file_size(filename)

data_header = {}
index = 0

if offset is None:
offset = self.__nsx_basic_header[nsx_nb]["bytes_in_headers"]

while offset < filesize:
dh = self.__read_nsx_dataheader(nsx_nb, offset)
data_header[index] = {
"header": dh["header"],
"timestamp": dh["timestamp"],
"nb_data_points": dh["nb_data_points"],
"offset_to_data_block": offset + dh.dtype.itemsize,
offset_to_first_data_block = int(self.__nsx_basic_header[nsx_nb]["bytes_in_headers"])
else:
offset_to_first_data_block = int(offset)

channel_count = int(self.__nsx_basic_header[nsx_nb]["channel_count"])
current_offset_bytes = offset_to_first_data_block
data_block_index = 0
while current_offset_bytes < filesize_bytes:
packet_header = self.__read_nsx_dataheader(nsx_nb, current_offset_bytes)
header_flag = packet_header["header_flag"]
# NSX data blocks must have header_flag = 1; other values indicate file corruption
if header_flag != 1:
raise ValueError(
f"Invalid NSX data block header at offset {current_offset_bytes:#x} in ns{nsx_nb} file. "
f"Expected header_flag=1, got {header_flag}. "
f"This may indicate file corruption or unsupported NSX format variant. "
f"Block index: {data_block_index}, File size: {filesize_bytes} bytes"
)
timestamp = packet_header["timestamp"]
num_data_points = int(packet_header["nb_data_points"])
offset_to_data_block_start = current_offset_bytes + packet_header.dtype.itemsize

data_header[data_block_index] = {
"header": header_flag,
"timestamp": timestamp,
"nb_data_points": num_data_points,
"offset_to_data_block": offset_to_data_block_start,
}

# data size = number of data points * (2bytes * number of channels)
# use of `int` avoids overflow problem
data_size = int(dh["nb_data_points"]) * int(self.__nsx_basic_header[nsx_nb]["channel_count"]) * 2
# define new offset (to possible next data block)
offset = int(data_header[index]["offset_to_data_block"]) + data_size
# Jump to the next data block; the data is encoded as int16
data_block_size_bytes = num_data_points * channel_count * np.dtype("int16").itemsize
current_offset_bytes = offset_to_data_block_start + data_block_size_bytes

index += 1
data_block_index += 1

return data_header
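As background for the loop above, a standalone sketch of the packet arithmetic it relies on (channel and sample counts invented): each data block is a small packed header (13 bytes with the uint64 timestamp of spec 3.0) followed by nb_data_points * channel_count int16 samples, so the next header sits at the data-block offset plus that many 2-byte samples.

import numpy as np

# Packed packet header: uint8 + uint64 + uint32 = 13 bytes (numpy structured dtypes are unaligned by default).
packet_dtype = np.dtype([("header_flag", "uint8"), ("timestamp", "uint64"), ("nb_data_points", "uint32")])
channel_count = 4          # invented channel count
nb_data_points = 1_000     # invented sample count for this block
current_offset_bytes = 0   # a real file would start at bytes_in_headers

offset_to_data_block_start = current_offset_bytes + packet_dtype.itemsize
data_block_size_bytes = nb_data_points * channel_count * np.dtype("int16").itemsize
next_packet_offset = offset_to_data_block_start + data_block_size_bytes
print(packet_dtype.itemsize, data_block_size_bytes, next_packet_offset)  # 13 8000 8013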

@@ -1082,19 +1107,20 @@ def __read_nsx_data_variant_b(self, nsx_nb):
Extract nsx data (blocks) from a 2.2, 2.3, or 3.0 .nsx file.
Blocks can arise if the recording was paused by the user.
"""
filename = ".".join([self._filenames["nsx"], f"ns{nsx_nb}"])
filename = f"{self._filenames['nsx']}.ns{nsx_nb}"

data = {}
for data_bl in self.__nsx_data_header[nsx_nb].keys():
data_header = self.__nsx_data_header[nsx_nb]
number_of_channels = int(self.__nsx_basic_header[nsx_nb]["channel_count"])

for data_block in data_header.keys():
# get shape and offset of data
shape = (
int(self.__nsx_data_header[nsx_nb][data_bl]["nb_data_points"]),
int(self.__nsx_basic_header[nsx_nb]["channel_count"]),
)
offset = int(self.__nsx_data_header[nsx_nb][data_bl]["offset_to_data_block"])
number_of_samples = int(data_header[data_block]["nb_data_points"])
shape = (number_of_samples, number_of_channels)
offset = int(data_header[data_block]["offset_to_data_block"])

# read data
data[data_bl] = np.memmap(filename, dtype="int16", shape=shape, offset=offset, mode="r")
data[data_block] = np.memmap(filename, dtype="int16", shape=shape, offset=offset, mode="r")

return data
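To close the loop, a small self-contained sketch of mapping one such block read-only; the file name and shape are invented, while the real reader derives shape and offset from the data header as shown above.

import numpy as np

# Write a tiny fake int16 block to disk, then map it back the way each data block is read.
filename = "example_block.ns6"                       # hypothetical file name
samples = np.arange(8, dtype="int16").reshape(4, 2)  # 4 samples x 2 channels
samples.tofile(filename)

block = np.memmap(filename, dtype="int16", shape=(4, 2), offset=0, mode="r")
print(block[0])  # first sample across both channels -> [0 1]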
