diff --git a/training/example_train_script.sh b/training/example_train_script.sh
index 4976c56..5b731e1 100644
--- a/training/example_train_script.sh
+++ b/training/example_train_script.sh
@@ -8,4 +8,5 @@ python main_2pt5d.py --max_epochs 100 --val_every 1 --optim_lr 0.000005 \
 --logdir finetune_ckpt_example --point_prompt --label_prompt --distributed --seed 12346 \
 --iterative_training_warm_up_epoch 50 --reuse_img_embedding \
 --label_prompt_warm_up_epoch 25 \
---checkpoint ./runs/9s_2dembed_model.pt
+--checkpoint ./runs/9s_2dembed_model.pt \
+--num_classes 105
diff --git a/training/main_2pt5d.py b/training/main_2pt5d.py
index 46f3637..a16c1c7 100644
--- a/training/main_2pt5d.py
+++ b/training/main_2pt5d.py
@@ -113,6 +113,8 @@
 parser.add_argument("--skip_bk", action="store_true", help="skip background (0) during training")
 parser.add_argument("--patch_embed_3d", action="store_true", help="using 3d patch embedding layer")
 
+parser.add_argument("--num_classes", default=105, type=int, help="number of output classes")
+
 
 def start_tb(log_dir):
     cmd = ["tensorboard", "--logdir", log_dir]
@@ -123,6 +125,10 @@ def main():
     args = parser.parse_args()
     args.amp = not args.noamp
     args.logdir = "./runs/" + args.logdir
+
+    if args.num_classes == 0:
+        warnings.warn("consider setting the correct number of classes")
+
     # start_tb(args.logdir)
     if args.seed > -1:
         set_determinism(seed=args.seed)
@@ -162,7 +168,7 @@ def main_worker(gpu, args):
 
     dice_loss = DiceCELoss(sigmoid=True)
 
-    post_label = AsDiscrete(to_onehot=105)
+    post_label = AsDiscrete(to_onehot=args.num_classes)
     post_pred = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
     dice_acc = DiceMetric(include_background=False, reduction=MetricReduction.MEAN, get_not_nans=True)
 
diff --git a/training/trainer_2pt5d.py b/training/trainer_2pt5d.py
index 1d1e906..8e9ff59 100644
--- a/training/trainer_2pt5d.py
+++ b/training/trainer_2pt5d.py
@@ -129,7 +129,7 @@ def prepare_sam_training_input(inputs, labels, args, model):
         unique_labels = unique_labels[: args.num_prompt]
 
     # add 4 background labels to every batch
-    background_labels = list(set([i for i in range(1, 105)]) - set(unique_labels.cpu().numpy()))
+    background_labels = list(set([i for i in range(1, args.num_classes)]) - set(unique_labels.cpu().numpy()))
     random.shuffle(background_labels)
     unique_labels = torch.cat([unique_labels, torch.tensor(background_labels[:4]).cuda(args.rank)])
 
@@ -375,7 +375,7 @@ def train_epoch_iterative(model, loader, optimizer, scaler, epoch, loss_func, ar
 
 
 def prepare_sam_test_input(inputs, labels, args, previous_pred=None):
-    unique_labels = torch.tensor([i for i in range(1, 105)]).cuda(args.rank)
+    unique_labels = torch.tensor([i for i in range(1, args.num_classes)]).cuda(args.rank)
 
     # preprocess make the size of lable same as high_res_logit
     batch_labels = torch.stack([labels == unique_labels[i] for i in range(len(unique_labels))], dim=0).float()
@@ -400,7 +400,7 @@ def prepare_sam_test_input(inputs, labels, args, previous_pred=None):
 
 def prepare_sam_val_input_cp_only(inputs, labels, args):
     # Don't exclude background in val but will ignore it in metric calculation
-    unique_labels = torch.tensor([i for i in range(1, 105)]).cuda(args.rank)
+    unique_labels = torch.tensor([i for i in range(1, args.num_classes)]).cuda(args.rank)
 
     # preprocess make the size of lable same as high_res_logit
     batch_labels = torch.stack([labels == unique_labels[i] for i in range(len(unique_labels))], dim=0).float()
@@ -457,15 +457,18 @@ def val_epoch(model, loader, epoch, acc_func, args, iterative=False, post_label=
                 y_pred = torch.stack(post_pred(decollate_batch(logit)), 0)
 
                 # TODO: we compute metric for each prompt for simplicity in validation.
-                acc_batch = compute_dice(y_pred=y_pred, y=target)
+                acc_batch = compute_dice(y_pred=y_pred[None,], y=target[None,])
                 acc_sum, not_nans = (
                     torch.nansum(acc_batch).item(),
-                    104 - torch.sum(torch.isnan(acc_batch).float()).item(),
+                    (args.num_classes - 1) - torch.sum(torch.isnan(acc_batch).float()).item(),
                 )
                 acc_sum_total += acc_sum
                 not_nans_total += not_nans
 
-            acc, not_nans = acc_sum_total / not_nans_total, not_nans_total
+            if not_nans_total > 0:
+                acc, not_nans = acc_sum_total / not_nans_total, not_nans_total
+            else:
+                acc, not_nans = 0, 0
             f_name = batch_data["image"].meta["filename_or_obj"]
             print(f"Rank: {args.rank}, Case: {f_name}, Acc: {acc:.4f}, N_prompts: {int(not_nans)} ")
 