Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions rfdetr/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ class TrainConfig(BaseModel):
early_stopping_patience: int = 10
early_stopping_min_delta: float = 0.001
early_stopping_use_ema: bool = False
progress_bar: bool = False
tensorboard: bool = True
wandb: bool = False
project: Optional[str] = None
Expand Down
46 changes: 34 additions & 12 deletions rfdetr/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import math
import sys
from typing import Iterable
from tqdm import tqdm

import torch

Expand Down Expand Up @@ -66,12 +67,10 @@ def train_one_epoch(
metric_logger.add_meter(
"class_error", utils.SmoothedValue(window_size=1, fmt="{value:.2f}")
)
header = "Epoch: [{}]".format(epoch)
print_freq = 10
start_steps = epoch * num_training_steps_per_epoch

print("Grad accum steps: ", args.grad_accum_steps)
print("Total batch size: ", batch_size * utils.get_world_size())
# print("Grad accum steps: ", args.grad_accum_steps)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure how helpful these print statements are in the current context, so I left them unchanged. I'll leave it to a core maintainer to decide whether they should be removed or kept.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The same info was printed before every epoch, so I moved it to the main code so it is printed only once, before training starts; I also added a progress bar.

# print("Total batch size: ", batch_size * utils.get_world_size())

# Add gradient scaler for AMP
if DEPRECATED_AMP:
Expand All @@ -82,10 +81,14 @@ def train_one_epoch(
optimizer.zero_grad()
assert batch_size % args.grad_accum_steps == 0
sub_batch_size = batch_size // args.grad_accum_steps
print("LENGTH OF DATA LOADER:", len(data_loader))
for data_iter_step, (samples, targets) in enumerate(
metric_logger.log_every(data_loader, print_freq, header)
):
# print("LENGTH OF DATA LOADER:", len(data_loader))

header = f"Epoch: [{epoch+1}/{args.epochs}]"
if args.progress_bar:
progress = tqdm(enumerate(data_loader), total=len(data_loader), desc=header, colour="green")
else:
progress = enumerate(metric_logger.log_every(data_loader, print_freq=10, header=header))
for data_iter_step, (samples, targets) in progress:
it = start_steps + data_iter_step
callback_dict = {
"step": it,
Expand Down Expand Up @@ -162,13 +165,22 @@ def train_one_epoch(
)
metric_logger.update(class_error=loss_dict_reduced["class_error"])
metric_logger.update(lr=optimizer.param_groups[0]["lr"])

if args.progress_bar:
log_dict = {k: meter.global_avg for k, meter in metric_logger.meters.items()}
log_dict = {'lr':log_dict['lr'],
'class_loss':"%.2f"%log_dict['class_error'],
'box_loss':"%.2f"%log_dict['loss_bbox'],
'loss':"%.2f"%log_dict['loss']}
progress.set_postfix(log_dict)

# gather the stats from all processes
metric_logger.synchronize_between_processes()
print("Averaged stats:", metric_logger)
print("Averaged stats:", metric_logger, '\n')
return {k: meter.global_avg for k, meter in metric_logger.meters.items()}


def evaluate(model, criterion, postprocessors, data_loader, base_ds, device, args=None):
def evaluate(model, criterion, postprocessors, data_loader, base_ds, device, args=None, header='Eval'):
model.eval()
if args.fp16_eval:
model.half()
Expand All @@ -178,12 +190,15 @@ def evaluate(model, criterion, postprocessors, data_loader, base_ds, device, arg
metric_logger.add_meter(
"class_error", utils.SmoothedValue(window_size=1, fmt="{value:.2f}")
)
header = "Test:"

iou_types = tuple(k for k in ("segm", "bbox") if k in postprocessors.keys())
coco_evaluator = CocoEvaluator(base_ds, iou_types)

for samples, targets in metric_logger.log_every(data_loader, 10, header):
if args.progress_bar:
progress = tqdm(data_loader, total=len(data_loader), desc=header, colour="green")
else:
progress = metric_logger.log_every(data_loader, print_freq=10, header=header)
for samples, targets in progress:
samples = samples.to(device)
targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

Expand Down Expand Up @@ -237,6 +252,11 @@ def evaluate(model, criterion, postprocessors, data_loader, base_ds, device, arg
if coco_evaluator is not None:
coco_evaluator.update(res)

log_dict = {k: meter.global_avg for k, meter in metric_logger.meters.items()}
log_dict = {'class_loss':"%.2f"%log_dict['class_error'],
'box_loss':"%.2f"%log_dict['loss_bbox'],
'loss':"%.2f"%log_dict['loss']}

# gather the stats from all processes
metric_logger.synchronize_between_processes()
print("Averaged stats:", metric_logger)
Expand All @@ -253,4 +273,6 @@ def evaluate(model, criterion, postprocessors, data_loader, base_ds, device, arg
stats["coco_eval_bbox"] = coco_evaluator.coco_eval["bbox"].stats.tolist()
if "segm" in postprocessors.keys():
stats["coco_eval_masks"] = coco_evaluator.coco_eval["segm"].stats.tolist()

print()
return stats, coco_evaluator
17 changes: 12 additions & 5 deletions rfdetr/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def train(self, callbacks: DefaultDict[str, List[Callable]], **kwargs):
model_without_ddp = model.module

n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
print('number of params:', n_parameters)
print('\nnumber of params:', n_parameters)
param_dicts = get_param_dict(args, model_without_ddp)

param_dicts = [p for p in param_dicts if p['params'].requires_grad]
Expand All @@ -191,7 +191,9 @@ def train(self, callbacks: DefaultDict[str, List[Callable]], **kwargs):
weight_decay=args.weight_decay)
# Choose the learning rate scheduler based on the new argument

print("\nloading training data:")
dataset_train = build_dataset(image_set='train', args=args, resolution=args.resolution)
print("\nloading validation data:")
dataset_val = build_dataset(image_set='val', args=args, resolution=args.resolution)

# for cosine annealing, calculate total training steps and warmup steps
Expand Down Expand Up @@ -265,7 +267,7 @@ def lr_lambda(current_step: int):
output_dir = Path(args.output_dir)

if utils.is_main_process():
print("Get benchmark")
print("\nGetting benchmark")
if args.do_benchmark:
benchmark_model = copy.deepcopy(model_without_ddp)
bm = benchmark(benchmark_model.float(), dataset_val, output_dir)
Expand Down Expand Up @@ -309,7 +311,12 @@ def lr_lambda(current_step: int):
args.cutoff_epoch, args.drop_mode, args.drop_schedule)
print("Min DP = %.7f, Max DP = %.7f" % (min(schedules['dp']), max(schedules['dp'])))

print("Start training")
print("batch size:", args.batch_size)
print("Grad accum steps:", args.grad_accum_steps)
print("Effective batch size:", effective_batch_size)
print("Training steps/epoch:", len(data_loader_train))
print("\ntraining...")

start_time = time.time()
best_map_holder = BestMetricHolder(use_ema=args.use_ema)
best_map_5095 = 0
Expand Down Expand Up @@ -355,7 +362,7 @@ def lr_lambda(current_step: int):

with torch.inference_mode():
test_stats, coco_evaluator = evaluate(
model, criterion, postprocessors, data_loader_val, base_ds, device, args=args
model, criterion, postprocessors, data_loader_val, base_ds, device, args=args, header="Test"
)

map_regular = test_stats['coco_eval_bbox'][0]
Expand All @@ -378,7 +385,7 @@ def lr_lambda(current_step: int):
'n_parameters': n_parameters}
if args.use_ema:
ema_test_stats, _ = evaluate(
self.ema_m.module, criterion, postprocessors, data_loader_val, base_ds, device, args=args
self.ema_m.module, criterion, postprocessors, data_loader_val, base_ds, device, args=args, header="Test-ema"
)
log_stats.update({f'ema_test_{k}': v for k,v in ema_test_stats.items()})
map_ema = ema_test_stats['coco_eval_bbox'][0]
Expand Down