You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
A namespace with the required arguments. Typically, this can be gotten from add_model_specific_args().
79
-
loss_fn : Callable
80
-
A function to be used to compute the loss for the training. The input of this function must match the output of the
81
-
forward() method. The output of this function must be a tensor with a single value.
82
78
output_stride : int
83
79
How many times the output of the network is smaller than the input.
80
+
loss_fn : Optional[Callable]
81
+
A function to be used to compute the loss for the training. The input of this function must match the output of the
82
+
forward() method. The output of this function must be a tensor with a single value.
83
+
lr : Optional[float]
84
+
The learning rate to be used for training the model. If not provided, it will be set as 1e-4.
85
+
wdecay : Optional[float]
86
+
The weight decay to be used for training the model. If not provided, it will be set as 1e-4.
87
+
warm_start : bool, default False
88
+
If True, use warm start to initialize the flow prediction. The warm_start strategy was presented by the RAFT method and forward interpolates the prediction from the last frame.
0 commit comments