|
17 | 17 | from nncf.structures import QuantizationRangeInitArgs |
18 | 18 | from nncf.utils import is_tensor |
19 | 19 | from nncf.utils import objwalk |
20 | | -from nncf.utils import training_mode_switcher |
| 20 | +from contextlib import contextmanager |
21 | 21 |
|
22 | 22 |
|
23 | 23 | class InitializingDataLoader: |
@@ -164,39 +164,66 @@ def __init__(self, model, init_device: str, num_bn_forget_steps): |
164 | 164 | self.num_bn_forget_steps = num_bn_forget_steps |
165 | 165 | self.momentum_bn_forget = 0.9 |
166 | 166 | self.original_momenta_values = {} |
| 167 | + self.original_training_state = {} |
167 | 168 |
|
168 | 169 | @staticmethod |
169 | 170 | def _apply_to_batchnorms(func): |
170 | 171 | def func_apply_to_bns(module): |
171 | | - if isinstance(module, torch.nn.modules.batchnorm.BatchNorm2d): |
| 172 | + if isinstance(module, (torch.nn.modules.batchnorm.BatchNorm1d, |
| 173 | + torch.nn.modules.batchnorm.BatchNorm2d, |
| 174 | + torch.nn.modules.batchnorm.BatchNorm3d)): |
172 | 175 | func(module) |
173 | 176 |
|
174 | 177 | return func_apply_to_bns |
175 | 178 |
|
176 | | - def _run_model_inference(self, data_loader, num_init_steps, device): |
177 | | - num_bn_forget_steps = self.num_bn_forget_steps |
| 179 | + @contextmanager |
| 180 | + def _bn_training_state_switcher(self) -> None: |
| 181 | + def save_original_bn_training_state(module: torch.nn.Module): |
| 182 | + self.original_training_state[module] = module.training |
| 183 | + |
| 184 | + def set_bn_training_state(module: torch.nn.Module, state: Dict[str, bool]): |
| 185 | + module.training = state |
| 186 | + |
| 187 | + def restore_original_bn_training_state(module: torch.nn.Module): |
| 188 | + module.training = self.original_training_state[module] |
| 189 | + |
| 190 | + self.model.apply(self._apply_to_batchnorms(save_original_bn_training_state)) |
| 191 | + self.model.apply(self._apply_to_batchnorms(partial(set_bn_training_state, state=True))) |
| 192 | + try: |
| 193 | + yield |
| 194 | + finally: |
| 195 | + self.model.apply(self._apply_to_batchnorms(restore_original_bn_training_state)) |
178 | 196 |
|
| 197 | + @contextmanager |
| 198 | + def _bn_momentum_switcher(self) -> None: |
179 | 199 | def set_bn_momentum(module, momentum_value): |
180 | 200 | module.momentum = momentum_value |
181 | 201 |
|
182 | | - def save_original_bn_momenta(module): |
| 202 | + def save_original_bn_momentum(module: torch.nn.Module): |
183 | 203 | self.original_momenta_values[module] = module.momentum |
184 | 204 |
|
185 | | - def restore_original_bn_momenta(module): |
| 205 | + def restore_original_bn_momentum(module: torch.nn.Module): |
186 | 206 | module.momentum = self.original_momenta_values[module] |
187 | 207 |
|
188 | | - with training_mode_switcher(self.model, is_training=True): |
189 | | - self.model.apply(self._apply_to_batchnorms(save_original_bn_momenta)) |
190 | | - self.model.apply(self._apply_to_batchnorms(partial(set_bn_momentum, |
191 | | - momentum_value=self.momentum_bn_forget))) |
| 208 | + self.model.apply(self._apply_to_batchnorms(save_original_bn_momentum)) |
| 209 | + self.model.apply(self._apply_to_batchnorms(partial(set_bn_momentum, |
| 210 | + momentum_value=self.momentum_bn_forget))) |
| 211 | + try: |
| 212 | + yield |
| 213 | + finally: |
| 214 | + self.model.apply(self._apply_to_batchnorms(restore_original_bn_momentum)) |
192 | 215 |
|
193 | | - for i, loaded_item in enumerate(data_loader): |
194 | | - if num_bn_forget_steps is not None and i >= num_bn_forget_steps: |
195 | | - break |
196 | | - args_kwargs_tuple = data_loader.get_inputs(loaded_item) |
197 | | - self._infer_batch(args_kwargs_tuple, device) |
| 216 | + def _run_model_inference(self, data_loader, num_init_steps, device): |
| 217 | + num_bn_forget_steps = self.num_bn_forget_steps |
198 | 218 |
|
199 | | - self.model.apply(self._apply_to_batchnorms(restore_original_bn_momenta)) |
| 219 | + with self._bn_training_state_switcher(): |
| 220 | + if num_bn_forget_steps is not None and num_bn_forget_steps > 0: |
| 221 | + with self._bn_momentum_switcher(): |
| 222 | + for i, loaded_item in enumerate(data_loader): |
| 223 | + if i >= num_bn_forget_steps: |
| 224 | + break |
| 225 | + args_kwargs_tuple = data_loader.get_inputs(loaded_item) |
| 226 | + self._infer_batch(args_kwargs_tuple, device) |
200 | 227 |
|
201 | 228 | for i, loaded_item in ProgressBar( |
202 | 229 | enumerate(data_loader), |
|
0 commit comments