@@ -109,31 +109,17 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
109109 with autocast ():
110110 outputs , * rest = model (** batch )
111111 acc = rest [0 ] if rest else - 1
112- audio_acc = rest [1 ] if rest else - 1 # seven layers of audio acc
113- layer_loss = rest [2 ] if rest else - 1 # eight layers of loss (seven audio and one text)
114112 loss = outputs .loss
115113
116114 loss = loss / gradient_accumulation_steps
117- layer_loss = [l / gradient_accumulation_steps for l in layer_loss ]
118115 acc = acc / gradient_accumulation_steps
119- audio_acc = [a / gradient_accumulation_steps for a in audio_acc ]
120116
121117 if log_config .use_wandb and step % log_config .log_interval == 0 :
122118 if train_config .enable_fsdp or train_config .enable_ddp :
123119 if rank == 0 :
124- wandb .log ({"train_inner/train_inner_loss" :loss , "train_inner/train_inner_text_accuracy" :acc }, step = (epoch * total_length + step ))
125- for layer , acc in enumerate (audio_acc ):
126- wandb .log ({f"train_inner/train_inner_audio_accuracy_layer{ layer } " :acc }, step = (epoch * total_length + step ))
127- for layer , l in enumerate (layer_loss [:- 1 ]):
128- wandb .log ({f"train_inner/train_inner_audio_loss_layer{ layer } " :l }, step = (epoch * total_length + step ))
129- wandb .log ({f"train_inner/train_inner_text_loss" :layer_loss [- 1 ]}, step = (epoch * total_length + step ))
120+ wandb .log ({"train_inner/train_inner_loss" :loss , "train_inner/train_inner_accuracy" :acc }, step = (epoch * total_length + step ))
130121 else :
131122 wandb .log ({"train_inner/train_inner_loss" :loss , "train_inner/train_inner_accuracy" :acc }, step = (epoch * total_length + step ))
132- for layer , acc in enumerate (audio_acc ):
133- wandb .log ({f"train_inner/train_inner_audio_accuracy_layer{ layer } " :acc }, step = (epoch * total_length + step ))
134- for layer , l in enumerate (layer_loss [:- 1 ]):
135- wandb .log ({f"train_inner/train_inner_audio_loss_layer{ layer } " :l }, step = (epoch * total_length + step ))
136- wandb .log ({f"train_inner/train_inner_text_loss" :layer_loss [- 1 ]}, step = (epoch * total_length + step ))
137123
138124 total_loss += loss .detach ().float ()
139125 total_acc += acc
0 commit comments