Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -252,3 +252,4 @@ testing_mmm_audio/validation/librosa_results
testing_mmm_audio/validation/mojo_results
testing_mmm_audio/validation/validation_results
peaks
.DS_Store
67 changes: 67 additions & 0 deletions examples/Classifier.mojo
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
from mmm_audio import *

comptime scaler_path = "examples/nn_trainings/mfcc_classifier_scaler.joblib"
comptime model_path = "examples/nn_trainings/mfcc_classifier_traced.pt"

comptime windowsize = 1024
comptime hopsize = windowsize // 2
comptime n_mfcc = 13

struct ClassifierWindow(FFTProcessable):
var model: PythonObject
var scaler: StandardScaler
var mfcc: MFCC
var scaled_coeffs: List[Float64]
var py_input: PythonObject
var py_output: PythonObject

def __init__(out self, sr: Float64):
self.scaler = StandardScaler(scaler_path)
self.mfcc = MFCC(sr=sr, fft_size=windowsize, num_coeffs=n_mfcc)
self.scaled_coeffs = List[Float64](fill=0.0, length=n_mfcc)

try:
torch = Python.import_module("torch")
self.model = torch.jit.load(model_path)
self.py_input = torch.zeros(n_mfcc)
self.py_output = torch.zeros(1) # Adjust the size based on your model's output
except e:
abort("Error loading PyTorch model: " + String(e))

def next_frame(mut self, mut mags: List[Float64], mut phss: List[Float64]):
self.mfcc.from_mags(mags)
self.scaler.transform_point(self.mfcc.coeffs, self.scaled_coeffs)
try:
for i in range(n_mfcc):
self.py_input[i] = self.scaled_coeffs[i]
self.py_output = self.model(self.py_input)
o = Float64(py=self.py_output.item())
display: String = "🐶" if o > 0.5 else "❌"
print("Dog:",display,"---", o)
except e:
abort("Error predicting: " + String(e))

struct Classifier(Movable,Copyable):
var world: World
var fftp: FFTProcess[ClassifierWindow,output_window_shape=WindowType.hann]
var src: Buffer
var player: Play
var src_path: String
var m: Messenger

def __init__(out self, world: World):
self.world = world
self.src_path = "/Users/ted/Desktop/dog-dataset/Media/Tremblay-BaB-SoundscapeGolcarWithDog.wav"
self.fftp = FFTProcess[ClassifierWindow](self.world, ClassifierWindow(self.world[].sample_rate), windowsize, hopsize)
self.src = Buffer.load(self.src_path)
self.player = Play(self.world)
self.m = Messenger(self.world)

def next(mut self) -> MFloat[2]:

if self.m.notify_update(self.src_path,"src_path"):
self.src = Buffer.load(self.src_path)

src = self.player.next(self.src)
_ = self.fftp.next(src)
return src
21 changes: 21 additions & 0 deletions examples/Classifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import sys
from pathlib import Path
import argparse

sys.path.insert(0, str(Path(__file__).parent.parent))

from mmm_python import *

def main():
parser = argparse.ArgumentParser(description="Run the MMMAudio Classifier example.")
parser.add_argument("--src", type=str, help="Source audio file to classify", required=False)
args = parser.parse_args()
# outdevice = 'BlackHole 2ch'
outdevice = 'default'
mmm_audio = MMMAudio(in_device=None, out_device=outdevice, blocksize=512, graph_name="Classifier", package_name="examples")
if args.src:
mmm_audio.send_string("src", args.src)
mmm_audio.start_audio()

if __name__ == "__main__":
main()
Loading
Loading