Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
160 changes: 73 additions & 87 deletions fs2/tests/test_chunking.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from string import ascii_lowercase
from tempfile import TemporaryDirectory
from typing import Callable
from unittest import TestCase

import torch
from everyvoice.config.shared_types import ContactInformation
Expand All @@ -18,29 +17,28 @@
from ..type_definitions import SynthesizeOutputFormats


class TestDuplicateFilename(TestCase):
def setUp(self):
self.contact = ContactInformation(
contact_name="Test Runner", contact_email="info@everyvoice.ca"
)
self.output_key = "output"
self.outputs = {
self.output_key: torch.ones([3, 500, 80], device="cpu"),
"duration_prediction": torch.ones([3, 7], device="cpu"),
"tgt_lens": [490, 490, 490],
}
self.batch1 = {
"basename": ["This is a chunk", "This is another chunk", "This is a chunk"],
"raw_text": ["This is a chunk", "This is another chunk", "This is a chunk"],
"text": [
torch.IntTensor([2, 3, 4, 5, 6, 7, 8], device="cpu"),
torch.IntTensor([2, 3, 4, 5, 6, 7, 8], device="cpu"),
torch.IntTensor([2, 3, 4, 5, 6, 7, 8], device="cpu"),
],
"speaker": ["S1", "S1", "S1"],
"language": ["L1", "L1", "L1"],
"is_last_input_chunk": [0, 1, 1],
}
class TestDuplicateFilename:
contact = ContactInformation(
contact_name="Test Runner", contact_email="info@everyvoice.ca"
)
output_key = "output"
outputs = {
output_key: torch.ones([3, 500, 80], device="cpu"),
"duration_prediction": torch.ones([3, 7], device="cpu"),
"tgt_lens": [490, 490, 490],
}
batch1 = {
"basename": ["This is a chunk", "This is another chunk", "This is a chunk"],
"raw_text": ["This is a chunk", "This is another chunk", "This is a chunk"],
"text": [
torch.IntTensor([2, 3, 4, 5, 6, 7, 8], device="cpu"),
torch.IntTensor([2, 3, 4, 5, 6, 7, 8], device="cpu"),
torch.IntTensor([2, 3, 4, 5, 6, 7, 8], device="cpu"),
],
"speaker": ["S1", "S1", "S1"],
"language": ["L1", "L1", "L1"],
"is_last_input_chunk": [0, 1, 1],
}

