
Commit 2e59069

Lucas Liebenwein committed
added lost/sis experiments
1 parent 0932cbc commit 2e59069

File tree

3 files changed: +312 -0 lines changed

Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
### SIS Experiments (Section 4, informative feature comparisons)

**Note: These experiments are provided as-is. There is no guarantee that they
will run bug-free.**

*N.B.: I never had time to clean them up, but for the sake of completeness I am adding them here. :)*

Make sure you install the `SIS` repository:

```
https://github.com/google-research/google-research/tree/master/sufficient_input_subsets
```

Then you can use `sis_backward_selection.py` as "inspiration" to run these
experiments.
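
As a minimal sketch (not part of this commit; the clone location is an assumption), one way to make a local checkout of the `google-research` repository importable so that `from sufficient_input_subsets import sis` works, as the scripts below expect:

```python
import sys

# Hypothetical path to a local clone of google-research; adjust as needed.
GOOGLE_RESEARCH_ROOT = "/path/to/google-research"
sys.path.insert(0, GOOGLE_RESEARCH_ROOT)

from sufficient_input_subsets import sis  # noqa: E402

# Smoke test: build the same kind of per-channel mask used in the scripts below.
initial_mask = sis.make_empty_boolean_mask_broadcast_over_axis([3, 32, 32], 0)
print(initial_mask.shape)
```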
Lines changed: 205 additions & 0 deletions
@@ -0,0 +1,205 @@
"""Heatmap plots for input decision rationales across models.

Example usage:
python sis_backward_selection.py \
    --idx_experiment=0 \
    --num_images=250
"""
from __future__ import print_function

import os  # noqa
import sys
this_folder = os.path.split(os.path.abspath('__file__'))[0]  # noqa
src_folder = os.path.join(this_folder, 'provable_compression', 'src')  # noqa
os.chdir(src_folder)  # noqa
sys.path.insert(0, src_folder)
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # noqa
# print(os.getcwd())  # noqa

from torchprune.util import get_parameters
from experiment import Logger, Evaluator
import torch
import copy
import numpy as np
import argparse


parser = argparse.ArgumentParser()
parser.add_argument('--idx_experiment', type=int, required=True)
parser.add_argument('--num_images', type=int, default=250)
args = parser.parse_args()
print(args)


# some parameters to configure the pruned networks
# parameters for running SIS

# NOTE: consider code down to line 131 for full extraction

# CHANGE IDX_EXPERIMENT TO DESIRED EXPERIMENT
# * range(0, 4)   --> resnet20
# * range(4, 8)   --> vgg16
# * range(8, 12)  --> resnet20_rewind
# * range(12, 16) --> vgg16_rewind
#
# Within a range the following order of methods applies:
# 0. WT   --> weight-pruning, magnitude-based
# 1. FT   --> filter-pruning, magnitude-based
# 2. SiPP --> weight-pruning, data-informed
# 3. PFP  --> filter-pruning, data-informed
IDX_EXPERIMENT = args.idx_experiment

# DON'T CHANGE
IDX_REF = 0
IDX_UNCORRELATED = -1

# DON'T CHANGE
FILES = [
    "experiment/cifar/resnet20.yaml",
    "experiment/cifar/vgg16.yaml",
    "experiment/cifar/resnet20_rewind.yaml",
    "experiment/cifar/vgg16_rewind.yaml"
]
METHODS = ["ThresNet", "FilterThresNet", "SiPPNetStar", "PopNet"]
DESIRED_PR = [0.15, 0.46, 0.69, 0.80, 0.98]

# Put together the model description and file
assert 0 <= IDX_EXPERIMENT <= 15

FILE = FILES[IDX_EXPERIMENT // len(METHODS)]

MODELS_DESCRIPTION = [{"method": "ReferenceNet", "pr": 0.0, "n_idx": 0}]
for pr in DESIRED_PR:
    MODELS_DESCRIPTION.append(
        {
            "method": METHODS[IDX_EXPERIMENT % len(METHODS)],
            "pr": pr,
            "n_idx": 0
        })
MODELS_DESCRIPTION.append({"method": "ReferenceNet", "pr": 0.0, "n_idx": 1})

# %%
# Run the compression experiments (or load results if available)
# Initialize the logger and get the parameters
param = next(get_parameters(FILE, 1, 0))
Logger().initialize_from_param(param)

# Initialize the evaluator
compressor = Evaluator()

# load stats into the logger so we don't have to re-run the evaluations;
# if that fails because some parameters don't match, we have to re-run them.
try:
    Logger().load_global_state()
except ValueError:
    compressor.run()

# store mean and std dev for later
mean_c = np.asarray(param['datasets'][param['network']['dataset']]['mean'])[
    :, np.newaxis, np.newaxis]
std_c = np.asarray(param['datasets'][param['network']['dataset']]['std'])[
    :, np.newaxis, np.newaxis]

# device settings
torch.cuda.set_device("cuda:0")
device = torch.device("cuda:0")
device_storage = torch.device("cpu")

# %%
# Retrieve the models we want ...

# Generate all the models we like.
# get a list of models
models = [compressor.get_by_pr(**kwargs) for kwargs in MODELS_DESCRIPTION]

# get the prune ratios
PRUNE_RATIOS = [1 - model.size() / models[IDX_REF].size() for model in models]

# construct the legends
LEGENDS = [f"{param['network_names'][type(model).__name__]} (PR={pr:.2f})"
           for model, pr in zip(models, PRUNE_RATIOS)]
LEGENDS[IDX_REF] = "Unpruned network"
LEGENDS[IDX_UNCORRELATED] = "Separate network"

# get the standard plot color for each network
COLORS = [param['network_colors'][type(model).__name__] for model in models]
COLORS[IDX_UNCORRELATED] = "grey"

# store accuracy as well for reference
TEST_LOSS = []
ACCURACY_TOP1 = []
ACCURACY_TOP5 = []
for model in models:
    model.to(device)
    loss, acc1, acc5 = compressor._net_trainer.test(model)
    model.to(device_storage)
    TEST_LOSS.append(loss.item())
    ACCURACY_TOP1.append(acc1)
    ACCURACY_TOP5.append(acc5)

# Load datasets
loader_train, loader_val, loader_test = compressor.get_dataloader(
    "train", "valid", "test")


# %%
# create one big tensor of images for each set
def get_entire_dataset(dataloader):
    dataset = copy.deepcopy(dataloader.dataset)
    num_imgs = len(dataset)
    images = torch.zeros(size=(num_imgs,) + dataset[0][0].shape)
    labels = torch.zeros(size=(num_imgs,), dtype=torch.long)

    for i in range(len(dataset)):
        images[i], labels[i] = dataset[i]
    return images, labels


data_train = get_entire_dataset(loader_train)
data_test = get_entire_dataset(loader_test)
data_val = get_entire_dataset(loader_val)


for m in models:
    m.to('cuda')
    m.eval()

# %%

os.chdir(this_folder)
import collections
import sis_util
from sufficient_input_subsets import sis
from tqdm import tqdm


OUT_BASEDIR = './sis_data/idx_experiment_%d/' % IDX_EXPERIMENT
print(OUT_BASEDIR)


# Run SIS backward selection on CIFAR test images and write to disk.

SIS_THRESHOLD = 0.0  # To capture the results of backward selection.
INITIAL_MASK = sis.make_empty_boolean_mask_broadcast_over_axis([3, 32, 32], 0)
FULLY_MASKED_IMAGE = np.zeros((3, 32, 32), dtype='float32')

for model_i in tqdm(range(len(models))):
    model = models[model_i]
    sis_out_dir = os.path.join(OUT_BASEDIR, 'model_%d' % model_i)
    print(sis_out_dir)
    if not os.path.exists(sis_out_dir):
        os.makedirs(sis_out_dir)

    for i in range(args.num_images):
        image = data_test[0][i]
        label = data_test[1][i]
        sis_filepath = os.path.join(sis_out_dir, 'test_%d.npz' % i)
        # If the SIS file already exists, skip.
        if os.path.exists(sis_filepath):
            continue
        sis_result = sis_util.find_sis_on_input(
            model, image, INITIAL_MASK, FULLY_MASKED_IMAGE, SIS_THRESHOLD,
            add_softmax=True, batch_size=128)
        sis_util.save_sis_result(sis_result, sis_filepath)
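
A minimal sketch (not part of the commit; it only reuses the `--idx_experiment`/`--num_images` flags from the docstring above) of how one might sweep all 16 experiment indices from the mapping in the comments, assuming it is launched from the folder containing `sis_backward_selection.py`:

```python
import subprocess
import sys

# 0-3: resnet20, 4-7: vgg16, 8-11: resnet20_rewind, 12-15: vgg16_rewind
for idx in range(16):
    subprocess.run(
        [sys.executable, "sis_backward_selection.py",
         f"--idx_experiment={idx}", "--num_images=250"],
        check=True,
    )
```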
Lines changed: 93 additions & 0 deletions
@@ -0,0 +1,93 @@
"""Util for running SIS on PyTorch models."""

import numpy as np
import torch

from sufficient_input_subsets import sis


def predict(model, inputs, add_softmax=False):
    model.eval()
    with torch.no_grad():
        preds = model(inputs)
        if add_softmax:
            preds = torch.nn.functional.softmax(preds, dim=1)
    return preds


def make_f_for_class(model, class_idx, batch_size=128, add_softmax=False):
    def f(inputs):
        with torch.no_grad():
            ret_np = False
            if isinstance(inputs, np.ndarray):
                ret_np = True
                inputs = torch.from_numpy(inputs).cuda()
            else:
                inputs = inputs.cuda()
            num_batches = int(np.ceil(inputs.shape[0] / batch_size))
            all_preds = []
            for batch_idx in range(num_batches):
                batch_start_i = batch_idx * batch_size
                batch_end_i = min(inputs.shape[0],
                                  (batch_idx + 1) * batch_size)
                assert batch_end_i > batch_start_i
                preds = predict(
                    model,
                    inputs[batch_start_i:batch_end_i],
                    add_softmax=add_softmax)[:, class_idx]
                all_preds.append(preds)
            all_preds = torch.cat(all_preds)
            if ret_np:
                all_preds = all_preds.cpu().numpy()
            return all_preds
    return f


def find_sis_on_input(model, x, initial_mask, fully_masked_input, threshold,
                      batch_size=128, add_softmax=False):
    """Find first SIS on input x with PyTorch model."""
    if isinstance(x, np.ndarray):
        x = torch.from_numpy(x).cuda()
    with torch.no_grad():
        pred = model(x.unsqueeze(0).cuda())[0]
    pred_class = int(pred.argmax())
    pred_confidence = float(pred.max())
    if pred_confidence < threshold:
        return None
    f_class = make_f_for_class(model, pred_class, batch_size=batch_size,
                               add_softmax=add_softmax)
    sis_result = sis.find_sis(
        f_class,
        threshold,
        x.cpu().numpy(),
        initial_mask,
        fully_masked_input,
    )
    return sis_result


def create_masked_input(x, sis_result, fully_masked_input):
    return sis.produce_masked_inputs(
        x.cpu().numpy(), fully_masked_input, [sis_result.mask])[0]


def save_sis_result(sr, filepath):
    np.savez_compressed(
        filepath,
        sis=np.array(sr.sis),
        ordering_over_entire_backselect=sr.ordering_over_entire_backselect,
        values_over_entire_backselect=sr.values_over_entire_backselect,
        mask=sr.mask,
    )


def load_sis_result(filepath):
    loaded = np.load(filepath)
    sr = sis.SISResult(
        sis=loaded['sis'],
        ordering_over_entire_backselect=(
            loaded['ordering_over_entire_backselect']),
        values_over_entire_backselect=loaded['values_over_entire_backselect'],
        mask=loaded['mask'],
    )
    return sr
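
For reading results back, a hedged sketch using the helpers above (the file path and the stand-in image are assumptions; in the real experiments the image would be the matching CIFAR test image from `data_test`):

```python
import numpy as np
import torch

import sis_util

# Hypothetical path written by sis_backward_selection.py; adjust as needed.
sis_filepath = "./sis_data/idx_experiment_0/model_0/test_0.npz"
image = torch.rand(3, 32, 32)  # stand-in for the corresponding test image
fully_masked_image = np.zeros((3, 32, 32), dtype="float32")

sis_result = sis_util.load_sis_result(sis_filepath)
# Keep only the pixels in the SIS mask; everything else becomes the baseline.
masked_image = sis_util.create_masked_input(image, sis_result, fully_masked_image)
print(masked_image.shape)
```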

0 commit comments
