|
"""Run SIS backward selection across models and save the results to disk.

Example usage:
python sis_backward_selection.py \
    --idx_experiment=0 \
    --num_images=250
"""
| 8 | +from __future__ import print_function |
| 9 | + |
| 10 | +import os # noqa |
| 11 | +import sys |
# Locate `provable_compression/src` below the starting folder, put it first on
# the import path and make it the working directory.
# NOTE: '__file__' is a quoted string literal here, not the module attribute,
# so `this_folder` resolves to the current working directory at launch.
this_folder = os.path.dirname(os.path.abspath('__file__'))  # noqa
src_folder = os.path.join(this_folder, 'provable_compression', 'src')  # noqa
sys.path.insert(0, src_folder)
os.chdir(src_folder)  # noqa
| 16 | +# os.environ["CUDA_VISIBLE_DEVICES"] = "0" # noqa |
| 17 | +# print(os.getcwd()) # noqa |
| 18 | + |
| 19 | +from torchprune.util import get_parameters |
| 20 | +from experiment import Logger, Evaluator |
| 21 | +import torch |
| 22 | +import copy |
| 23 | +import numpy as np |
| 24 | +import argparse |
| 25 | + |
| 26 | + |
# Command-line interface.
#   --idx_experiment: selects the experiment file and pruning method.
#   --num_images:     number of test images SIS is run on (per model).
parser = argparse.ArgumentParser()
parser.add_argument('--idx_experiment', type=int, required=True)
# A required option never falls back to its default, so `required=True` is
# dropped here to let the documented default of 250 actually apply.
parser.add_argument('--num_images', type=int, default=250)
args = parser.parse_args()
print(args)
| 32 | + |
| 33 | + |
# some parameters to configure the pruned networks
# parameters for running SIS

# NOTE: consider code down to line 131 for full extraction

# CHANGE IDX_EXPERIMENT TO DESIRED EXPERIMENT
# * range(0, 4) --> resnet20
# * range(4, 8) --> vgg16
# * range(8, 12) --> resnet20_rewind
# * range(12, 16) --> vgg16_rewind
#
# Within a range the following order of method applies:
# 0. WT --> weight-pruning, magnitude-based
# 1. FT --> filter-pruning, magnitude-based
# 2. SiPP --> weight-pruning, data-informed
# 3. PFP --> filter-pruning, data-informed
IDX_EXPERIMENT = args.idx_experiment

# DON'T CHANGE
IDX_REF = 0
IDX_UNCORRELATED = -1

# DON'T CHANGE
FILES = [
    "experiment/cifar/resnet20.yaml",
    "experiment/cifar/vgg16.yaml",
    "experiment/cifar/resnet20_rewind.yaml",
    "experiment/cifar/vgg16_rewind.yaml",
]
METHODS = ["ThresNet", "FilterThresNet", "SiPPNetStar", "PopNet"]
DESIRED_PR = [0.15, 0.46, 0.69, 0.80, 0.98]

# Put together the model description and file.
# Validate explicitly instead of with `assert`, which is stripped under
# `python -O`; the bound follows from the tables above.
_num_experiments = len(FILES) * len(METHODS)
if not 0 <= IDX_EXPERIMENT < _num_experiments:
    raise ValueError(
        f"--idx_experiment must be in [0, {_num_experiments - 1}], "
        f"got {IDX_EXPERIMENT}")

# The quotient picks the experiment file, the remainder the pruning method.
_file_idx, _method_idx = divmod(IDX_EXPERIMENT, len(METHODS))
FILE = FILES[_file_idx]

# Reference net first, then one pruned net per desired prune ratio, and a
# second reference net (n_idx=1) last, which is where IDX_UNCORRELATED points.
MODELS_DESCRIPTION = (
    [{"method": "ReferenceNet", "pr": 0.0, "n_idx": 0}]
    + [{"method": METHODS[_method_idx], "pr": pr, "n_idx": 0}
       for pr in DESIRED_PR]
    + [{"method": "ReferenceNet", "pr": 0.0, "n_idx": 1}]
)
| 80 | + |
| 81 | +# %% |
# Run the compression experiments (or load results if available)
# Initialize the logger and get the parameters
param = next(get_parameters(FILE, 1, 0))
Logger().initialize_from_param(param)

# Initialize the evaluator
compressor = Evaluator()

# Load stored stats into the logger so the evaluations don't have to be
# re-run; if loading raises a ValueError, run the evaluation instead.
try:
    Logger().load_global_state()
except ValueError:
    compressor.run()

# Store mean and std dev of the network's dataset for later, each with two
# trailing singleton axes appended.
_dataset_cfg = param['datasets'][param['network']['dataset']]
mean_c = np.asarray(_dataset_cfg['mean'])[:, None, None]
std_c = np.asarray(_dataset_cfg['std'])[:, None, None]

# device settings: GPU 0 as compute device, CPU as storage device
torch.cuda.set_device("cuda:0")
device_storage = torch.device("cpu")
device = torch.device("cuda:0")
| 107 | + |
| 108 | +# %% |
# Retrieve the models we want ...

# One model per entry of MODELS_DESCRIPTION, in the same order.
models = []
for description in MODELS_DESCRIPTION:
    models.append(compressor.get_by_pr(**description))

