Skip to content

Commit f4745b3

Browse files
minor
1 parent 66c245b commit f4745b3

File tree

2 files changed

+3
-5
lines changed

2 files changed

+3
-5
lines changed

tests/post_training/pipelines/image_classification_timm.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,8 @@ def prepare_model(self) -> None:
4545
onnx_path = self.fp32_model_dir / "model_fp32.onnx"
4646
additional_kwargs = {}
4747
if self.batch_size > 1:
48-
batch = torch.export.Dim("batch")
49-
additional_kwargs["dynamic_shapes"] = ({0: batch},)
50-
48+
additional_kwargs["input_names"] = ["image"]
49+
additional_kwargs["dynamic_axes"] = {"image": {0: "batch"}}
5150
torch.onnx.export(
5251
timm_model,
5352
self.dummy_tensor,
@@ -56,7 +55,6 @@ def prepare_model(self) -> None:
5655
opset_version=13,
5756
**additional_kwargs,
5857
)
59-
6058
self.model = onnx.load(onnx_path)
6159
self.input_name = self.model.graph.input[0].name
6260

tests/post_training/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,4 +22,4 @@ accelerate==1.9.0
2222
transformers==4.53.0
2323
whowhatbench @ git+https://github.com/openvinotoolkit/[email protected]#subdirectory=tools/who_what_benchmark
2424
datasets==3.6.0
25-
onnxscript
25+
onnxscript==0.5.4

0 commit comments

Comments
 (0)