Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions graph_net/test/swin_t_extract_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import argparse
import os

import torch
from torchvision import transforms
from torchvision.models import get_model, get_model_weights

import graph_net


def extract_swin_t_graph(model_name: str, model_path: str):
normalize = transforms.Normalize(
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
)

batch_size = 1
height, width = 224, 224
num_channels = 3
random_input = torch.rand(batch_size, num_channels, height, width)
normalized_input = normalize(random_input)

weights = None
try:
w = get_model_weights(model_path)
weights = w.DEFAULT
except Exception:
pass

model = get_model(model_path, weights=weights)
model.eval()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
normalized_input = normalized_input.to(device)

model = graph_net.torch.extract(name=model_name, dynamic=False)(model)

print("Running inference...")
print("Input shape:", normalized_input.shape)
with torch.no_grad():
output = model(normalized_input)
print("Inference finished. Output shape:", output.shape)


if __name__ == "__main__":
workspace_default = os.environ.get("GRAPH_NET_EXTRACT_WORKSPACE", "workspace")

parser = argparse.ArgumentParser()
parser.add_argument("--model_name", type=str, default="swin_t")
parser.add_argument("--model_path", type=str, default="swin_t")
parser.add_argument("--workspace", type=str, default=workspace_default)
args = parser.parse_args()

os.environ["GRAPH_NET_EXTRACT_WORKSPACE"] = args.workspace

extract_swin_t_graph(args.model_name, args.model_path)
1 change: 1 addition & 0 deletions samples/torchvision/swin_t/graph_hash.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
7d17ac7dfc2ce690ae08f3a44dc915b308b8995d29b685a1d4e1a51741d03ffc
7 changes: 7 additions & 0 deletions samples/torchvision/swin_t/graph_net.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
{
"framework": "torch",
"num_devices_required": 1,
"num_nodes_required": 1,
"dynamic": false,
"model_name": "swin_t"
}
Empty file.
Empty file.
866 changes: 866 additions & 0 deletions samples/torchvision/swin_t/model.py

Large diffs are not rendered by default.

Loading
Loading