diff --git a/everyvoice/model/aligner/wav2vec2aligner b/everyvoice/model/aligner/wav2vec2aligner index 98741ea4..30b92ad6 160000 --- a/everyvoice/model/aligner/wav2vec2aligner +++ b/everyvoice/model/aligner/wav2vec2aligner @@ -1 +1 @@ -Subproject commit 98741ea470ad5675badc62a22361baf53433948b +Subproject commit 30b92ad6ca4b73dec90f0c4064aaa339b42cde49 diff --git a/everyvoice/model/feature_prediction/FastSpeech2_lightning b/everyvoice/model/feature_prediction/FastSpeech2_lightning index 5840db6a..0a95bb80 160000 --- a/everyvoice/model/feature_prediction/FastSpeech2_lightning +++ b/everyvoice/model/feature_prediction/FastSpeech2_lightning @@ -1 +1 @@ -Subproject commit 5840db6a63c09081dd8d9de02d33db91096d4628 +Subproject commit 0a95bb80c64fa5a88bd6376e54277fbfef84d5c9 diff --git a/everyvoice/tests/preprocessed_audio_fixture.py b/everyvoice/tests/preprocessed_audio_fixture.py index 842c03c8..4efbcbc1 100644 --- a/everyvoice/tests/preprocessed_audio_fixture.py +++ b/everyvoice/tests/preprocessed_audio_fixture.py @@ -19,7 +19,7 @@ class PreprocessedAudioFixture: """ - A unittest fixture to preprocess the audio files. + A unittest fixture (implemented as a base class) to preprocess the audio files. 
""" _tempdir = tempfile.TemporaryDirectory(prefix="tmpdir_PreprocessedInputFixture_") @@ -67,12 +67,8 @@ class PreprocessedAudioFixture: preprocessor = Preprocessor(fp_config) - # def setUp(self): - # """Each test function should get a fresh preprocessor""" - # self.preprocessor = Preprocessor(self.fp_config) - @classmethod - def setUpClass(cls): + def setup_class(cls): """Generate a preprocessed test set that can be used in various test cases.""" # We only need to actually run this once if not PreprocessedAudioFixture._preprocess_ran: diff --git a/everyvoice/tests/regression/test-demo-app-mix.py b/everyvoice/tests/regression/test-demo-app-mix.py index b7580df5..0f51d8a5 100644 --- a/everyvoice/tests/regression/test-demo-app-mix.py +++ b/everyvoice/tests/regression/test-demo-app-mix.py @@ -30,7 +30,7 @@ def test_rundemo(self) -> None: with page.expect_download() as download_info: page.get_by_label("Download").click() download = download_info.value - self.assertTrue(download.suggested_filename.endswith(".wav")) + assert download.suggested_filename.endswith(".wav") page.get_by_label("Output Format").click() page.get_by_label("spec").click() synthesize_button.click() @@ -38,28 +38,28 @@ def test_rundemo(self) -> None: with page.expect_download() as download1_info: page.locator("#file_output").get_by_role("link").click() download = download1_info.value - self.assertTrue(download.suggested_filename.endswith(".pt")) + assert download.suggested_filename.endswith(".pt") page.get_by_label("Output Format").click() page.get_by_label("textgrid").click() synthesize_button.click() with page.expect_download() as download2_info: page.locator("#file_output").get_by_role("link").click() download = download2_info.value - self.assertTrue(download.suggested_filename.endswith(".TextGrid")) + assert download.suggested_filename.endswith(".TextGrid") page.get_by_label("Output Format").click() page.get_by_label("readalong-xml").click() synthesize_button.click() with page.expect_download() 
as download3_info: page.locator("#file_output").get_by_role("link").click() download = download3_info.value - self.assertTrue(download.suggested_filename.endswith(".readalong")) + assert download.suggested_filename.endswith(".readalong") page.get_by_label("Output Format").click() page.get_by_label("readalong-html").click() synthesize_button.click() with page.expect_download() as download4_info: page.locator("#file_output").get_by_role("link").click() download = download4_info.value - self.assertTrue(download.suggested_filename.endswith(".html")) + assert download.suggested_filename.endswith(".html") page.get_by_label("Language").click() page.get_by_label("und").click() page.get_by_label("Speaker", exact=True).click() diff --git a/everyvoice/tests/stubs.py b/everyvoice/tests/stubs.py index ac9fd7f0..740dd121 100644 --- a/everyvoice/tests/stubs.py +++ b/everyvoice/tests/stubs.py @@ -356,7 +356,7 @@ def __exit__(self, *_exc_info): def flatten_log(log_output: str) -> str: """Replace newlines and other sequences of whitespace by a single space. - Usage: self.assertIn("some text", flatten_log(captured_output)) + Usage: assert "some text" in flatten_log(captured_output) Avoids having to use self.assertRegex everywhere just because of rich or pretty printing of messages over multiple lines. 
diff --git a/everyvoice/tests/test_cli.py b/everyvoice/tests/test_cli.py index 3b6b4d03..c1ae1474 100755 --- a/everyvoice/tests/test_cli.py +++ b/everyvoice/tests/test_cli.py @@ -153,8 +153,8 @@ def get_dummy_models(cls) -> tuple[Path, Path]: def test_version(self): result = self.runner.invoke(app, ["--version"]) - self.assertEqual(result.exit_code, 0) - self.assertIn(VERSION, result.stdout) + assert result.exit_code == 0 + assert VERSION in result.stdout def test_submodule_versions(self): # Team decision 2025-02-10: we won't keep submodule versions in complete lockstep, @@ -181,12 +181,12 @@ def test_submodule_versions(self): def test_diagnostic(self): with capture_stdout(): result = self.runner.invoke(app, ["--diagnostic"]) - self.assertEqual(result.exit_code, 0) - self.assertIn("EveryVoice version", result.stdout) - self.assertIn("Python version", result.stdout) + assert result.exit_code == 0 + assert "EveryVoice version" in result.stdout + assert "Python version" in result.stdout # We can't really validate the whole dependency list, but we should at least find torch # [5:] ignores the header generated by everyvoice --diagnostic and only looks at deps - self.assertIn("torch", "".join(result.stdout.lower().splitlines()[5:])) + assert "torch" in "".join(result.stdout.lower().splitlines()[5:]) def wip_test_synthesize(self): # TODO: Here's a stub for getting synthesis unit tests working @@ -229,7 +229,7 @@ def wip_test_synthesize(self): "wav", ], ) - self.assertEqual(single_text_result.exit_code, 0) + assert single_text_result.exit_code == 0 self.assertEqual( len(list(Path("single_text/wav").glob("*.wav"))), 1 ) # assert synthesizes a single file @@ -249,7 +249,7 @@ def wip_test_synthesize(self): "wav", ], ) - self.assertEqual(filelist_result.exit_code, 0) + assert filelist_result.exit_code == 0 self.assertEqual( len(list((tmpdir / "filelist" / "wav").glob("*.wav"))), 3 ) # assert synthesizes three files @@ -258,16 +258,16 @@ def test_commands_present(self): result 
= self.runner.invoke(app, ["--help"]) # each command has some help for command in self.commands: - self.assertIn(command, result.stdout) + assert command in result.stdout # link to docs is present - self.assertIn("https://docs.everyvoice.ca", result.stdout) + assert "https://docs.everyvoice.ca" in result.stdout def test_command_help_messages(self): for command in self.commands: result = self.runner.invoke(app, [command, "--help"]) - self.assertEqual(result.exit_code, 0) + assert result.exit_code == 0 result = self.runner.invoke(app, [command, "-h"]) - self.assertEqual(result.exit_code, 0) + assert result.exit_code == 0 def test_update_schema(self): dummy_contact = ContactInformation( @@ -317,8 +317,8 @@ def test_update_schema(self): # Next, but only if everything above passed, we make sure we can't overwrite # existing schemas by accident. result = self.runner.invoke(app, ["update-schemas"]) - self.assertNotEqual(result.exit_code, 0) - self.assertIn("FileExistsError", str(result)) + assert result.exit_code != 0 + assert "FileExistsError" in str(result) def test_evaluate(self): result = self.runner.invoke( @@ -331,12 +331,12 @@ def test_evaluate(self): self.data_dir / "lj" / "wavs" / "LJ050-0269.wav", ], ) - self.assertEqual(result.exit_code, 0) - self.assertIn("LJ010-0008", result.stdout) - self.assertIn("STOI", result.stdout) - self.assertIn("MOS", result.stdout) - self.assertIn("SI-SDR", result.stdout) - self.assertIn("PESQ", result.stdout) + assert result.exit_code == 0 + assert "LJ010-0008" in result.stdout + assert "STOI" in result.stdout + assert "MOS" in result.stdout + assert "SI-SDR" in result.stdout + assert "PESQ" in result.stdout dir_result = self.runner.invoke( app, [ @@ -347,22 +347,22 @@ def test_evaluate(self): self.data_dir / "LJ010-0008.wav", ], ) - self.assertEqual(dir_result.exit_code, 0) - self.assertIn("LJ050-0269", dir_result.stdout, "should print out the basenames") + assert dir_result.exit_code == 0 + assert "LJ050-0269" in dir_result.stdout, 
"should print out the basenames" self.assertIn( "Average STOI", dir_result.stdout, "should report metrics in terms of averages", ) evaluation_output = self.data_dir / "lj" / "wavs" / "evaluation.json" - self.assertTrue(evaluation_output.exists(), "should print results to a file") + assert evaluation_output.exists(), "should print results to a file" evaluation_output.unlink() def test_old_inspect_checkpoint(self): result = self.runner.invoke( app, ["inspect-checkpoint", str(self.data_dir / "test.ckpt")] ) - self.assertEqual(result.exit_code, 0) + assert result.exit_code == 0 self.assertIn( "This command has been renamed to `everyvoice checkpoint inspect`", flatten_log(result.stdout), @@ -370,19 +370,19 @@ def test_old_inspect_checkpoint(self): def test_inspect_checkpoint_help(self): result = self.runner.invoke(app, ["checkpoint", "inspect", "--help"]) - self.assertIn("checkpoint inspect [OPTIONS] MODEL_PATH", result.stdout) + assert "checkpoint inspect [OPTIONS] MODEL_PATH" in result.stdout def test_inspect_checkpoint(self): result = self.runner.invoke( app, ["checkpoint", "inspect", str(self.data_dir / "test.ckpt")] ) - self.assertIn('global_step": 52256', result.stdout) + assert 'global_step": 52256' in result.stdout self.assertIn( "We couldn't read your file, possibly because the version of EveryVoice that created it is incompatible with your installed version.", result.stdout, ) - self.assertIn("It appears to have 0.0 M parameters.", result.stdout) - self.assertIn("Number of Parameters", result.stdout) + assert "It appears to have 0.0 M parameters." 
in result.stdout + assert "Number of Parameters" in result.stdout def test_inspect_not_a_checkpoint(self) -> None: result = self.runner.invoke(app, ["checkpoint", "inspect", os.devnull]) @@ -433,7 +433,7 @@ def test_preprocessing_with_wrong_config(self): str(self.config_dir / "everyvoice-spec-to-wav.yaml"), ], ) - self.assertEqual(result.exit_code, 1) + assert result.exit_code == 1 self.assertIn( "We are expecting a FastSpeech2Config but it looks like you provided a HiFiGANConfig", "\n".join(output), @@ -455,14 +455,14 @@ def test_expensive_imports_are_tucked_away(self): def test_demo_with_bad_args(self): result = self.runner.invoke(app, ["demo"]) - self.assertNotEqual(result.exit_code, 0) - self.assertIn("Missing argument", result.output) + assert result.exit_code != 0 + assert "Missing argument" in result.output result = self.runner.invoke( app, ["demo", os.devnull, os.devnull, "--output-format", "not-a-format"] ) - self.assertNotEqual(result.exit_code, 0) - self.assertIn("Invalid value", result.output) + assert result.exit_code != 0 + assert "Invalid value" in result.output EMPTY_DEMO_ARGS: dict[str, Any] = { "languages": [], @@ -481,7 +481,7 @@ def test_create_demo_app_with_errors(self): **self.EMPTY_DEMO_ARGS, # type: ignore[arg-type] outputs=[], ) - self.assertIn("Empty outputs list", str(cm.exception)) + assert "Empty outputs list" in str(cm.exception) class WrongEnum(str, enum.Enum): foo = "foo" @@ -494,13 +494,13 @@ class WrongEnum(str, enum.Enum): **self.EMPTY_DEMO_ARGS, # type: ignore[arg-type] outputs=outputs, ) - self.assertIn("Unknown output format 'foo'", str(cm.exception)) + assert "Unknown output format 'foo'" in str(cm.exception) def test_demo_with_bad_models(self) -> None: devnull = Path(os.devnull) with self.assertRaises(ValueError) as cm: create_demo_app(devnull, devnull, **self.EMPTY_DEMO_ARGS, outputs=["wav"]) # type: ignore[arg-type] - self.assertIn("It does not appear to be a valid checkpoint", str(cm.exception)) + assert "It does not 
appear to be a valid checkpoint" in str(cm.exception) with self.assertRaises(ValueError) as cm: create_demo_app( @@ -509,13 +509,13 @@ def test_demo_with_bad_models(self) -> None: **self.EMPTY_DEMO_ARGS, # type: ignore[arg-type] outputs=["wav"], ) - self.assertIn("maybe it's not actually a HiFiGAN model", str(cm.exception)) + assert "maybe it's not actually a HiFiGAN model" in str(cm.exception) def test_demo_with_wrong_models(self) -> None: fp_path, vocoder_path = self.get_dummy_models() with self.assertRaises(ValueError) as cm: create_demo_app(fp_path, fp_path, **self.EMPTY_DEMO_ARGS, outputs=["wav"]) # type: ignore[arg-type] - self.assertIn("maybe it's not actually a HiFiGAN model", str(cm.exception)) + assert "maybe it's not actually a HiFiGAN model" in str(cm.exception) with self.assertRaises(ValueError) as cm: create_demo_app( @@ -524,7 +524,7 @@ def test_demo_with_wrong_models(self) -> None: **self.EMPTY_DEMO_ARGS, # type: ignore[arg-type] outputs=["wav"], ) - self.assertIn("maybe it's not actually an fs2 model", str(cm.exception)) + assert "maybe it's not actually an fs2 model" in str(cm.exception) def test_g2p(self): result = self.runner.invoke( @@ -538,8 +538,8 @@ def test_g2p(self): ], ) - self.assertEqual(result.exit_code, 0) - self.assertIn("['hello', 'world']", result.stdout) + assert result.exit_code == 0 + assert "['hello', 'world']" in result.stdout self.assertNotIn("['HELLO', 'WORLD']", result.stdout) def mock_create_demo_app(self, *_args, **_kwargs): @@ -605,10 +605,10 @@ def test_create_demo_app(self): ip, # Mock IP address ], ) - self.assertEqual(result.exit_code, 0) - self.assertIn(f" - Port: {port}", result.output) - self.assertIn(" - Share: True", result.output) - self.assertIn(f" - Server Name: {ip}", result.output) + assert result.exit_code == 0 + assert f" - Port: {port}" in result.output + assert " - Share: True" in result.output + assert f" - Server Name: {ip}" in result.output def mock_demo_load_model_from_checkpoint( *_arg, **kwargs @@ 
-729,7 +729,7 @@ def test_create_demo_app_with_ui_config_file(self) -> None: ) # print(result.output, result.exit_code) # Debug output - self.assertEqual(result.exit_code, 0) + assert result.exit_code == 0 self.assertIn( f"Using speakers from app config JSON: [('{config['speakers']['default']}', 'default')]", result.output, @@ -813,7 +813,7 @@ def test_create_demo_app_with_malformed_ui_config_file(self): ], ) # print(result.output, result.exit_code) # Debug output - self.assertNotEqual(result.exit_code, 0) + assert result.exit_code != 0 self.assertRegex( result.output, r"(?s)Your config file.*malformed.*has.*errors" ) @@ -918,7 +918,7 @@ def test_rename_speaker(self): ], ) - self.assertEqual(result.exit_code, 0) + assert result.exit_code == 0 self.assertIn( "Renamed speaker 'old_speaker' to 'new_speaker'.", result.output ) @@ -953,8 +953,8 @@ def test_rename_speaker_with_non_existing_speaker(self): ], ) # print(result.output) - self.assertNotEqual(result.exit_code, 0) - self.assertIn("Speaker 'non_existing_speaker' not found", result.output) + assert result.exit_code != 0 + assert "Speaker 'non_existing_speaker' not found" in result.output def test_rename_speaker_with_no_speakers(self): with tempfile.TemporaryDirectory() as tmpdir_str: @@ -979,8 +979,8 @@ def test_rename_speaker_with_no_speakers(self): ], ) # print(result.output) - self.assertNotEqual(result.exit_code, 0) - self.assertIn("No speakers found", result.output) + assert result.exit_code != 0 + assert "No speakers found" in result.output class TestBaseCLIHelper(TestCase): @@ -1002,10 +1002,10 @@ def test_save_configuration_to_log_dir(self): log_dir = config.training.logger.save_dir / config.training.logger.name log = log_dir / "log" - self.assertTrue(log.exists()) + assert log.exists() hparams = log_dir / "hparams.yaml" - self.assertTrue(hparams.exists()) + assert hparams.exists() with hparams.open(mode="r", encoding="UTF8") as f: config_reloaded = yaml.load(f, Loader=Loader) self.assertEqual( diff --git 
a/everyvoice/tests/test_configs.py b/everyvoice/tests/test_configs.py index 30b81c5a..3e015c9a 100755 --- a/everyvoice/tests/test_configs.py +++ b/everyvoice/tests/test_configs.py @@ -88,9 +88,9 @@ def test_from_object(self): vocoder=VocoderConfig(contact=TEST_CONTACT), training=E2ETrainingConfig(batch_size=32), ) - self.assertEqual(config_default.training.batch_size, 16) - self.assertEqual(config_declared.training.batch_size, 16) - self.assertEqual(config_32.training.batch_size, 32) + assert config_default.training.batch_size == 16 + assert config_declared.training.batch_size == 16 + assert config_32.training.batch_size == 32 def test_config_save_dirs(self): with TemporaryDirectory(prefix="test_config_save_dirs") as tempdir: @@ -100,7 +100,7 @@ def test_config_save_dirs(self): preprocessing_config = PreprocessingConfig(save_dir=1) with init_context({"writing_config": tempdir}): preprocessing_config = PreprocessingConfig(save_dir="./bloop") - self.assertTrue((tempdir / preprocessing_config.save_dir).exists()) + assert (tempdir / preprocessing_config.save_dir).exists() def test_config_partial(self): with ( @@ -113,7 +113,7 @@ def test_config_partial(self): config = PreprocessingConfig( path_to_audio_config_file=(tempdir / "audio.json") ) - self.assertTrue(isinstance(config.audio, AudioConfig)) + assert isinstance(config.audio, AudioConfig) # bad partial with self.assertRaises(exceptions.InvalidConfiguration): with NamedTemporaryFile(prefix="test", mode="w", suffix=".yaml") as tf: @@ -141,7 +141,7 @@ def test_config_partial(self): path_to_training_config_file=tempdir / "fp-training.json", ) _writer_helper(fp_config, tempdir / "fp.json") - self.assertTrue(isinstance(fp_config, FeaturePredictionConfig)) + assert isinstance(fp_config, FeaturePredictionConfig) # Vocoder Config _writer_helper( VocoderConfig(contact=TEST_CONTACT).training, @@ -158,7 +158,7 @@ def test_config_partial(self): path_to_training_config_file=tempdir / "vocoder-training.json", ) 
_writer_helper(vocoder_config, tempdir / "vocoder.json") - self.assertTrue(isinstance(vocoder_config, VocoderConfig)) + assert isinstance(vocoder_config, VocoderConfig) # E2E Config with mute_logger("everyvoice.config.utils"): e2e_config = EveryVoiceConfig( @@ -167,7 +167,7 @@ def test_config_partial(self): path_to_training_config_file=(tempdir / "training.json"), path_to_vocoder_config_file=(tempdir / "vocoder.json"), ) - self.assertTrue(isinstance(e2e_config, EveryVoiceConfig)) + assert isinstance(e2e_config, EveryVoiceConfig) def test_config_partial_override(self): """Test override of partial""" @@ -182,13 +182,13 @@ def test_config_partial_override(self): path_to_audio_config_file=tf.name, audio=AudioConfig(min_audio_length=1.0), ) - self.assertEqual(config.audio.min_audio_length, 1.0) + assert config.audio.min_audio_length == 1.0 # override with dict with mute_logger("everyvoice.config.utils"): config = PreprocessingConfig( path_to_audio_config_file=tf.name, audio={"max_audio_length": 1.0} ) - self.assertEqual(config.audio.max_audio_length, 1.0) + assert config.audio.max_audio_length == 1.0 # pass something invalid with mute_logger("everyvoice.config.utils"): with self.assertRaises(ValidationError): @@ -202,15 +202,15 @@ def test_update_from_file(self): with open(TEST_DATA_DIR / "update.yaml", encoding="utf8") as f: update = yaml.safe_load(f) self.config.update_config(update) - self.assertEqual(self.config.feature_prediction.training.batch_size, 123) - self.assertEqual(self.config.vocoder.training.batch_size, 456) + assert self.config.feature_prediction.training.batch_size == 123 + assert self.config.vocoder.training.batch_size == 456 def test_string_to_callable(self): # Test Basic Functionality config = FeaturePredictionConfig( contact=TEST_CONTACT, text=TextConfig(cleaners=["everyvoice.utils.lower"]) ) - self.assertEqual(config.text.cleaners, [lower]) + assert config.text.cleaners == [lower] # Test missing function with self.assertRaises(AttributeError): 
config.update_config({"text": {"cleaners": ["everyvoice.utils.foobarfoo"]}}) @@ -222,23 +222,23 @@ def test_string_to_callable(self): config.update_config({"text": {"cleaners": [1]}}) # Test plain string config = LoggerConfig(sub_dir_callable="foobar") - self.assertEqual(config.sub_dir_callable(), "foobar") + assert config.sub_dir_callable() == "foobar" def test_call_sub_dir(self): config = LoggerConfig() # sub_dir should get called from sub_dir_callable and be a string of an int - self.assertTrue(isinstance(int(config.sub_dir), int)) + assert isinstance(int(config.sub_dir), int) # Just in case we're super speedy time.sleep(1) self.assertGreater(int(config.sub_dir_callable()), int(config.sub_dir)) serialized_config = config.model_dump() # Exclude sub_dir by default when serializing as it should get overriden on each run - self.assertTrue("sub_dir" not in serialized_config) + assert "sub_dir" not in serialized_config def test_properly_deserialized_callables(self): config = TextConfig(cleaners=[nfc_normalize, "everyvoice.utils.lower"]) for fn in config.cleaners: - self.assertTrue(isinstance(fn, Callable)) + assert isinstance(fn, Callable) def test_string_to_dict(self): base_config = EveryVoiceConfig( @@ -253,21 +253,21 @@ def test_string_to_dict(self): ] # test_missing = ["training.foobar.gan_type=original"] test_dict = expand_config_string_syntax(test_string) - self.assertEqual(test_dict, {"vocoder": {"training": {"gan_type": "wgan"}}}) + assert test_dict == {"vocoder": {"training": {"gan_type": "wgan"}}} for bs in test_bad_strings: with self.assertRaises(ValueError): expand_config_string_syntax(bs) - self.assertEqual(base_config.vocoder.training.gan_type.value, "original") + assert base_config.vocoder.training.gan_type.value == "original" config = base_config.combine_configs(base_config, test_dict) - self.assertEqual(config["vocoder"]["training"]["gan_type"], "wgan") + assert config["vocoder"]["training"]["gan_type"] == "wgan" def test_changes(self): """Test that 
the changes to the config are correct""" self.config.update_config( {"feature_prediction": {"text": {"cleaners": ["everyvoice.utils.lower"]}}} ) - self.assertEqual(self.config.feature_prediction.text.cleaners, [lower]) + assert self.config.feature_prediction.text.cleaners == [lower] def test_load_empty_config(self): with NamedTemporaryFile( @@ -316,10 +316,10 @@ def test_shared_sox(self) -> None: feature_prediction=FeaturePredictionConfig(contact=TEST_CONTACT), ) sox_effects = config.vocoder.preprocessing.source_data[0].sox_effects - self.assertEqual(len(config.vocoder.preprocessing.source_data), 4) + assert len(config.vocoder.preprocessing.source_data) == 4 for d_other in config.vocoder.preprocessing.source_data[1:]: - self.assertEqual(sox_effects, d_other.sox_effects) - self.assertEqual(sox_effects[0], ["channels", "1"]) + assert sox_effects == d_other.sox_effects + assert sox_effects[0] == ["channels", "1"] def test_correct_number_typing(self): batch_size = 64.0 @@ -330,7 +330,7 @@ def test_correct_number_typing(self): vocoder=VocoderConfig(contact=TEST_CONTACT), ) self.assertIsInstance(batch_size, float) - self.assertEqual(config.training.batch_size, 64) + assert config.training.batch_size == 64 self.assertIsInstance(config.feature_prediction.training.batch_size, int) @@ -344,8 +344,8 @@ def validate_config_path(self, path: Path): """ Helper method to validate a path once loaded by a config. 
""" - self.assertTrue(path.is_absolute(), msg=path) - self.assertTrue(path.exists(), msg=path) + assert path.is_absolute(), f"{path} should be absolute" + assert path.exists(), f"{path} should exist" def test_preprocessing_config(self): """Create a PreprocessingConfig which pydantic will validate for us.""" @@ -353,16 +353,16 @@ def test_preprocessing_config(self): with config_path.open("r", encoding="utf8") as f: pre_test = yaml.safe_load(f) self.assertFalse(Path(pre_test["save_dir"]).is_absolute()) - self.assertEqual(len(pre_test["source_data"]), 1) + assert len(pre_test["source_data"]) == 1 for data in pre_test["source_data"]: self.assertFalse(Path(data["data_dir"]).is_absolute()) self.assertFalse(Path(data["filelist"]).is_absolute()) config = PreprocessingConfig.load_config_from_path(config_path) # print(config.model_dump_json(indent=2)) - self.assertTrue(isinstance(config, PreprocessingConfig)) - self.assertEqual(config.dataset, self.DATASET_NAME) + assert isinstance(config, PreprocessingConfig) + assert config.dataset == self.DATASET_NAME self.validate_config_path(config.save_dir) - self.assertEqual(len(config.source_data), 1) + assert len(config.source_data) == 1 for data in config.source_data: self.validate_config_path(data.data_dir) self.validate_config_path(data.filelist) @@ -383,7 +383,7 @@ def test_feature_prediction_config(self): with mute_logger("everyvoice.configs.text_config"): config = FeaturePredictionConfig.load_config_from_path(config_path) # print(config.model_dump_json(indent=2)) - self.assertEqual(config.preprocessing.dataset, self.DATASET_NAME) + assert config.preprocessing.dataset == self.DATASET_NAME self.validate_config_path(config.path_to_text_config_file) self.validate_config_path(config.path_to_text_config_file) self.validate_config_path(config.training.logger.save_dir) @@ -404,8 +404,8 @@ def test_vocoder_config(self): self.assertFalse(Path(training["validation_filelist"]).is_absolute()) config = 
VocoderConfig.load_config_from_path(config_path) # print(config.model_dump_json(indent=2)) - self.assertTrue(isinstance(config, VocoderConfig)) - self.assertEqual(config.preprocessing.dataset, self.DATASET_NAME) + assert isinstance(config, VocoderConfig) + assert config.preprocessing.dataset == self.DATASET_NAME self.validate_config_path(config.path_to_preprocessing_config_file) self.validate_config_path(config.training.logger.save_dir) self.validate_config_path(config.training.training_filelist) @@ -429,7 +429,7 @@ def test_everyvoice_config(self): with mute_logger("everyvoice.config.text_config"): config = EveryVoiceConfig.load_config_from_path(config_path) # print(config.model_dump_json(indent=2)) - self.assertTrue(isinstance(config, EveryVoiceConfig)) + assert isinstance(config, EveryVoiceConfig) self.assertEqual( config.feature_prediction.preprocessing.dataset, self.DATASET_NAME ) @@ -483,15 +483,15 @@ def test_everyvoice_config(self): # path_to_text_config_file=text_config_path, # path_to_training_config_file=aligner_training_path, # ) - # self.assertTrue(isinstance(aligner_config, AlignerConfig)) + # assert isinstance(aligner_config, AlignerConfig) # aligner_config_path = tempdir / "aligner.json" # _writer_helper(aligner_config, aligner_config_path) # # Reload and validate # with mute_logger("everyvoice.config.utils"): # config = AlignerConfig.load_config_from_path(aligner_config_path) - # self.assertTrue(isinstance(config, AlignerConfig)) - # self.assertEqual(config.preprocessing.dataset, self.DATASET_NAME) + # assert isinstance(config, AlignerConfig) + # assert config.preprocessing.dataset == self.DATASET_NAME # self.validate_config_path(config.path_to_model_config_file) # self.validate_config_path(config.path_to_preprocessing_config_file) # self.validate_config_path(config.path_to_text_config_file) @@ -511,7 +511,7 @@ def test_everyvoice_config(self): # config = PreprocessingConfig( # path_to_audio_config_file=(tempdir / "audio.json") # ) - # 
self.assertTrue(isinstance(config.audio, AudioConfig)) + # assert isinstance(config.audio, AudioConfig) # # Write shared: # _writer_helper( # PreprocessingConfig(dataset=self.DATASET_NAME), @@ -536,7 +536,7 @@ def test_everyvoice_config(self): # path_to_training_config_file=tempdir / "aligner-training.json", # ) # _writer_helper(aligner_config, tempdir / "aligner.json") - # self.assertTrue(isinstance(aligner_config, AlignerConfig)) + # assert isinstance(aligner_config, AlignerConfig) # # Create the missing partial config file by deleting. # # NOTE, we need the file to exists if we want to write its parent config to disk. # (tempdir / "preprocessing.json").unlink() @@ -557,13 +557,13 @@ def test_ckpt_epochs_cannot_be_negative(self): every_n_epochs aka ckpt_epochs must be None or non-negative. """ config = BaseTrainingConfig(ckpt_epochs=None, ckpt_steps=None) - self.assertEqual(config.ckpt_epochs, None) + assert config.ckpt_epochs is None config = BaseTrainingConfig(ckpt_epochs=0, ckpt_steps=None) - self.assertEqual(config.ckpt_epochs, 0) + assert config.ckpt_epochs == 0 config = BaseTrainingConfig(ckpt_epochs=10, ckpt_steps=None) - self.assertEqual(config.ckpt_epochs, 10) + assert config.ckpt_epochs == 10 with self.assertRaises(ValueError): _ = BaseTrainingConfig(ckpt_epochs=-1, ckpt_steps=None) @@ -573,13 +573,13 @@ def test_ckpt_steps_cannot_be_negative(self): every_n_train_steps aka ckpt_steps must be None or non-negative. 
""" config = BaseTrainingConfig(ckpt_epochs=None, ckpt_steps=None) - self.assertEqual(config.ckpt_steps, None) + assert config.ckpt_steps is None config = BaseTrainingConfig(ckpt_epochs=None, ckpt_steps=0) - self.assertEqual(config.ckpt_steps, 0) + assert config.ckpt_steps == 0 config = BaseTrainingConfig(ckpt_epochs=None, ckpt_steps=10) - self.assertEqual(config.ckpt_steps, 10) + assert config.ckpt_steps == 10 with self.assertRaises(ValueError): _ = BaseTrainingConfig(ckpt_epochs=None, ckpt_steps=-1) @@ -589,20 +589,20 @@ def test_mutually_exclusive_ckpt_options(self): ckpt_epochs and ckpt_steps must be mutually exclusive. """ config = BaseTrainingConfig() - self.assertEqual(config.ckpt_epochs, 1) - self.assertEqual(config.ckpt_steps, None) + assert config.ckpt_epochs == 1 + assert config.ckpt_steps is None config = BaseTrainingConfig(ckpt_epochs=None, ckpt_steps=None) - self.assertEqual(config.ckpt_epochs, None) - self.assertEqual(config.ckpt_steps, None) + assert config.ckpt_epochs is None + assert config.ckpt_steps is None config = BaseTrainingConfig(ckpt_epochs=7, ckpt_steps=None) - self.assertEqual(config.ckpt_epochs, 7) - self.assertEqual(config.ckpt_steps, None) + assert config.ckpt_epochs == 7 + assert config.ckpt_steps is None config = BaseTrainingConfig(ckpt_epochs=None, ckpt_steps=11) - self.assertEqual(config.ckpt_epochs, None) - self.assertEqual(config.ckpt_steps, 11) + assert config.ckpt_epochs is None + assert config.ckpt_steps == 11 with self.assertRaises(ValueError): _ = BaseTrainingConfig( diff --git a/everyvoice/tests/test_custom_g2p.py b/everyvoice/tests/test_custom_g2p.py index 15cc4e8f..b9fe07dd 100755 --- a/everyvoice/tests/test_custom_g2p.py +++ b/everyvoice/tests/test_custom_g2p.py @@ -100,15 +100,15 @@ def test_basic_g2p(self): def test_unusual_ipa_code(self): # sal-apa goes to sal-ipa instead of sal-apa-ipa sal_apa_g2p = get_g2p_engine("sal-apa") - self.assertEqual(sal_apa_g2p("ac"), list("ats")) + assert sal_apa_g2p("ac") == list("ats") 
# but iku-sro goes to iku-sro-ipa, not iku-ipa iku_sro_g2p = get_g2p_engine("iku-sro") - self.assertEqual(iku_sro_g2p("akaq"), list("akaq")) + assert iku_sro_g2p("akaq") == list("akaq") def test_phonemizer_normalization(self): moh_g2p = get_g2p_engine("moh") - self.assertEqual(moh_g2p("\u00e9"), ["\u00e9"]) + assert moh_g2p("\u00e9") == ["\u00e9"] def test_invalid_lang_id(self): """ @@ -136,7 +136,7 @@ def test_custom_g2p_engine(self): get_g2p_engine(lang_id) with mute_logger("everyvoice.config.text_config"): TextConfig(g2p_engines={lang_id: "everyvoice.tests.g2p_engines.valid"}) - self.assertIn(lang_id, AVAILABLE_G2P_ENGINES) + assert lang_id in AVAILABLE_G2P_ENGINES self.assertIs( AVAILABLE_G2P_ENGINES[lang_id], everyvoice.tests.g2p_engines.valid, @@ -160,12 +160,12 @@ def test_autoload(self): Default G2PEngine should autoload a CachingG2PEngine(lang_id). """ lang_id = "eng" - self.assertIn(lang_id, AVAILABLE_G2P_ENGINES) - self.assertEqual(AVAILABLE_G2P_ENGINES[lang_id], DEFAULT_G2P) + assert lang_id in AVAILABLE_G2P_ENGINES + assert AVAILABLE_G2P_ENGINES[lang_id] == DEFAULT_G2P g2p_engine = get_g2p_engine(lang_id) self.assertFalse(isinstance(g2p_engine, str)) - self.assertTrue(isinstance(g2p_engine, CachingG2PEngine)) + assert isinstance(g2p_engine, CachingG2PEngine) class TextConfigWithG2pTest(TestCase): @@ -188,7 +188,7 @@ def test_no_user_provided_g2p_engines(self): """ num_g2p_engines = len(AVAILABLE_G2P_ENGINES.keys()) TextConfig() - self.assertEqual(num_g2p_engines, len(AVAILABLE_G2P_ENGINES.keys())) + assert num_g2p_engines == len(AVAILABLE_G2P_ENGINES.keys()) def test_loading_g2p_engines(self): """ @@ -203,8 +203,8 @@ def test_loading_g2p_engines(self): lang_id_2: "everyvoice.tests.g2p_engines.valid", } ) - self.assertIn(lang_id_1, AVAILABLE_G2P_ENGINES) - self.assertIn(lang_id_2, AVAILABLE_G2P_ENGINES) + assert lang_id_1 in AVAILABLE_G2P_ENGINES + assert lang_id_2 in AVAILABLE_G2P_ENGINES self.assertIs( AVAILABLE_G2P_ENGINES[lang_id_1], 
everyvoice.tests.g2p_engines.valid, @@ -230,7 +230,7 @@ def test_loading_g2p_engines_with_invalid_module(self): ): TextConfig(g2p_engines={lang_id: "unknown_module.g2p"}) self.assertNotIn(lang_id, AVAILABLE_G2P_ENGINES) - self.assertIn("Invalid G2P engine", "\n".join(logs.output)) + assert "Invalid G2P engine" in "\n".join(logs.output) def test_g2p_engine_signature_multiple_arguments(self): """ @@ -279,7 +279,7 @@ def test_overriding_default_g2p_engine(self): """ num_g2p_engines = len(AVAILABLE_G2P_ENGINES.keys()) lang_id = "fra" - self.assertIn(lang_id, AVAILABLE_G2P_ENGINES) + assert lang_id in AVAILABLE_G2P_ENGINES old_g2p_engine = AVAILABLE_G2P_ENGINES[lang_id] with mute_logger("everyvoice.config.text_config"): TextConfig(g2p_engines={lang_id: "everyvoice.tests.g2p_engines.valid"}) diff --git a/everyvoice/tests/test_dataloader.py b/everyvoice/tests/test_dataloader.py index 9e35dcca..1c49561d 100755 --- a/everyvoice/tests/test_dataloader.py +++ b/everyvoice/tests/test_dataloader.py @@ -1,9 +1,8 @@ #!/usr/bin/env python import sys -from unittest import TestCase -from pytest import main +from pytest import fixture, main, raises from everyvoice.config.type_definitions import TargetTrainingTextRepresentationLevel from everyvoice.dataloader import BaseDataModule @@ -24,69 +23,65 @@ from everyvoice.utils import filter_dataset_based_on_target_text_representation_level -class DataLoaderTest(PreprocessedAudioFixture, TestCase): - """Basic test for dataloaders""" - - def setUp(self) -> None: - super().setUp() - - self.config = EveryVoiceConfig( +@fixture +def config() -> EveryVoiceConfig: + return EveryVoiceConfig( + contact=TEST_CONTACT, + feature_prediction=FeaturePredictionConfig(contact=TEST_CONTACT), + vocoder=VocoderConfig( contact=TEST_CONTACT, - feature_prediction=FeaturePredictionConfig(contact=TEST_CONTACT), - vocoder=VocoderConfig( - contact=TEST_CONTACT, - training=HiFiGANTrainingConfig( - training_filelist=PreprocessedAudioFixture.lj_preprocessed - / 
"preprocessed_filelist.psv", - validation_filelist=PreprocessedAudioFixture.lj_preprocessed - / "validation_preprocessed_filelist.psv", - ), - preprocessing=PreprocessingConfig( - save_dir=PreprocessedAudioFixture.lj_preprocessed, - ), + training=HiFiGANTrainingConfig( + training_filelist=PreprocessedAudioFixture.lj_preprocessed + / "preprocessed_filelist.psv", + validation_filelist=PreprocessedAudioFixture.lj_preprocessed + / "validation_preprocessed_filelist.psv", ), - ) + preprocessing=PreprocessingConfig( + save_dir=PreprocessedAudioFixture.lj_preprocessed, + ), + ), + ) - def test_base_data_loader(self): - bdm = BaseDataModule(self.config.vocoder) - with self.assertRaises(NotImplementedError): + +class TestDataLoader(PreprocessedAudioFixture): + """Basic test for dataloaders""" + + def test_base_data_loader(self, config): + bdm = BaseDataModule(config.vocoder) + with raises(NotImplementedError): bdm.load_dataset() - def test_spec_dataset(self): + def test_spec_dataset(self, config): dataset = SpecDataset( - self.config.vocoder.training.filelist_loader( - self.config.vocoder.training.training_filelist + config.vocoder.training.filelist_loader( + config.vocoder.training.training_filelist ), - self.config.vocoder, + config.vocoder, use_segments=True, ) for sample in dataset: spec, audio, basename, spec_from_audio = sample - self.assertTrue(isinstance(basename, str)) - self.assertEqual(spec.size(), spec_from_audio.size()) - self.assertEqual( - spec.size(0), self.config.vocoder.preprocessing.audio.n_mels - ) - self.assertEqual( - spec.size(1), - self.config.vocoder.preprocessing.audio.vocoder_segment_size - / ( - self.config.vocoder.preprocessing.audio.fft_hop_size - * ( - self.config.vocoder.preprocessing.audio.output_sampling_rate - // self.config.vocoder.preprocessing.audio.input_sampling_rate - ) - ), + assert isinstance(basename, str) + assert spec.size() == spec_from_audio.size() + assert spec.size(0) == config.vocoder.preprocessing.audio.n_mels + assert 
spec.size( + 1 + ) == config.vocoder.preprocessing.audio.vocoder_segment_size / ( + config.vocoder.preprocessing.audio.fft_hop_size + * ( + config.vocoder.preprocessing.audio.output_sampling_rate + // config.vocoder.preprocessing.audio.input_sampling_rate + ) ) - def test_hifi_data_loader(self): - hfgdm = HiFiGANDataModule(self.config.vocoder) + def test_hifi_data_loader(self, config): + hfgdm = HiFiGANDataModule(config.vocoder) hfgdm.load_dataset() - self.assertEqual(len(hfgdm.train_dataset), 5) + assert len(hfgdm.train_dataset) == 5 def test_filter_dataset(self): train_dataset = [{"character_tokens": "b", "phone_tokens": ""}] * 4 - with self.assertRaises(SystemExit) as cm: + with raises(SystemExit) as cm: with mute_logger("everyvoice.utils"): filter_dataset_based_on_target_text_representation_level( TargetTrainingTextRepresentationLevel.characters, @@ -94,8 +89,8 @@ def test_filter_dataset(self): "training", 6, ) - self.assertEqual(cm.exception.code, 1) - with self.assertRaises(SystemExit) as cm: + assert cm.value.code == 1 + with raises(SystemExit) as cm: with mute_logger("everyvoice.utils"): filter_dataset_based_on_target_text_representation_level( TargetTrainingTextRepresentationLevel.ipa_phones, @@ -103,7 +98,7 @@ def test_filter_dataset(self): "training", 4, ) - self.assertEqual(cm.exception.code, 1) + assert cm.value.code == 1 train_ds = filter_dataset_based_on_target_text_representation_level( TargetTrainingTextRepresentationLevel.characters, train_dataset, @@ -116,8 +111,8 @@ def test_filter_dataset(self): "validation", 4, ) - self.assertEqual(len(train_ds), 4) - self.assertEqual(len(val_ds), 4) + assert len(train_ds) == 4 + assert len(val_ds) == 4 def test_hifi_ft_data_loader(self): """TODO: can't make this test until I generate some synthesized samples""" @@ -131,18 +126,18 @@ def test_e2e_data_module(self): # TODO: once e2e is done pass - def test_imbalanced_sampler(self): + def test_imbalanced_sampler(self, config): dataset = SpecDataset( - 
self.config.vocoder.training.filelist_loader( - self.config.vocoder.training.training_filelist + config.vocoder.training.filelist_loader( + config.vocoder.training.training_filelist ), - self.config.vocoder, + config.vocoder, use_segments=True, ) sampler = ImbalancedDatasetSampler(dataset) print(sampler.weights) sample = list(sampler) - self.assertEqual(len(sample), 5) + assert len(sample) == 5 if __name__ == "__main__": diff --git a/everyvoice/tests/test_doctests.py b/everyvoice/tests/test_doctests.py index ea172a56..1821cc33 100755 --- a/everyvoice/tests/test_doctests.py +++ b/everyvoice/tests/test_doctests.py @@ -2,7 +2,6 @@ import doctest import sys -from unittest import TestCase from pytest import main @@ -12,22 +11,18 @@ import everyvoice.wizard.utils -class RunDocTests(TestCase): - - def test_run_all_doctests(self): - for module_with_doctests in ( - everyvoice.demo.app, - everyvoice.text.features, - everyvoice.text.text_processor, - everyvoice.text.utils, - everyvoice.utils, - everyvoice.wizard.utils, - ): - with self.subTest( - "Running doctests in", module=module_with_doctests.__name__ - ): - results = doctest.testmod(module_with_doctests) - self.assertFalse(results.failed, results) +def test_run_all_doctests(subtests) -> None: + for module_with_doctests in ( + everyvoice.demo.app, + everyvoice.text.features, + everyvoice.text.text_processor, + everyvoice.text.utils, + everyvoice.utils, + everyvoice.wizard.utils, + ): + with subtests.test("Running doctests in", module=module_with_doctests.__name__): + results = doctest.testmod(module_with_doctests) + assert not results.failed, results if __name__ == "__main__": diff --git a/everyvoice/tests/test_evaluation.py b/everyvoice/tests/test_evaluation.py index 6a3d5398..b1b219b3 100755 --- a/everyvoice/tests/test_evaluation.py +++ b/everyvoice/tests/test_evaluation.py @@ -27,10 +27,10 @@ def test_squim_evaluation(self): subj_model, subj_sr, ) - self.assertEqual(round(mos, 2), 4.47) + assert round(mos, 2) == 4.47 
self.assertLess(stoi, 1) - self.assertEqual(round(pesq, 2), 3.88) - self.assertEqual(round(si_sdr, 2), 28.64) + assert round(pesq, 2) == 3.88 + assert round(si_sdr, 2) == 28.64 if __name__ == "__main__": diff --git a/everyvoice/tests/test_model.py b/everyvoice/tests/test_model.py index 200f857a..d597c173 100755 --- a/everyvoice/tests/test_model.py +++ b/everyvoice/tests/test_model.py @@ -81,8 +81,8 @@ def setUp(self) -> None: def test_hparams(self): self.hifi_gan = HiFiGAN(self.config.vocoder) - self.assertEqual(self.config.vocoder, self.hifi_gan.hparams.config) - self.assertEqual(self.config.vocoder, self.hifi_gan.config) + assert self.config.vocoder == self.hifi_gan.hparams.config + assert self.config.vocoder == self.hifi_gan.config def test_checkpoints_only_contain_serializable_content(self): """These tests help remove any dependencies on specific versions of Pydantic. @@ -300,13 +300,13 @@ def test_wrong_model_type(self): ckpt_fn = tmpdir_str + "/checkpoint.ckpt" trainer.save_checkpoint(ckpt_fn) m = torch.load(ckpt_fn, weights_only=True) - self.assertIn("model_info", m.keys()) + assert "model_info" in m.keys() m["model_info"]["name"] = "BAD_TYPE" torch.save(m, ckpt_fn) m = torch.load(ckpt_fn, weights_only=True) - self.assertIn("model_info", m.keys()) - self.assertEqual(m["model_info"]["name"], "BAD_TYPE") - # self.assertEqual(m["model_info"]["version"], "1.0") + assert "model_info" in m.keys() + assert m["model_info"]["name"] == "BAD_TYPE" + # assert m["model_info"]["version"] == "1.0" with self.assertRaisesRegex( TypeError, r"Wrong model type \(BAD_TYPE\), we are expecting a 'FastSpeech2' model", @@ -373,9 +373,9 @@ def test_missing_model_version(self): ckpt_fn = tmpdir_str + "/checkpoint.ckpt" trainer.save_checkpoint(ckpt_fn) m = torch.load(ckpt_fn, weights_only=True) - self.assertIn("model_info", m.keys()) - self.assertEqual(m["model_info"]["name"], ModelType.__name__) - self.assertEqual(m["model_info"]["version"], CANARY_VERSION) + assert "model_info" in 
m.keys() + assert m["model_info"]["name"] == ModelType.__name__ + assert m["model_info"]["version"] == CANARY_VERSION del m["model_info"]["version"] torch.save(m, ckpt_fn) if isinstance(model, FastSpeech2): @@ -387,7 +387,7 @@ def test_missing_model_version(self): else: with mute_logger("everyvoice.config.text_config"): model = ModelType.load_from_checkpoint(ckpt_fn) - self.assertIn(model._VERSION, ["1.0", "1.1"]) + assert model._VERSION in ["1.0", "1.1"] def test_newer_model_version(self): """ @@ -448,9 +448,9 @@ def test_newer_model_version(self): ckpt_fn = tmpdir_str + "/checkpoint.ckpt" trainer.save_checkpoint(ckpt_fn) m = torch.load(ckpt_fn, weights_only=True) - self.assertIn("model_info", m.keys()) - self.assertEqual(m["model_info"]["name"], ModelType.__name__) - self.assertEqual(m["model_info"]["version"], NEWER_VERSION) + assert "model_info" in m.keys() + assert m["model_info"]["name"] == ModelType.__name__ + assert m["model_info"]["version"] == NEWER_VERSION with self.assertRaisesRegex( ValueError, r"Your model was created with a newer version of EveryVoice, please update your software.", @@ -484,7 +484,7 @@ def test_config_versionless(self): self.assertNotIn("VERSION", arguments) c = ConfigType(**arguments) - self.assertEqual(c.VERSION, "1.0") + assert c.VERSION == "1.0" def test_config_newer_version(self): """ @@ -514,7 +514,7 @@ def test_newer_version(self): from packaging.version import Version self.assertFalse("10.0" > "9.0") - self.assertTrue(Version("10.0") > Version("9.0")) + assert Version("10.0") > Version("9.0") if __name__ == "__main__": diff --git a/everyvoice/tests/test_preprocessing.py b/everyvoice/tests/test_preprocessing.py index 6a1cefc7..3a512be0 100755 --- a/everyvoice/tests/test_preprocessing.py +++ b/everyvoice/tests/test_preprocessing.py @@ -47,7 +47,7 @@ class PreprocessingTest(PreprocessedAudioFixture, TestCase): filelist = generic_psv_filelist_reader(TEST_DATA_DIR / "metadata.psv") def test_read_filelist(self): - 
self.assertEqual(self.filelist[0]["basename"], "LJ050-0269") + assert self.filelist[0]["basename"] == "LJ050-0269" def test_no_permissions(self): no_permissions_args = self.fp_config.model_dump() @@ -110,16 +110,16 @@ def test_process_empty_audio(self) -> None: for audiofile in ["empty.wav", "zeros.wav"]: with mute_logger("everyvoice.preprocessor.preprocessor"): audio, sr = self.preprocessor.process_audio(TEST_DATA_DIR / audiofile) - self.assertEqual(audio, None) - self.assertEqual(sr, None) + assert audio is None + assert sr is None def test_too_short(self) -> None: # too-short.wav is only .28s long, shorter than our minimum .4s with mute_logger("everyvoice.preprocessor.preprocessor"): audio, sr = self.preprocessor.process_audio(TEST_DATA_DIR / "too-short.wav") - self.assertEqual(audio, None) - self.assertEqual(sr, None) - self.assertEqual(self.preprocessor.counters.value("audio_too_short"), 1) + assert audio is None + assert sr is None + assert self.preprocessor.counters.value("audio_too_short") == 1 def test_process_bad_sox_effects(self) -> None: sox_errors_before = self.preprocessor.counters.value("sox_error") @@ -148,8 +148,8 @@ def test_multichannel_audio_skipped(self): ) # Should return None, None indicating the file was skipped - self.assertEqual(audio, None) - self.assertEqual(sr, None) + assert audio is None + assert sr is None # Should be added to the multichannel files list self.assertIn( @@ -157,7 +157,7 @@ def test_multichannel_audio_skipped(self): ) # Should increment the counter - self.assertEqual(self.preprocessor.counters.value("multichannel_files"), 1) + assert self.preprocessor.counters.value("multichannel_files") == 1 def test_multichannel_files_report(self): """Test that multichannel files appear in the report""" @@ -173,11 +173,11 @@ def test_multichannel_files_report(self): report = self.preprocessor.report() # Check that multichannel files are mentioned in the report - self.assertIn("multichannel_files", report) + assert 
"multichannel_files" in report expected_count = initial_count + 1 - self.assertIn(f"multichannel_files {expected_count}", report) - self.assertIn(f"Multichannel Audio Files ({expected_count} total)", report) - self.assertIn(str(multichannel_audio_path), report) + assert f"multichannel_files {expected_count}" in report + assert f"Multichannel Audio Files ({expected_count} total)" in report + assert str(multichannel_audio_path) in report def test_multichannel_files_empty_report(self): """Test that report works correctly when no multichannel files exist""" @@ -188,7 +188,7 @@ def test_multichannel_files_empty_report(self): report = fresh_preprocessor.report() # Should show 0 multichannel files and no multichannel files section - self.assertIn("multichannel_files 0", report) + assert "multichannel_files 0" in report self.assertNotIn("Multichannel Audio Files", report) def test_multichannel_files_file_creation(self): @@ -221,14 +221,14 @@ def test_multichannel_files_file_creation(self): # Check that multichannel_files.txt was created multichannel_file = save_dir / "multichannel_files.txt" - self.assertTrue(multichannel_file.exists()) + assert multichannel_file.exists() # Check the content of multichannel_files.txt with open(multichannel_file, "r") as f: content = f.read() - self.assertIn("Multichannel Audio Files", content) - self.assertIn("multichannel_test.wav", content) - self.assertIn("=" * 50, content) + assert "Multichannel Audio Files" in content + assert "multichannel_test.wav" in content + assert "=" * 50 in content def test_multichannel_preprocess_file_output(self): """Test the exact multichannel file output code path from preprocess method""" @@ -261,13 +261,13 @@ def test_multichannel_preprocess_file_output(self): # Verify the file was created and has correct content multichannel_file = save_dir / "multichannel_files.txt" - self.assertTrue(multichannel_file.exists()) + assert multichannel_file.exists() with open(multichannel_file, "r") as f: content = 
f.read() - self.assertIn("Multichannel Audio Files (1 total)", content) - self.assertIn("multichannel_test.wav", content) - self.assertIn("=" * 50, content) + assert "Multichannel Audio Files (1 total)" in content + assert "multichannel_test.wav" in content + assert "=" * 50 in content def test_audio_too_long_condition(self): """Test that audio files longer than max_audio_length are skipped""" @@ -277,11 +277,11 @@ def test_audio_too_long_condition(self): audio, sr = self.preprocessor.process_audio(long_audio_path, hop_size=256) # Should return None, None indicating the file was skipped - self.assertEqual(audio, None) - self.assertEqual(sr, None) + assert audio is None + assert sr is None # Should increment the counter - self.assertEqual(self.preprocessor.counters.value("audio_too_long"), 1) + assert self.preprocessor.counters.value("audio_too_long") == 1 def test_full_preprocess_with_multichannel_files(self): """Test the full preprocess method creates multichannel_files.txt""" @@ -349,9 +349,9 @@ def test_full_preprocess_with_multichannel_files(self): # Verify the content with open(multichannel_file, "r") as f: content = f.read() - self.assertIn("Multichannel Audio Files (1 total)", content) - self.assertIn("multichannel_test.wav", content) - self.assertIn("=" * 50, content) + assert "Multichannel Audio Files (1 total)" in content + assert "multichannel_test.wav" in content + assert "=" * 50 in content def test_process_audio(self): import torchaudio @@ -360,8 +360,8 @@ def test_process_audio(self): audio, sr = self.preprocessor.process_audio( self.wavs_dir / (entry["basename"] + ".wav"), hop_size=256 ) - self.assertEqual(sr, 22050) - self.assertEqual(audio.dtype, float32) + assert sr == 22050 + assert audio.dtype == float32 # test that truncating according to hop size actually happened raw_audio, raw_sr = torchaudio.load( str(self.wavs_dir / (entry["basename"] + ".wav")) @@ -430,9 +430,9 @@ def test_spectral_feats(self): 
linear_preprocessor.config.preprocessing.audio.n_fft // 2 + 1, ) # check all same length - self.assertEqual(feats.size(1), linear_feats.size(1)) + assert feats.size(1) == linear_feats.size(1) # check all same length - self.assertEqual(complex_feats.size(1), linear_feats.size(1)) + assert complex_feats.size(1) == linear_feats.size(1) def test_bad_pitch(self): """Some files don't have any pitch values so we should make sure we handle these properly""" @@ -445,9 +445,9 @@ def test_bad_pitch(self): audio, self.preprocessor.input_spectral_transform ) frame_pitch_pyworld = preprocessor_pyworld.extract_pitch(audio) - self.assertEqual(frame_pitch_pyworld.max(), 0) - self.assertEqual(frame_pitch_pyworld.min(), 0) - self.assertEqual(frame_pitch_pyworld.size(0), feats.size(1)) + assert frame_pitch_pyworld.max() == 0 + assert frame_pitch_pyworld.min() == 0 + assert frame_pitch_pyworld.size(0) == feats.size(1) def test_pitch(self): pyworld_config = VocoderConfig( @@ -487,9 +487,9 @@ def test_pitch(self): frame_pitch_pyworld, durs ) # Ensure avg pitch for each phone - self.assertEqual(len(durs), pyworld_phone_avg_energy.size(0)) + assert len(durs) == pyworld_phone_avg_energy.size(0) # Ensure same number of frames - self.assertEqual(frame_pitch_pyworld.size(0), feats.size(1)) + assert frame_pitch_pyworld.size(0) == feats.size(1) # TODO: test nans: torch.any(torch.Tensor([[torch.nan, 2]]).isnan()) @@ -524,7 +524,7 @@ def test_duration(self): # note: this is off by a few frames due to mismatches in hop size between the aligner the test data # was trained with and the settings defined by the spectral transform function here. # It would be a problem if it weren't but it's not really relevant since we're using jointly learned alignments now. 
- self.assertTrue(feats.size(1) - int(sum(durs)) <= 10) + assert feats.size(1) - int(sum(durs)) <= 10 def test_energy(self): frame_energy_config = VocoderConfig( @@ -563,9 +563,9 @@ def test_energy(self): frame_energy, durs ) # Ensure avg energy for each phone - self.assertEqual(phone_avg_energy.size(0), len(durs)) + assert phone_avg_energy.size(0) == len(durs) # Ensure same number of frames - self.assertEqual(frame_energy.size(0), feats.size(1)) + assert frame_energy.size(0) == feats.size(1) def test_sanity(self): """TODO: make sanity checking code for each type of data, maybe also data analysis tooling""" @@ -744,7 +744,7 @@ def test_mixed_cleaners(self) -> None: ) if result.exit_code != 0 or stubs.VERBOSE_OVERRIDE: print(result.output) - self.assertEqual(result.exit_code, 0) + assert result.exit_code == 0 os.chdir("mixed-cleaners") with open( "config/everyvoice-shared-text.yaml", "r", encoding="utf8" @@ -752,23 +752,23 @@ def test_mixed_cleaners(self) -> None: text_config = TextConfig(**yaml.load(f, Loader=yaml.FullLoader)) symbols = text_config.symbols.all_except_punctuation for character in ("é", "É", "é"): # nfc(é), nfc(É), nfd(é) - self.assertIn(character, symbols) + assert character in symbols result = runner.invoke( app, ["preprocess", "config/everyvoice-text-to-spec.yaml"] ) if result.exit_code != 0 or stubs.VERBOSE_OVERRIDE: print(result.output) - self.assertEqual(result.exit_code, 0) + assert result.exit_code == 0 filelist = generic_psv_filelist_reader("preprocessed/filelist.psv") - self.assertEqual(filelist[4]["label"], "lowercase-only") - self.assertIn("/é/", filelist[4]["character_tokens"]) # lower NFD only + assert filelist[4]["label"] == "lowercase-only" + assert "/é/" in filelist[4]["character_tokens"] # lower NFD only self.assertNotIn("/é/", filelist[4]["character_tokens"]) # not NFC self.assertNotIn("/É/", filelist[4]["character_tokens"]) # not upper - self.assertEqual(filelist[8]["label"], "nfc-only") - self.assertIn("/é/", 
filelist[8]["character_tokens"]) # lower NFC + assert filelist[8]["label"] == "nfc-only" + assert "/é/" in filelist[8]["character_tokens"] # lower NFC self.assertNotIn("/é/", filelist[8]["character_tokens"]) # not NFD - self.assertEqual(filelist[9]["label"], "nfc-only") - self.assertIn("/É/", filelist[9]["character_tokens"]) # upper NFC + assert filelist[9]["label"] == "nfc-only" + assert "/É/" in filelist[9]["character_tokens"] # upper NFC self.assertNotIn("/É/", filelist[9]["character_tokens"]) # not NFD def test_incremental_preprocess(self): @@ -894,8 +894,8 @@ def fail_config_lock( to_process=to_process, ) log_output = "\n".join(logs.output) - self.assertIn("Config lock mismatch:", log_output) - self.assertIn(message, log_output) + assert "Config lock mismatch:" in log_output + assert message in log_output fail_config_lock( fp_config.preprocessing.audio, @@ -941,20 +941,20 @@ def test_train_split(self): PreprocessingConfig's train_split should be [0., 1.]. """ config = PreprocessingConfig(train_split=0.5) - self.assertEqual(config.train_split, 0.5) + assert config.train_split == 0.5 config = PreprocessingConfig(train_split=0.0) - self.assertEqual(config.train_split, 0.0) + assert config.train_split == 0.0 config = PreprocessingConfig(train_split=1.0) - self.assertEqual(config.train_split, 1.0) + assert config.train_split == 1.0 with self.assertRaises(ValidationError), capture_stdout() as cout: config = PreprocessingConfig(train_split=-0.1) - self.assertIn("Input should be greater than or equal to 0", cout.getvalue()) + assert "Input should be greater than or equal to 0" in cout.getvalue() with self.assertRaises(ValidationError), capture_stdout() as cout: config = PreprocessingConfig(train_split=1.1) - self.assertIn("Input should be less than or equal to 1", cout.getvalue()) + assert "Input should be less than or equal to 1" in cout.getvalue() def test_no_speaker(self) -> None: """Exercise getting the default speaker and languages during preprocessing""" @@ 
-1015,15 +1015,15 @@ def test_stats(self) -> None: ) assert char_length_data is not None char_length_stats = char_length_data.calculate_stats() - self.assertEqual(char_length_stats["min"], 83) - self.assertEqual(char_length_stats["max"], 118) + assert char_length_stats["min"] == 83 + assert char_length_stats["max"] == 118 self.assertAlmostEqual(char_length_stats["std"], sqrt(200.5), places=4) - self.assertEqual(char_length_stats["mean"], 105) + assert char_length_stats["mean"] == 105 assert phone_length_data is not None phone_length_stats = phone_length_data.calculate_stats() - self.assertEqual(phone_length_stats["min"], 76) - self.assertEqual(phone_length_stats["max"], 111) + assert phone_length_stats["min"] == 76 + assert phone_length_stats["max"] == 111 self.assertAlmostEqual(phone_length_stats["std"], sqrt(216.3), places=4) self.assertAlmostEqual(phone_length_stats["mean"], 98.4, places=4) @@ -1072,20 +1072,20 @@ def test_missing_audio_files_detection(self): with open(missing_files_path, "r", encoding="utf8") as f: content = f.read() - self.assertIn("Missing Audio Files (2 total)", content) - self.assertIn("nonexistent1.wav", content) - self.assertIn("nonexistent2.wav", content) + assert "Missing Audio Files (2 total)" in content + assert "nonexistent1.wav" in content + assert "nonexistent2.wav" in content # Check that missing files are also included in summary report summary_path = tmpdir / "preprocessed" / "summary.txt" - self.assertTrue(summary_path.exists()) + assert summary_path.exists() with open(summary_path, "r", encoding="utf8") as f: summary_content = f.read() - self.assertIn("Missing Audio Files (2 total)", summary_content) - self.assertIn("nonexistent1.wav", summary_content) - self.assertIn("nonexistent2.wav", summary_content) + assert "Missing Audio Files (2 total)" in summary_content + assert "nonexistent1.wav" in summary_content + assert "nonexistent2.wav" in summary_content def test_no_missing_files(self): """Test that missing_files.txt is not 
created when all files exist""" @@ -1115,7 +1115,7 @@ def test_no_missing_files(self): # Check that summary doesn't mention missing files summary_path = tmpdir / "preprocessed" / "summary.txt" - self.assertTrue(summary_path.exists()) + assert summary_path.exists() with open(summary_path, "r", encoding="utf8") as f: summary_content = f.read() @@ -1133,8 +1133,8 @@ def test_missing_files_spec_processing(self): # Should return None for both specs and track missing file self.assertIsNone(input_spec) self.assertIsNone(output_spec) - self.assertEqual(len(preprocessor.missing_files_list), 1) - self.assertIn("nonexistent", preprocessor.missing_files_list[0]) + assert len(preprocessor.missing_files_list) == 1 + assert "nonexistent" in preprocessor.missing_files_list[0] def test_missing_files_report_formatting(self): """Test report method includes missing files section with correct formatting""" @@ -1149,9 +1149,9 @@ def test_missing_files_report_formatting(self): report = preprocessor.report() # Check report contains missing files section - self.assertIn("Missing Audio Files (2 total)", report) - self.assertIn("- /path/to/missing1.wav", report) - self.assertIn("- /path/to/missing2.wav", report) + assert "Missing Audio Files (2 total)" in report + assert "- /path/to/missing1.wav" in report + assert "- /path/to/missing2.wav" in report def test_missing_files_basename_with_wav_extension(self): """Test missing files when basename already has .wav extension""" @@ -1191,14 +1191,14 @@ def test_missing_files_basename_with_wav_extension(self): with open(missing_files_path, "r", encoding="utf8") as f: content = f.read() - self.assertIn("missing.wav", content) + assert "missing.wav" in content def test_empty_missing_files_list_report(self): """Test report method when no missing files exist""" preprocessor = Preprocessor(FeaturePredictionConfig(contact=TEST_CONTACT)) # Empty missing files list (default state) - self.assertEqual(len(preprocessor.missing_files_list), 0) + assert 
len(preprocessor.missing_files_list) == 0 report = preprocessor.report() @@ -1295,7 +1295,7 @@ def test_working_call1(self) -> None: apply_sox_effects_to_file( self.audiofile, tmpdir / "output1.wav", self.many_effects ) - self.assertTrue((tmpdir / "output1.wav").exists()) + assert (tmpdir / "output1.wav").exists() def test_working_call2(self) -> None: with tempfile.TemporaryDirectory(prefix="sox_effects_", dir=".") as tmpdir_s: @@ -1303,7 +1303,7 @@ def test_working_call2(self) -> None: apply_sox_effects_to_file( self.audiofile, tmpdir / "output2.wav", self.many_effects[:-1] ) - self.assertTrue((tmpdir / "output2.wav").exists()) + assert (tmpdir / "output2.wav").exists() def test_working_call3(self) -> None: with tempfile.TemporaryDirectory(prefix="sox_effects_", dir=".") as tmpdir_s: @@ -1311,7 +1311,7 @@ def test_working_call3(self) -> None: apply_sox_effects_to_file( self.audiofile, tmpdir / "output3.wav", self.many_effects[1:] ) - self.assertTrue((tmpdir / "output3.wav").exists()) + assert (tmpdir / "output3.wav").exists() class PreprocessingHierarchyTest(TestCase): @@ -1358,14 +1358,14 @@ def test_hierarchy(self): if t == "audio" else list(tmpdir.glob(f"**/{t}/LJ010/*.pt")) ) - self.assertEqual(len(files), 1) + assert len(files) == 1 # Second speaker has 5 recordings files = ( list(tmpdir.glob(f"**/{t}/LJ050/*.wav")) if t == "audio" else list(tmpdir.glob(f"**/{t}/LJ050/*.pt")) ) - self.assertEqual(len(files), 5) + assert len(files) == 5 if __name__ == "__main__": diff --git a/everyvoice/tests/test_subsample.py b/everyvoice/tests/test_subsample.py index 536f1d94..6f76fd4b 100755 --- a/everyvoice/tests/test_subsample.py +++ b/everyvoice/tests/test_subsample.py @@ -30,12 +30,12 @@ def test_sv(self): "psv", ], ) - self.assertEqual(result.exit_code, 0) - self.assertIn("basename|", result.stdout) - self.assertIn("LJ050-0269|", result.stdout) - self.assertIn("LJ050-0270|", result.stdout) - self.assertIn("LJ050-0271|", result.stdout) - self.assertIn("LJ050-0272.wav|", 
result.stdout) + assert result.exit_code == 0 + assert "basename|" in result.stdout + assert "LJ050-0269|" in result.stdout + assert "LJ050-0270|" in result.stdout + assert "LJ050-0271|" in result.stdout + assert "LJ050-0272.wav|" in result.stdout self.assertNotIn("LJ050-0273|", result.stdout) def test_festival(self): @@ -46,14 +46,14 @@ def test_festival(self): [str(self.metadata_path), str(self.wavs_path), "-d", "7", "-f", "festival"], ) - self.assertEqual(result.exit_code, 0) - self.assertIn("LJ050-0269", result.stdout) - self.assertIn("LJ050-0270", result.stdout) + assert result.exit_code == 0 + assert "LJ050-0269" in result.stdout + assert "LJ050-0270" in result.stdout self.assertNotIn("LJ050-0271", result.stdout) def test_help(self): result = self.runner.invoke(app, ["--help"]) - self.assertIn("Standalone test script for subsampling corpora", result.stdout) + assert "Standalone test script for subsampling corpora" in result.stdout def test_speakerid(self): self.metadata_path = ( @@ -77,9 +77,9 @@ def test_speakerid(self): ], ) - self.assertIn("basename|", result.stdout) - self.assertIn("LJ050-0269|", result.stdout) - self.assertIn("LJ050-0272.wav|", result.stdout) + assert "basename|" in result.stdout + assert "LJ050-0269|" in result.stdout + assert "LJ050-0272.wav|" in result.stdout self.assertNotIn("LJ050-0270|", result.stdout) def test_error_validation(self): @@ -89,8 +89,8 @@ def test_error_validation(self): result = self.runner.invoke( app, [str(self.metadata_path), str(self.wavs_path), "-d", "7", "-f", "txt"] ) - self.assertNotEqual(result.exit_code, 0) - self.assertIn("Invalid value for", result.output) + assert result.exit_code != 0 + assert "Invalid value for" in result.output self.assertRegex( result.output, r"(?s)txt is not one of psv tsv csv festival".replace(" ", r".*"), @@ -115,7 +115,7 @@ def test_error_validation(self): ], ) - self.assertNotEqual(result.exit_code, 0) + assert result.exit_code != 0 self.assertRegex( result.output, r"Invalid 
value: Festival formatted files cannot have a speaker id.".replace( @@ -137,7 +137,7 @@ def test_error_validation(self): "psv", ], ) - self.assertNotEqual(result.exit_code, 0) + assert result.exit_code != 0 self.assertRegex( result.output, r"A \.wav file could not be found".replace(" ", r"[\s\S]*"), diff --git a/everyvoice/tests/test_text.py b/everyvoice/tests/test_text.py index 392e8ce4..89ceef79 100755 --- a/everyvoice/tests/test_text.py +++ b/everyvoice/tests/test_text.py @@ -42,7 +42,7 @@ def test_text_to_sequence(self): def test_token_sequence_to_text(self): sequence = [60, 57, 64, 64, 67, 1, 75, 67, 70, 64, 56] - self.assertEqual(self.base_text_processor.encode_text("hello world"), sequence) + assert self.base_text_processor.encode_text("hello world") == sequence def test_hardcoded_symbols(self): self.assertEqual( @@ -191,9 +191,9 @@ def test_phonological_features(self): encode_as_phonological_features=True, ) self.assertEqual(moh_text_processor.decode_tokens(g2p_tokens, "", ""), "séːɡũ") - self.assertEqual(len(g2p_tokens), len(feats)) + assert len(g2p_tokens) == len(feats) self.assertNotEqual(len(g2p_tokens), len(one_hot_tokens)) - self.assertEqual(len(feats[0]), N_PHONOLOGICAL_FEATURES) + assert len(feats[0]) == N_PHONOLOGICAL_FEATURES def test_duplicates_removed(self): duplicate_symbols_text_processor = TextProcessor( @@ -219,7 +219,7 @@ def test_dipgrahs(self): ) text = "ee" # should be treated as "ee" and not two instances of "e" sequence = digraph_text_processor.encode_text(text) - self.assertEqual(len(sequence), 1) + assert len(sequence) == 1 def test_normalization(self): # This test doesn't really test very much, but just here to highlight that base cleaning doesn't involve NFC @@ -246,8 +246,8 @@ def test_missing_symbol(self): text = "h3llo world" sequence = self.base_text_processor.encode_text(text) self.assertNotEqual(self.base_text_processor.decode_tokens(sequence), text) - self.assertIn("3", self.base_text_processor.missing_symbols) - 
self.assertEqual(self.base_text_processor.missing_symbols["3"], 1) + assert "3" in self.base_text_processor.missing_symbols + assert self.base_text_processor.missing_symbols["3"] == 1 def test_use_slash(self): text = "word/token" @@ -256,9 +256,9 @@ def test_use_slash(self): ) sequence = text_processor.encode_text(text) decoded = text_processor.decode_tokens(sequence) - self.assertEqual(decoded, "w/o/r/d/" + JOINER_SUBSTITUTION + "/t/o/k/e/n") + assert decoded == "w/o/r/d/" + JOINER_SUBSTITUTION + "/t/o/k/e/n" encoded = text_processor.encode_escaped_string_sequence(decoded) - self.assertEqual(encoded, sequence) + assert encoded == sequence with self.assertRaises(exceptions.OutOfVocabularySymbolError): # / is OOV, so JOINER_SUBSTITUTION will also be OOV @@ -275,10 +275,10 @@ def test_encode_string_tokens(self): self.base_text_processor.encode_string_tokens([JOINER_SUBSTITUTION]) def test_is_sentence_final(self): - self.assertTrue(is_sentence_final("!")) - self.assertTrue(is_sentence_final("?")) - self.assertTrue(is_sentence_final(".")) - self.assertTrue(is_sentence_final("᙮")) + assert is_sentence_final("!") + assert is_sentence_final("?") + assert is_sentence_final(".") + assert is_sentence_final("᙮") self.assertFalse(is_sentence_final("¡")) self.assertFalse(is_sentence_final("¿")) @@ -463,7 +463,7 @@ def test_custom_desired_length(self): def test_normalization(self): a = "Welcome to the EveryVoice Documentation! Please read the background section below." text = " Welcome to the EveryVoice Documentation!\n\n\n\nPlease read the background section below. " - self.assertEqual([a], chunk_text(text)) + assert [a] == chunk_text(text) def test_quote_toggling(self): text = 'There are approximately "70 Indigenous languages spoken in Canada. The majority of these languages" now have fewer than 500 fluent speakers remaining.' 
@@ -496,7 +496,7 @@ def test_custom_weak_boundaries(self): # With custom weak boundaries self.assertNotIn(a, chunk_text(text, 15, 30, weak_boundaries=":;")) # Without custom weak boundaries - self.assertIn(a, chunk_text(text, 15, 30)) + assert a in chunk_text(text, 15, 30) def test_custom_strong_boundaries(self): """ diff --git a/everyvoice/tests/test_utils.py b/everyvoice/tests/test_utils.py index 4f4e1c60..919f9812 100755 --- a/everyvoice/tests/test_utils.py +++ b/everyvoice/tests/test_utils.py @@ -28,7 +28,7 @@ class VersionTest(TestCase): def test_version_is_pep440_compliant(self): - self.assertTrue(is_canonical(VERSION)) + assert is_canonical(VERSION) class UtilsTest(TestCase): @@ -49,12 +49,12 @@ def test_write_filelist(self): write_filelist(basic_files, basic_path) with open(basic_path, encoding="utf8") as f: headers = f.readline().strip().split("|") - self.assertEqual(len(headers), 5) - self.assertEqual(headers[0], "basename") - self.assertEqual(headers[1], "language") - self.assertEqual(headers[2], "characters") - self.assertEqual(headers[3], "phones") - self.assertEqual(headers[4], "extra") + assert len(headers) == 5 + assert headers[0] == "basename" + assert headers[1] == "language" + assert headers[2] == "characters" + assert headers[3] == "phones" + assert headers[4] == "extra" class ContextableBaseModel(BaseModel): @@ -92,7 +92,7 @@ def test_using_a_directory_with_context(self): root_dir = Path(__file__).parent / "data" root_dir = root_dir.resolve() directory = Path("hierarchy") - self.assertTrue((root_dir / directory).exists()) + assert (root_dir / directory).exists() with init_context({"writing_config": root_dir}): PathIsADirectory(path=directory) except ValueError: @@ -127,7 +127,7 @@ def test_using_a_directory(self): root_dir = Path(__file__).parent / "data" root_dir = root_dir.resolve() directory = Path("hierarchy") - self.assertTrue((root_dir / directory).exists()) + assert (root_dir / directory).exists() PathIsADirectory(path=root_dir / 
directory) except ValueError: self.fail("Failed to detect that the argument is a directory") @@ -173,7 +173,7 @@ def test_already_absolute(self): path = Path(__file__).absolute() with init_context({"config_path": path.parent / "data"}): dir = RelativePathToAbsolute(path=path) - self.assertEqual(dir.path, path) + assert dir.path == path def test_should_not_change(self): """ @@ -181,7 +181,7 @@ def test_should_not_change(self): """ path = Path("data") test = RelativePathToAbsolute(path=path) - self.assertEqual(test.path, path) + assert test.path == path def test_with_context(self): """ @@ -191,7 +191,7 @@ def test_with_context(self): path = Path("data") with init_context({"config_path": root_dir}): dir = RelativePathToAbsolute(path=path) - self.assertTrue(dir.path.is_absolute()) + assert dir.path.is_absolute() class DirectoryPathMustExist(ContextableBaseModel): @@ -226,8 +226,8 @@ def test_using_a_directory_with_context(self): with init_context({"writing_config": root_dir.resolve()}): dir = DirectoryPathMustExist(path=directory) # Note: dir.path shouldn't not change to an absolute value. - self.assertEqual(dir.path, directory) - self.assertTrue((root_dir / directory).exists()) + assert dir.path == directory + assert (root_dir / directory).exists() # Note: since dir.path is NOT replaced with an absolute it # shouldn't exist because it was created relative to the context's # path. @@ -243,7 +243,7 @@ def test_path_already_exists(self): dir = DirectoryPathMustExist(path=path) # There should be no info logged. self.assertListEqual(output, []) - self.assertTrue(dir.path.exists()) + assert dir.path.exists() def test_using_a_directory(self): """ @@ -255,13 +255,13 @@ def test_using_a_directory(self): with patch_logger(everyvoice.config.validation_helpers) as logger: with self.assertLogs(logger) as cm: dir = DirectoryPathMustExist(path=path) - self.assertEqual(dir.path, path) + assert dir.path == path self.assertIn( f"Directory at {path} does not exist. 
Creating...", "".join(cm.output), ) - self.assertTrue(path.exists()) - self.assertTrue(dir.path.exists()) + assert path.exists() + assert dir.path.exists() class GetDeviceFromAcceleratorTest(TestCase): @@ -272,16 +272,16 @@ def test_auto(self): ) def test_cpu(self): - self.assertEqual(get_device_from_accelerator("cpu"), torch.device("cpu")) + assert get_device_from_accelerator("cpu") == torch.device("cpu") def test_gpu(self): - self.assertEqual(get_device_from_accelerator("gpu"), torch.device("cuda:0")) + assert get_device_from_accelerator("gpu") == torch.device("cuda:0") def test_mps(self): - self.assertEqual(get_device_from_accelerator("mps"), torch.device("mps")) + assert get_device_from_accelerator("mps") == torch.device("mps") def test_unknown_accelerator(self): - self.assertEqual(get_device_from_accelerator("unknown"), torch.device("cpu")) + assert get_device_from_accelerator("unknown") == torch.device("cpu") if __name__ == "__main__": diff --git a/everyvoice/tests/test_wizard.py b/everyvoice/tests/test_wizard.py index 7cae1f10..29dbdbb5 100755 --- a/everyvoice/tests/test_wizard.py +++ b/everyvoice/tests/test_wizard.py @@ -203,7 +203,7 @@ def recursive_helper(steps_and_answers: Iterable[StepAndAnswer]): tour = Tour(name, steps=[step for (step, *_) in steps_and_answers]) # fail on accidentally shared initializer - self.assertTrue(tour.state == {} or tour.state == {"dataset_0": {}}) + assert tour.state == {} or tour.state == {"dataset_0": {}} with capture_stdout() as out: recursive_helper(steps_and_answers) return tour, out.getvalue() @@ -230,8 +230,8 @@ def test_config_format_effect(self): testing that no exceptions get raised. 
""" config_step = basic.ConfigFormatStep(name="Config Step") - self.assertTrue(config_step.validate("yaml")) - self.assertTrue(config_step.validate("json")) + assert config_step.validate("yaml") + assert config_step.validate("json") with tempfile.TemporaryDirectory() as tmpdirname: config_step._state = State() config_step.state[SN.output_step.value] = tmpdirname @@ -277,7 +277,7 @@ def test_config_format_effect(self): f.read(), "basename|language|speaker|text\n0001|und|default|hello\n0002|und|default|hello\n0003|und|default|hello\n", ) - self.assertIn("Congratulations", stdout.getvalue()) + assert "Congratulations" in stdout.getvalue() self.assertTrue( (Path(tmpdirname) / config_step.name / "logs_and_checkpoints").exists() ) @@ -300,8 +300,8 @@ def validate(self, x): second_step.validate = MethodType(validate, second_step) for i, node in enumerate(PreOrderIter(root_step)): if i != 0: - self.assertEqual(second_step.parent.response, "foo") - self.assertTrue(node.validate("bar")) + assert second_step.parent.response == "foo" + assert node.validate("bar") self.assertFalse(node.validate("foo")) node.run() @@ -318,8 +318,8 @@ def test_visualize(self): with capture_stdout() as out: tour.visualize() log = out.getvalue() - self.assertIn("── Contact Name ", log) - self.assertIn("── Validate Wavs ", log) + assert "── Contact Name " in log + assert "── Validate Wavs " in log def test_name_step(self): """Exercise providing a valid dataset name.""" @@ -327,9 +327,9 @@ def test_name_step(self): with capture_stdout() as stdout: with patch_questionary("myname"): step.run() - self.assertEqual(step.response, "myname") - self.assertIn("'myname'", stdout.getvalue()) - self.assertTrue(step.completed) + assert step.response == "myname" + assert "'myname'" in stdout.getvalue() + assert step.completed def test_bad_name_step(self): """Exercise providing an invalid dataset name.""" @@ -341,18 +341,18 @@ def test_bad_name_step(self): self.assertFalse(step.validate("foo/bar")) 
self.assertFalse(step.validate("")) output = stdout.getvalue() - self.assertIn("'foo/bar'", output) - self.assertIn("is not valid", output) - self.assertIn("your project needs a name", output) + assert "'foo/bar'" in output + assert "is not valid" in output + assert "your project needs a name" in output step = basic.NameStep("") with capture_stdout() as stdout: with patch_questionary(("bad/name", "good-name"), True): step.run() output = stdout.getvalue() - self.assertIn("'bad/name'", stdout.getvalue()) - self.assertIn("is not valid", stdout.getvalue()) - self.assertEqual(step.response, "good-name") + assert "'bad/name'" in stdout.getvalue() + assert "is not valid" in stdout.getvalue() + assert step.response == "good-name" def test_bad_contact_name_step(self): """Exercise providing an invalid contact name.""" @@ -361,8 +361,8 @@ def test_bad_contact_name_step(self): self.assertFalse(step.validate("a")) self.assertFalse(step.validate("")) output = stdout.getvalue() - self.assertIn("Sorry", output) - self.assertIn("EveryVoice requires a name", output) + assert "Sorry" in output + assert "EveryVoice requires a name" in output def test_bad_contact_email_step(self): """Exercise providing an invalid contact email.""" @@ -371,7 +371,7 @@ def test_bad_contact_email_step(self): self.assertFalse(step.validate("test")) self.assertFalse(step.validate("test@")) self.assertFalse(step.validate("test@test.")) - self.assertTrue(step.validate("test@test.ca")) + assert step.validate("test@test.ca") self.assertFalse(step.validate("")) output = stdout.getvalue().replace(" \n", " ") # Supporting email-validator prior and post 2.2.0 where the error string changed. 
@@ -379,8 +379,8 @@ def test_bad_contact_email_step(self): "It must have exactly one @-sign" in output or "An email address must have an @-sign" in output ) - self.assertIn("There must be something after the @-sign", output) - self.assertIn("An email address cannot end with a period", output) + assert "There must be something after the @-sign" in output + assert "An email address cannot end with a period" in output def test_no_permissions(self): """Exercise lacking permissions, then trying again""" @@ -388,17 +388,17 @@ def test_no_permissions(self): permission_step = find_step(SN.dataset_permission_step, tour.steps) self.assertGreater(len(permission_step.children), 8) self.assertGreater(len(tour.root.descendants), 14) - self.assertIn("dataset_0", tour.state) + assert "dataset_0" in tour.state with patch_menu_prompt(0): # 0 is no, I don't have permission permission_step.run() - self.assertEqual(permission_step.children, ()) + assert permission_step.children == () self.assertLess(len(tour.root.descendants), 10) self.assertNotIn("dataset_0", tour.state) more_dataset_step = find_step(SN.more_datasets_step, tour.steps) with patch_menu_prompt(1): # 1 is Yes, I have more data more_dataset_step.run() - self.assertIn("dataset_0", tour.state) + assert "dataset_0" in tour.state self.assertGreater(len(more_dataset_step.descendants), 8) self.assertGreater(len(tour.root.descendants), 14) @@ -442,17 +442,17 @@ def test_output_path_step(self): ro_dir.mkdir(mode=0x555) with capture_stdout() as out: self.assertFalse(step.validate(str(ro_dir))) - self.assertIn("could not create", out.getvalue()) + assert "could not create" in out.getvalue() # Case with a deep path, make sure we don't leave part of it around - self.assertTrue(step.validate(tmpdir / "deep" / "path")) + assert step.validate(tmpdir / "deep" / "path") self.assertFalse((tmpdir / "deep").exists()) # Good case with capture_stdout() as stdout: with monkeypatch(step, "prompt", Say(tmpdirname)): step.run() - 
self.assertIn("will put your files", stdout.getvalue()) + assert "will put your files" in stdout.getvalue() def test_more_data_step(self): """Exercise giving an invalid response and a yes response to more data.""" @@ -463,12 +463,12 @@ def test_more_data_step(self): step = tour.steps[1] self.assertFalse(step.validate("foo")) - self.assertTrue(step.validate("yes")) - self.assertEqual(len(step.children), 0) + assert step.validate("yes") + assert len(step.children) == 0 with patch_menu_prompt(0): # answer 0 is "no" step.run() - self.assertEqual(len(step.children), 1) + assert len(step.children) == 1 self.assertIsInstance(step.children[0], basic.ConfigFormatStep) with patch_menu_prompt(1): # answer 1 is "yes" @@ -481,8 +481,8 @@ def test_no_data_to_save(self): step = tour.steps[0] with patch_menu_prompt(0), capture_stdout() as out: # answer 0 is "no" step.run() - self.assertEqual(len(step.children), 0) - self.assertIn("No dataset to save", out.getvalue()) + assert len(step.children) == 0 + assert "No dataset to save" in out.getvalue() def test_dataset_name(self): step = dataset.DatasetNameStep() @@ -490,10 +490,10 @@ def test_dataset_name(self): with capture_stdout() as stdout: step.run() output = stdout.getvalue().replace(" \n", " ").split("\n") - self.assertIn("your dataset needs a name", output[0]) - self.assertIn("is not valid", output[1]) - self.assertIn("finished the configuration", "".join(output[2:])) - self.assertTrue(step.completed) + assert "your dataset needs a name" in output[0] + assert "is not valid" in output[1] + assert "finished the configuration" in "".join(output[2:]) + assert step.completed def test_unique_dataset_name(self): tour = Tour( @@ -506,15 +506,15 @@ def test_unique_dataset_name(self): ) with patch_questionary("set1"): tour.steps[0].run() - self.assertEqual(tour.state["dataset_0"][SN.dataset_name_step], "set1") + assert tour.state["dataset_0"][SN.dataset_name_step] == "set1" with patch_questionary(("set1", "set2")), capture_stdout() as 
out: tour.steps[1].run() - self.assertIn("Please choose unique", flatten_log(out.getvalue())) - self.assertEqual(tour.state["dataset_1"][SN.dataset_name_step], "set2") + assert "Please choose unique" in flatten_log(out.getvalue()) + assert tour.state["dataset_1"][SN.dataset_name_step] == "set2" with patch_questionary(("set1", "set2", "set3")), capture_stdout() as out: tour.steps[2].run() - self.assertIn("Please choose unique", flatten_log(out.getvalue())) - self.assertEqual(tour.state["dataset_2"][SN.dataset_name_step], "set3") + assert "Please choose unique" in flatten_log(out.getvalue()) + assert tour.state["dataset_2"][SN.dataset_name_step] == "set3" def test_speaker_name(self): step = dataset.AddSpeakerStep() @@ -522,10 +522,10 @@ def test_speaker_name(self): with capture_stdout() as stdout: step.run() output = stdout.getvalue().replace(" \n", " ").split("\n") - self.assertIn("speaker needs an ID", output[0]) - self.assertIn("is not valid", output[1]) - self.assertIn("will be used as the speaker ID", "".join(output[2:])) - self.assertTrue(step.completed) + assert "speaker needs an ID" in output[0] + assert "is not valid" in output[1] + assert "will be used as the speaker ID" in "".join(output[2:]) + assert step.completed def test_wavs_dir(self): with tempfile.TemporaryDirectory() as tmpdirname: @@ -543,8 +543,8 @@ def test_wavs_dir(self): with patch_questionary(("not-a-path", no_wavs_dir, has_wavs_dir)): with capture_stdout(): step.run() - self.assertTrue(step.completed) - self.assertEqual(step.response, has_wavs_dir) + assert step.completed + assert step.response == has_wavs_dir # No symlinks on Windows, so just skip those test case subitems if os.name == "nt": @@ -560,8 +560,8 @@ def test_wavs_dir(self): with patch_questionary(wavs_are_symlinks): with capture_stdout(): step.run() - self.assertTrue(step.completed) - self.assertEqual(step.response, wavs_are_symlinks) + assert step.completed + assert step.response == wavs_are_symlinks wavs_dir_is_symlink = 
os.path.join(tmpdirname, "link-to-wavs-dir") os.symlink(wavs_are_symlinks, wavs_dir_is_symlink) @@ -569,8 +569,8 @@ def test_wavs_dir(self): with patch_questionary(wavs_dir_is_symlink): with capture_stdout(): step.run() - self.assertTrue(step.completed) - self.assertEqual(step.response, wavs_dir_is_symlink) + assert step.completed + assert step.response == wavs_dir_is_symlink # wavs_are_symlinks and wavs_dir_is_symlink pass even if # WavsDirStep.validate() were to use Path.glob(), but deeper_links @@ -583,8 +583,8 @@ def test_wavs_dir(self): with patch_questionary(deeper_links): with capture_stdout(): step.run() - self.assertTrue(step.completed) - self.assertEqual(step.response, deeper_links) + assert step.completed + assert step.response == deeper_links def test_sample_rate_config(self): step = dataset.SampleRateConfigStep("") @@ -602,9 +602,9 @@ def test_sample_rate_config(self): step.run() output = stdout.getvalue().replace(" \n", " ").split(".\n") for i in range(4): - self.assertIn("not a valid sample rate", output[i]) - self.assertTrue(step.completed) - self.assertEqual(step.response, 512) + assert "not a valid sample rate" in output[i] + assert step.completed + assert step.response == 512 def test_whitespace_always_collapsed(self): tour = Tour("unit testing", steps=dataset.get_dataset_steps()) @@ -699,37 +699,37 @@ def test_dataset_subtour(self): format_step = find_step(SN.filelist_format_step, tour.steps) with patch_menu_prompt(0): # 0 is "psv" format_step.run() - self.assertEqual(len(format_step.children), 3) + assert len(format_step.children) == 3 step = format_step.children[0] self.assertIsInstance(step, dataset.HasHeaderLineStep) - self.assertEqual(step.name, SN.data_has_header_line_step.value) + assert step.name == SN.data_has_header_line_step.value with patch_menu_prompt(1): # 1 is "yes" step.run() - self.assertEqual(step.state[SN.data_has_header_line_step.value], "yes") - self.assertEqual(len(step.state["filelist_data_list"]), 5) + assert 
step.state[SN.data_has_header_line_step.value] == "yes" + assert len(step.state["filelist_data_list"]) == 5 step = format_step.children[1] self.assertIsInstance(step, dataset.HeaderStep) - self.assertEqual(step.name, SN.basename_header_step.value) + assert step.name == SN.basename_header_step.value with patch_menu_prompt(1): # 1 is second column step.run() - self.assertEqual(step.response, 1) - self.assertEqual(step.state["filelist_headers"][1], "basename") + assert step.response == 1 + assert step.state["filelist_headers"][1] == "basename" step = format_step.children[2] self.assertIsInstance(step, dataset.HeaderStep) - self.assertEqual(step.name, SN.text_header_step.value) + assert step.name == SN.text_header_step.value with patch_menu_prompt(1): # 1 is second remaining column, i.e., third column step.run() - self.assertEqual(step.state["filelist_headers"][2], "text") + assert step.state["filelist_headers"][2] == "text" text_representation_step = find_step( SN.filelist_text_representation_step, tour.steps ) with patch_menu_prompt(0): # 0 is "characters" text_representation_step.run() - self.assertEqual(step.state["filelist_headers"][2], "characters") + assert step.state["filelist_headers"][2] == "characters" text_processing_step = find_step(SN.text_processing_step, tour.steps) # 0 is lowercase, 1 is NFC Normalization, select both @@ -751,27 +751,27 @@ def test_dataset_subtour(self): children_before = len(speaker_step.children) with patch_menu_prompt(0): # 0 is "no" speaker_step.run() - self.assertEqual(len(speaker_step.children), children_before + 1) + assert len(speaker_step.children) == children_before + 1 self.assertIsInstance(speaker_step.children[0], dataset.KnowSpeakerStep) know_speaker_step = speaker_step.children[0] children_before = len(know_speaker_step.children) with patch_menu_prompt(1): # 1 is "yes" know_speaker_step.run() - self.assertEqual(len(know_speaker_step.children), children_before + 1) + assert len(know_speaker_step.children) == 
children_before + 1 self.assertIsInstance(know_speaker_step.children[0], dataset.AddSpeakerStep) add_speaker_step = know_speaker_step.children[0] children_before = len(add_speaker_step.children) with patch_questionary("default"), capture_stdout(): add_speaker_step.run() - self.assertEqual(len(add_speaker_step.children), children_before) + assert len(add_speaker_step.children) == children_before language_step = find_step(SN.data_has_language_value_step, tour.steps) children_before = len(language_step.children) with patch_menu_prompt(0): # 0 is "no" language_step.run() - self.assertEqual(len(language_step.children), children_before + 1) + assert len(language_step.children) == children_before + 1 self.assertIsInstance(language_step.children[0], dataset.SelectLanguageStep) select_lang_step = language_step.children[0] @@ -789,12 +789,12 @@ def test_dataset_subtour(self): ) # Make sure realoading the data as dict stripped the header line - self.assertEqual(len(step.state["filelist_data"]), 4) + assert len(step.state["filelist_data"]) == 4 custom_g2p_step = find_step(SN.custom_g2p_step, tour.steps) with monkeypatch(custom_g2p_step, "prompt", Say(0)): custom_g2p_step.run() - self.assertEqual(custom_g2p_step.language_codes, ["fin"]) + assert custom_g2p_step.language_codes == ["fin"] wavs_dir_step = find_step(SN.wavs_dir_step, tour.steps) with monkeypatch(wavs_dir_step, "prompt", Say(str(self.data_dir))): @@ -803,15 +803,15 @@ def test_dataset_subtour(self): validate_wavs_step = find_step(SN.validate_wavs_step, tour.steps) with patch_menu_prompt(1), capture_stdout() as out: validate_wavs_step.run() - self.assertEqual(step.state[SN.validate_wavs_step][:2], "No") + assert step.state[SN.validate_wavs_step][:2] == "No" self.assertRegex(out.getvalue(), "Warning: .*4.* wav files were not found") symbol_set_step = find_step(SN.symbol_set_step, tour.steps) - self.assertEqual(len(symbol_set_step.state["filelist_data"]), 4) + assert len(symbol_set_step.state["filelist_data"]) == 4 
with capture_stdout(), capture_stderr(): symbol_set_step.run() - self.assertEqual(len(symbol_set_step.state[SN.symbol_set_step.value]), 2) - self.assertIn("t͡s", symbol_set_step.state[SN.symbol_set_step.value]["phones"]) + assert len(symbol_set_step.state[SN.symbol_set_step.value]) == 2 + assert "t͡s" in symbol_set_step.state[SN.symbol_set_step.value]["phones"] self.assertNotIn( ":", symbol_set_step.state[SN.symbol_set_step.value]["characters"] ) @@ -872,8 +872,8 @@ def test_empty_filelist(self): with patch_menu_prompt(1) as stdout: format_step.run() output = stdout.getvalue() - self.assertIn("is empty", output) - self.assertEqual(cm.exception.code, 1) + assert "is empty" in output + assert cm.exception.code == 1 def test_wrong_fileformat_psv(self): tour = Tour( @@ -895,8 +895,8 @@ def test_wrong_fileformat_psv(self): output = flatten_log(stdout.getvalue()) self.assertRegex(output, r"does not look like a .*'tsv'") self.assertRegex(output, r"does not look like a .*'csv'") - self.assertIn("is not in the festival format", output) - self.assertTrue(format_step.completed) + assert "is not in the festival format" in output + assert format_step.completed # print(format_step.state) def test_wrong_fileformat_festival(self): @@ -920,7 +920,7 @@ def test_wrong_fileformat_festival(self): self.assertRegex(output, r"does not look like a .*'psv'") self.assertRegex(output, r"does not look like a .*'tsv'") self.assertRegex(output, r"does not look like a .*'csv'") - self.assertTrue(format_step.completed) + assert format_step.completed # print(format_step.state) def test_validate_path(self): @@ -969,13 +969,13 @@ def test_prompt(self): answer = prompts.get_response_from_menu_prompt( choices=("choice1", "choice2") ) - self.assertEqual(answer, "choice1") + assert answer == "choice1" with patch_menu_prompt(1) as stdout: answer = prompts.get_response_from_menu_prompt( "some question", ("choice1", "choice2") ) - self.assertEqual(answer, "choice2") - self.assertIn("some question", 
stdout.getvalue()) + assert answer == "choice2" + assert "some question" in stdout.getvalue() with patch_menu_prompt((2, 4)): answer = prompts.get_response_from_menu_prompt( choices=("a", "b", "c", "d", "e"), multi=True @@ -985,7 +985,7 @@ def test_prompt(self): answer = prompts.get_response_from_menu_prompt( choices=("a", "b", "c", "d", "e"), return_indices=True ) - self.assertEqual(answer, 1) + assert answer == 1 def test_monkey_tour_1(self): with tempfile.TemporaryDirectory() as tmpdirname, capture_stdout(): @@ -1000,8 +1000,8 @@ def test_monkey_tour_1(self): ), ], ) - self.assertEqual(tour.state[SN.name_step.value], "my-dataset-name") - self.assertEqual(tour.state[SN.output_step.value], tmpdirname) + assert tour.state[SN.name_step.value] == "my-dataset-name" + assert tour.state[SN.output_step.value] == tmpdirname def test_monkey_tour_2(self): data_dir = Path(__file__).parent / "data" @@ -1049,19 +1049,19 @@ def test_monkey_tour_2(self): ) tree = str(RenderTree(tour.root)) - self.assertIn("├── Validate Wavs Step", tree) - self.assertIn("│ └── Validate Wavs Step", tree) - self.assertIn("Great! All audio files found in directory", out) + assert "├── Validate Wavs Step" in tree + assert "│ └── Validate Wavs Step" in tree + assert "Great! 
All audio files found in directory" in out # print(tour.state) - self.assertEqual(len(tour.state["filelist_data"]), 5) - self.assertTrue(tour.steps[-1].completed) + assert len(tour.state["filelist_data"]) == 5 + assert tour.steps[-1].completed def test_get_iso_code(self): - self.assertEqual(utils.get_iso_code("eng"), "eng") - self.assertEqual(utils.get_iso_code("[eng]"), "eng") - self.assertEqual(utils.get_iso_code("es"), "es") - self.assertEqual(utils.get_iso_code("[es]"), "es") + assert utils.get_iso_code("eng") == "eng" + assert utils.get_iso_code("[eng]") == "eng" + assert utils.get_iso_code("es") == "es" + assert utils.get_iso_code("[es]") == "es" self.assertIs(utils.get_iso_code(None), None) def test_with_language_column(self): @@ -1134,17 +1134,17 @@ def test_with_language_column(self): ], ) - self.assertEqual(tour.state["dataset_0"][SN.speaker_header_step.value], 2) - self.assertEqual(tour.state["dataset_0"][SN.language_header_step.value], 3) - self.assertTrue(tour.steps[-1].completed) + assert tour.state["dataset_0"][SN.speaker_header_step.value] == 2 + assert tour.state["dataset_0"][SN.language_header_step.value] == 3 + assert tour.steps[-1].completed with open( tmpdir / "out/project/config/everyvoice-text-to-spec.yaml", encoding="utf8", ) as f: text_to_spec_config = "\n".join(f) - self.assertIn("multilingual: true", text_to_spec_config) - self.assertIn("multispeaker: true", text_to_spec_config) + assert "multilingual: true" in text_to_spec_config + assert "multispeaker: true" in text_to_spec_config def test_no_header_line(self): with tempfile.TemporaryDirectory() as tmpdir_s: @@ -1228,19 +1228,19 @@ def test_no_header_line(self): ), ], ) - self.assertEqual(len(tour.state["dataset_0"]["filelist_data"]), 3) + assert len(tour.state["dataset_0"]["filelist_data"]) == 3 with open( tmpdir / "out/project/dataset-filelist.psv", encoding="utf8" ) as f: output_filelist = list(f) - self.assertEqual(len(output_filelist), 4) + assert len(output_filelist) == 4 with 
open( tmpdir / "out/project/config/everyvoice-text-to-spec.yaml", encoding="utf8", ) as f: text_to_spec_config = "\n".join(f) - self.assertIn("multilingual: false", text_to_spec_config) - self.assertIn("multispeaker: false", text_to_spec_config) + assert "multilingual: false" in text_to_spec_config + assert "multispeaker: false" in text_to_spec_config def test_running_out_of_columns(self): with tempfile.TemporaryDirectory() as tmpdir_s: @@ -1331,7 +1331,7 @@ def test_running_out_of_columns(self): self.assertEqual( tour.state["dataset_0"][SN.data_has_language_value_step.value], "no" ) - self.assertEqual(len(tour.state["dataset_0"]["filelist_data"]), 3) + assert len(tour.state["dataset_0"]["filelist_data"]) == 3 with open( tmpdir / "out/project/dataset-filelist.psv", encoding="utf8" ) as f: @@ -1385,7 +1385,7 @@ def test_leading_white_space_in_outpath(self): ) tour.run() self.assertFalse(tour.state[SN.output_step.value].startswith(" ")) - self.assertEqual(tour.state[SN.output_step.value], tmpdirname) + assert tour.state[SN.output_step.value] == tmpdirname def test_leading_white_space_in_wav_dir(self): """ @@ -1399,7 +1399,7 @@ def test_leading_white_space_in_wav_dir(self): with create_app_session(input=pipe_input, output=DummyOutput()): step.run() self.assertFalse(step.response.startswith(" ")) - self.assertEqual(step.response, str(path)) + assert step.response == str(path) def test_leading_white_space_in_filelist(self): """ @@ -1413,7 +1413,7 @@ def test_leading_white_space_in_filelist(self): with create_app_session(input=pipe_input, output=DummyOutput()): step.run() self.assertFalse(step.response.startswith(" ")) - self.assertEqual(step.response, str(path)) + assert step.response == str(path) def test_festival(self): with tempfile.TemporaryDirectory() as tmpdir_s: @@ -1689,8 +1689,8 @@ def test_multilingual_multispeaker_true_config(self) -> None: encoding="utf8", ) as f: text_to_spec_config = "\n".join(f) - self.assertIn("multilingual: true", text_to_spec_config) - 
self.assertIn("multispeaker: true", text_to_spec_config) + assert "multilingual: true" in text_to_spec_config + assert "multispeaker: true" in text_to_spec_config # Assertions about dataset-specific and global cleaners with open( @@ -1910,8 +1910,8 @@ def test_multilingual_multispeaker_false_config(self) -> None: encoding="utf8", ) as f: text_to_spec_config = "\n".join(f) - self.assertIn("multilingual: false", text_to_spec_config) - self.assertIn("multispeaker: false", text_to_spec_config) + assert "multilingual: false" in text_to_spec_config + assert "multispeaker: false" in text_to_spec_config trivial_tour_results = { SN.name_step.value: "project_name", @@ -1936,7 +1936,7 @@ def test_control_c_go_back(self): ): with patch_menu_prompt(0): # say 0==go back each time tour.run() - self.assertEqual(tour.state, self.trivial_tour_results) + assert tour.state == self.trivial_tour_results def test_control_c_continue(self): # Ctrl-C plus option 1 continues @@ -1948,7 +1948,7 @@ def test_control_c_continue(self): # Ctrl-C once, then hit 1 to continue with patch_menu_prompt([KeyboardInterrupt(), 1], multi=True): tour.run() - self.assertEqual(tour.state, self.trivial_tour_results) + assert tour.state == self.trivial_tour_results def test_control_c_display_tree(self): # Ctrl-C plus option 2 displays the current tree @@ -1959,8 +1959,8 @@ def test_control_c_display_tree(self): ): with patch_menu_prompt(2) as output: tour.run() - self.assertIn("Contact Name: Jane Doe", flatten_log(output.getvalue())) - self.assertEqual(tour.state, self.trivial_tour_results) + assert "Contact Name: Jane Doe" in flatten_log(output.getvalue()) + assert tour.state == self.trivial_tour_results progress_template = dedent( """\ @@ -1993,7 +1993,7 @@ def test_control_c_save_progress(self): ): with patch_menu_prompt(3): tour.run() - self.assertTrue(progress_file.exists()) + assert progress_file.exists() with open(progress_file, encoding="utf8") as f: progress_contents = f.read() # 
print(progress_contents) @@ -2011,8 +2011,8 @@ def test_resume_from(self): tour = make_trivial_tour() with patch_questionary("email@mail.com"), capture_stdout() as out: tour.run(resume_from=progress_file) - self.assertIn("Applying saved response", out.getvalue()) - self.assertEqual(tour.state, self.trivial_tour_results) + assert "Applying saved response" in out.getvalue() + assert tour.state == self.trivial_tour_results def test_resume_from_the_future(self): with tempfile.TemporaryDirectory() as tmpdirname: @@ -2031,10 +2031,10 @@ def test_resume_from_the_future(self): with patch_questionary("email@mail.com"), capture_stdout() as out: tour.run(resume_from=changed_version) output = flatten_log(out.getvalue()) - self.assertIn("Proceeding anyway", output) - self.assertIn("consider updating your software", output) - self.assertIn("Applying saved response", output) - self.assertEqual(tour.state, self.trivial_tour_results) + assert "Proceeding anyway" in output + assert "consider updating your software" in output + assert "Applying saved response" in output + assert tour.state == self.trivial_tour_results def test_resume_from_near_past(self): with tempfile.TemporaryDirectory() as tmpdirname: @@ -2049,10 +2049,10 @@ def test_resume_from_near_past(self): with patch_questionary("email@mail.com"), capture_stdout() as out: tour.run(resume_from=changed_version) output = flatten_log(out.getvalue()) - self.assertIn("expected to be compatible", output) - self.assertIn("Proceeding anyway", output) - self.assertIn("Applying saved response", output) - self.assertEqual(tour.state, self.trivial_tour_results) + assert "expected to be compatible" in output + assert "Proceeding anyway" in output + assert "Applying saved response" in output + assert tour.state == self.trivial_tour_results def test_resume_from_far_past(self): with tempfile.TemporaryDirectory() as tmpdirname: @@ -2065,10 +2065,10 @@ def test_resume_from_far_past(self): with patch_questionary("email@mail.com"), 
capture_stdout() as out: tour.run(resume_from=changed_version) output = flatten_log(out.getvalue()) - self.assertIn("not fully compatible", output) - self.assertIn("Proceeding anyway", output) - self.assertIn("Applying saved response", output) - self.assertEqual(tour.state, self.trivial_tour_results) + assert "not fully compatible" in output + assert "Proceeding anyway" in output + assert "Applying saved response" in output + assert tour.state == self.trivial_tour_results def test_resume_with_invalid_progress_files(self): with tempfile.TemporaryDirectory() as tmpdirname: @@ -2082,8 +2082,8 @@ def test_resume_with_invalid_progress_files(self): tour = make_trivial_tour() with patch_questionary("email@mail.com"), capture_stdout() as out: tour.run(resume_from=invalid_response) - self.assertIn("Error: saved response 'invalid email'", out.getvalue()) - self.assertEqual(tour.state, self.trivial_tour_results) + assert "Error: saved response 'invalid email'" in out.getvalue() + assert tour.state == self.trivial_tour_results # From here on, it's lots of ways to fail, which always causes a SystemExit @@ -2125,7 +2125,7 @@ def test_resume_with_invalid_progress_files(self): f.write("".join(progress_lines[-4:-2])) with self.assertRaises(SystemExit), capture_stdout() as out: tour.run(resume_from=questions_out_of_order) - self.assertIn("out of sync", out.getvalue()) + assert "out of sync" in out.getvalue() extra_question_not_in_tour = tmpdir / "extra-question" with open(extra_question_not_in_tour, "w", encoding="utf8") as f: @@ -2135,7 +2135,7 @@ def test_resume_with_invalid_progress_files(self): tour = make_trivial_tour() with self.assertRaises(SystemExit), capture_stdout() as out: tour.run(resume_from=extra_question_not_in_tour) - self.assertIn("saved responses left", out.getvalue()) + assert "saved responses left" in out.getvalue() wrong_software_name = tmpdir / "wrong-software-name" with open(wrong_software_name, "w", encoding="utf8") as f: @@ -2143,7 +2143,7 @@ def 
test_resume_with_invalid_progress_files(self): f.write("".join(progress_lines[1:])) with self.assertRaises(SystemExit), capture_stdout() as out: tour.run(resume_from=wrong_software_name) - self.assertIn("it is for software", flatten_log(out.getvalue())) + assert "it is for software" in flatten_log(out.getvalue()) def test_control_c_exit(self): # Ctrl-C plus option 4 (Exit) exits @@ -2169,12 +2169,12 @@ def test_trace(self): tour.run() for step in tour.steps: # When not the current step: - self.assertIn(step.name.replace(" Step", "") + " ", out.getvalue()) + assert step.name.replace(" Step", "") + " " in out.getvalue() # When it is the current step: self.assertRegex(out.getvalue(), step.name.replace(" Step", "") + " *←") # When previously filled: if step != tour.steps[-1]: - self.assertIn(step.name.replace(" Step", "") + ": ", out.getvalue()) + assert step.name.replace(" Step", "") + ": " in out.getvalue() def test_debug_state(self): tour = make_trivial_tour(debug_state=True) diff --git a/everyvoice/tests/test_wizard_helpers.py b/everyvoice/tests/test_wizard_helpers.py index 8ca2bb02..706b28de 100755 --- a/everyvoice/tests/test_wizard_helpers.py +++ b/everyvoice/tests/test_wizard_helpers.py @@ -263,18 +263,18 @@ def test_enum_dict(self): """Enum values need to behave the same with or without .value""" d = EnumDict() d[SN.audio_config_step] = "foo" - self.assertEqual(d[SN.audio_config_step.value], "foo") - self.assertEqual(d.get(SN.audio_config_step.value), "foo") + assert d[SN.audio_config_step.value] == "foo" + assert d.get(SN.audio_config_step.value) == "foo" d[SN.wavs_dir_step.value] = "bar" - self.assertEqual(d[SN.wavs_dir_step], "bar") - self.assertEqual(d.get(SN.wavs_dir_step), "bar") + assert d[SN.wavs_dir_step] == "bar" + assert d.get(SN.wavs_dir_step) == "bar" self.assertEqual(d.get(SN.filelist_format_step, None), None) self.assertEqual(d.get(SN.filelist_format_step.value, None), None) d.update({SN.contact_email_step: "a@b.com"}) - 
self.assertEqual(d[SN.contact_email_step.value], "a@b.com") + assert d[SN.contact_email_step.value] == "a@b.com" self.assertEqual( d, @@ -308,10 +308,10 @@ def test_node_mixin(self): forward_order = list(PreOrderIter(root)) for prev, next in zip(forward_order, forward_order[1:] + [None]): - self.assertEqual(prev.next(), next) + assert prev.next() == next for next, prev in zip(forward_order, [None] + forward_order[:-1]): - self.assertEqual(next.prev(), prev) + assert next.prev() == prev if __name__ == "__main__": diff --git a/pyproject.toml b/pyproject.toml index 5ed3d6db..3a2962e7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,7 +57,7 @@ dependencies = [ "ipatok>=0.4.1", "librosa==0.11.0", "lightning>=2.1.0", - "loguru==0.6.0", + "loguru>=0.6.0", "matplotlib>=3.9.0", "merge-args", "nltk==3.9.3", @@ -112,7 +112,8 @@ test = [ "jsonschema>=4.17.3", "pep440>=0.1.2", "playwright>=1.52.0", - "pytest", + "pytest>=7", + "pytest-subtests", ] docs = [ "mkdocs>=1.5.2",