diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index b82803c2..ca11102b 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -31,17 +31,17 @@ jobs: - name: Create Librosa Results for Testing Against run: | - python testing/librosa_results_for_testing_against.py + python testing_mmm_audio/validation/librosa_results_for_testing_against.py - name: Run UnitTests.mojo id: run-tests run: | - mojo testing/UnitTests.mojo + mojo testing_mmm_audio/UnitTests.mojo - name: Test Building Mojo Files run: | - python testing/test_build_mojo_files.py + python testing_mmm_audio/test_build_mojo_files.py - name: Validate Against Snapshots run: | - python testing/validate_against_snapshot.py + python testing_mmm_audio/validation/validate_against_snapshot.py diff --git a/.gitignore b/.gitignore index 2132191d..11083304 100644 --- a/.gitignore +++ b/.gitignore @@ -222,11 +222,12 @@ __marimo__/ doc_generation/docs_md site .vscode -testing/librosa_results +testing_mmm_audio/validation/librosa_results .venv-257 venv0257 mine.worktrees *.o -testing/outputs -testing/mojo_results -testing/validation_results +testing_mmm_audio/outputs +testing_mmm_audio/validation/mojo_results +testing_mmm_audio/validation/validation_results + diff --git a/doc_generation/static_docs/contributing/testing.md b/doc_generation/static_docs/contributing/testing.md index 21a0b9d6..38b2067a 100644 --- a/doc_generation/static_docs/contributing/testing.md +++ b/doc_generation/static_docs/contributing/testing.md @@ -4,14 +4,14 @@ There are three kinds of tests the MMMAudio is set to run (see below). All these ## 1. Unit Tests -These can be found in [UnitTests.mojo](https://github.com/spluta/MMMAudio/blob/dev/testing/UnitTests.mojo). Some of these tests rely on the Librosa "results" generated by [`/testing/librosa_results_for_testing_against.py`](https://github.com/spluta/MMMAudio/blob/dev/testing/librosa_results_for_testing_against.py), so be sure to run this first. +These can be found in [UnitTests.mojo](https://github.com/spluta/MMMAudio/blob/dev/testing_mmm_audio/UnitTests.mojo). Some of these tests rely on the Librosa "results" generated by [`/testing_mmm_audio/validation/librosa_results_for_testing_against.py`](https://github.com/spluta/MMMAudio/blob/dev/testing_mmm_audio/validation/librosa_results_for_testing_against.py), so be sure to run this first. ## 2. "Smoke" Tests > Turn everything on and see if anything "catches on fire." -See [`testing/test_build_mojo_files.py`](https://github.com/spluta/MMMAudio/blob/dev/testing/test_build_mojo_files.py) to see how they are run. This script tries to "build" each `.mojo` file, generating a `.o` file which is then deleted. This way any syntax issues can be identified quickly and easily across all the (appropriate) `.mojo` files in the code base. +See [`testing_mmm_audio/test_build_mojo_files.py`](https://github.com/spluta/MMMAudio/blob/dev/testing_mmm_audio/test_build_mojo_files.py) to see how they are run. This script tries to "build" each `.mojo` file, generating a `.o` file which is then deleted. This way any syntax issues can be identified quickly and easily across all the (appropriate) `.mojo` files in the code base. ## 3. Snapshot Tests -Some code, such as the audio analyses, are difficult to test against a "ground truth" because different codebases have different opinions about how to implement them, have different levels of precision, or other concerns. These tests run select `.mojo` files, the outputs of which are [compared against](https://github.com/spluta/MMMAudio/blob/dev/testing/validate_against_snapshot.py) a "snapshot" of what they previously output (the previous snapshot is [taken](https://github.com/spluta/MMMAudio/blob/dev/testing/make_validation_snapshot.py) when the files are known to be working). These tests pass if the current outputs match the previous snapshot. \ No newline at end of file +Some code, such as the audio analyses, are difficult to test against a "ground truth" because different codebases have different opinions about how to implement them, have different levels of precision, or other concerns. These tests run select `.mojo` files, the outputs of which are [compared against](https://github.com/spluta/MMMAudio/blob/dev/testing_mmm_audio/validation/validate_against_snapshot.py) a "snapshot" of what they previously output (the previous snapshot is [taken](https://github.com/spluta/MMMAudio/blob/dev/testing_mmm_audio/make_validation_snapshot.py) when the files are known to be working). These tests pass if the current outputs match the previous snapshot. \ No newline at end of file diff --git a/examples/SpectrogramExample.mojo b/examples/SpectrogramExample.mojo new file mode 100644 index 00000000..d8f0745f --- /dev/null +++ b/examples/SpectrogramExample.mojo @@ -0,0 +1,37 @@ +from mmm_audio import * + +struct Spectrogram(FFTProcessable): + var world: World + var m: Messenger + var mags: List[Float64] + + fn __init__(out self, world: World, fftsize: Int = 1024): + self.world = world + self.m = Messenger(world) + self.mags = List[Float64](length=(fftsize // 2) + 1, fill=0.0) + + fn next_frame(mut self, mut mags: List[Float64], mut freqs: List[Float64]): + for i in range(len(mags)): + self.mags[i] = mags[i] + + fn send_streams(mut self) -> None: + self.m.reply_stream("mags", self.mags) + +struct SpectrogramExample(Movable, Copyable): + var world: World + var buf: Buffer + var play: Play + var m: Messenger + var fftproces: FFTProcess[Spectrogram] + + fn __init__(out self, world: World): + self.world = world + self.buf = Buffer.load("resources/Shiverer.wav") + self.play = Play(world) + self.m = Messenger(world) + self.fftproces = FFTProcess(world, Spectrogram(world)) + + fn next(mut self) -> MFloat[2]: + sig = self.play.next(self.buf) + _ = self.fftproces.next(sig) + return sig \ No newline at end of file diff --git a/examples/SpectrogramExample.py b/examples/SpectrogramExample.py new file mode 100644 index 00000000..727431f5 --- /dev/null +++ b/examples/SpectrogramExample.py @@ -0,0 +1,95 @@ +import sys +from pathlib import Path +import numpy as np + +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from mmm_python import * +from PySide6.QtWidgets import QApplication, QMainWindow, QGraphicsView, QGraphicsScene, QGraphicsRectItem +from PySide6.QtCore import Signal, QObject, Qt +from PySide6.QtGui import QBrush, QPen, QColor + +class SpectrogramGraphicsView(QGraphicsView): + + def __init__(self, num_bars=513, parent=None): + super().__init__(parent) + self.num_bars = num_bars + self.bar_height = 400 + + self.scene = QGraphicsScene(self) + self.setScene(self.scene) + + self.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff) + self.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOff) + self.setRenderHint(self.renderHints().Antialiasing, False) + self.setViewportUpdateMode(QGraphicsView.MinimalViewportUpdate) + self.setOptimizationFlags(QGraphicsView.DontSavePainterState) + self.setBackgroundBrush(QBrush(QColor(0, 0, 0))) + + self.bars = [] + self.bar_width = 800 / num_bars + brush = QBrush(QColor(255, 255, 255)) + pen = QPen(Qt.NoPen) + + for i in range(num_bars): + bar = QGraphicsRectItem(0, 0, self.bar_width, 0) + bar.setBrush(brush) + bar.setPen(pen) + bar.setPos(i * self.bar_width, self.bar_height) + self.scene.addItem(bar) + self.bars.append(bar) + + self.scene.setSceneRect(0, 0, 800, self.bar_height) + self.setMinimumSize(800, 400) + + def update_data(self, data): + for i, mag in enumerate(data): + normalized = mag / 15.0 + bar_height = normalized * self.bar_height + + self.bars[i].setRect(0, 0, self.bar_width, bar_height) + self.bars[i].setPos(i * self.bar_width, self.bar_height - bar_height) + + def resizeEvent(self, event): + super().resizeEvent(event) + self.fitInView(self.scene.sceneRect(), Qt.IgnoreAspectRatio) + + +class SpectrogramWindow(QMainWindow): + data_ready = Signal(object) + + def __init__(self): + super().__init__() + self.setWindowTitle("Real-time Spectrogram") + + self.spectrogram_widget = SpectrogramGraphicsView(num_bars=513) + self.setCentralWidget(self.spectrogram_widget) + + self.data_ready.connect(self.spectrogram_widget.update_data) + + def callback(self, args): + self.data_ready.emit(args) + + def closeEvent(self, event): + QApplication.quit() + event.accept() + + +if __name__ == "__main__": + + app = QApplication() + + window = SpectrogramWindow() + window.show() + + m = MMMAudio(128, graph_name="SpectrogramExample", package_name="examples") + m.register_callback("mags", window.callback) + m.start_audio() + + def shutdown(): + m.stop_audio() + m.stop_process() + + app.aboutToQuit.connect(shutdown) + + sys.exit(app.exec()) \ No newline at end of file diff --git a/examples/ToPythonExample.mojo b/examples/ToPythonExample.mojo new file mode 100644 index 00000000..596ea02c --- /dev/null +++ b/examples/ToPythonExample.mojo @@ -0,0 +1,31 @@ + +from mmm_audio import * + +struct ToPythonExample(Movable, Copyable): + var world: World + var m: Messenger + var yin: BufferedInput[YIN[1024],1024,512] + var buf: Buffer + var play: Play + var vals: List[Float64] + + fn __init__(out self, world: World): + self.world = world + self.m = Messenger(self.world) + self.buf = Buffer.load("resources/Shiverer.wav") + self.play = Play(self.world) + yin = YIN[1024](self.world) + self.yin = BufferedInput[YIN[1024],1024,512](self.world,yin^) + self.vals = List[Float64]() + for i in range(1025): + self.vals.append(i / 1024.0) + + fn next(mut self) -> SIMD[DType.float64, 2]: + + sig = self.play.next(self.buf) + self.yin.next(sig) + self.m.reply_stream("pitch", self.yin.process.pitch) + self.m.reply_stream("vals", self.vals) + self.m.reply_stream("bool", random_float64() > 0.5) + + return SIMD[DType.float64, 2](sig, sig) \ No newline at end of file diff --git a/examples/ToPythonExample.py b/examples/ToPythonExample.py new file mode 100644 index 00000000..c6b3a400 --- /dev/null +++ b/examples/ToPythonExample.py @@ -0,0 +1,9 @@ +from mmm_python import * +m = MMMAudio(128, graph_name="ToPythonExample", package_name="examples") +m.register_callback("pitch", lambda args: print(f"pitch: {args}")) +m.register_callback("vals", lambda args: print(f"vals: {args}")) +m.register_callback("bool", lambda args: print(f"bool: {args}")) +m.register_callback("trig", lambda args: print(f"trig: {args}")) +m.start_audio() + +m.stop_audio() \ No newline at end of file diff --git a/mmm_audio/BufferedProcess_Module.mojo b/mmm_audio/BufferedProcess_Module.mojo index e42385cd..bd33da23 100644 --- a/mmm_audio/BufferedProcess_Module.mojo +++ b/mmm_audio/BufferedProcess_Module.mojo @@ -8,23 +8,42 @@ from math import floor # parameters. I think `hop_size` would still be a parameter of the BufferedProcess struct. trait BufferedProcessable(Movable, Copyable): """Trait that user structs must implement to be used with a BufferedProcess. - - Requires two functions: - - - `next_window(buffer: List[Float64]) -> None`: This function is called when enough samples have been buffered. - The user can process the input buffer in place meaning that the samples you want to return to the output need - to replace the samples that you receive in the input list. - - - `get_messages() -> None`: This function is called at the top of each audio block to allow the user to retrieve any messages - they may have sent to this process. Put your [Messenger](Messenger.md) message retrieval code here. (e.g. `self.messenger.update(self.param, "param_name")`) """ + fn next_window(mut self, mut buffer: List[Float64]) -> None: + """This function is called when enough samples have been buffered. + The user can process the input buffer in place meaning that the samples you want to return to the output need + to replace the samples that you receive in the input list. + + This function has a default implementation that does nothing so it is possible to *not* + implement it. This would probably be because a stereo process is implementing `next_stereo_window()` instead. + """ return None fn next_stereo_window(mut self, mut buffer: List[SIMD[DType.float64, 2]]) -> None: + """The stereo version of `next_window()`. See that for details. + + This function has a default implementation that does nothing so it is possible to *not* + implement it. This would probably be because a mono process is implementing `next_window()` instead. + """ return None fn get_messages(mut self) -> None: + """This function is called at the top of each audio block to allow the user to retrieve any messages + they may have sent to this process. Put your [Messenger](Messenger.md) message retrieval code here. + (e.g. `self.messenger.update(self.param, "param_name")`). + + This method has a default implementation that does nothing, so it is not necessary to + implement it if you don't need to retrieve any messages. + """ + return None + + fn send_streams(mut self) -> None: + """This function can be used to stream data back to Python. Put your [Messenger](Messenger.md) message sending code here. + (e.g. `self.messenger.reply_stream("stream_name", value)`). + + This method has a default implementation that does nothing, so it is not necessary to implement it if you don't need to send any stream data. + """ return None struct BufferedInput[T: BufferedProcessable, window_size: Int = 1024, hop_size: Int = 512, input_window_shape: Int = WindowType.hann](Movable, Copyable): @@ -161,6 +180,9 @@ struct BufferedProcess[T: BufferedProcessable, window_size: Int = 1024, hop_size """ if self.world[].top_of_block: self.process.get_messages() + + if self.world[].messengerManager.accepting_stream_data: + self.process.send_streams() self.input_buffer[self.input_buffer_write_head] = input self.input_buffer[self.input_buffer_write_head + Self.window_size] = input diff --git a/mmm_audio/FFTProcess_Module.mojo b/mmm_audio/FFTProcess_Module.mojo index 4475cde8..5301f7a6 100644 --- a/mmm_audio/FFTProcess_Module.mojo +++ b/mmm_audio/FFTProcess_Module.mojo @@ -41,6 +41,10 @@ struct FFTProcessor[T: FFTProcessable, window_size: Int = 1024](BufferedProcessa @doc_private fn get_messages(mut self) -> None: self.process.get_messages() + + @doc_private + fn send_streams(mut self) -> None: + self.process.send_streams() trait FFTProcessable(Movable,Copyable): """Implement this trait in a custom struct to pass to `FFTProcess` @@ -50,10 +54,33 @@ trait FFTProcessable(Movable,Copyable): using a struct that implements FFTProcessable. """ fn next_frame(mut self, mut magnitudes: List[Float64], mut phases: List[Float64]) -> None: + """This function is called when the internal buffered process has enough samples to + perform an FFT. The user can modify the magnitudes and phases in place to achieve + their desired spectral processing. The modified magnitudes and phases will then be + used by the internal buffered process to perform an IFFT and return to the time domain. + + This function has a default implementation that does nothing so it is possible to *not* + implement it. This would probably be because a stereo process is implementing + `next_stereo_frame()` instead. + """ return None fn next_stereo_frame(mut self, mut magnitudes: List[SIMD[DType.float64,2]], mut phases: List[SIMD[DType.float64,2]]) -> None: + """The stereo version of `next_frame()`. See that for details. + """ return None fn get_messages(mut self) -> None: + """This function is called at the top of each audio block to allow the user to retrieve any messages + they may have sent to this process. Put your [Messenger](Messenger.md) message retrieval code here. (e.g. `self.messenger.update(self.param, "param_name")`). + + This method has a default implementation that does nothing, so it is not necessary to implement it if you don't need to retrieve any messages. + """ + return None + fn send_streams(mut self) -> None: + """This function can be used to stream data back to Python. Put your [Messenger](Messenger.md) message sending code here. + (e.g. `self.messenger.reply_stream("stream_name", value)`). + + This method has a default implementation that does nothing, so it is not necessary to implement it if you don't need to send any stream data. + """ return None struct FFTProcess[T: FFTProcessable, window_size: Int = 1024, hop_size: Int = 512, input_window_shape: Int = WindowType.hann, output_window_shape: Int = WindowType.hann](Movable,Copyable): diff --git a/mmm_audio/MMMAudioBridge.mojo b/mmm_audio/MMMAudioBridge.mojo index 5d88a565..b30d73c8 100644 --- a/mmm_audio/MMMAudioBridge.mojo +++ b/mmm_audio/MMMAudioBridge.mojo @@ -2,6 +2,7 @@ # i don't want it to be in this directory, but it needs to be here due to a mojo compiler bug from python import PythonObject +from python import Python from python.bindings import PythonModuleBuilder from os import abort @@ -13,6 +14,9 @@ from examples.FeedbackDelays import FeedbackDelays struct MMMAudioBridge(Representable, Movable): var world: World var graph: FeedbackDelays # The audio graph instance + var block_counter: Int # count how many blocks this audio thread has been alive + var np: PythonObject # an instance of numpy for sending np arrays back to python + var pydict: PythonObject # a Python dictionary for sending messages back to python var osc_buffers: UnsafePointer[mut=True, OscBuffers, MutExternalOrigin] var windows: UnsafePointer[mut=True, Windows, MutExternalOrigin] @@ -45,6 +49,18 @@ struct MMMAudioBridge(Representable, Movable): self.graph = FeedbackDelays(self.world) + self.block_counter = 0 + + self.np = PythonObject(None) + + self.pydict = PythonObject(None) + + try: + self.np = Python.import_module("numpy") + self.pydict = Python.dict() + except error: + print("Error occurred while importing numpy. Error: ", error) + @staticmethod fn set_channel_count(py_selfA: PythonObject, args: PythonObject) raises -> PythonObject: var num_in_chans = Int(py=args[0]) @@ -168,13 +184,21 @@ struct MMMAudioBridge(Representable, Movable): self.world[].top_of_block = True self.world[].messengerManager.transfer_msgs() - + self.world[].messengerManager.accepting_stream_data = False + for i in range(self.world[].block_size): self.world[].block_state = i # Update the block state if i == 1: self.world[].top_of_block = False self.world[].messengerManager.empty_msg_dicts() + elif i == self.world[].block_size - 1 and self.block_counter % 10 == 0: + # Add to the "stream" Dicts every 10 blocks. This is a hard coded throttling of sending + # data back because otherwise it is way to much CPU (and not really necessary?) + # Perhaps in the future the "10" can be turned into a user defined parameter when they + # boot up an audio process. There might be times when a user would want to use the CPU + # to stream back data every block. + self.world[].messengerManager.accepting_stream_data = True if self.world[].top_of_block: self.world[].print_counter += 1 @@ -187,6 +211,7 @@ struct MMMAudioBridge(Representable, Movable): # Fill the wire buffer with the sample data for j in range(min(self.world[].num_out_chans, samples.__len__())): loc_out_buffer[i * self.world[].num_out_chans + j] = samples[Int(j)] + @staticmethod fn next(py_selfA: PythonObject, in_buffer: PythonObject, out_buffer: PythonObject) raises -> PythonObject: @@ -201,9 +226,104 @@ struct MMMAudioBridge(Representable, Movable): for i in range(py_self[0].world[].block_size): loc_out_buffer[i * py_self[0].world[].num_out_chans + j] = 0.0 - py_self[0].get_audio_samples(loc_in_buffer, loc_out_buffer) - - return PythonObject(None) # Return a PythonObject wrapping the float value + py_self[0].world[].messengerManager.reply_stream_float.clear() # Clear the reply_stream_float dictionary at the start of each block + py_self[0].world[].messengerManager.reply_stream_floats.clear() # Clear the reply_stream_floats dictionary at the start of each block + + py_self[0].get_audio_samples(loc_in_buffer, loc_out_buffer) + + ############################################################################ + # Any data that needs to go back to Python needs to get into a Python.dict() + ############################################################################ + # even though this is a lot of code right here inside the next function, I think + # it's better to keep the Python.dict() pretty localized because I think a lot of the + # overhead comes from dealing with Python objects. + + py_self[0].pydict.clear() + + # check the "reply_once" Dicts every block + for pf in py_self[0].world[].messengerManager.reply_once_float.take_items(): + py_self[0].pydict[pf.key] = pf.value + + for pfs in py_self[0].world[].messengerManager.reply_once_floats.take_items(): + arr = py_self[0].np.empty(len(pfs.value),dtype=py_self[0].np.float64) + for i in range(len(pfs.value)): + arr[i] = pfs.value[i] + py_self[0].pydict[pfs.key] = arr + + for pi in py_self[0].world[].messengerManager.reply_once_int.take_items(): + py_self[0].pydict[pi.key] = pi.value + + for pis in py_self[0].world[].messengerManager.reply_once_ints.take_items(): + arr = py_self[0].np.empty(len(pis.value),dtype=py_self[0].np.int64) + for i in range(len(pis.value)): + arr[i] = pis.value[i] + py_self[0].pydict[pis.key] = arr + + for pb in py_self[0].world[].messengerManager.reply_once_bool.take_items(): + py_self[0].pydict[pb.key] = pb.value + + for ps in py_self[0].world[].messengerManager.reply_once_string.take_items(): + py_self[0].pydict[ps.key] = ps.value + + for pss in py_self[0].world[].messengerManager.reply_once_strings.take_items(): + arr = Python.list() + for i in range(len(pss.value)): + arr.append(pss.value[i]) + py_self[0].pydict[pss.key] = arr + + for pt in py_self[0].world[].messengerManager.reply_once_trig: + py_self[0].pydict[pt] = None # Set the key to True to indicate the trig was received + + py_self[0].world[].messengerManager.reply_once_trig.clear() + + if py_self[0].world[].messengerManager.accepting_stream_data: + + # float + for pf in py_self[0].world[].messengerManager.reply_stream_float.take_items(): + py_self[0].pydict[pf.key] = pf.value + + # floats + for pfs in py_self[0].world[].messengerManager.reply_stream_floats.take_items(): + arr = py_self[0].np.empty(len(pfs.value),dtype=py_self[0].np.float64) + for i in range(len(pfs.value)): + arr[i] = pfs.value[i] + py_self[0].pydict[pfs.key] = arr + + # int + for pi in py_self[0].world[].messengerManager.reply_stream_int.take_items(): + py_self[0].pydict[pi.key] = pi.value + + # ints + for pis in py_self[0].world[].messengerManager.reply_stream_ints.take_items(): + arr = py_self[0].np.empty(len(pis.value),dtype=py_self[0].np.int64) + for i in range(len(pis.value)): + arr[i] = pis.value[i] + py_self[0].pydict[pis.key] = arr + + # bool + for pb in py_self[0].world[].messengerManager.reply_stream_bool.take_items(): + py_self[0].pydict[pb.key] = pb.value + + # bools + # for pbs in py_self[0].world[].messengerManager.reply_stream_bools.take_items(): + # arr = py_self[0].np.empty(len(pbs.value),dtype=py_self[0].np.bool_) + # for i in range(len(pbs.value)): + # arr[i] = pbs.value[i] + # py_self[0].pydict[pbs.key] = arr + + # string + for ps in py_self[0].world[].messengerManager.reply_stream_string.take_items(): + py_self[0].pydict[ps.key] = ps.value + + # strings + for pss in py_self[0].world[].messengerManager.reply_stream_strings.take_items(): + arr = Python.list() + for i in range(len(pss.value)): + arr.append(pss.value[i]) + py_self[0].pydict[pss.key] = arr + + py_self[0].block_counter += 1 + return py_self[0].pydict # this is needed to make the module importable in Python - so simple! @export diff --git a/mmm_audio/Messenger_Module.mojo b/mmm_audio/Messenger_Module.mojo index 2ebd6363..f5f4ec7d 100644 --- a/mmm_audio/Messenger_Module.mojo +++ b/mmm_audio/Messenger_Module.mojo @@ -1,5 +1,6 @@ from mmm_audio import * from collections import Dict, Set +from python import PythonObject struct Messenger(Copyable, Movable): """Communication between Python and Mojo. @@ -8,6 +9,12 @@ struct Messenger(Copyable, Movable): any parameters registered with it accordingly. Each data type has its own `update` function and `notify_update` which will return a Bool indicating whether the parameter was updated. For example usage, see the MessengerExample.mojo file in the [Examples](../examples/index.md) folder. + + Similarly, each data type has its own `reply_stream` and `reply_once` functions to send values back to Python. `reply_stream` is for continuously changing values that you want to "stream" back to Python. Every 10 audio blocks, the most recent value is sent (whatever the value was on the last sample of that audio block). + + `reply_once` is for sending values just once (or infrequently). At the end of the current audio block this value will be sent to Python. `reply_once` is best used inside an `if` statement in Mojo so the sending only happens in certain circumstances. + + See the Python class `MMMAudio`'s `register_callback` for information on how to receive the values. """ var namespace: Optional[String] @@ -22,6 +29,8 @@ struct Messenger(Copyable, Movable): For example, if a Float64 updates with the name 'freq' and this Messenger has the namespace 'synth1', then to update the freq value from Python, the user must send 'synth1.freq'. + Similarly, if a value is being sent from Mojo to Python and a namespace is supplied, the value will be sent under the name 'namespace.value_name'. For example, if a Float64 value is being sent with the name 'spec_cent' and the namespace is 'sound1', then in Python this value will be received under the name 'sound1.spec_cent'. + Args: world: An `World` to the world to check for new messages. namespace: A `String` (or by defaut `None`) to declare as the 'namespace' for this Messenger. If a 'namespace' is provided, any messages sent from Python need to be prepended with this name. For example, if a Float64 updates with the name 'freq' and this Messenger has the namespace 'synth1', then to update the freq value from Python, the user must send 'synth1.freq'. @@ -31,6 +40,212 @@ struct Messenger(Copyable, Movable): self.namespace = namespace self.key_dict = Dict[String, String]() + fn reply_stream(mut self, name: String, value: Float64): + """Stream a Float64 value to Python under the specified name. + + Args: + name: A `String` to identify the value in Python. + value: A `Float64` value to be sent to Python. + """ + if self.world[].messengerManager.accepting_stream_data: + try: + self.world[].messengerManager.reply_stream_float[self.get_name_with_namespace(name)[]] = value + except error: + print("Error occurred while sending float to python. Error: ", error) + + fn reply_once(mut self, name: String, value: Float64): + """Send a Float64 value to Python under the specified name. + + Args: + name: A `String` to identify the value in Python. + value: A `Float64` value to be sent to Python. + """ + try: + self.world[].messengerManager.reply_once_float[self.get_name_with_namespace(name)[]] = value + except error: + print("Error occurred while sending float to python. Error: ", error) + + fn reply_stream(mut self, name: String, value: List[Float64]): + """Stream a List[Float64] value to Python under the specified name. + + It will be received in Python as a 1D numpy array. + + Args: + name: A `String` to identify the value in Python. + value: A `List[Float64]` value to be sent to Python. + """ + if self.world[].messengerManager.accepting_stream_data: + try: + self.world[].messengerManager.reply_stream_floats[self.get_name_with_namespace(name)[]] = value.copy() + except error: + print("Error occurred while sending float list to python. Error: ", error) + + fn reply_once(mut self, name: String, value: List[Float64]): + """Send a List[Float64] value to Python under the specified name. + + It will be received in Python as a 1D numpy array. + + Args: + name: A `String` to identify the value in Python. + value: A `List[Float64]` value to be sent to Python. + """ + try: + self.world[].messengerManager.reply_once_floats[self.get_name_with_namespace(name)[]] = value.copy() + except error: + print("Error occurred while sending float list to python. Error: ", error) + + fn reply_stream(mut self, name: String, value: Int): + """Stream an Int value to Python under the specified name. + + Args: + name: A `String` to identify the value in Python. + value: An `Int` value to be sent to Python. + """ + if self.world[].messengerManager.accepting_stream_data: + try: + self.world[].messengerManager.reply_stream_int[self.get_name_with_namespace(name)[]] = value + except error: + print("Error occurred while sending int to python. Error: ", error) + + fn reply_once(mut self, name: String, value: Int): + """Send an Int value to Python under the specified name. + + Args: + name: A `String` to identify the value in Python. + value: An `Int` value to be sent to Python. + """ + try: + self.world[].messengerManager.reply_once_int[self.get_name_with_namespace(name)[]] = value + except error: + print("Error occurred while sending int to python. Error: ", error) + + fn reply_stream(mut self, name: String, value: List[Int]): + """Stream a List[Int] value to Python under the specified name. + + It will be received in Python as a 1D numpy array. + + Args: + name: A `String` to identify the value in Python. + value: A `List[Int]` value to be sent to Python. + """ + if self.world[].messengerManager.accepting_stream_data: + try: + self.world[].messengerManager.reply_stream_ints[self.get_name_with_namespace(name)[]] = value.copy() + except error: + print("Error occurred while sending int list to python. Error: ", error) + + fn reply_once(mut self, name: String, value: List[Int]): + """Send a List[Int] value to Python under the specified name. + + It will be received in Python as a 1D numpy array. + + Args: + name: A `String` to identify the value in Python. + value: A `List[Int]` value to be sent to Python. + """ + try: + self.world[].messengerManager.reply_once_ints[self.get_name_with_namespace(name)[]] = value.copy() + except error: + print("Error occurred while sending int list to python. Error: ", error) + + fn reply_stream(mut self, name: String, value: Bool): + """Stream a Bool value to Python under the specified name. + + Args: + name: A `String` to identify the value in Python. + value: A `Bool` value to be sent to Python. + """ + if self.world[].messengerManager.accepting_stream_data: + try: + self.world[].messengerManager.reply_stream_bool[self.get_name_with_namespace(name)[]] = value + except error: + print("Error occurred while sending bool to python. Error: ", error) + + fn reply_once(mut self, name: String, value: Bool): + """Send a Bool value to Python under the specified name. + + Args: + name: A `String` to identify the value in Python. + value: A `Bool` value to be sent to Python. + """ + try: + self.world[].messengerManager.reply_once_bool[self.get_name_with_namespace(name)[]] = value + except error: + print("Error occurred while sending bool to python. Error: ", error) + + # fn reply_stream(mut self, name: String, value: List[Bool]): + # """Stream a List[Bool] value to Python under the specified name. + + # Args: + # name: A `String` to identify the value in Python. + # value: A `List[Bool]` value to be sent to Python. + # """ + # if self.world[].messengerManager.accepting_stream_data: + # try: + # self.world[].messengerManager.reply_stream_bools[self.get_name_with_namespace(name)[]] = value.copy() + # except error: + # print("Error occurred while sending bool list to python. Error: ", error) + + fn reply_stream(mut self, name: String, value: String): + """Stream a String to Python under the specified name. + + Args: + name: A `String` to identify the value in Python. + value: A `String` value to be sent to Python. + """ + if self.world[].messengerManager.accepting_stream_data: + try: + self.world[].messengerManager.reply_stream_string[self.get_name_with_namespace(name)[]] = value + except error: + print("Error occurred while sending string to python. Error: ", error) + + fn reply_once(mut self, name: String, value: String): + """Send a String value to Python under the specified name. + + Args: + name: A `String` to identify the value in Python. + value: A `String` value to be sent to Python. + """ + try: + self.world[].messengerManager.reply_once_string[self.get_name_with_namespace(name)[]] = value + except error: + print("Error occurred while sending string to python. Error: ", error) + + fn reply_stream(mut self, name: String, value: List[String]): + """Stream a List[String] value to Python under the specified name. + + Args: + name: A `String` to identify the value in Python. + value: A `List[String]` value to be sent to Python. + """ + if self.world[].messengerManager.accepting_stream_data: + try: + self.world[].messengerManager.reply_stream_strings[self.get_name_with_namespace(name)[]] = value.copy() + except error: + print("Error occurred while sending string list to python. Error: ", error) + + fn reply_once(mut self, name: String, value: List[String]): + """Send a List[String] value to Python under the specified name. + + Args: + name: A `String` to identify the value in Python. + value: A `List[String]` value to be sent to Python. + """ + try: + self.world[].messengerManager.reply_once_strings[self.get_name_with_namespace(name)[]] = value.copy() + except error: + print("Error occurred while sending string list to python. Error: ", error) + + fn reply_once(mut self, name: String): + """Send a trigger message to Python under the specified name. + + Args: + name: A `String` to identify the trigger in Python. + """ + try: + self.world[].messengerManager.reply_once_trig.add(self.get_name_with_namespace(name)[]) + except error: + print("Error occurred while sending trig to python. Error: ", error) @doc_private fn get_name_with_namespace(mut self, name: String) raises -> LegacyUnsafePointer[mut=False,String]: @@ -466,6 +681,7 @@ struct TrigsMessage(Movable, Copyable): @doc_private struct MessengerManager(Movable, Copyable): + # Data Structure for Receiving Data from Python var bool_msg_pool: Dict[String, Bool] var bool_msgs: Dict[String, BoolMessage] @@ -498,7 +714,29 @@ struct MessengerManager(Movable, Copyable): var trigs_msg_pool: Dict[String, List[Bool]] var trigs_msgs: Dict[String, TrigsMessage] - + + # Data Structures for Sending Data to Python + var reply_stream_float: Dict[String, Float64] + var reply_stream_floats: Dict[String, List[Float64]] + var reply_stream_int: Dict[String, Int] + var reply_stream_ints: Dict[String, List[Int]] + var reply_stream_bool: Dict[String, Bool] + # var reply_stream_bools: Dict[String, List[Bool]] + var reply_stream_string: Dict[String, String] + var reply_stream_strings: Dict[String, List[String]] + + var reply_once_float: Dict[String, Float64] + var reply_once_floats: Dict[String, List[Float64]] + var reply_once_int: Dict[String, Int] + var reply_once_ints: Dict[String, List[Int]] + var reply_once_bool: Dict[String, Bool] + # var reply_once_bools: Dict[String, List[Bool]] + var reply_once_string: Dict[String, String] + var reply_once_strings: Dict[String, List[String]] + var reply_once_trig: Set[String] + + var accepting_stream_data: Bool + fn __init__(out self): self.bool_msg_pool = Dict[String, Bool]() @@ -531,6 +769,27 @@ struct MessengerManager(Movable, Copyable): self.trigs_msg_pool = Dict[String, List[Bool]]() self.trigs_msgs = Dict[String, TrigsMessage]() + self.reply_stream_float = Dict[String, Float64]() + self.reply_stream_floats = Dict[String, List[Float64]]() + self.reply_stream_int = Dict[String, Int]() + self.reply_stream_ints = Dict[String, List[Int]]() + self.reply_stream_bool = Dict[String, Bool]() + # self.reply_stream_bools = Dict[String, List[Bool]]() + self.reply_stream_string = Dict[String, String]() + self.reply_stream_strings = Dict[String, List[String]]() + + self.reply_once_float = Dict[String, Float64]() + self.reply_once_floats = Dict[String, List[Float64]]() + self.reply_once_int = Dict[String, Int]() + self.reply_once_ints = Dict[String, List[Int]]() + self.reply_once_bool = Dict[String, Bool]() + # self.reply_once_bools = Dict[String, List[Bool]]() + self.reply_once_string = Dict[String, String]() + self.reply_once_strings = Dict[String, List[String]]() + self.reply_once_trig = Set[String]() + + self.accepting_stream_data = False + ##### Bool ##### @always_inline fn update_bool_msg(mut self, key: String, value: Bool): diff --git a/mmm_python/MMMAudio.py b/mmm_python/MMMAudio.py index 545e869d..4be90247 100644 --- a/mmm_python/MMMAudio.py +++ b/mmm_python/MMMAudio.py @@ -10,10 +10,11 @@ from typing import Optional, Tuple, List from enum import IntEnum import mojo.importer +import threading +import asyncio import pyautogui - class AudioCommand(IntEnum): STOP_PROCESS = 0 START_AUDIO = 1 @@ -28,6 +29,10 @@ class AudioCommand(IntEnum): SEND_STRINGS = 10 GET_SAMPLES = 11 +class ResponseCommand(IntEnum): + SAMPLES = 0 + CALLBACKS = 1 + class MMMAudio: """ MMMAudio class that runs in its own dedicated process. @@ -78,6 +83,9 @@ def __init__( # Response queue for getting data back from audio process self.response_queue = Queue() + # Callback queue for receiving callbacks from audio process + self.callback_queue = Queue() + # Shared values for real-time parameter control # Add more as needed for your specific parameters self.shared_float_params = {} @@ -86,7 +94,33 @@ def __init__( # Sample rate will be set when process initializes self.sample_rate = Value(ctypes.c_int, 0) + self.callbacks = {} + + # Callback polling thread + self.callback_thread = None + self.callback_active = threading.Event() + self.start_process() + + def register_callback(self, name: str, callback): + """Register a callback function that can be called from the audio process. + + The function will be passed a single argument, which will be the data being sent from Mojo. The callback function will be called whenever the audio process sends a message with the corresponding name. + + If Mojo is sending a List[Float64] the callback will be a numpy array of dtype float64. + If Mojo is sending a List[Int] the callback will be a numpy array of dtype int64. + Other Lists from Mojo will be sent as Python lists. + + Args: + name: The name of the callback. + callback: The function to call when the callback is triggered. + """ + self.callbacks[name] = callback + + def unregister_callback(self, name: str): + """Unregister a previously registered callback function.""" + if name in self.callbacks: + del self.callbacks[name] def start_process(self): """Start the audio process""" @@ -112,6 +146,7 @@ def start_process(self): self.process_ready, self.command_queue, self.response_queue, + self.callback_queue, self.sample_rate ) ) @@ -129,6 +164,9 @@ def stop_process(self): if self.process is None: return + # Stop callback polling first + self._stop_callback_polling() + print("[Main] Stopping audio process...") self.stop_flag.set() @@ -147,9 +185,11 @@ def stop_process(self): def start_audio(self): """Start audio streaming in the audio process""" self.command_queue.put((AudioCommand.START_AUDIO, None)) + self._start_callback_polling() def stop_audio(self): """Stop audio streaming in the audio process""" + self._stop_callback_polling() self.command_queue.put((AudioCommand.STOP_AUDIO, None)) def is_running(self) -> bool: @@ -160,6 +200,68 @@ def is_process_alive(self) -> bool: """Check if the audio process is alive""" return self.process is not None and self.process.is_alive() + # ========================================================================= + # Callback polling methods + # ========================================================================= + + def _start_callback_polling(self): + """Start the callback polling thread""" + if self.callback_thread is not None and self.callback_thread.is_alive(): + return # Already running + + self.callback_active.set() + self.callback_thread = threading.Thread( + target=asyncio.run, + args=(self._callback_polling_loop(),), + daemon=False + ) + self.callback_thread.start() + print("[Main] Callback polling started") + + def _stop_callback_polling(self): + """Stop the callback polling thread""" + if self.callback_thread is None: + return + + self.callback_active.clear() + if self.callback_thread.is_alive(): + self.callback_thread.join(timeout=2.0) + if self.callback_thread.is_alive(): + print("[Main] Warning: Callback thread did not stop cleanly") + self.callback_thread = None + print("[Main] Callback polling stopped") + + async def _callback_polling_loop(self): + """Async loop that polls for callbacks from the audio process""" + import asyncio + from queue import Empty + + print("[Main] Callback polling loop started") + + while self.callback_active.is_set(): + try: + # Try to get callback message from queue (non-blocking) + try: + command, data = self.callback_queue.get(timeout=0.1) + + if command == ResponseCommand.CALLBACKS: + for key, value in data.items(): + if key in self.callbacks: + try: + self.callbacks[key](value) + except Exception as e: + print(f"[Main] Callback error for '{key}': {e}") + except Empty: + pass # No callbacks available + + await asyncio.sleep(0.01) + + except Exception as e: + print(f"[Main] Error in callback polling loop: {e}") + await asyncio.sleep(0.1) + + print("[Main] Callback polling loop ended") + # ========================================================================= # Message sending methods (same interface as original) # ========================================================================= @@ -207,7 +309,7 @@ def get_samples(self, samples: int) -> np.ndarray: # Wait for response try: response = self.response_queue.get(timeout=30.0) - if response[0] == "SAMPLES": + if response[0] == ResponseCommand.SAMPLES: return response[1] else: print(f"[Main] Unexpected response: {response[0]}") @@ -264,6 +366,7 @@ def _audio_process_main( process_ready: Event, command_queue: Queue, response_queue: Queue, + callback_queue: Queue, sample_rate_value: Value ): """ @@ -401,8 +504,13 @@ def output_callback(in_data, frame_count, time_info, status): ) # Process through Mojo bridge with bridge_lock: - mmm_audio_bridge.next(in_array, out_buffer) - + to_py_dict = mmm_audio_bridge.next(in_array, out_buffer) + if to_py_dict: + try: + callback_queue.put_nowait((ResponseCommand.CALLBACKS, to_py_dict)) + except: + pass # Queue full, drop callbacks + out_buffer = np.clip(out_buffer, -1.0, 1.0) output_bytes = out_buffer.astype(np.float32).tobytes() @@ -585,7 +693,7 @@ def handle_get_samples(args): if i * blocksize + j < samples: waveform[i * blocksize + j] = temp_out[j] - response_queue.put(("SAMPLES", waveform)) + response_queue.put((ResponseCommand.SAMPLES, waveform)) return True command_handlers = [ diff --git a/pixi.toml b/pixi.toml index 4d92350a..9636a425 100644 --- a/pixi.toml +++ b/pixi.toml @@ -9,23 +9,26 @@ version = "0.0.0" args = [ { "arg" = "plots", "default" = "" } ] -cmd = "python testing/run_all_validations.py {{ plots }}" +cmd = "python testing_mmm_audio/validation/run_all_validations.py {{ plots }}" [tasks.make_librosa_results] -cmd = "python testing/librosa_results_for_testing_against.py" +cmd = "python testing_mmm_audio/validation/librosa_results_for_testing_against.py" [tasks.validate_snapshot] -cmd = "python testing/validate_against_snapshot.py" +cmd = "python testing_mmm_audio/validation/validate_against_snapshot.py" [tasks.unit_tests] depends-on = ["make_librosa_results"] -cmd = "mojo testing/UnitTests.mojo" +cmd = "mojo testing_mmm_audio/UnitTests.mojo" + +[tasks.py_unit_tests] +cmd = "python testing_mmm_audio/unit_test_py_files.py" [tasks.test_building] -cmd = "python testing/test_build_mojo_files.py" +cmd = "python testing_mmm_audio/test_build_mojo_files.py" [tasks.test_all] -depends-on = ["unit_tests", "test_building","validate_snapshot"] +depends-on = ["unit_tests", "test_building","validate_snapshot","py_unit_tests"] [tasks.docs_serve] cmd = "mkdocs serve" diff --git a/testing/UnitTests.mojo b/testing_mmm_audio/UnitTests.mojo similarity index 98% rename from testing/UnitTests.mojo rename to testing_mmm_audio/UnitTests.mojo index 5f611c82..29812f5e 100644 --- a/testing/UnitTests.mojo +++ b/testing_mmm_audio/UnitTests.mojo @@ -146,7 +146,7 @@ def _test_mel_bands_weights[n_mels: Int, n_fft: Int, sr: Int](): # print("melband weights flat len: ", len(weights_flat)) - expected_path = "testing/librosa_results/librosa_mel_bands_weights_results" + expected_path = "testing_mmm_audio/validation/librosa_results/librosa_mel_bands_weights_results" expected_path += "_nmels=" + String(n_mels) expected_path += "_fftsize=" + String(n_fft) expected_path += "_sr=" + String(sr) diff --git a/testing/__init__.mojo b/testing_mmm_audio/__init__.mojo similarity index 100% rename from testing/__init__.mojo rename to testing_mmm_audio/__init__.mojo diff --git a/testing/__init__.py b/testing_mmm_audio/__init__.py similarity index 100% rename from testing/__init__.py rename to testing_mmm_audio/__init__.py diff --git a/testing_mmm_audio/py_unit_tests/MessengerRoundTripTest.mojo b/testing_mmm_audio/py_unit_tests/MessengerRoundTripTest.mojo new file mode 100644 index 00000000..79a0f2a7 --- /dev/null +++ b/testing_mmm_audio/py_unit_tests/MessengerRoundTripTest.mojo @@ -0,0 +1,75 @@ +from mmm_audio import * + +struct MessengerRoundTripTest(Movable, Copyable): + var world: World + var m: Messenger + var float_val: Float64 + var floats_val: List[Float64] + var int_val: Int + var ints_val: List[Int] + var bool_val: Bool + var str: String + var strs: List[String] + + var stream_float_val: Float64 + var stream_floats_val: List[Float64] + var stream_int_val: Int + var stream_ints_val: List[Int] + var stream_bool_val: Bool + var stream_str_val: String + var stream_strs_val: List[String] + + fn __init__(out self, world: World): + self.world = world + self.m = Messenger(self.world) + self.float_val = 0.0 + self.floats_val = List[Float64]() + self.int_val = 0 + self.ints_val = List[Int]() + self.bool_val = False + self.str = "" + self.strs = List[String]() + + self.stream_float_val = 42.42 + self.stream_floats_val = [1.1, 2.2, 3.3] + self.stream_int_val = 1825 + self.stream_ints_val = [1776,2026] + self.stream_bool_val = True + self.stream_str_val = "kenobi" + self.stream_strs_val = ["luke", "leia", "han"] + + fn next(mut self) -> SIMD[DType.float64, 2]: + + if self.m.notify_update(self.float_val,"float"): + self.m.reply_once("float_return", self.float_val + 1.0) + + if self.m.notify_update(self.floats_val,"floats"): + self.m.reply_once("floats_return", [x + 1.0 for x in self.floats_val]) + + if self.m.notify_update(self.int_val,"int"): + self.m.reply_once("int_return", self.int_val + 1) + + if self.m.notify_update(self.ints_val,"ints"): + self.m.reply_once("ints_return", [x + 1 for x in self.ints_val]) + + if self.m.notify_update(self.bool_val,"bool"): + self.m.reply_once("bool_return", not self.bool_val) + + if self.m.notify_trig("trig"): + self.m.reply_once("trig_return") + + if self.m.notify_update(self.str,"str"): + self.m.reply_once("str_return", self.str + "_return") + + if self.m.notify_update(self.strs,"strs"): + self.m.reply_once("strs_return", [s + "_return" for s in self.strs]) + + self.m.reply_stream("stream_float", self.stream_float_val) + self.m.reply_stream("stream_floats", self.stream_floats_val) + self.m.reply_stream("stream_int", self.stream_int_val) + self.m.reply_stream("stream_ints", self.stream_ints_val) + self.m.reply_stream("stream_bool", self.stream_bool_val) + self.m.reply_stream("stream_str", self.stream_str_val) + self.m.reply_stream("stream_strs", self.stream_strs_val) + + return SIMD[DType.float64, 2](0.0, 0.0) \ No newline at end of file diff --git a/testing_mmm_audio/py_unit_tests/MessengerRoundTripTest.py b/testing_mmm_audio/py_unit_tests/MessengerRoundTripTest.py new file mode 100644 index 00000000..a0187187 --- /dev/null +++ b/testing_mmm_audio/py_unit_tests/MessengerRoundTripTest.py @@ -0,0 +1,89 @@ +import time + +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent)) + +from mmm_python import * + +def main(): + m = MMMAudio(128, graph_name="MessengerRoundTripTest", package_name="testing_mmm_audio.py_unit_tests") + + # received values + rv = {} + + m.register_callback("float_return", lambda args: rv.update({"float": args})) + m.register_callback("floats_return", lambda args: rv.update({"floats": args})) + m.register_callback("int_return", lambda args: rv.update({"int": args})) + m.register_callback("ints_return", lambda args: rv.update({"ints": args})) + m.register_callback("bool_return", lambda args: rv.update({"bool": args})) + m.register_callback("str_return", lambda args: rv.update({"str": args})) + m.register_callback("strs_return", lambda args: rv.update({"strs": args})) + m.register_callback("trig_return", lambda args: rv.update({"trig": args})) + + m.register_callback("stream_float", lambda args: rv.update({"stream_float": args})) + m.register_callback("stream_floats", lambda args: rv.update({"stream_floats": args})) + m.register_callback("stream_int", lambda args: rv.update({"stream_int": args})) + m.register_callback("stream_ints", lambda args: rv.update({"stream_ints": args})) + m.register_callback("stream_bool", lambda args: rv.update({"stream_bool": args})) + m.register_callback("stream_str", lambda args: rv.update({"stream_str": args})) + m.register_callback("stream_strs", lambda args: rv.update({"stream_strs": args})) + + m.start_audio() + + # print("callbacks",m.callbacks) + + m.send_float("float",2.1415) + m.send_floats("floats",[1.0,2.0,3.0]) + m.send_int("int",42) + m.send_ints("ints",[100, 200, 300]) + m.send_bool("bool",True) + m.send_string("str","hello") + m.send_strings("strs",["foo", "bar", "baz"]) + m.send_trig("trig") + + time.sleep(0.1) + + # expected values + ev = { + "float": 3.1415, + "floats": np.array([2.0, 3.0, 4.0]), + "int": 43, + "ints": np.array([101, 201, 301]), + "bool": False, + "str": "hello_return", + "strs": ["foo_return", "bar_return", "baz_return"], + "trig": None, + "stream_float": 42.42, + "stream_floats": np.array([1.1, 2.2, 3.3]), + "stream_int": 1825, + "stream_ints": np.array([1776,2026]), + "stream_bool": True, + "stream_str": "kenobi", + "stream_strs": ["luke", "leia", "han"] + } + + assert rv["float"] == ev["float"], f"Expected {ev}, but got {rv}" + assert np.array_equal(rv["floats"], ev["floats"]), f"Expected {ev}, but got {rv}" + assert rv["int"] == ev["int"], f"Expected {ev}, but got {rv}" + assert np.array_equal(rv["ints"], ev["ints"]), f"Expected {ev}, but got {rv}" + assert rv["bool"] == ev["bool"], f"Expected {ev}, but got {rv}" + assert rv["str"] == ev["str"], f"Expected {ev}, but got {rv}" + assert rv["strs"] == ev["strs"], f"Expected {ev}, but got {rv}" + assert rv["trig"] == ev["trig"], f"Expected {ev}, but got {rv}" + assert rv["stream_float"] == ev["stream_float"], f"Expected {ev}, but got {rv}" + assert np.array_equal(rv["stream_floats"], ev["stream_floats"]), f"Expected {ev}, but got {rv}" + assert rv["stream_int"] == ev["stream_int"], f"Expected {ev}, but got {rv}" + assert np.array_equal(rv["stream_ints"], ev["stream_ints"]), f"Expected {ev}, but got {rv}" + assert rv["stream_bool"] == ev["stream_bool"], f"Expected {ev}, but got {rv}" + assert rv["stream_str"] == ev["stream_str"], f"Expected {ev}, but got {rv}" + assert rv["stream_strs"] == ev["stream_strs"], f"Expected {ev}, but got {rv}" + + print("MessengerRoundTripTest passed") + + m.stop_audio() + m.stop_process() + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/testing_mmm_audio/py_unit_tests/__init__.mojo b/testing_mmm_audio/py_unit_tests/__init__.mojo new file mode 100644 index 00000000..e69de29b diff --git a/testing/test_build_mojo_files.py b/testing_mmm_audio/test_build_mojo_files.py similarity index 100% rename from testing/test_build_mojo_files.py rename to testing_mmm_audio/test_build_mojo_files.py diff --git a/testing_mmm_audio/unit_test_py_files.py b/testing_mmm_audio/unit_test_py_files.py new file mode 100644 index 00000000..b2265ee4 --- /dev/null +++ b/testing_mmm_audio/unit_test_py_files.py @@ -0,0 +1,13 @@ +from glob import glob +import os +import subprocess + +def test_dir_of_pys(dirpath: str): + python_files = glob(f"{dirpath}/*.py") + for py_file in python_files: + if os.path.basename(py_file) != "__init__.py": + print(f"Running {py_file}...") + subprocess.run(["python", py_file], check=True) + +if __name__ == "__main__": + test_dir_of_pys("testing_mmm_audio/py_unit_tests") \ No newline at end of file diff --git a/testing/MFCC_Validation.mojo b/testing_mmm_audio/validation/MFCC_Validation.mojo similarity index 95% rename from testing/MFCC_Validation.mojo rename to testing_mmm_audio/validation/MFCC_Validation.mojo index 18560c06..48b0d9cc 100644 --- a/testing/MFCC_Validation.mojo +++ b/testing_mmm_audio/validation/MFCC_Validation.mojo @@ -31,7 +31,7 @@ def main(): print("Number of frames processed: ", len(fftprocess.buffered_process.process.process.data)) - with open("testing/mojo_results/mfcc_mojo_results.csv", "w") as f: + with open("testing_mmm_audio/validation/mojo_results/mfcc_mojo_results.csv", "w") as f: f.write("windowsize," + String(fftsize) + "\n") f.write("hopsize," + String(hopsize) + "\n") f.write("num_coeffs," + String(num_coeffs) + "\n") diff --git a/testing/MFCC_Validation.py b/testing_mmm_audio/validation/MFCC_Validation.py similarity index 86% rename from testing/MFCC_Validation.py rename to testing_mmm_audio/validation/MFCC_Validation.py index 4090210b..97888c71 100644 --- a/testing/MFCC_Validation.py +++ b/testing_mmm_audio/validation/MFCC_Validation.py @@ -19,21 +19,21 @@ args = parser.parse_args() show_plots = args.show_plots -os.makedirs("testing/validation_results", exist_ok=True) -os.makedirs("testing/mojo_results", exist_ok=True) -os.makedirs("testing/flucoma_sc_results", exist_ok=True) +os.makedirs("testing_mmm_audio/validation/validation_results", exist_ok=True) +os.makedirs("testing_mmm_audio/validation/mojo_results", exist_ok=True) +os.makedirs("testing_mmm_audio/validation/flucoma_sc_results", exist_ok=True) -os.system("mojo run testing/MFCC_Validation.mojo") +os.system("mojo run -I . testing_mmm_audio/validation/MFCC_Validation.mojo") print("mojo analysis complete") try: - flucoma_csv_path = "testing/flucoma_sc_results/mfcc_flucoma_results.csv" + flucoma_csv_path = "testing_mmm_audio/validation/flucoma_sc_results/mfcc_flucoma_results.csv" if not os.path.exists(flucoma_csv_path): - os.system("sclang testing/MFCC_Validation.scd") + os.system("sclang testing_mmm_audio/validation/MFCC_Validation.scd") except Exception as e: print("Error running SuperCollider script (make sure `sclang` can be called from the Terminal):", e) -with open("testing/mojo_results/mfcc_mojo_results.csv", "r") as f: +with open("testing_mmm_audio/validation/mojo_results/mfcc_mojo_results.csv", "r") as f: lines = f.readlines() windowsize = int(lines[0].strip().split(",")[1]) @@ -118,7 +118,7 @@ def compare_mfcc(arr1, arr2): ax[1].set(title="FluCoMa", ylabel="MFCC") ax[2].set(title="MMMAudio", xlabel="Frame", ylabel="MFCC") plt.tight_layout() -plt.savefig("testing/validation_results/mfcc_comparison.png") +plt.savefig("testing_mmm_audio/validation/validation_results/mfcc_comparison.png") if show_plots: plt.show() else: diff --git a/testing/MFCC_Validation.scd b/testing_mmm_audio/validation/MFCC_Validation.scd similarity index 92% rename from testing/MFCC_Validation.scd rename to testing_mmm_audio/validation/MFCC_Validation.scd index acdfd4ef..183f09fa 100644 --- a/testing/MFCC_Validation.scd +++ b/testing_mmm_audio/validation/MFCC_Validation.scd @@ -14,7 +14,7 @@ s.waitForBoot{ "numCoeffs: %".format(numCoeffs).postln; "numBands: %".format(numBands).postln; - b = Buffer.readChannel(s,dir+/+"../resources/Shiverer.wav",channels:[0]); + b = Buffer.readChannel(s,dir+/+"../../resources/Shiverer.wav",channels:[0]); f = Buffer(s); s.sync; FluidBufMFCC.processBlocking( diff --git a/testing/MelBands_Validation.mojo b/testing_mmm_audio/validation/MelBands_Validation.mojo similarity index 93% rename from testing/MelBands_Validation.mojo rename to testing_mmm_audio/validation/MelBands_Validation.mojo index ac38f760..a7c7f41d 100644 --- a/testing/MelBands_Validation.mojo +++ b/testing_mmm_audio/validation/MelBands_Validation.mojo @@ -27,7 +27,7 @@ def main(): print("Number of frames processed: ", len(fftprocess.buffered_process.process.process.data)) - with open("testing/mojo_results/mel_bands_mojo.csv", "w") as f: + with open("testing_mmm_audio/validation/mojo_results/mel_bands_mojo.csv", "w") as f: for i,frame in enumerate(fftprocess.buffered_process.process.process.data): if i > 0: f.write("\n") diff --git a/testing/MelBands_Validation.py b/testing_mmm_audio/validation/MelBands_Validation.py similarity index 85% rename from testing/MelBands_Validation.py rename to testing_mmm_audio/validation/MelBands_Validation.py index 79d210b7..1ec9b928 100644 --- a/testing/MelBands_Validation.py +++ b/testing_mmm_audio/validation/MelBands_Validation.py @@ -12,15 +12,15 @@ args = parser.parse_args() show_plots = args.show_plots -os.makedirs("testing/validation_results", exist_ok=True) -os.makedirs("testing/mojo_results", exist_ok=True) -os.makedirs("testing/flucoma_sc_results", exist_ok=True) +os.makedirs("testing_mmm_audio/validation/validation_results", exist_ok=True) +os.makedirs("testing_mmm_audio/validation/mojo_results", exist_ok=True) +os.makedirs("testing_mmm_audio/validation/flucoma_sc_results", exist_ok=True) -flucoma_csv_path = "./testing/flucoma_sc_results/mel_bands_flucoma.csv" +flucoma_csv_path = "./testing_mmm_audio/validation/flucoma_sc_results/mel_bands_flucoma.csv" if not os.path.exists(flucoma_csv_path): - os.system("sclang ./testing/MelBands_Validation.scd") + os.system("sclang ./testing_mmm_audio/validation/MelBands_Validation.scd") -os.system("mojo run ./testing/MelBands_Validation.mojo") +os.system("mojo run -I . ./testing_mmm_audio/validation/MelBands_Validation.mojo") with open(flucoma_csv_path, "r") as f: reader = csv.reader(f) @@ -28,7 +28,7 @@ for row in reader: flucoma_results.append([float(value) for value in row]) -with open("./testing/mojo_results/mel_bands_mojo.csv", "r") as f: +with open("./testing_mmm_audio/validation/mojo_results/mel_bands_mojo.csv", "r") as f: reader = csv.reader(f) mojo_results = [] for row in reader: @@ -86,7 +86,7 @@ def compare_mel_bands(arr1, arr2, name1, name2): ax[1].set(title='FluCoMa') ax[2].set(title='MMMAudio') ax[0].label_outer() -plt.savefig("./testing/validation_results/mel_bands_comparison.png") +plt.savefig("./testing_mmm_audio/validation/validation_results/mel_bands_comparison.png") if show_plots: plt.show() else: diff --git a/testing/MelBands_Validation.scd b/testing_mmm_audio/validation/MelBands_Validation.scd similarity index 100% rename from testing/MelBands_Validation.scd rename to testing_mmm_audio/validation/MelBands_Validation.scd diff --git a/testing/RMS_Validation.mojo b/testing_mmm_audio/validation/RMS_Validation.mojo similarity index 94% rename from testing/RMS_Validation.mojo rename to testing_mmm_audio/validation/RMS_Validation.mojo index 0ef18d2c..052db824 100644 --- a/testing/RMS_Validation.mojo +++ b/testing_mmm_audio/validation/RMS_Validation.mojo @@ -31,7 +31,7 @@ fn main(): sample = playBuf.next(buffer) analyzer.next(sample) - pth = "testing/mojo_results/rms_mojo_results.csv" + pth = "testing_mmm_audio/validation/mojo_results/rms_mojo_results.csv" try: with open(pth, "w") as f: f.write("windowsize,",windowsize,"\n") diff --git a/testing/RMS_Validation.py b/testing_mmm_audio/validation/RMS_Validation.py similarity index 80% rename from testing/RMS_Validation.py rename to testing_mmm_audio/validation/RMS_Validation.py index 71ab0b32..3f44f8ca 100644 --- a/testing/RMS_Validation.py +++ b/testing_mmm_audio/validation/RMS_Validation.py @@ -18,14 +18,14 @@ args = parser.parse_args() show_plots = args.show_plots -os.makedirs("testing/validation_results", exist_ok=True) -os.makedirs("testing/mojo_results", exist_ok=True) -os.makedirs("testing/flucoma_sc_results", exist_ok=True) +os.makedirs("testing_mmm_audio/validation/validation_results", exist_ok=True) +os.makedirs("testing_mmm_audio/validation/mojo_results", exist_ok=True) +os.makedirs("testing_mmm_audio/validation/flucoma_sc_results", exist_ok=True) -os.system("mojo run testing/RMS_Validation.mojo") +os.system("mojo run -I . testing_mmm_audio/validation/RMS_Validation.mojo") print("mojo analysis complete") -with open("testing/mojo_results/rms_mojo_results.csv", "r") as f: +with open("testing_mmm_audio/validation/mojo_results/rms_mojo_results.csv", "r") as f: lines = f.readlines() windowsize = int(lines[0].strip().split(",")[1]) hopsize = int(lines[1].strip().split(",")[1]) @@ -51,9 +51,9 @@ def compare_analyses(list1, list2): return np.mean(np.abs(diff)), np.std(diff) try: - flucoma_csv_path = "testing/flucoma_sc_results/rms_flucoma_results.csv" + flucoma_csv_path = "testing_mmm_audio/validation/flucoma_sc_results/rms_flucoma_results.csv" if not os.path.exists(flucoma_csv_path): - os.system("sclang testing/RMS_Validation.scd") + os.system("sclang testing_mmm_audio/validation/RMS_Validation.scd") except Exception as e: print("Error running SuperCollider script (make sure `sclang` can be called from the Terminal):", e) @@ -88,7 +88,7 @@ def compare_analyses(list1, list2): plt.legend() plt.ylabel("Amplitude") plt.title("RMS Comparison") -plt.savefig("testing/validation_results/rms_comparison.png") +plt.savefig("testing_mmm_audio/validation/validation_results/rms_comparison.png") if show_plots: plt.show() else: diff --git a/testing/RMS_Validation.scd b/testing_mmm_audio/validation/RMS_Validation.scd similarity index 91% rename from testing/RMS_Validation.scd rename to testing_mmm_audio/validation/RMS_Validation.scd index e1bf9828..82ff233c 100644 --- a/testing/RMS_Validation.scd +++ b/testing_mmm_audio/validation/RMS_Validation.scd @@ -8,7 +8,7 @@ s.waitForBoot{ "windowsize: %".format(windowsize).postln; "hopsize: %".format(hopsize).postln; - b = Buffer.readChannel(s,dir+/+"../resources/Shiverer.wav",channels:[0]); + b = Buffer.readChannel(s,dir+/+"../../resources/Shiverer.wav",channels:[0]); f = Buffer(s); s.sync; FluidBufLoudness.processBlocking( diff --git a/testing/SpectralCentroid_Validation.mojo b/testing_mmm_audio/validation/SpectralCentroid_Validation.mojo similarity index 95% rename from testing/SpectralCentroid_Validation.mojo rename to testing_mmm_audio/validation/SpectralCentroid_Validation.mojo index 87be9cbe..21928211 100644 --- a/testing/SpectralCentroid_Validation.mojo +++ b/testing_mmm_audio/validation/SpectralCentroid_Validation.mojo @@ -36,7 +36,7 @@ fn main(): sample = playBuf.next(buffer) analyzer.next(sample) - pth = "testing/mojo_results/spectral_centroid_mojo_results.csv" + pth = "testing_mmm_audio/validation/mojo_results/spectral_centroid_mojo_results.csv" try: with open(pth, "w") as f: f.write("windowsize,",windowsize,"\n") diff --git a/testing/SpectralCentroid_Validation.py b/testing_mmm_audio/validation/SpectralCentroid_Validation.py similarity index 82% rename from testing/SpectralCentroid_Validation.py rename to testing_mmm_audio/validation/SpectralCentroid_Validation.py index 47995ca1..a506b00e 100644 --- a/testing/SpectralCentroid_Validation.py +++ b/testing_mmm_audio/validation/SpectralCentroid_Validation.py @@ -19,14 +19,14 @@ args = parser.parse_args() show_plots = args.show_plots -os.makedirs("testing/validation_results", exist_ok=True) -os.makedirs("testing/mojo_results", exist_ok=True) -os.makedirs("testing/flucoma_sc_results", exist_ok=True) +os.makedirs("testing_mmm_audio/validation/validation_results", exist_ok=True) +os.makedirs("testing_mmm_audio/validation/mojo_results", exist_ok=True) +os.makedirs("testing_mmm_audio/validation/flucoma_sc_results", exist_ok=True) -os.system("mojo run testing/SpectralCentroid_Validation.mojo") +os.system("mojo run -I . testing_mmm_audio/validation/SpectralCentroid_Validation.mojo") print("mojo analysis complete") -with open("testing/mojo_results/spectral_centroid_mojo_results.csv", "r") as f: +with open("testing_mmm_audio/validation/mojo_results/spectral_centroid_mojo_results.csv", "r") as f: lines = f.readlines() windowsize = int(lines[0].strip().split(",")[1]) hopsize = int(lines[1].strip().split(",")[1]) @@ -66,9 +66,9 @@ def compare_analyses_pitch(list1, list2): return mean_hz, std_hz, mean_st, std_st try: - flucoma_csv_path = "testing/flucoma_sc_results/spectral_centroid_flucoma_results.csv" + flucoma_csv_path = "testing_mmm_audio/validation/flucoma_sc_results/spectral_centroid_flucoma_results.csv" if not os.path.exists(flucoma_csv_path): - os.system("sclang testing/SpectralCentroid_Validation.scd") + os.system("sclang testing_mmm_audio/validation/SpectralCentroid_Validation.scd") except Exception as e: print("Error running SuperCollider script (make sure `sclang` can be called from the Terminal):", e) @@ -103,7 +103,7 @@ def compare_analyses_pitch(list1, list2): except Exception as e: print("Error comparing FluCoMa results:", e) -plt.savefig("testing/validation_results/spectral_centroid_comparison.png") +plt.savefig("testing_mmm_audio/validation/validation_results/spectral_centroid_comparison.png") if show_plots: plt.show() else: diff --git a/testing/SpectralCentroid_Validation.scd b/testing_mmm_audio/validation/SpectralCentroid_Validation.scd similarity index 86% rename from testing/SpectralCentroid_Validation.scd rename to testing_mmm_audio/validation/SpectralCentroid_Validation.scd index fa9b7b95..93262831 100644 --- a/testing/SpectralCentroid_Validation.scd +++ b/testing_mmm_audio/validation/SpectralCentroid_Validation.scd @@ -1,4 +1,5 @@ var dir = thisProcess.nowExecutingPath.dirname; +s.options.numWireBufs = 1024; s.waitForBoot{ var windowsize = 1024; @@ -7,7 +8,7 @@ s.waitForBoot{ "windowsize: %".format(windowsize).postln; "hopsize: %".format(hopsize).postln; - b = Buffer.readChannel(s,dir+/+"../resources/Shiverer.wav",channels:[0]); + b = Buffer.readChannel(s,dir+/+"../../resources/Shiverer.wav",channels:[0]); f = Buffer(s); s.sync; FluidBufSpectralShape.processBlocking( diff --git a/testing/YIN_Validation.mojo b/testing_mmm_audio/validation/YIN_Validation.mojo similarity index 96% rename from testing/YIN_Validation.mojo rename to testing_mmm_audio/validation/YIN_Validation.mojo index c0404ae7..fdef2856 100644 --- a/testing/YIN_Validation.mojo +++ b/testing_mmm_audio/validation/YIN_Validation.mojo @@ -47,7 +47,7 @@ fn main(): sample = playBuf.next(buffer) analyzer.next(sample) - pth = "testing/mojo_results/yin_mojo_results.csv" + pth = "testing_mmm_audio/validation/mojo_results/yin_mojo_results.csv" try: with open(pth, "w") as f: f.write("windowsize,",windowsize,"\n") diff --git a/testing/YIN_Validation.py b/testing_mmm_audio/validation/YIN_Validation.py similarity index 89% rename from testing/YIN_Validation.py rename to testing_mmm_audio/validation/YIN_Validation.py index beefc16e..4daed480 100644 --- a/testing/YIN_Validation.py +++ b/testing_mmm_audio/validation/YIN_Validation.py @@ -21,14 +21,14 @@ args = parser.parse_args() show_plots = args.show_plots -os.makedirs("testing/validation_results", exist_ok=True) -os.makedirs("testing/mojo_results", exist_ok=True) -os.makedirs("testing/flucoma_sc_results", exist_ok=True) +os.makedirs("testing_mmm_audio/validation/validation_results", exist_ok=True) +os.makedirs("testing_mmm_audio/validation/mojo_results", exist_ok=True) +os.makedirs("testing_mmm_audio/validation/flucoma_sc_results", exist_ok=True) -os.system("mojo run testing/YIN_Validation.mojo") +os.system("mojo run -I . testing_mmm_audio/validation/YIN_Validation.mojo") print("mojo analysis complete") -with open("testing/mojo_results/yin_mojo_results.csv", "r") as f: +with open("testing_mmm_audio/validation/mojo_results/yin_mojo_results.csv", "r") as f: lines = f.readlines() windowsize = int(lines[0].strip().split(",")[1]) hopsize = int(lines[1].strip().split(",")[1]) @@ -95,9 +95,9 @@ def compare_analyses_confidence(list1, list2): return mean_diff, std_diff try: - flucoma_csv_path = "testing/flucoma_sc_results/yin_flucoma_results.csv" + flucoma_csv_path = "testing_mmm_audio/validation/flucoma_sc_results/yin_flucoma_results.csv" if not os.path.exists(flucoma_csv_path): - os.system("sclang testing/YIN_Validation.scd") + os.system("sclang testing_mmm_audio/validation/YIN_Validation.scd") except Exception as e: print("Error running SuperCollider script (make sure `sclang` can be called from the Terminal):", e) @@ -168,7 +168,7 @@ def compare_analyses_confidence(list1, list2): print("Error comparing FluCoMa results:", e) plt.tight_layout() -plt.savefig("testing/validation_results/yin_comparison.png") +plt.savefig("testing_mmm_audio/validation/validation_results/yin_comparison.png") if show_plots: plt.show() else: @@ -199,7 +199,7 @@ def compare_analyses_confidence(list1, list2): plt.title('Histogram of Pitch Deviation (Semitones)') plt.legend() plt.xticks(bins) -plt.savefig("testing/validation_results/yin_deviation_histogram.png") +plt.savefig("testing_mmm_audio/validation/validation_results/yin_deviation_histogram.png") if show_plots: plt.show() else: diff --git a/testing/YIN_Validation.scd b/testing_mmm_audio/validation/YIN_Validation.scd similarity index 92% rename from testing/YIN_Validation.scd rename to testing_mmm_audio/validation/YIN_Validation.scd index a7c5122f..de806af8 100644 --- a/testing/YIN_Validation.scd +++ b/testing_mmm_audio/validation/YIN_Validation.scd @@ -12,7 +12,7 @@ s.waitForBoot{ "minfreq: %".format(minfreq).postln; "maxfreq: %".format(maxfreq).postln; - b = Buffer.readChannel(s,dir+/+"../resources/Shiverer.wav",channels:[0]); + b = Buffer.readChannel(s,dir+/+"../../resources/Shiverer.wav",channels:[0]); f = Buffer(s); s.sync; FluidBufPitch.processBlocking( diff --git a/testing/flucoma_sc_results/mel_bands_flucoma.csv b/testing_mmm_audio/validation/flucoma_sc_results/mel_bands_flucoma.csv similarity index 100% rename from testing/flucoma_sc_results/mel_bands_flucoma.csv rename to testing_mmm_audio/validation/flucoma_sc_results/mel_bands_flucoma.csv diff --git a/testing/flucoma_sc_results/mfcc_flucoma_results.csv b/testing_mmm_audio/validation/flucoma_sc_results/mfcc_flucoma_results.csv similarity index 100% rename from testing/flucoma_sc_results/mfcc_flucoma_results.csv rename to testing_mmm_audio/validation/flucoma_sc_results/mfcc_flucoma_results.csv diff --git a/testing/flucoma_sc_results/rms_flucoma_results.csv b/testing_mmm_audio/validation/flucoma_sc_results/rms_flucoma_results.csv similarity index 100% rename from testing/flucoma_sc_results/rms_flucoma_results.csv rename to testing_mmm_audio/validation/flucoma_sc_results/rms_flucoma_results.csv diff --git a/testing/flucoma_sc_results/spectral_centroid_flucoma_results.csv b/testing_mmm_audio/validation/flucoma_sc_results/spectral_centroid_flucoma_results.csv similarity index 100% rename from testing/flucoma_sc_results/spectral_centroid_flucoma_results.csv rename to testing_mmm_audio/validation/flucoma_sc_results/spectral_centroid_flucoma_results.csv diff --git a/testing/flucoma_sc_results/spectral_shape_flucoma_results.csv b/testing_mmm_audio/validation/flucoma_sc_results/spectral_shape_flucoma_results.csv similarity index 100% rename from testing/flucoma_sc_results/spectral_shape_flucoma_results.csv rename to testing_mmm_audio/validation/flucoma_sc_results/spectral_shape_flucoma_results.csv diff --git a/testing/flucoma_sc_results/yin_flucoma_results.csv b/testing_mmm_audio/validation/flucoma_sc_results/yin_flucoma_results.csv similarity index 100% rename from testing/flucoma_sc_results/yin_flucoma_results.csv rename to testing_mmm_audio/validation/flucoma_sc_results/yin_flucoma_results.csv diff --git a/testing/librosa_results_for_testing_against.py b/testing_mmm_audio/validation/librosa_results_for_testing_against.py similarity index 91% rename from testing/librosa_results_for_testing_against.py rename to testing_mmm_audio/validation/librosa_results_for_testing_against.py index 141c5ef1..a00b504c 100644 --- a/testing/librosa_results_for_testing_against.py +++ b/testing_mmm_audio/validation/librosa_results_for_testing_against.py @@ -62,8 +62,8 @@ def mel_bands_weights_results(n_mels: int, n_fft: int, sr: int): - os.makedirs("testing/librosa_results", exist_ok=True) - with open(f"testing/librosa_results/librosa_mel_bands_weights_results_nmels={n_mels}_fftsize={n_fft}_sr={sr}.csv", "w") as f: + os.makedirs("testing_mmm_audio/validation/librosa_results", exist_ok=True) + with open(f"testing_mmm_audio/validation/librosa_results/librosa_mel_bands_weights_results_nmels={n_mels}_fftsize={n_fft}_sr={sr}.csv", "w") as f: for row in range(len(mel_weights)): for col in range(len(mel_weights[row])): f.write(f"{mel_weights[row][col]}\n") diff --git a/testing/make_validation_snapshot.py b/testing_mmm_audio/validation/make_validation_snapshot.py similarity index 98% rename from testing/make_validation_snapshot.py rename to testing_mmm_audio/validation/make_validation_snapshot.py index b9e36203..21b47a99 100644 --- a/testing/make_validation_snapshot.py +++ b/testing_mmm_audio/validation/make_validation_snapshot.py @@ -70,7 +70,7 @@ def build_snapshot(show_plots: bool) -> dict[str, Any]: scripts = validation_scripts() if not scripts: - raise RuntimeError("No *_Validation.py scripts found under testing/.") + raise RuntimeError("No *_Validation.py scripts found under testing_mmm_audio/.") for script in scripts: name = os.path.basename(script) @@ -97,7 +97,7 @@ def build_snapshot(show_plots: bool) -> dict[str, Any]: def main() -> int: parser = argparse.ArgumentParser( - description="Run validation scripts and create/update testing/validation_snapshot.json" + description="Run validation scripts and create/update testing_mmm_audio/validation_snapshot.json" ) parser.add_argument( "--show-plots", diff --git a/testing/run_all_validations.py b/testing_mmm_audio/validation/run_all_validations.py similarity index 86% rename from testing/run_all_validations.py rename to testing_mmm_audio/validation/run_all_validations.py index 91002307..a532c189 100644 --- a/testing/run_all_validations.py +++ b/testing_mmm_audio/validation/run_all_validations.py @@ -6,7 +6,7 @@ parser.add_argument("--show-plots", action="store_true", help="Display plots for each validation script") args = parser.parse_args() -validations = glob("testing/*_Validation.py") +validations = glob("testing_mmm_audio/validation/*_Validation.py") for validation in validations: print(f"Running {validation}...") diff --git a/testing/validate_against_snapshot.py b/testing_mmm_audio/validation/validate_against_snapshot.py similarity index 97% rename from testing/validate_against_snapshot.py rename to testing_mmm_audio/validation/validate_against_snapshot.py index 5f8ff518..6c2d381f 100644 --- a/testing/validate_against_snapshot.py +++ b/testing_mmm_audio/validation/validate_against_snapshot.py @@ -12,11 +12,11 @@ def repo_root() -> str: - return os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) + return os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")) def testing_dir() -> str: - return os.path.join(repo_root(), "testing") + return os.path.join(repo_root(), "testing_mmm_audio/validation") def snapshot_path() -> str: @@ -76,7 +76,7 @@ def current_results(show_plots: bool) -> dict[str, Any]: out: dict[str, Any] = {"scripts": {}} scripts = validation_scripts() if not scripts: - raise RuntimeError("No *_Validation.py scripts found under testing/.") + raise RuntimeError("No *_Validation.py scripts found under testing_mmm_audio/.") for script in scripts: name = os.path.basename(script) @@ -116,7 +116,7 @@ def load_snapshot() -> dict[str, Any]: def main() -> int: parser = argparse.ArgumentParser( - description="Run validation scripts and compare against testing/validation_snapshot.json" + description="Run validation scripts and compare against testing_mmm_audio/validation_snapshot.json" ) parser.add_argument( "--show-plots", diff --git a/testing/validation_snapshot.json b/testing_mmm_audio/validation/validation_snapshot.json similarity index 100% rename from testing/validation_snapshot.json rename to testing_mmm_audio/validation/validation_snapshot.json