Skip to content

Commit 4a85b17

Browse files
author
root
committed
fix loss logging issue
1 parent a993949 commit 4a85b17

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

lavis/tasks/base_task.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,12 @@ def build_datasets(self, cfg):
6161
return datasets
6262

6363
def train_step(self, model, samples):
64-
loss_dict = model(samples)
65-
loss = loss_dict["loss"]
66-
return loss, loss_dict
64+
output = model(samples)
65+
loss_dict = {}
66+
for k,v in output.items():
67+
if "loss" in k:
68+
loss_dict[k] = v
69+
return output["loss"], loss_dict
6770

6871
def valid_step(self, model, samples):
6972
raise NotImplementedError

0 commit comments

Comments
 (0)