diff --git a/AL_center.py b/AL_center.py index 392d0f1..dfa4578 100644 --- a/AL_center.py +++ b/AL_center.py @@ -29,7 +29,7 @@ parser = argparse.ArgumentParser("Center Loss Example") # dataset -parser.add_argument('-d', '--dataset', type=str, default='cifar10', choices=['mnist', 'cifar100', 'cifar10']) +parser.add_argument('-d', '--dataset', type=str, default='combined_wafer_data.npz', choices=['combined_wafer_data.npz']) parser.add_argument('-j', '--workers', default=0, type=int, help="number of data loading workers (default: 4)") # optimization @@ -281,7 +281,7 @@ def plot_features(features, labels, num_classes, epoch, prefix): features: (num_instances, num_features). labels: (num_instances). """ - colors = ['C0', 'C1', 'C2', 'C3', 'C4', 'C5', 'C6', 'C7', 'C8', 'C9'] + colors = ['C0', 'C1', 'C2', 'C3', 'C4', 'C5', 'C6', 'C7', 'C8'] for label_idx in range(num_classes): plt.scatter( features[labels==label_idx, 0], @@ -289,7 +289,7 @@ def plot_features(features, labels, num_classes, epoch, prefix): c=colors[label_idx], s=1, ) - plt.legend(['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'], loc='upper right') + plt.legend(['0', '1', '2', '3', '4', '5', '6', '7', '8'], loc='upper right') dirname = osp.join(args.save_dir, prefix) if not osp.exists(dirname): os.mkdir(dirname) diff --git a/datasets.py b/datasets.py index 845aca4..e85294c 100644 --- a/datasets.py +++ b/datasets.py @@ -2,7 +2,6 @@ import torchvision from torch.utils.data import DataLoader from torch.utils.data import SubsetRandomSampler -from simclr.modules.transformations import TransformsSimCLR import random import transforms @@ -293,4 +292,4 @@ def create(name, known_class_, init_percent_, batch_size, use_gpu, num_workers, init_percent = init_percent_ if name not in __factory.keys(): raise KeyError("Unknown dataset: {}".format(name)) - return __factory[name](batch_size, use_gpu, num_workers, is_filter, is_mini, unlabeled_ind_train, labeled_ind_train) \ No newline at end of file + return __factory[name](batch_size, use_gpu, num_workers, is_filter, is_mini, unlabeled_ind_train, labeled_ind_train)