Skip to content

Commit 4830fe1

Browse files
committed
update train_utils
1 parent 229765a commit 4830fe1

File tree

1 file changed

+1
-15
lines changed

1 file changed

+1
-15
lines changed

src/slam_llm/utils/train_utils.py

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -109,31 +109,17 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
109109
with autocast():
110110
outputs, *rest = model(**batch)
111111
acc = rest[0] if rest else -1
112-
audio_acc = rest[1] if rest else -1 # seven layers of audio acc
113-
layer_loss = rest[2] if rest else -1 # eight layers of loss (seven audio and one text)
114112
loss = outputs.loss
115113

116114
loss = loss / gradient_accumulation_steps
117-
layer_loss = [l / gradient_accumulation_steps for l in layer_loss]
118115
acc = acc / gradient_accumulation_steps
119-
audio_acc = [a / gradient_accumulation_steps for a in audio_acc]
120116

121117
if log_config.use_wandb and step % log_config.log_interval == 0:
122118
if train_config.enable_fsdp or train_config.enable_ddp:
123119
if rank==0:
124-
wandb.log({"train_inner/train_inner_loss":loss, "train_inner/train_inner_text_accuracy":acc}, step=(epoch * total_length + step))
125-
for layer, acc in enumerate(audio_acc):
126-
wandb.log({f"train_inner/train_inner_audio_accuracy_layer{layer}":acc}, step=(epoch * total_length + step))
127-
for layer, l in enumerate(layer_loss[:-1]):
128-
wandb.log({f"train_inner/train_inner_audio_loss_layer{layer}":l}, step=(epoch * total_length + step))
129-
wandb.log({f"train_inner/train_inner_text_loss":layer_loss[-1]}, step=(epoch * total_length + step))
120+
wandb.log({"train_inner/train_inner_loss":loss, "train_inner/train_inner_accuracy":acc}, step=(epoch * total_length + step))
130121
else:
131122
wandb.log({"train_inner/train_inner_loss":loss, "train_inner/train_inner_accuracy":acc}, step=(epoch * total_length + step))
132-
for layer, acc in enumerate(audio_acc):
133-
wandb.log({f"train_inner/train_inner_audio_accuracy_layer{layer}":acc}, step=(epoch * total_length + step))
134-
for layer, l in enumerate(layer_loss[:-1]):
135-
wandb.log({f"train_inner/train_inner_audio_loss_layer{layer}":l}, step=(epoch * total_length + step))
136-
wandb.log({f"train_inner/train_inner_text_loss":layer_loss[-1]}, step=(epoch * total_length + step))
137123

138124
total_loss += loss.detach().float()
139125
total_acc += acc

0 commit comments

Comments (0)