Binary file modified datasets/__pycache__/__init__.cpython-39.pyc
Binary file not shown.
Binary file modified datasets/__pycache__/mvtec.cpython-39.pyc
Binary file not shown.
15 changes: 8 additions & 7 deletions datasets/mvtec.py
@@ -228,20 +228,21 @@ def get_image_data(self):
anomaly_files = sorted(os.listdir(anomaly_path))
imgpaths_per_class[self.classname][anomaly] = [os.path.join(anomaly_path, x) for x in anomaly_files]

if self.split == DatasetSplit.TEST and anomaly != "good":
anomaly_mask_path = os.path.join(maskpath, anomaly)
anomaly_mask_files = sorted(os.listdir(anomaly_mask_path))
maskpaths_per_class[self.classname][anomaly] = [os.path.join(anomaly_mask_path, x) for x in anomaly_mask_files]
else:
maskpaths_per_class[self.classname]["good"] = None
# if self.split == DatasetSplit.TEST and anomaly != "good":
# anomaly_mask_path = os.path.join(maskpath, anomaly)
# anomaly_mask_files = sorted(os.listdir(anomaly_mask_path))
# maskpaths_per_class[self.classname][anomaly] = [os.path.join(anomaly_mask_path, x) for x in anomaly_mask_files]
# else:
# maskpaths_per_class[self.classname]["good"] = None

data_to_iterate = []
for classname in sorted(imgpaths_per_class.keys()):
for anomaly in sorted(imgpaths_per_class[classname].keys()):
for i, image_path in enumerate(imgpaths_per_class[classname][anomaly]):
data_tuple = [classname, anomaly, image_path]
if self.split == DatasetSplit.TEST and anomaly != "good":
data_tuple.append(maskpaths_per_class[classname][anomaly][i])
# data_tuple.append(maskpaths_per_class[classname][anomaly][i])
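# Ground-truth mask paths are disabled in this setup, so a None placeholder is appended instead.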
data_tuple.append(None)
else:
data_tuple.append(None)
data_to_iterate.append(data_tuple)
226 changes: 194 additions & 32 deletions glass.py
@@ -66,13 +66,16 @@ def load(
svd=0,
step=20,
limit=392,
es_epoch=10,
**kwargs,
):

self.backbone = backbone.to(device)
self.layers_to_extract_from = layers_to_extract_from
self.input_shape = input_shape
self.device = device
self.es_epoch = es_epoch
assert es_epoch is not None, "Please set the early stopping epochs (es_epoch)"

self.forward_modules = torch.nn.ModuleDict({})
feature_aggregator = common.NetworkFeatureAggregator(
@@ -258,6 +261,10 @@ def update_state_dict():
pbar_str1 = ""
best_record = None

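# Early-stopping bookkeeping: best validation score seen so far and evaluation rounds since the last improvement.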
best_score = -1
epoch_counter = 0
best_state = None

with mlflow.start_run():
mlflow.log_param("meta_epochs", self.meta_epochs)
mlflow.log_param("eval_epochs", self.eval_epochs)
@@ -293,44 +300,56 @@

if (i_epoch + 1) % self.eval_epochs == 0:
images, scores, segmentations, labels_gt, masks_gt = self.predict(val_data)
image_auroc, image_ap, pixel_auroc, pixel_ap, pixel_pro, img_threshold, img_f1_max = self._evaluate(images, scores, segmentations,
# image_auroc, image_ap, pixel_auroc, pixel_ap, pixel_pro, img_threshold, img_f1_max = self._evaluate(images, scores, segmentations,
# labels_gt, masks_gt, name)
image_auroc, image_ap, img_threshold, img_f1_max = self._evaluate(images, scores, segmentations,
labels_gt, masks_gt, name)

mlflow.log_metric("img_auroc", image_auroc, step=i_epoch)
mlflow.log_metric("pixel_auroc", pixel_auroc, step=i_epoch)
# mlflow.log_metric("pixel_auroc", pixel_auroc, step=i_epoch)
mlflow.log_metric("img_threshold", img_threshold, step=i_epoch)
mlflow.log_metric("img_f1_max", img_f1_max, step=i_epoch)

# self.logger.logger.add_scalar("i-auroc", image_auroc, i_epoch)
# self.logger.logger.add_scalar("i-ap", image_ap, i_epoch)
# self.logger.logger.add_scalar("p-auroc", pixel_auroc, i_epoch)
# self.logger.logger.add_scalar("p-ap", pixel_ap, i_epoch)
# self.logger.logger.add_scalar("p-pro", pixel_pro, i_epoch)

eval_path = './results/eval/' + name + '/'
train_path = './results/training/' + name + '/'
if best_record is None:
best_record = [image_auroc, image_ap, pixel_auroc, pixel_ap, pixel_pro, img_f1_max, i_epoch]
# best_record = [image_auroc, image_ap, pixel_auroc, pixel_ap, pixel_pro, img_f1_max, i_epoch]
best_record = [image_auroc, image_ap, img_f1_max, i_epoch]
ckpt_path_best = os.path.join(self.ckpt_dir, "ckpt_best_{}.pth".format(i_epoch))
torch.save(state_dict, ckpt_path_best)
shutil.rmtree(eval_path, ignore_errors=True)
shutil.copytree(train_path, eval_path)

elif image_auroc + pixel_auroc > best_record[0] + best_record[2]:
best_record = [image_auroc, image_ap, pixel_auroc, pixel_ap, pixel_pro, img_f1_max, i_epoch]
# elif image_auroc + pixel_auroc > best_record[0] + best_record[2]:
elif image_auroc > best_record[0]:
# best_record = [image_auroc, image_ap, pixel_auroc, pixel_ap, pixel_pro, img_f1_max, i_epoch]
best_record = [image_auroc, image_ap, img_f1_max, i_epoch]
os.remove(ckpt_path_best)
ckpt_path_best = os.path.join(self.ckpt_dir, "ckpt_best_{}.pth".format(i_epoch))
torch.save(state_dict, ckpt_path_best)
shutil.rmtree(eval_path, ignore_errors=True)
shutil.copytree(train_path, eval_path)

pbar_str1 = f" IAUC:{round(image_auroc * 100, 2)}({round(best_record[0] * 100, 2)})" \
f" PAUC:{round(pixel_auroc * 100, 2)}({round(best_record[2] * 100, 2)})" \
f" IF1-max:{round(img_f1_max * 100, 2)}({round(best_record[5] * 100, 2)})" \
f" PAUC: Do not have PAUC)" \
f" IF1-max:{round(img_f1_max * 100, 2)}({round(best_record[2] * 100, 2)})" \
f" E:{i_epoch}({best_record[-1]})"
# f" PAUC:{round(pixel_auroc * 100, 2)}({round(best_record[2] * 100, 2)})" \

pbar_str += pbar_str1
pbar.set_description_str(pbar_str)

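# Early stopping: the patience counter resets only when image AUROC improves by more than 0.1; after more than es_epoch evaluation rounds without such an improvement, training stops.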
# current_score = image_auroc*1 + pixel_auroc*0
current_score = image_auroc*1
if current_score - best_score > 0.1:
best_score = current_score
epoch_counter = 0
else:
epoch_counter += 1
if epoch_counter > self.es_epoch:
LOGGER.info(f"Early stopping triggered at epoch {i_epoch}")
break

torch.save(state_dict, ckpt_path_save)
return best_record

@@ -503,14 +522,17 @@ def tester(self, test_data, name):
self.load_state_dict(state_dict, strict=False)

images, scores, segmentations, labels_gt, masks_gt = self.predict(test_data)
image_auroc, image_ap, pixel_auroc, pixel_ap, pixel_pro, img_threshold, img_f1_max = self._evaluate(images, scores, segmentations,
# image_auroc, image_ap, pixel_auroc, pixel_ap, pixel_pro, img_threshold, img_f1_max = self._evaluate(images, scores, segmentations,
# labels_gt, masks_gt, name, path='eval')
image_auroc, image_ap, img_threshold, img_f1_max = self._evaluate(images, scores, segmentations,
labels_gt, masks_gt, name, path='eval')
epoch = int(ckpt_path[0].split('_')[-1].split('.')[0])
else:
image_auroc, image_ap, pixel_auroc, pixel_ap, pixel_pro, img_threshold, img_f1_max, epoch = 0., 0., 0., 0., 0., 0., 0., -1.
LOGGER.info("No ckpt file found!")

return image_auroc, image_ap, pixel_auroc, pixel_ap, pixel_pro, img_threshold, img_f1_max, epoch
# return image_auroc, image_ap, pixel_auroc, pixel_ap, pixel_pro, img_threshold, img_f1_max, epoch
return image_auroc, image_ap, img_threshold, img_f1_max, epoch

def _evaluate(self, images, scores, segmentations, labels_gt, masks_gt, name, path='training'):
scores = np.squeeze(np.array(scores))
@@ -530,23 +552,23 @@ def _evaluate(self, images, scores, segmentations, labels_gt, masks_gt, name, pa
max_scores = np.max(segmentations)
norm_segmentations = (segmentations - min_scores) / (max_scores - min_scores + 1e-10)

pixel_scores = metrics.compute_pixelwise_retrieval_metrics(norm_segmentations, masks_gt, path)
pixel_auroc = pixel_scores["auroc"]
pixel_ap = pixel_scores["ap"]
if path == 'eval':
try:
pixel_pro = metrics.compute_pro(np.squeeze(np.array(masks_gt)), norm_segmentations)
# pixel_scores = metrics.compute_pixelwise_retrieval_metrics(norm_segmentations, masks_gt, path)
# pixel_auroc = pixel_scores["auroc"]
# pixel_ap = pixel_scores["ap"]
# if path == 'eval':
# try:
# pixel_pro = metrics.compute_pro(np.squeeze(np.array(masks_gt)), norm_segmentations)

except:
pixel_pro = 0.
else:
pixel_pro = 0.
# except:
# pixel_pro = 0.
# else:
# pixel_pro = 0.

else:
pixel_auroc = -1.
pixel_ap = -1.
pixel_pro = -1.
return image_auroc, image_ap, pixel_auroc, pixel_ap, pixel_pro,img_threshold, img_f1_max
# else:
# pixel_auroc = -1.
# pixel_ap = -1.
# pixel_pro = -1.
# return image_auroc, image_ap, pixel_auroc, pixel_ap, pixel_pro,img_threshold, img_f1_max

defects = np.array(images)
targets = np.array(masks_gt)
@@ -565,7 +587,8 @@ def _evaluate(self, images, scores, segmentations, labels_gt, masks_gt, name, pa
utils.del_remake_dir(full_path, del_flag=False)
cv2.imwrite(full_path + str(i + 1).zfill(3) + '.png', img_up)

return image_auroc, image_ap, pixel_auroc, pixel_ap, pixel_pro, img_threshold, img_f1_max
# return image_auroc, image_ap, pixel_auroc, pixel_ap, pixel_pro, img_threshold, img_f1_max
return image_auroc, image_ap, img_threshold, img_f1_max

def predict(self, test_dataloader):
"""This function provides anomaly scores/maps for full dataloaders."""
@@ -604,7 +627,6 @@ def _predict(self, img):
self.discriminator.eval()

with torch.no_grad():

patch_features, patch_shapes = self._embed(img, provide_patch_shapes=True, evaluation=True)
if self.pre_proj > 0:
patch_features = self.pre_projection(patch_features)
@@ -622,3 +644,143 @@ def _predict(self, img):
image_scores = image_scores.cpu().numpy()

return list(image_scores), list(masks)

def tta_evaluate(self, images, scores, segmentations, labels_gt, masks_gt, name, path='training'):
scores = np.squeeze(np.array(scores))
img_min_scores = min(scores)
img_max_scores = max(scores)
norm_scores = (scores - img_min_scores) / (img_max_scores - img_min_scores + 1e-10)

image_scores = metrics.compute_imagewise_retrieval_metrics(norm_scores, labels_gt, path)
image_auroc = image_scores["auroc"]
image_ap = image_scores["ap"]

img_threshold, img_f1_max = metrics.compute_best_pr_re(labels_gt, norm_scores)

if len(masks_gt) > 0:
segmentations = np.array(segmentations)

target_height, target_width = images[0].shape[:2]

segmentations = segmentations.astype(np.float32)
min_scores = np.min(segmentations)
max_scores = np.max(segmentations)
norm_segmentations = (segmentations - min_scores) / (max_scores - min_scores + 1e-10)

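# Scale the normalized maps to uint8 [0, 255] so they can be resized and colour-mapped with OpenCV below.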
norm_segmentations = (norm_segmentations * 255).astype(np.uint8)

pixel_scores = metrics.compute_pixelwise_retrieval_metrics(norm_segmentations, masks_gt, path)
pixel_auroc = pixel_scores["auroc"]
pixel_ap = pixel_scores["ap"]
if path == 'eval':
try:
pixel_pro = metrics.compute_pro(np.squeeze(np.array(masks_gt)), norm_segmentations)

except:
pixel_pro = 0.
else:
pixel_pro = 0.

else:
pixel_auroc = -1.
pixel_ap = -1.
pixel_pro = -1.
return image_auroc, image_ap, pixel_auroc, pixel_ap, pixel_pro,img_threshold, img_f1_max

defects = np.array(images)
targets = np.array(masks_gt)
for i in range(len(defects)):
defect = utils.torch_format_2_numpy_img(defects[i])
target = utils.torch_format_2_numpy_img(targets[i])

resized_mask = cv2.resize(norm_segmentations[i], (target_width, target_height), interpolation=cv2.INTER_LINEAR)
resized_mask = resized_mask.astype(np.uint8)
mask = cv2.cvtColor(resized_mask, cv2.COLOR_GRAY2BGR)

# norm_segmentations was already scaled to uint8 [0, 255] above, so the map can be colour-mapped directly
mask = cv2.applyColorMap(mask, cv2.COLORMAP_JET)

img_up = np.hstack([defect, target, mask])
img_up = cv2.resize(img_up, (256 * 3, 256))
full_path = './results/' + path + '/' + name + '/'
utils.del_remake_dir(full_path, del_flag=False)
cv2.imwrite(full_path + str(i + 1).zfill(3) + '.png', img_up)

return image_auroc, image_ap, pixel_auroc, pixel_ap, pixel_pro, img_threshold, img_f1_max

def tta_predict(self, test_dataloader):
"""This function provides anomaly scores/maps for full dataloaders."""
self.forward_modules.eval()

img_paths = []
images = []
scores = []
masks = []
labels_gt = []
masks_gt = []

with tqdm.tqdm(test_dataloader, desc="Inferring...", leave=False, unit='batch') as data_iterator:
for data in data_iterator:
if isinstance(data, dict):
labels_gt.extend(data["is_anomaly"].numpy().tolist())
if data.get("mask_gt", None) is not None:
masks_gt.extend(data["mask_gt"].numpy().tolist())
image = data["image"]
images.extend(image.numpy().tolist())
img_paths.extend(data["image_path"])
_scores, _masks = self.tta__predict(image)
for score, mask in zip(_scores, _masks):
scores.append(score)
masks.append(mask)

return images, scores, masks, labels_gt, masks_gt

def tta__predict(self, img):
"""Infer score and mask for a batch of images with TTA (仅修改此函数)."""
self.forward_modules.eval()
if self.pre_proj > 0:
self.pre_projection.eval()
self.discriminator.eval()

img = img.to(self.device) if not img.is_cuda else img

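# Test-time augmentation: each entry pairs a transform with its inverse so predictions can be mapped back to the original orientation. Flips along dim 3 (width) and dim 2 (height) of the NCHW batch are their own inverses.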
tta_transforms = [
{'name': 'original', 'transform': lambda x: x, 'reverse': lambda x: x},
{'name': 'h_flip', 'transform': lambda x: x.flip(3), 'reverse': lambda x: x.flip(3)},
{'name': 'v_flip', 'transform': lambda x: x.flip(2), 'reverse': lambda x: x.flip(2)}
]

all_scores = []
all_masks = []

with torch.no_grad():
for aug in tta_transforms:
transformed_img = aug['transform'](img)

patch_features, patch_shapes = self._embed(transformed_img, provide_patch_shapes=True, evaluation=True)

if self.pre_proj > 0:
patch_features = self.pre_projection(patch_features)
patch_features = patch_features[0] if len(patch_features)==2 else patch_features

patch_scores = self.discriminator(patch_features)

patch_scores_unpatched = self.patch_maker.unpatch_scores(patch_scores, batchsize=img.shape[0])
scales = patch_shapes[0]

score_tensor = patch_scores_unpatched.reshape(img.shape[0], scales[0], scales[1])

masks = self.anomaly_segmentor.convert_to_segmentation(score_tensor)
if isinstance(masks, list):
masks = torch.stack([torch.from_numpy(m).to(self.device) for m in masks])

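# Undo the spatial augmentation on the anomaly map so all TTA views align with the original image before averaging.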
reversed_masks = aug['reverse'](masks.unsqueeze(1)).squeeze(1)

image_scores = self.patch_maker.score(patch_scores_unpatched)
all_scores.append(image_scores)
all_masks.append(reversed_masks)

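# Aggregate over the TTA views by averaging the image-level scores and the pixel-level anomaly maps.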
avg_scores = torch.mean(torch.stack(all_scores), dim=0)
avg_masks = torch.mean(torch.stack(all_masks), dim=0)

return avg_scores.cpu().numpy().tolist(), avg_masks.cpu().numpy().tolist()