We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent a993949 commit 4a85b17Copy full SHA for 4a85b17
lavis/tasks/base_task.py
@@ -61,9 +61,12 @@ def build_datasets(self, cfg):
61
return datasets
62
63
def train_step(self, model, samples):
64
- loss_dict = model(samples)
65
- loss = loss_dict["loss"]
66
- return loss, loss_dict
+ output = model(samples)
+ loss_dict = {}
+ for k,v in output.items():
67
+ if "loss" in k:
68
+ loss_dict[k] = v
69
+ return output["loss"], loss_dict
70
71
def valid_step(self, model, samples):
72
raise NotImplementedError
0 commit comments