diff --git a/hw1/cs285/infrastructure/replay_buffer.py b/hw1/cs285/infrastructure/replay_buffer.py index 60148e79..96171943 100644 --- a/hw1/cs285/infrastructure/replay_buffer.py +++ b/hw1/cs285/infrastructure/replay_buffer.py @@ -76,8 +76,9 @@ def sample_random_data(self, batch_size): ## HINT 1: use np.random.permutation to sample random indices ## HINT 2: return corresponding data points from each array (i.e., not different indices from each array) ## HINT 3: look at the sample_recent_data function below - - return TODO, TODO, TODO, TODO, TODO + + indices = np.random.permutation(len(self))[:batch_size] + return self.obs[indices], self.acs[indices], self.rews[indices], self.next_obs[indices], self.terminals[indices] def sample_recent_data(self, batch_size=1): return (