How to convert .jit to .onnx? #644
Unanswered
gitqinxinyu asked this question in Q&A
Replies: 2 comments
-
One approach would be to rebuild the model from its original components and load the fine-tuned weights into it:

```python
import torch
import torch.nn as nn
import numpy as np
import onnxruntime as ort
from tuning.utils import VADDecoderRNNJIT

# 1. Load the fine-tuned JIT model to get its components
print("Loading fine-tuned model components...")
finetuned_jit = torch.jit.load('tuning/finetuned_model.jit', map_location='cpu')
stft_module = finetuned_jit._model.stft
encoder_module = finetuned_jit._model.encoder
finetuned_decoder_state_dict = finetuned_jit._model.decoder.state_dict()
print("Model components loaded.")

# 2. Reconstruct the full model architecture using the original components
class SileroVADModel(nn.Module):
    def __init__(self, stft, encoder, decoder):
        super().__init__()
        self.stft = stft
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, x, state):
        # This forward pass is simplified for ONNX export and assumes 16 kHz audio
        x = self.stft(x)
        x = self.encoder(x)
        x, new_state = self.decoder(x, state)
        return x, new_state

# 3. Instantiate the new, clean model with the original components
print("Building a new model with original components...")
clean_decoder = VADDecoderRNNJIT()
clean_model = SileroVADModel(stft_module, encoder_module, clean_decoder)
clean_model.eval()
print("New model built.")

# 4. Load the fine-tuned decoder weights
print("Loading fine-tuned decoder weights...")
clean_model.decoder.load_state_dict(finetuned_decoder_state_dict)
print("Fine-tuned decoder weights successfully loaded.")

# 5. Script the new model to make it fully compatible for export
print("Scripting the new model...")
scripted_model = torch.jit.script(clean_model)
print("Model scripted successfully.")

# 6. Convert the scripted model to ONNX
print("Converting the scripted model to ONNX...")
onnx_output_path = 'tuning/finetuned_model.onnx'
dummy_input = torch.randn(1, 512)
dummy_state = torch.zeros(2, 1, 128)
torch.onnx.export(
    scripted_model,
    (dummy_input, dummy_state),
    onnx_output_path,
    export_params=True,
    opset_version=16,
    do_constant_folding=True,
    input_names=['input', 'state'],
    output_names=['output', 'stateN'],
    dynamic_axes={
        'input': {0: 'batch_size', 1: 'sequence_length'},
        'state': {1: 'batch_size'},
        'output': {0: 'batch_size'},
        'stateN': {1: 'batch_size'}
    }
)
print(f"Model successfully converted to {onnx_output_path}")

# 7. Verify the ONNX model
print("Verifying the ONNX model...")
session = ort.InferenceSession(onnx_output_path)
input_name = session.get_inputs()[0].name
state_name = session.get_inputs()[1].name
output, stateN = session.run(None, {
    input_name: dummy_input.numpy(),
    state_name: dummy_state.numpy()
})
print("ONNX model verification successful.")
print(f"Output shape: {output.shape}")
print(f"New state shape: {stateN.shape}")
```

The exported model is fixed to a 16 kHz sample rate.
You can test inference on the exported model like this:

```python
import onnxruntime as ort
import numpy as np

# 1. Load the ONNX model
onnx_model_path = 'tuning/finetuned_model.onnx'
try:
    session = ort.InferenceSession(onnx_model_path)
    print(f"Successfully loaded ONNX model from: {onnx_model_path}")
except Exception as e:
    print(f"Error loading ONNX model: {e}")
    exit()

# 2. Get input and output names
input_names = [inp.name for inp in session.get_inputs()]
output_names = [out.name for out in session.get_outputs()]
print(f"Input names: {input_names}")
print(f"Output names: {output_names}")

# 3. Create dummy input tensors
# The shapes should match what the model expects.
# 'input' shape: (batch_size, sequence_length) -> e.g., (1, 512)
# 'state' shape: (2, batch_size, 128) -> e.g., (2, 1, 128)
dummy_input = np.random.randn(1, 512).astype(np.float32)
dummy_state = np.zeros((2, 1, 128), dtype=np.float32)
print(f"Dummy input shape: {dummy_input.shape}")
print(f"Dummy state shape: {dummy_state.shape}")

# 4. Run inference
print("\nRunning inference...")
try:
    output, new_state = session.run(
        output_names,
        {
            'input': dummy_input,
            'state': dummy_state
        }
    )
    print("Inference successful.")
    print(f"Output shape: {output.shape}")
    print(f"New state shape: {new_state.shape}")
except Exception as e:
    print(f"An error occurred during inference: {e}")
```
-
Hello~ Have you solved this problem yet? I've run into the same issue. How did you solve it?
-
I want to convert my trained Torch (.jit) model to an ONNX (.onnx) model, but I've run into a bug. Could you please provide a related conversion script?