Skip to content

Regarding inference with pretrained weights. #53

@sorobedio

Description

@sorobedio

Hello.
i am currently working withe checkpoint of model query with dataset cifar10-valid. I am unable to reproduce the results in the architecture information using the same dataloader.

here is part of the code used
`
import os
import argparse
import numpy as np
import pandas as pd
import torchvision.datasets as dset
import torchvision.transforms as transforms

from tqdm import tqdm
import torch.nn as nn
import torch

from nats_bench import create
from nas_201_api import NASBench201API as API
from nas_201_api import ResultsCount
from models import get_cell_based_tiny_net

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

mean = [x / 255 for x in [125.3, 123.0, 113.9]]
std = [x / 255 for x in [63.0, 62.1, 66.7]]

root = '../../../../Datasets/NASBench/'
batch_size=32
test_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])
test_data = dset.CIFAR10(root, train=False, transform=test_transform, download=True)

testloader = torch.utils.data.DataLoader(test_data, batch_size=batch_size,
shuffle=False, num_workers=4)

def test():
correct = 0
total = 0

with torch.no_grad():
    for data in tqdm(testloader):
        images, labels = data
        images = images.to(device)
        labels = labels.to(device)
        _, outputs = model(images)
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy of the network on the 10000 test images: {100 * correct // total} %')

if name == 'main':
base_nasdir = '../../../../Datasets/NASBench/NATS-tss-v1_0-3ffb9-full/'
api = create(base_nasdir, 'tss', fast_mode=True, verbose=False)

config = api.get_net_config(0, 'cifar10-valid')
model = get_cell_based_tiny_net(config)
params = api.get_net_param(0, 'cifar10-valid', seed=777, hp="200")
model.load_state_dict(params)

)
model = model.to(device)
test()

outpus 75%

expected results should be in range below
82.092,
81.616
82.240
`
i wonder if my procedure is correct.
Is there a better way to reprouce the result given by

results = api.query_by_index(0, 'cifar10-valid', hp="200") for seed, result in results.items(): vacc = result.get_eval('x-valid')['accuracy'] tacc = result.get_eval('ori-test')['accuracy']

thank you in advance

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions