@@ -265,17 +265,40 @@ def _add(self, *args):
265265 Args:
266266 *args: All the elements in a transition.
267267 """
268- cursor = self .cursor ()
268+ self ._check_args_length (* args )
269+ transition = {e .name : args [idx ]
270+ for idx , e in enumerate (self .get_add_args_signature ())}
271+ self ._add_transition (transition )
272+
273+ def _add_transition (self , transition ):
274+ """Internal add method to add transition dictionary to storage arrays.
269275
270- arg_names = [e .name for e in self .get_add_args_signature ()]
271- for arg_name , arg in zip (arg_names , args ):
272- self ._store [arg_name ][cursor ] = arg
276+ Args:
277+ transition: The dictionary of names and values of the transition
278+ to add to the storage.
279+ """
280+ cursor = self .cursor ()
281+ for arg_name in transition :
282+ self ._store [arg_name ][cursor ] = transition [arg_name ]
273283
274284 self .add_count += 1
275285 self .invalid_range = invalid_range (
276286 self .cursor (), self ._replay_capacity , self ._stack_size ,
277287 self ._update_horizon )
278288
289+ def _check_args_length (self , * args ):
290+ """Check if args passed to the add method have the same length as storage.
291+
292+ Args:
293+ *args: Args for elements used in storage.
294+
295+ Raises:
296+ ValueError: If args have wrong length.
297+ """
298+ if len (args ) != len (self .get_add_args_signature ()):
299+ raise ValueError ('Add expects {} elements, received {}' .format (
300+ len (self .get_add_args_signature ()), len (args )))
301+
279302 def _check_add_types (self , * args ):
280303 """Checks if args passed to the add method match those of the storage.
281304
@@ -285,9 +308,7 @@ def _check_add_types(self, *args):
285308 Raises:
286309 ValueError: If args have wrong shape or dtype.
287310 """
288- if len (args ) != len (self .get_add_args_signature ()):
289- raise ValueError ('Add expects {} elements, received {}' .format (
290- len (self .get_add_args_signature ()), len (args )))
311+ self ._check_args_length (* args )
291312 for arg_element , store_element in zip (args , self .get_add_args_signature ()):
292313 if isinstance (arg_element , np .ndarray ):
293314 arg_shape = arg_element .shape
0 commit comments