diff --git a/README.md b/README.md index 3cea37a..78a47cb 100644 --- a/README.md +++ b/README.md @@ -219,7 +219,7 @@ In our original paper, we acknowledge that there were anomalies in the test data ## 👉 Train Prepare your own dataset and refer to the samples in `SAM-Med2D/data_demo` to replace them according to your specific scenario. You need to generate the `image2label_train.json` file before running `train.py`. -If you want to use mixed-precision training, please install [Apex](https://github.com/NVIDIA/apex). If you don't want to install Apex, you can comment out the line `from apex import amp` and set `use_amp` to False. +If you want to use mixed-precision training, please install [Apex](https://github.com/NVIDIA/apex). ```bash cd ./SAM-Med2D diff --git a/train.py b/train.py index aa7f1ba..99d0901 100644 --- a/train.py +++ b/train.py @@ -13,7 +13,6 @@ import numpy as np import datetime from torch.nn import functional as F -from apex import amp import random @@ -102,6 +101,8 @@ def prompt_and_decoder(args, batched_input, model, image_embeddings, decoder_ite def train_one_epoch(args, model, optimizer, train_loader, epoch, criterion): + if args.use_amp: + from apex import amp train_loader = tqdm(train_loader) train_losses = [] train_iter_metrics = [0] * len(args.metrics)