3 changes: 3 additions & 0 deletions README.md
@@ -200,6 +200,9 @@ It is important to include `donate_argnums` when calling `jax.jit` to enable JAX
### Storing Data with Vault
As mentioned above, Vault stores experience data to disk by extending the temporal axis of a Flashbax buffer state. By default, Vault conveniently handles the bookkeeping of this process: consuming a buffer state and saving any fresh, previously unseen data. For example, suppose we write 10 timesteps to our Flashbax buffer and then save this state to a Vault; since all of this data is fresh, all of it will be written to disk. If we then write one more timestep and save the state to the Vault, only that new timestep will be written, preventing any duplication of data that has already been saved. Importantly, one must remember that Flashbax states are implemented as _ring buffers_, so the Vault must be updated frequently enough that unseen data in the Flashbax buffer state is not overwritten. That is, if our buffer state has a time-axis length of $\tau$, then we must save to the vault at least every $\tau - 1$ steps, lest we overwrite (and lose) unsaved data.
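
Below is a minimal sketch of this save loop, assuming a small flat buffer and a `Vault` constructed from the buffer state's experience structure; the sizes, field names, and constructor arguments shown here are illustrative assumptions rather than a prescribed setup.

```python
import jax.numpy as jnp
import flashbax as fbx
from flashbax.vault import Vault

# A small flat buffer whose time axis holds 1000 timesteps (illustrative sizes).
buffer = fbx.make_flat_buffer(max_length=1000, min_length=1, sample_batch_size=32)
dummy_timestep = {"obs": jnp.zeros(4), "reward": jnp.zeros(())}
state = buffer.init(dummy_timestep)

# A Vault that mirrors the buffer's experience structure on disk (assumed arguments).
vault = Vault(vault_name="demo_vault", experience_structure=state.experience)

# With a time-axis length of tau = 1000, save at least every tau - 1 steps so that
# unseen data is never overwritten by the ring buffer.
save_period = 1000 - 1
for step in range(1, 5_000):
    timestep = {"obs": jnp.ones(4) * step, "reward": jnp.asarray(float(step))}
    state = buffer.add(state, timestep)
    if step % save_period == 0:
        vault.write(state)  # only fresh, previously unseen timesteps are written
```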

### X64 Precision
There can be issues when using 32-bit precision with the sum tree or the prioritised experience replay buffer. Due to numerical instabilities, you run the risk of sampling priority-zero transitions. To fix this, you can use 64-bit precision, or simply mask out zero-probability transitions when using them for RL losses or importance sampling weights. Alternatively, you can wrap the prioritised replay buffer's sample function so that it replaces zero-probability transitions with another random (or high-probability) transition in the batch. We did not want to impose this replacement functionality on the user.
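
As a hedged illustration of both options: enabling 64-bit precision is standard JAX configuration, while the masking helper below uses names (`masked_importance_weights`, `beta`) that are ours, and assumes the per-item priorities are available alongside the sampled batch.

```python
import jax
import jax.numpy as jnp

# Option 1: enable 64-bit precision (must be set before any JAX computation runs).
jax.config.update("jax_enable_x64", True)

# Option 2: mask out zero-probability transitions when forming importance weights.
# `priorities` is assumed to be the per-item priorities returned with the sampled batch.
def masked_importance_weights(priorities, total_priority, num_items, beta=0.4):
    probs = priorities / total_priority
    valid = probs > 0.0
    safe_probs = jnp.where(valid, probs, 1.0)  # avoid dividing by zero
    weights = (1.0 / (num_items * safe_probs)) ** beta
    # Normalise by the largest weight and zero out invalid entries entirely.
    return jnp.where(valid, weights / jnp.max(weights), 0.0)
```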

In summary, understanding and addressing these considerations will help you navigate potential pitfalls and ensure the effectiveness of your reinforcement learning strategies while utilising Flashbax buffers.

## Benchmarks 📈
104 changes: 17 additions & 87 deletions flashbax/buffers/prioritised_trajectory_buffer.py
@@ -184,8 +184,17 @@ def prioritised_init(
sum_tree_size = get_sum_tree_capacity(max_length_time_axis, period, add_batch_size)
sum_tree_state = sum_tree.init(sum_tree_size)

# Set the running index - Ideally int64 but we put as int32
running_index = jnp.array(0, dtype=jnp.int32)
# Set the running index. Ideally this is int64: for extremely long training runs it is
# important to avoid overflow of the running index. With int32, the buffer can receive at
# most 2,147,483,647 * add_batch_size timesteps before the running index overflows.
# Check if 64-bit precision is enabled and, if not, fall back to int32.
if jax.config.x64_enabled: # type: ignore
running_index_dtype = jnp.int64
else:
running_index_dtype = jnp.int32

running_index = jnp.array(0, dtype=running_index_dtype)

return PrioritisedTrajectoryBufferState( # type: ignore
sum_tree_state=sum_tree_state,
@@ -441,13 +450,17 @@ def _calculate_new_item_priorities(
padding_value = _get_padding_num(max_length_time_axis, period, add_batch_size)
# Calculate the masked valid priorities
new_valid_priorities = jnp.full_like(
newly_valid_item_indices, fill_value=sum_tree_state.max_recorded_priority
newly_valid_item_indices,
fill_value=sum_tree_state.max_recorded_priority,
dtype=sum_tree_state.dtype,
)
vp_mask = newly_valid_item_indices != padding_value
new_valid_priorities = new_valid_priorities * vp_mask

# Get invalid priorities
new_invalid_priorities = jnp.zeros_like(newly_invalid_item_indices)
new_invalid_priorities = jnp.zeros_like(
newly_invalid_item_indices, dtype=sum_tree_state.dtype
)

return new_valid_priorities, new_invalid_priorities

@@ -562,89 +575,6 @@ def prioritised_add(
)


def get_invalid_indices(
state: TrajectoryBufferState[Experience],
sample_sequence_length: int,
period: int,
add_batch_size: int,
max_length_time_axis: int,
) -> Array:
"""
Get the indices of the items that will be invalid when sampling from the buffer state. This
is used to mask out the invalid items when sampling. The indices are in the format of a
flattened array and refer to items, not the actual data. To convert item indices into data
indices, we would perform the following:

indices = item_indices * period
row_indices = indices // max_length_time_axis
time_indices = indices % max_length_time_axis

Item indices essentially refer to a flattened array picture of the
items (i.e. subsequences that can be sampled) in the buffer state.


Args:
state: The buffer state.
sample_sequence_length: The length of the sequence that will be sampled from the buffer
state.
period: The period refers to the interval between sampled sequences. It serves to regulate
how much overlap there is between the trajectories that are sampled. To understand the
degree of overlap, you can calculate it as the difference between the
sample_sequence_length and the period. For instance, if you set period=1, it means that
trajectories will be sampled uniformly with the potential for any degree of overlap. On
the other hand, if period is equal to sample_sequence_length - 1, then trajectories can
be sampled in a way where only the first and last timesteps overlap with each other.
This helps you control the extent of overlap between consecutive sequences in your
sampling process.
add_batch_size: The number of trajectories that will be added to the buffer state.
max_length_time_axis: The maximum length of the time axis of the buffer state.

Returns:
The indices of the items (with shape : [add_batch_size, num_items]) that will be invalid
when sampling from the buffer state.
"""
# We get the max subsequence data index as done in the add function.
max_divisible_length = max_length_time_axis - (max_length_time_axis % period)
max_subsequence_data_index = max_divisible_length - 1
# We get the data index that is at least sample_sequence_length away from the
# current index.
previous_valid_data_index = (
state.current_index - sample_sequence_length
) % max_length_time_axis
# We ensure that this index is not above the maximum mappable data index of the buffer.
previous_valid_data_index = jnp.minimum(
previous_valid_data_index, max_subsequence_data_index
)
# We then convert the data index into the item index and add one to get the index
# of the item that is broken apart.
invalid_item_starting_index = (previous_valid_data_index // period) + 1
# We then take the modulo of the invalid item index to ensure that it is within the
# bounds of the priority array. max_length_time_axis // period is the maximum number
# of items/subsequences that can be sampled from the buffer state.
invalid_item_starting_index = invalid_item_starting_index % (
max_length_time_axis // period
)

# Calculate the maximum number of items/subsequences that can start within a
# sample length of data. We add one to account for situations where the max
# number of items has been broken. Often, this will unfortunately mask an item
# that is valid; however, this should not be a severe issue as it is only
# one additional item.
max_num_invalid_items = (sample_sequence_length // period) + 1
# Get the actual indices of the items we cannot sample from.
invalid_item_indices = (
jnp.arange(max_num_invalid_items) + invalid_item_starting_index
) % (max_length_time_axis // period)
# Since items that are broken are broken in the same place in each row, we
# broadcast and add the total number of items to each index to reference
# the invalid items in each add_batch row.
invalid_item_indices = invalid_item_indices + jnp.arange(add_batch_size)[
:, None
] * (max_length_time_axis // period)

return invalid_item_indices


def prioritised_sample(
state: PrioritisedTrajectoryBufferState[Experience],
rng_key: chex.PRNGKey,
44 changes: 25 additions & 19 deletions flashbax/buffers/sum_tree.py
@@ -12,19 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.


"""
Pure functions defining a sum-tree data structure. The desired use is within a prioritised replay
buffer, see Prioritized Experience Replay by Schaul et al. (2015) and `prioritised_replay.py`.

This is an adaptation of the sum-tree implementation from
dopamine: https://github.com/google/dopamine/blob/master/dopamine/replay_memory/sum_tree.py.
Much of the code is copied verbatim.
The key differences between this implementation and the dopamine implementation are (1) this
implementation is in JAX with a functional style, and (2) this implementation focuses on
vectorised adding (rather than sequential adding).
"""

from typing import TYPE_CHECKING, Optional, Tuple, Union

if TYPE_CHECKING: # https://github.com/python/mypy/issues/6239
@@ -46,6 +33,7 @@ class SumTreeState:
max_recorded_priority: Array
tree_depth: int = struct.field(pytree_node=False)
capacity: int = struct.field(pytree_node=False)
dtype: jnp.dtype = struct.field(pytree_node=False)


def get_tree_depth(capacity: int) -> int:
@@ -61,22 +49,34 @@ def get_tree_depth(capacity: int) -> int:
return int(np.ceil(np.log2(capacity)))


def init(capacity: int) -> SumTreeState:
def init(capacity: int, dtype: Optional[jnp.dtype] = None) -> SumTreeState:
"""Creates the sum tree data structure for the given replay capacity.

Args:
capacity: The maximum number of elements that can be stored in this
data structure.
dtype: The data type of the sum tree. If None, will use jnp.float64 if
jax.config.x64_enabled is True, otherwise will use jnp.float32.
It is strongly recommended to use jnp.float64; with jnp.float32 there is a
risk of numerical instability causing the sum tree to sample priority-zero
transitions.
"""
if dtype is None:
if jax.config.x64_enabled: # type: ignore
dtype = jnp.float64
else:
dtype = jnp.float32

tree_depth = get_tree_depth(capacity)
array_size = 2 ** (tree_depth + 1) - 1
nodes = jnp.zeros(array_size)
max_recorded_priority = jnp.array(1.0)
nodes = jnp.zeros(array_size, dtype=dtype)
max_recorded_priority = jnp.array(1.0, dtype=dtype)
return SumTreeState(
nodes=nodes,
max_recorded_priority=max_recorded_priority,
tree_depth=tree_depth,
capacity=capacity,
dtype=dtype,
)


@@ -137,7 +137,7 @@ def sample(
raise ValueError(
"Either the `rng_key` or the `query_value` must be specified."
)
query_value = jax.random.uniform(rng_key)
query_value = jax.random.uniform(rng_key, dtype=state.dtype)
query_value = query_value * _total_priority(state)

# Now traverse the sum tree.
@@ -198,13 +198,13 @@ def stratified_sample(
Batch of indices sampled from the sum tree.
"""
query_keys = jax.random.split(rng_key, batch_size)
bounds = jnp.linspace(0.0, 1.0, batch_size + 1)
bounds = jnp.linspace(0.0, 1.0, batch_size + 1, dtype=state.dtype)

lower_bounds = bounds[:-1, None]
upper_bounds = bounds[1:, None]

query_values = jax.vmap(jax.random.uniform, in_axes=(0, None, None, 0, 0))(
query_keys, (), jnp.float32, lower_bounds, upper_bounds
query_keys, (), state.dtype, lower_bounds, upper_bounds
)

return jax.vmap(sample, in_axes=(None, None, 0))(
@@ -272,6 +272,8 @@ def set_non_batched(
nonnegative. Setting value = 0 will cause the element to never be sampled.

"""
# Cast the value to the correct dtype.
value = jnp.asarray(value, dtype=state.dtype)
# We get the tree index of the node.
mapped_index = get_tree_index(state.tree_depth, node_index)
# We ensure that if we index out of bounds (which is what we do for padding)
@@ -330,6 +332,8 @@ def set_batch_bincount(
Returns:
A buffer state with updates nodes.
"""
# Cast the values to the correct dtype.
values = jnp.asarray(values, dtype=state.dtype)
# We get the tree indices of the nodes.
mapped_indices = get_tree_index(state.tree_depth, node_indices)

@@ -408,6 +412,8 @@ def set_batch_scan(
Returns:
A buffer state with updates nodes.
"""
# Cast the values to the correct dtype.
values = jnp.asarray(values, dtype=state.dtype)

def update_node_priority(state: SumTreeState, node_data: Tuple[Array, Array]):
"""Updates the priority of a single node."""