diff --git a/hw1/cs285/policies/MLP_policy.py b/hw1/cs285/policies/MLP_policy.py
index c8e1fd7d..2e353c68 100644
--- a/hw1/cs285/policies/MLP_policy.py
+++ b/hw1/cs285/policies/MLP_policy.py
@@ -109,7 +109,11 @@ def update(
         adv_n=None, acs_labels_na=None, qvals=None
     ):
         # TODO: update the policy and return the loss
-        loss = TODO
+        self.optimizer.zero_grad()
+        predicted_actions = self(torch.Tensor(observations).to(self.device))
+        loss = self.loss_func(predicted_actions, torch.Tensor(actions).to(self.device))
+        loss.backward()
+        self.optimizer.step()
         return {
             # You can add extra logging information here, but keep this line
             'Training Loss': ptu.to_numpy(loss),
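
For context, here is a minimal, self-contained sketch of the same supervised (behavioral cloning) update step outside the homework scaffolding. The `ToyMLPPolicy` class and its `mean_net`, `loss_func`, `optimizer`, and `device` attributes are assumptions that mirror the names used in the diff, not the repo's actual `MLPPolicySL` API; the real code would also log the loss through `ptu.to_numpy` rather than `loss.item()`.

```python
# Hedged sketch: same update pattern as the diff (zero grads, predict,
# regress predicted actions onto expert actions, backprop, optimizer step).
# Class/attribute names are illustrative assumptions, not the course API.
import numpy as np
import torch
from torch import nn, optim


class ToyMLPPolicy(nn.Module):
    def __init__(self, ob_dim: int, ac_dim: int, hidden: int = 64):
        super().__init__()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.mean_net = nn.Sequential(
            nn.Linear(ob_dim, hidden), nn.Tanh(), nn.Linear(hidden, ac_dim)
        ).to(self.device)
        self.loss_func = nn.MSELoss()
        self.optimizer = optim.Adam(self.mean_net.parameters(), lr=1e-3)

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        return self.mean_net(obs)

    def update(self, observations: np.ndarray, actions: np.ndarray) -> dict:
        self.optimizer.zero_grad()
        predicted_actions = self(torch.Tensor(observations).to(self.device))
        loss = self.loss_func(
            predicted_actions, torch.Tensor(actions).to(self.device)
        )
        loss.backward()
        self.optimizer.step()
        return {"Training Loss": loss.item()}


if __name__ == "__main__":
    # Toy batch of 32 (observation, expert action) pairs with made-up dims.
    policy = ToyMLPPolicy(ob_dim=11, ac_dim=3)
    obs = np.random.randn(32, 11).astype(np.float32)
    acs = np.random.randn(32, 3).astype(np.float32)
    print(policy.update(obs, acs))
```

One design note on the diff itself: converting with `torch.Tensor(...).to(self.device)` works, but it silently casts to float32 and copies; the homework's `ptu.from_numpy` helper (if available in this version of the repo) is the more idiomatic way to get observations and actions onto the right device.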