Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
244 changes: 199 additions & 45 deletions glass.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,18 @@ def load(
svd=0,
step=20,
limit=392,
es_epoch=10,
tta=False,
**kwargs,
):

self.backbone = backbone.to(device)
self.layers_to_extract_from = layers_to_extract_from
self.input_shape = input_shape
self.device = device
self.es_epoch = es_epoch
print(f"early stopping epochs: {self.es_epoch}")
self.tta = tta

self.forward_modules = torch.nn.ModuleDict({})
feature_aggregator = common.NetworkFeatureAggregator(
Expand Down Expand Up @@ -199,8 +204,9 @@ def trainer(self, training_data, val_data, name):
ckpt_path = glob.glob(self.ckpt_dir + '/ckpt_best*')
ckpt_path_save = os.path.join(self.ckpt_dir, "ckpt.pth")
if len(ckpt_path) != 0:
LOGGER.info("Start testing, ckpt file found!")
return 0., 0., 0., 0., 0., -1.
# LOGGER.info("Start testing, ckpt file found!")
# return 0., 0., 0., 0., 0., -1.
LOGGER.info("Ckpt file found, retrain!")

def update_state_dict():
state_dict["discriminator"] = OrderedDict({
Expand Down Expand Up @@ -258,6 +264,10 @@ def update_state_dict():
pbar_str1 = ""
best_record = None

best_score = -1
epoch_counter = 0
best_state = None

with mlflow.start_run():
mlflow.log_param("meta_epochs", self.meta_epochs)
mlflow.log_param("eval_epochs", self.eval_epochs)
Expand Down Expand Up @@ -293,44 +303,50 @@ def update_state_dict():

if (i_epoch + 1) % self.eval_epochs == 0:
images, scores, segmentations, labels_gt, masks_gt = self.predict(val_data)
image_auroc, image_ap, pixel_auroc, pixel_ap, pixel_pro, img_threshold, img_f1_max = self._evaluate(images, scores, segmentations,
image_auroc, image_ap, img_threshold, img_f1_max = self._evaluate(images, scores, segmentations,
labels_gt, masks_gt, name)

mlflow.log_metric("img_auroc", image_auroc, step=i_epoch)
mlflow.log_metric("pixel_auroc", pixel_auroc, step=i_epoch)
mlflow.log_metric("img_threshold", img_threshold, step=i_epoch)
mlflow.log_metric("img_f1_max", img_f1_max, step=i_epoch)

# self.logger.logger.add_scalar("i-auroc", image_auroc, i_epoch)
# self.logger.logger.add_scalar("i-ap", image_ap, i_epoch)
# self.logger.logger.add_scalar("p-auroc", pixel_auroc, i_epoch)
# self.logger.logger.add_scalar("p-ap", pixel_ap, i_epoch)
# self.logger.logger.add_scalar("p-pro", pixel_pro, i_epoch)

eval_path = './results/eval/' + name + '/'
train_path = './results/training/' + name + '/'
if best_record is None:
best_record = [image_auroc, image_ap, pixel_auroc, pixel_ap, pixel_pro, img_f1_max, i_epoch]
best_record = [image_auroc, image_ap, img_f1_max, i_epoch]
ckpt_path_best = os.path.join(self.ckpt_dir, "ckpt_best_{}.pth".format(i_epoch))
torch.save(state_dict, ckpt_path_best)
shutil.rmtree(eval_path, ignore_errors=True)
shutil.copytree(train_path, eval_path)

elif image_auroc + pixel_auroc > best_record[0] + best_record[2]:
best_record = [image_auroc, image_ap, pixel_auroc, pixel_ap, pixel_pro, img_f1_max, i_epoch]
# elif image_auroc + pixel_auroc > best_record[0] + best_record[2]:
elif image_auroc > best_record[0]:
best_record = [image_auroc, image_ap, img_f1_max, i_epoch]
os.remove(ckpt_path_best)
ckpt_path_best = os.path.join(self.ckpt_dir, "ckpt_best_{}.pth".format(i_epoch))
torch.save(state_dict, ckpt_path_best)
shutil.rmtree(eval_path, ignore_errors=True)
shutil.copytree(train_path, eval_path)

pbar_str1 = f" IAUC:{round(image_auroc * 100, 2)}({round(best_record[0] * 100, 2)})" \
f" PAUC:{round(pixel_auroc * 100, 2)}({round(best_record[2] * 100, 2)})" \
f" IF1-max:{round(img_f1_max * 100, 2)}({round(best_record[5] * 100, 2)})" \
f" PAUC: Do not have PAUC)" \
f" IF1-max:{round(img_f1_max * 100, 2)}({round(best_record[2] * 100, 2)})" \
f" E:{i_epoch}({best_record[-1]})"

pbar_str += pbar_str1
pbar.set_description_str(pbar_str)

# current_score = image_auroc*1 + pixel_auroc * 0
current_score = image_auroc*0.5 + img_f1_max*0.5
if current_score - best_score > 0.1:
best_score = current_score
epoch_counter = 0
else:
epoch_counter += 1
if epoch_counter > self.es_epoch:
LOGGER.info(f"Early stopping triggered at epoch {i_epoch}")
break

torch.save(state_dict, ckpt_path_save)
return best_record

Expand Down Expand Up @@ -362,7 +378,7 @@ def _train_discriminator(self, input_data, cur_epoch, pbar, pbar_str1):
true_feats = self._embed(img, evaluation=False)[0]
true_feats.requires_grad = True

mask_s_gt = data_item["mask_s"].reshape(-1, 1).to(self.device)
mask_s_gt = data_item["mask_s"].reshape(-1, 1).to(self.device) # feature
noise = torch.normal(0, self.noise, true_feats.shape).to(self.device)
gaus_feats = true_feats + noise

Expand Down Expand Up @@ -502,15 +518,21 @@ def tester(self, test_data, name):
else:
self.load_state_dict(state_dict, strict=False)

images, scores, segmentations, labels_gt, masks_gt = self.predict(test_data)
image_auroc, image_ap, pixel_auroc, pixel_ap, pixel_pro, img_threshold, img_f1_max = self._evaluate(images, scores, segmentations,
labels_gt, masks_gt, name, path='eval')
if self.tta:
print("Do tta")
images, scores, segmentations, labels_gt, masks_gt = self.tta_predict(test_data)
image_auroc, image_ap, img_threshold, img_f1_max = self.tta_evaluate(images, scores, segmentations,
labels_gt, masks_gt, name, path='eval')
else:
images, scores, segmentations, labels_gt, masks_gt = self.predict(test_data)
image_auroc, image_ap, img_threshold, img_f1_max = self._evaluate(images, scores, segmentations,
labels_gt, masks_gt, name, path='eval')
epoch = int(ckpt_path[0].split('_')[-1].split('.')[0])
else:
image_auroc, image_ap, pixel_auroc, pixel_ap, pixel_pro, img_threshold, img_f1_max, epoch = 0., 0., 0., 0., 0., 0., 0., -1.
image_auroc, image_ap, img_threshold, img_f1_max, epoch = 0., 0., 0., 0., -1.
LOGGER.info("No ckpt file found!")

return image_auroc, image_ap, pixel_auroc, pixel_ap, pixel_pro, img_threshold, img_f1_max, epoch
return image_auroc, image_ap, img_threshold, img_f1_max, epoch

def _evaluate(self, images, scores, segmentations, labels_gt, masks_gt, name, path='training'):
scores = np.squeeze(np.array(scores))
Expand All @@ -530,42 +552,26 @@ def _evaluate(self, images, scores, segmentations, labels_gt, masks_gt, name, pa
max_scores = np.max(segmentations)
norm_segmentations = (segmentations - min_scores) / (max_scores - min_scores + 1e-10)

pixel_scores = metrics.compute_pixelwise_retrieval_metrics(norm_segmentations, masks_gt, path)
pixel_auroc = pixel_scores["auroc"]
pixel_ap = pixel_scores["ap"]
if path == 'eval':
try:
pixel_pro = metrics.compute_pro(np.squeeze(np.array(masks_gt)), norm_segmentations)

except:
pixel_pro = 0.
else:
pixel_pro = 0.

else:
pixel_auroc = -1.
pixel_ap = -1.
pixel_pro = -1.
return image_auroc, image_ap, pixel_auroc, pixel_ap, pixel_pro,img_threshold, img_f1_max

defects = np.array(images)
targets = np.array(masks_gt)
# targets = np.array(masks_gt)
for i in range(len(defects)):
defect = utils.torch_format_2_numpy_img(defects[i])
target = utils.torch_format_2_numpy_img(targets[i])
# target = utils.torch_format_2_numpy_img(targets[i])

mask = cv2.cvtColor(cv2.resize(norm_segmentations[i], (defect.shape[1], defect.shape[0])),
cv2.COLOR_GRAY2BGR)
mask = (mask * 255).astype('uint8')
mask = cv2.applyColorMap(mask, cv2.COLORMAP_JET)

img_up = np.hstack([defect, target, mask])
img_up = cv2.resize(img_up, (256 * 3, 256))
# img_up = np.hstack([defect, target, mask])
# img_up = cv2.resize(img_up, (256 * 3, 256))
img_up = np.hstack([defect, mask])
img_up = cv2.resize(img_up, (256 * 2, 256))
full_path = './results/' + path + '/' + name + '/'
utils.del_remake_dir(full_path, del_flag=False)
cv2.imwrite(full_path + str(i + 1).zfill(3) + '.png', img_up)

return image_auroc, image_ap, pixel_auroc, pixel_ap, pixel_pro, img_threshold, img_f1_max
return image_auroc, image_ap, img_threshold, img_f1_max

def predict(self, test_dataloader):
"""This function provides anomaly scores/maps for full dataloaders."""
Expand Down Expand Up @@ -604,7 +610,6 @@ def _predict(self, img):
self.discriminator.eval()

with torch.no_grad():

patch_features, patch_shapes = self._embed(img, provide_patch_shapes=True, evaluation=True)
if self.pre_proj > 0:
patch_features = self.pre_projection(patch_features)
Expand All @@ -622,3 +627,152 @@ def _predict(self, img):
image_scores = image_scores.cpu().numpy()

return list(image_scores), list(masks)


def tta_evaluate(self, images, scores, segmentations, labels_gt, masks_gt, name, path='training'):
scores = np.squeeze(np.array(scores))
img_min_scores = min(scores)
img_max_scores = max(scores)
norm_scores = (scores - img_min_scores) / (img_max_scores - img_min_scores + 1e-10)

image_scores = metrics.compute_imagewise_retrieval_metrics(norm_scores, labels_gt, path)
image_auroc = image_scores["auroc"]
image_ap = image_scores["ap"]

img_threshold, img_f1_max = metrics.compute_best_pr_re(labels_gt, norm_scores)

if len(masks_gt) > 0:
segmentations = np.array(segmentations)
min_scores = np.min(segmentations)
max_scores = np.max(segmentations)
norm_segmentations = (segmentations - min_scores) / (max_scores - min_scores + 1e-10)

# ============== DEBUG ==============
print("\n[DEBUG] Data shape validation:")
print("1. Original segmentation result dtype:", segmentations.dtype, "shape:", segmentations.shape)
print("2. Normalized dtype:", norm_segmentations.dtype, "range:", np.min(norm_segmentations), "-", np.max(norm_segmentations))

norm_segmentations = norm_segmentations.astype(np.float32) # Fix 1:convert to float32
print("3. Converted dtype:", norm_segmentations.dtype)

defects = np.array(images)
targets = np.array(masks_gt)
for i in range(len(defects)):
if i == 0: # print debug info only for the first sample
print("\n[DEBUG] Sample processing pipeline validation (i=0):")

defect = utils.torch_format_2_numpy_img(defects[i])
target = utils.torch_format_2_numpy_img(targets[i])

# Maintain single channel during resizing
resized_mask = cv2.resize(
norm_segmentations[i],
(defect.shape[1], defect.shape[0]),
interpolation=cv2.INTER_LINEAR
)
if i == 0:
print("4. Resized mask shape:", resized_mask.shape, "dtype:", resized_mask.dtype)

# Convert to 0-255 and uint8
mask_8bit = (resized_mask * 255).astype(np.uint8)
if i == 0:
print("5. Converted to uint8 range:", np.min(mask_8bit), "-", np.max(mask_8bit), "dtype:", mask_8bit.dtype)

# Color sapce conversion
try:
mask_color = cv2.cvtColor(mask_8bit, cv2.COLOR_GRAY2BGR)
if i == 0:
print("6. Converted color shape:", mask_color.shape)
except Exception as e:
print(f"\n[ERROR] Color conversion failed!")
print("error mask_8bit parameters:", f"shape:{mask_8bit.shape}", f"dtype:{mask_8bit.dtype}")
raise e

# Apply color mapping
mask_color = cv2.applyColorMap(mask_color, cv2.COLORMAP_JET)
if i == 0:
print("7. Applied color mapping shape:", mask_color.shape, "dtype:", mask_color.dtype)

# Concatenate result images
img_up = np.hstack([defect, target, mask_color])
img_up = cv2.resize(img_up, (256 * 3, 256))
full_path = './results/' + path + '/' + name + '/'
utils.del_remake_dir(full_path, del_flag=False)
cv2.imwrite(full_path + str(i + 1).zfill(3) + '.png', img_up)

return image_auroc, image_ap, img_threshold, img_f1_max

def tta_predict(self, test_dataloader):
"""This function provides anomaly scores/maps for full dataloaders."""
self.forward_modules.eval()

img_paths = []
images = []
scores = []
masks = []
labels_gt = []
masks_gt = []

with tqdm.tqdm(test_dataloader, desc="Inferring...", leave=False, unit='batch') as data_iterator:
for data in data_iterator:
if isinstance(data, dict):
labels_gt.extend(data["is_anomaly"].numpy().tolist())
if data.get("mask_gt", None) is not None:
masks_gt.extend(data["mask_gt"].numpy().tolist())
image = data["image"]
images.extend(image.numpy().tolist())
img_paths.extend(data["image_path"])
_scores, _masks = self.tta__predict(image)
for score, mask in zip(_scores, _masks):
scores.append(score)
masks.append(mask)

return images, scores, masks, labels_gt, masks_gt

def tta__predict(self, img):
"""Infer score and mask for a batch of images with TTA"""
self.forward_modules.eval()
if self.pre_proj > 0:
self.pre_projection.eval()
self.discriminator.eval()

tta_transforms = [
{'name': 'original',
'transform': lambda x: x,
'reverse': lambda x: x},

{'name': 'h_flip',
'transform': lambda x: x.flip(-1),
'reverse': lambda x: x.flip(-1)},

{'name': 'v_flip',
'transform': lambda x: x.flip(-2),
'reverse': lambda x: x.flip(-2)},

# {'name': 'rotate90',
# 'transform': lambda x: x.rot90(1, [-2, -1]).flip(-1),
# 'reverse': lambda x: x.flip(-1).rot90(-1, [-2, -1])},

# {'name': 'color_jitter',
# 'transform': lambda x: x * (0.9 + 0.2*torch.rand(1,device=x.device)) + 0.1*torch.randn_like(x),
# 'reverse': lambda x: x},
]

all_scores = []
all_masks = []

with torch.no_grad():
for aug in tta_transforms:
transformed_img = aug['transform'](img)
_scores, _masks = self._predict(transformed_img)

mask_tensor = torch.tensor(np.array(_masks))
reversed_masks = aug['reverse'](torch.tensor(_masks))

all_scores.append(_scores)
all_masks.append(reversed_masks.numpy())

avg_scores = np.mean(all_scores, axis=0)
avg_masks = np.mean(all_masks, axis=0)

return avg_scores.tolist(), avg_masks.tolist()
11 changes: 7 additions & 4 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ def main(**kwargs):


@main.command("net")
@click.option("--es_epoch", type=int, default=10, help="Early stopping epochs")
@click.option("--tta", is_flag=True, default=False, help="If using the tta")
@click.option("--dsc_margin", type=float, default=0.5)
@click.option("--train_backbone", is_flag=True)
@click.option("--backbone_names", "-b", type=str, multiple=True, default=[])
Expand Down Expand Up @@ -68,6 +70,8 @@ def net(
svd,
step,
limit,
es_epoch,
tta,
):
backbone_names = list(backbone_names)
if len(backbone_names) > 1:
Expand Down Expand Up @@ -110,6 +114,8 @@ def get_glass(input_shape, device):
svd=svd,
step=step,
limit=limit,
es_epoch=es_epoch,
tta=tta,
)
glasses.append(glass_inst.to(device))
return glasses
Expand Down Expand Up @@ -304,15 +310,12 @@ def run(
df = pd.concat([df, pd.DataFrame(row_dist, index=[0])])

if type(flag) != int:
i_auroc, i_ap, p_auroc, p_ap, p_pro, img_threshold, i_f1_max, epoch = GLASS.tester(dataloaders["testing"], dataset_name)
i_auroc, i_ap, img_threshold, i_f1_max, epoch = GLASS.tester(dataloaders["testing"], dataset_name)
result_collect.append(
{
"dataset_name": dataset_name,
"image_auroc": i_auroc,
"image_ap": i_ap,
"pixel_auroc": p_auroc,
"pixel_ap": p_ap,
"pixel_pro": p_pro,
"image_f1_max": i_f1_max,
"f1_max_threshold": img_threshold,
"best_epoch": epoch,
Expand Down
Loading