-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
89 lines (74 loc) · 3.57 KB
/
train.py
File metadata and controls
89 lines (74 loc) · 3.57 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import torch
from torch.amp import GradScaler, autocast
from torch.distributed import ReduceOp, all_reduce, barrier

from basicsr.metrics.psnr_ssim import calculate_psnr_pt, calculate_ssim_pt

# Module-level fp16 loss scaler shared by the training loop below.
# (Bug fix: the original repeated the torch/torch.distributed/basicsr
# import group twice verbatim; deduplicated and grouped stdlib-style.)
scaler = GradScaler()
def train_step(model, loss_fn, optimizer, dataloader, device, start_iter, print_freq, rank=0, accumulation_steps=4):
    """Run one AMP (fp16) training pass over ``dataloader`` with gradient accumulation.

    Args:
        model: network being trained; assumed to already live on ``device``.
        loss_fn: criterion mapping ``(pred, gt)`` to a scalar loss tensor.
        optimizer: optimizer stepped once every ``accumulation_steps`` batches.
        dataloader: yields dicts with ``'GT'`` and ``'LR'`` image tensors.
        device: device tensors are moved to (channels-last, non-blocking).
        print_freq: rank 0 prints the averaged loss every ``print_freq`` calls.
        start_iter: iteration counter from the caller; only used to gate printing.
        rank: distributed rank; only rank 0 prints.
        accumulation_steps: batches to accumulate before each optimizer step.

    Returns:
        float: training loss averaged over this loader and across all ranks.
    """
    model.train()
    # Accumulate the loss as an on-device tensor so all_reduce needs no copies.
    train_loss = torch.zeros(1, device=device)
    optimizer.zero_grad()
    start_iter += 1
    num_batches = len(dataloader)
    for batch, data in enumerate(dataloader):
        gt = data['GT'].to(device, memory_format=torch.channels_last, non_blocking=True)
        lr = data['LR'].to(device, memory_format=torch.channels_last, non_blocking=True)
        # Forward under fp16 autocast; backward runs outside the autocast region.
        with autocast('cuda', dtype=torch.float16):
            pred = model(lr)
            loss = loss_fn(pred, gt)
        # Normalize by accumulation steps so the accumulated gradient matches a
        # single large batch; scaler.scale prevents fp16 gradient underflow.
        scaler.scale(loss / accumulation_steps).backward()
        # Step every `accumulation_steps` batches, and also flush on the final
        # batch (bug fix: the original dropped the last partial accumulation
        # window whenever len(dataloader) % accumulation_steps != 0).
        if (batch + 1) % accumulation_steps == 0 or (batch + 1) == num_batches:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
        train_loss += loss.detach()
    # Local mean (guard against an empty loader), then global mean across ranks.
    train_loss_avg = train_loss / max(num_batches, 1)
    all_reduce(train_loss_avg, op=ReduceOp.AVG)
    if rank == 0 and start_iter % print_freq == 0:
        # Bug fix: the original formatted undefined `epoch` (NameError);
        # report the iteration counter instead.
        print(f"[Iteration {start_iter}] Train Loss: {train_loss_avg.item():.4f}")
    # ALL ranks return the value so collective calls stay in lockstep.
    return train_loss_avg.item()
def validation_step(model, loss_fn, dataloader, device, epoch=0, rank=0):
    """Evaluate ``model`` over ``dataloader`` and report loss / PSNR / SSIM.

    Args:
        model: network to evaluate; assumed to already live on ``device``.
        loss_fn: criterion mapping ``(pred, gt)`` to a scalar loss tensor.
        dataloader: yields dicts with ``'GT'`` and ``'LR'`` image tensors.
        device: device tensors are moved to (channels-last, non-blocking).
        epoch: unused here beyond the caller's bookkeeping; kept for interface
            compatibility.
        rank: distributed rank; only rank 0 prints.

    Returns:
        tuple[float, float, float]: (loss, PSNR, SSIM), each averaged over this
        loader and across all ranks.
    """
    model.eval()
    # Keep running metrics as on-device tensors so all_reduce needs no copies.
    val_loss = torch.zeros(1, device=device)
    total_psnr = torch.zeros(1, device=device)
    total_ssim = torch.zeros(1, device=device)
    with torch.no_grad():
        for batch, data in enumerate(dataloader):
            gt = data['GT'].to(device, memory_format=torch.channels_last, non_blocking=True)
            lr = data['LR'].to(device, memory_format=torch.channels_last, non_blocking=True)
            # Bug fix: the original called deprecated torch.cuda.amp.autocast
            # with 'cuda' in its `enabled` positional slot; use the modern
            # torch.amp.autocast API, matching train_step.
            with autocast('cuda', dtype=torch.float16):
                pred = model(lr)
                loss = loss_fn(pred, gt)
            # Compute metrics in fp32 for numerical stability.
            pred = pred.float()
            psnr = calculate_psnr_pt(pred, gt, crop_border=0, test_y_channel=True)
            ssim = calculate_ssim_pt(pred, gt, crop_border=0, test_y_channel=True)
            val_loss += loss.detach()
            total_psnr += psnr.detach()
            total_ssim += ssim.detach()
    # Local means (guard against an empty loader), then global means across ranks.
    num_batches = max(len(dataloader), 1)
    val_loss_avg = val_loss / num_batches
    psnr_avg = total_psnr / num_batches
    ssim_avg = total_ssim / num_batches
    all_reduce(val_loss_avg, op=ReduceOp.AVG)
    all_reduce(psnr_avg, op=ReduceOp.AVG)
    all_reduce(ssim_avg, op=ReduceOp.AVG)
    if rank == 0:
        print(f"[Validation] Loss: {val_loss_avg.item():.4f}, PSNR: {psnr_avg.item():.2f}, SSIM: {ssim_avg.item():.4f}")
    # ALL ranks return so callers behave identically across processes.
    return val_loss_avg.item(), psnr_avg.item(), ssim_avg.item()