Skip to content

Commit b3b96a1

Browse files
authored
Merge pull request #677 from lhotse-speech/feature/fix-assertion
Fix assertions, rename variables
2 parents dc0c86b + 66a35b0 commit b3b96a1

File tree

4 files changed

+27
-24
lines changed

4 files changed

+27
-24
lines changed

lhotse/bin/modes/features.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -277,9 +277,9 @@ def upload(
277277

278278

279279
def _upload_one(item: Features, url: str) -> Features:
280-
feats_mtx = item.load()
280+
feats_mat = item.load()
281281
feats_writer = LilcomURLWriter(url)
282-
new_key = feats_writer.write(key=item.storage_key, value=feats_mtx)
282+
new_key = feats_writer.write(key=item.storage_key, value=feats_mat)
283283
return fastcopy(
284284
item, storage_path=url, storage_key=new_key, storage_type=feats_writer.name
285285
)

lhotse/cut.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4650,37 +4650,37 @@ def compute_and_store_features_batch(
46504650
waves, sampling_rate=cuts[0].sampling_rate
46514651
)
46524652

4653-
for cut, feat_mtx in zip(cuts, features):
4653+
for cut, feat_mat in zip(cuts, features):
46544654
if isinstance(cut, PaddingCut):
46554655
# For padding cuts, just fill out the fields in the manfiest
46564656
# and don't store anything.
46574657
cuts_writer.write(
46584658
fastcopy(
46594659
cut,
4660-
num_frames=feat_mtx.shape[0],
4661-
num_features=feat_mtx.shape[1],
4660+
num_frames=feat_mat.shape[0],
4661+
num_features=feat_mat.shape[1],
46624662
frame_shift=frame_shift,
46634663
)
46644664
)
46654665
continue
46664666
# Store the computed features and describe them in a manifest.
4667-
if isinstance(feat_mtx, torch.Tensor):
4668-
feat_mtx = feat_mtx.cpu().numpy()
4669-
storage_key = feats_writer.write(cut.id, feat_mtx)
4667+
if isinstance(feat_mat, torch.Tensor):
4668+
feat_mat = feat_mat.cpu().numpy()
4669+
storage_key = feats_writer.write(cut.id, feat_mat)
46704670
feat_manifest = Features(
46714671
start=cut.start,
46724672
duration=cut.duration,
46734673
type=extractor.name,
4674-
num_frames=feat_mtx.shape[0],
4675-
num_features=feat_mtx.shape[1],
4674+
num_frames=feat_mat.shape[0],
4675+
num_features=feat_mat.shape[1],
46764676
frame_shift=frame_shift,
46774677
sampling_rate=cut.sampling_rate,
46784678
channels=0,
46794679
storage_type=feats_writer.name,
46804680
storage_path=str(feats_writer.storage_path),
46814681
storage_key=storage_key,
46824682
)
4683-
validate_features(feat_manifest, feats_data=feat_mtx)
4683+
validate_features(feat_manifest, feats_data=feat_mat)
46844684

46854685
# Update the cut manifest.
46864686
if isinstance(cut, MonoCut):

lhotse/features/io.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -824,7 +824,7 @@ def __init__(
824824

825825
if "b" not in mode:
826826
mode = mode + "b"
827-
assert mode == "wb" or "ab"
827+
assert mode in ("wb", "ab")
828828

829829
# ".lca" -> "lilcom chunky archive"
830830
self.storage_path_ = Path(storage_path).with_suffix(".lca")

lhotse/qa.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from lhotse.array import Array, TemporalArray
99
from lhotse.audio import Recording, RecordingSet
10-
from lhotse.cut import Cut, CutSet, MixedCut, PaddingCut
10+
from lhotse.cut import Cut, CutSet, MixedCut, MonoCut, PaddingCut
1111
from lhotse.features import FeatureSet, Features
1212
from lhotse.supervision import SupervisionSegment, SupervisionSet
1313
from lhotse.utils import compute_num_frames, overlaps
@@ -379,17 +379,20 @@ def validate_cut(c: Cut, read_data: bool = False) -> None:
379379
c.num_samples == samples.shape[1]
380380
), f"MonoCut {c.id}: expected {c.num_samples} samples, got {samples.shape[1]}"
381381

382-
# Conditions related to supervisions
383-
for s in c.supervisions:
384-
validate_supervision(s)
385-
assert s.recording_id == c.recording_id, (
386-
f"MonoCut {c.id}: supervision {s.id} has a mismatched recording_id "
387-
f"(expected {c.recording_id}, supervision has {s.recording_id})"
388-
)
389-
assert s.channel == c.channel, (
390-
f"MonoCut {c.id}: supervision {s.id} has a mismatched channel "
391-
f"(expected {c.channel}, supervision has {s.channel})"
392-
)
382+
# Conditions related to supervisions.
383+
# We only validate those for MonoCut; PaddingCut doesn't have supervisions,
384+
# and MixedCut may consist of more than one recording/channel.
385+
if isinstance(c, MonoCut):
386+
for s in c.supervisions:
387+
validate_supervision(s)
388+
assert s.recording_id == c.recording_id, (
389+
f"MonoCut {c.id}: supervision {s.id} has a mismatched recording_id "
390+
f"(expected {c.recording_id}, supervision has {s.recording_id})"
391+
)
392+
assert s.channel == c.channel, (
393+
f"MonoCut {c.id}: supervision {s.id} has a mismatched channel "
394+
f"(expected {c.channel}, supervision has {s.channel})"
395+
)
393396

394397
# Conditions related to custom fields
395398
if c.custom is not None:

0 commit comments

Comments
 (0)