def test_duplicate_filename(self):
"""
Expand Down Expand Up @@ -77,28 +75,25 @@ def test_duplicate_filename(self):
)
output_dir = writer.save_dir
# print(output_dir, *output_dir.glob("**")) # For debugging
self.assertTrue(output_dir.exists())
self.assertTrue(
(
output_dir
/ "This-is-a-chunkThis--9fc7184d--S1--L1--ckpt=77--v_ckpt=10--pred.wav"
).exists()
)
self.assertTrue(
(
output_dir / "This-is-a-chunk--S1--L1--ckpt=77--v_ckpt=10--pred.wav"
).exists()
)
assert output_dir.exists()
assert (
output_dir
/ "This-is-a-chunkThis--9fc7184d--S1--L1--ckpt=77--v_ckpt=10--pred.wav"
).exists()
assert (
output_dir / "This-is-a-chunk--S1--L1--ckpt=77--v_ckpt=10--pred.wav"
).exists()


class ChunkingTestBase(TestCase):
class ChunkingTestBase:
# Type declarations only, values are injected by setup_class
get_test_callback: Callable
outputs: dict
batch1: dict
batch2: dict

@classmethod
def setUpClass(cls):
def setup_class(cls):
# Define the function that gets the callbacks, get_test_callback
with TemporaryDirectory() as tmp_dir:
tmp_dir = Path(tmp_dir)
Expand Down Expand Up @@ -190,7 +185,7 @@ def test_wav_chunks(self):
)
output_dir = writer.save_dir
# print(output_dir, *output_dir.glob("**")) # For debugging
self.assertTrue(output_dir.exists())
assert output_dir.exists()

# Batch 2
writer = next(iter(writers.values()))
Expand All @@ -204,19 +199,16 @@ def test_wav_chunks(self):
)

# Test that the correctly named files were outputted
self.assertTrue(
(output_dir / "one--S1--L1--ckpt=77--v_ckpt=10--pred.wav").exists()
)
self.assertTrue(
(output_dir / "twothreefour--S2--L2--ckpt=77--v_ckpt=10--pred.wav").exists()
)
assert (output_dir / "one--S1--L1--ckpt=77--v_ckpt=10--pred.wav").exists()
assert (
output_dir / "twothreefour--S2--L2--ckpt=77--v_ckpt=10--pred.wav"
).exists()

# Test that last_file_written holds the name of the most recently written file
# This is important for the demo
self.assertEqual(
(output_dir / "twothreefour--S2--L2--ckpt=77--v_ckpt=10--pred.wav"),
Path(writer.last_file_written),
)
assert (
output_dir / "twothreefour--S2--L2--ckpt=77--v_ckpt=10--pred.wav"
) == Path(writer.last_file_written)

# Checks that the files have reasonable lengths
output_one = AudioSegment.from_file(
Expand All @@ -228,7 +220,7 @@ def test_wav_chunks(self):

# There are four chunks but two outputs.
# Output one contains only one chunk, so output_two should be 3 times longer
self.assertEqual(len(output_one) * 3, len(output_two))
assert len(output_one) * 3 == len(output_two)


class TestWritingSpec(ChunkingTestBase):
Expand All @@ -250,7 +242,7 @@ def test_spec_chunks(self):
)
output_dir = writer.save_dir
# print(output_dir, *output_dir.glob("**")) # For debugging
self.assertTrue(output_dir.exists())
assert output_dir.exists()

# Batch 2
writer = next(iter(writers.values()))
Expand All @@ -264,14 +256,10 @@ def test_spec_chunks(self):
)

# Test that the correctly named files were outputted
self.assertTrue(
(output_dir / "one--S1--L1--spec-pred-22050-mel-librosa.pt").exists()
)
self.assertTrue(
(
output_dir / "twothreefour--S2--L2--spec-pred-22050-mel-librosa.pt"
).exists()
)
assert (output_dir / "one--S1--L1--spec-pred-22050-mel-librosa.pt").exists()
assert (
output_dir / "twothreefour--S2--L2--spec-pred-22050-mel-librosa.pt"
).exists()

# Checks that the files have reasonable lengths
output_one = torch.load(
Expand All @@ -283,7 +271,7 @@ def test_spec_chunks(self):

# There are four chunks but two outputs.
# Output one contains only one chunk, so output_two should be 3 times longer
self.assertEqual(output_one.size(-1) * 3, output_two.size(-1))
assert output_one.size(-1) * 3 == output_two.size(-1)


class TestWritingTextGrid(ChunkingTestBase):
Expand All @@ -305,7 +293,7 @@ def test_textgrid_chunks(self):
)
output_dir = writer.save_dir
# print(output_dir, *output_dir.glob("**/*")) # For debugging
self.assertTrue(output_dir.exists())
assert output_dir.exists()

# Batch 2
writer = next(iter(writers.values()))
Expand All @@ -318,12 +306,10 @@ def test_textgrid_chunks(self):
_dataloader_idx=1,
)
# Test that the correctly named files were outputted
self.assertTrue(
(output_dir / "one--S1--L1--22050-mel-librosa.TextGrid").exists()
)
self.assertTrue(
(output_dir / "twothreefour--S2--L2--22050-mel-librosa.TextGrid").exists()
)
assert (output_dir / "one--S1--L1--22050-mel-librosa.TextGrid").exists()
assert (
output_dir / "twothreefour--S2--L2--22050-mel-librosa.TextGrid"
).exists()

# Check that the correct words were added to the first TextGrid
tg = TextGrid(
Expand All @@ -333,11 +319,11 @@ def test_textgrid_chunks(self):

phones = [interval[2] for interval in tiers[0].get_all_intervals()]
for phone, char in zip(list(phones), list("one")):
self.assertEqual(phone, char)
assert phone == char

words = tiers[2].get_all_intervals()
self.assertEqual(len(words), 1)
self.assertEqual(words[0][2], "one")
assert len(words) == 1
assert words[0][2] == "one"

# Check that the correct words were added to the second TextGrid
tg = TextGrid(
Expand All @@ -347,17 +333,17 @@ def test_textgrid_chunks(self):

phones = [interval[2] for interval in tiers[0].get_all_intervals()]
for phone, char in zip(list(phones), list("twothreefour")):
self.assertEqual(phone, char)
assert phone == char

words = tiers[2].get_all_intervals()
self.assertEqual(len(words), 3)
self.assertEqual(words[0][2], "two")
self.assertEqual(words[1][2], "three")
self.assertEqual(words[2][2], "four")
assert len(words) == 3
assert words[0][2] == "two"
assert words[1][2] == "three"
assert words[2][2] == "four"


class TestWritingReadAlongXML(ChunkingTestBase):
def test_writing_readalong(self):
def test_writing_readalong(self, subtests):
writers = self.get_test_callback([SynthesizeOutputFormats.readalong_xml])

# Batch 1
Expand All @@ -372,7 +358,7 @@ def test_writing_readalong(self):
)
output_dir = writer.save_dir

self.assertTrue(output_dir.exists())
assert output_dir.exists()

# Batch 2
writer = next(iter(writers.values()))
Expand All @@ -392,17 +378,17 @@ def test_writing_readalong(self):
output_dir / "twothreefour--S2--L2--22050-mel-librosa.readalong",
)
for output_file in output_files:
with self.subTest(output_file=output_file):
self.assertTrue(output_file.exists())
with subtests.test(output_file=output_file):
assert output_file.exists()
with open(output_file, "r", encoding="utf8") as f:
readalong = f.read()
# print(readalong)
self.assertIn("<read-along", readalong)
self.assertIn('<w time="0.0" dur=', readalong)
assert "<read-along" in readalong
assert '<w time="0.0" dur=' in readalong


class TestWritingReadAlongHTML(ChunkingTestBase):
def test_writing_readalong(self) -> None:
def test_writing_readalong(self, subtests) -> None:
writers = self.get_test_callback([SynthesizeOutputFormats.readalong_html])

for writer in writers.values():
Expand All @@ -416,7 +402,7 @@ def test_writing_readalong(self) -> None:
_dataloader_idx=idx,
)
output_dir = writer.save_dir
self.assertTrue(output_dir.exists())
assert output_dir.exists()

# Test that the correctly named files were outputted
# print(output_dir, *output_dir.glob("**/*")) # For debugging
Expand All @@ -426,10 +412,10 @@ def test_writing_readalong(self) -> None:
)
for output_file_basename in output_file_basenames:
output_file = output_dir.parent / "readalongs" / output_file_basename
with self.subTest(output_file=output_file):
self.assertTrue(output_file.exists())
with subtests.test(output_file=output_file):
assert output_file.exists()
with open(output_file, "r", encoding="utf8") as f:
readalong = f.read()
# print(readalong)
self.assertIn("<read-along", readalong)
self.assertIn("<span slot", readalong)
assert "<read-along" in readalong
assert "<span slot" in readalong
Loading
Loading