# Prune ratio of every model, measured against the size of the reference
# model at IDX_REF.
PRUNE_RATIOS = [1 - net.size() / models[IDX_REF].size() for net in models]

# Legend entries: the configured display name of the model class plus the
# prune ratio; the first and last entries get fixed labels instead.
LEGENDS = []
for net, ratio in zip(models, PRUNE_RATIOS):
    display_name = param['network_names'][type(net).__name__]
    LEGENDS.append(f"{display_name} (PR={ratio:.2f})")
LEGENDS[IDX_REF] = "Unpruned network"
LEGENDS[IDX_UNCORRELATED] = "Separate network"

# Standard plot color of each model class; the separate network is grey.
COLORS = []
for net in models:
    COLORS.append(param['network_colors'][type(net).__name__])
COLORS[IDX_UNCORRELATED] = "grey"
| 127 | + |
# Test loss and top-1 / top-5 accuracy of every model, kept for reference.
TEST_LOSS, ACCURACY_TOP1, ACCURACY_TOP5 = [], [], []
for net in models:
    # Evaluate on the compute device, then move the model back to storage.
    net.to(device)
    test_loss, top1, top5 = compressor._net_trainer.test(net)
    net.to(device_storage)
    TEST_LOSS.append(test_loss.item())
    ACCURACY_TOP1.append(top1)
    ACCURACY_TOP5.append(top5)

# Load datasets
loader_train, loader_val, loader_test = compressor.get_dataloader(
    "train", "valid", "test")
| 143 | + |
| 144 | + |
| 145 | +# %% |
| 146 | +# create one big tensor of images for each set |
def get_entire_dataset(dataloader):
    """Materialize a dataloader's whole dataset as two tensors.

    Args:
        dataloader: object with a ``dataset`` attribute that supports
            ``len()`` and integer indexing, where every item is an
            ``(image, label)`` pair and all images share one shape.

    Returns:
        Tuple ``(images, labels)``. ``images`` has shape
        ``(len(dataset),) + image.shape`` in torch's default dtype and
        ``labels`` has shape ``(len(dataset),)`` with dtype ``torch.long``.
        For an empty dataset the image shape is unknown, so ``images`` is
        an empty 1-D tensor.
    """
    # Index a private copy rather than the loader's own dataset object.
    dataset = copy.deepcopy(dataloader.dataset)
    num_imgs = len(dataset)
    labels = torch.zeros(num_imgs, dtype=torch.long)
    if num_imgs == 0:
        return torch.zeros(0), labels

    # Fetch the first sample once: it both fixes the image shape and fills
    # slot 0, so sample 0 is not read a second time.
    first_image, labels[0] = dataset[0]
    images = torch.zeros((num_imgs,) + tuple(first_image.shape))
    images[0] = first_image

    for idx in range(1, num_imgs):
        images[idx], labels[idx] = dataset[idx]
    return images, labels
| 156 | + |
| 157 | + |
# Materialize each split as an (images, labels) pair of tensors.
data_train = get_entire_dataset(loader_train)
data_test = get_entire_dataset(loader_test)
data_val = get_entire_dataset(loader_val)

# Move every model to the configured compute device (the same `device` the
# evaluation loop uses, instead of a separate hard-coded 'cuda' string) and
# put it in eval mode.
for m in models:
    m.to(device)
    m.eval()
| 167 | + |
| 168 | +# %% |
| 169 | + |
# Switch back to the starting folder before the imports below; presumably so
# that they and the relative output path resolve from there.
os.chdir(this_folder)
import collections
import sis_util
from sufficient_input_subsets import sis
from tqdm import tqdm

# Output directory of this experiment, relative to the working directory.
OUT_BASEDIR = './sis_data/idx_experiment_%d/' % IDX_EXPERIMENT
print(OUT_BASEDIR)

# Run SIS backward selection on CIFAR test images and write to disk.

_IMAGE_SHAPE = (3, 32, 32)
SIS_THRESHOLD = 0.0  # To capture the results of backward selection.
INITIAL_MASK = sis.make_empty_boolean_mask_broadcast_over_axis(
    list(_IMAGE_SHAPE), 0)
FULLY_MASKED_IMAGE = np.zeros(_IMAGE_SHAPE, dtype=np.float32)
| 187 | + |
# Never index past the end of the test set, even if --num_images is larger.
num_images = min(args.num_images, len(data_test[0]))

# For every model, run SIS on the first `num_images` test images and store
# one .npz file per image under OUT_BASEDIR/model_<index>/.
for model_i, model in enumerate(tqdm(models)):
    sis_out_dir = os.path.join(OUT_BASEDIR, 'model_%d' % model_i)
    print(sis_out_dir)
    # exist_ok avoids the check-then-create race and is a no-op on re-runs.
    os.makedirs(sis_out_dir, exist_ok=True)

    for i in range(num_images):
        sis_filepath = os.path.join(sis_out_dir, 'test_%d.npz' % i)
        # If SIS file already exists, skip (this makes re-runs resumable).
        if os.path.exists(sis_filepath):
            continue
        image = data_test[0][i]
        sis_result = sis_util.find_sis_on_input(
            model, image, INITIAL_MASK, FULLY_MASKED_IMAGE, SIS_THRESHOLD,
            add_softmax=True, batch_size=128)
        sis_util.save_sis_result(sis_result, sis_filepath)
0 commit comments