Description
Describe the bug
The LineariseRewards transform applies a weighted sum to its input rewards and writes the result to an output key. However, it does not override transform_output_spec of the Transform class. Because the output key is therefore missing from the spec during spec validation, subsequent transforms in a Compose container that rely on this key (e.g., reward scaling applied to the weighted sum) fail to initialize.
To be clear: the LineariseRewards transform itself may work, but it does not register its output keys for validation.
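For context, the linearisation itself is just a weighted sum over the components of the reward vector. A minimal plain-Python sketch of that computation (illustrative only, not the torchrl API; the function name and default weights are hypothetical):

```python
# Conceptual sketch of what a reward linearisation computes: a weighted
# sum collapsing a multi-objective reward vector into a scalar.
# Plain Python for illustration; not the torchrl implementation.

def linearise_rewards(reward_vec, weights=None):
    """Collapse a reward vector into a scalar via a weighted sum."""
    if weights is None:
        # Hypothetical default: uniform weights, i.e. a plain sum.
        weights = [1.0] * len(reward_vec)
    return sum(w * r for w, r in zip(weights, reward_vec))

# Two objectives weighted 0.7 / 0.3:
print(linearise_rewards([1.0, 2.0], weights=[0.7, 0.3]))  # 1.3
```

The bug is not in this computation but in the fact that the key holding its result never gets registered in the environment's output spec.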
To Reproduce
Consider a simple example where a reward vector is linearised and the written value is then scaled. When the Compose iterates through its elements and lets each one check the spec, RewardScaling fails its check because the key is missing:
```python
Compose(
    transforms=[
        LineariseRewards(
            in_keys=[("agent1", "reward_vec")],
            out_keys=[("agent1", "weighted_reward")],  # will be missing from the tensordict's output spec
        ),
        RewardScaling(in_keys=[("agent1", "weighted_reward")], ...),  # relies on the missing key in the spec
    ],
),
```
This also emits the following warning from Transform.transform_output_spec:

```
FutureWarning: The key '('agent1', 'weighted_reward')' is unaccounted for by the transform (expected keys [<redacted>]).
Every new entry in the tensordict resulting from a call to a transform must be registered in the specs for torchrl rollouts to be consistently built.
Make sure transform_output_spec/transform_observation_spec/... is coded correctly.
This warning will trigger a KeyError in v0.9, make sure to adapt your code accordingly.
```
Expected behavior
The transform_output_spec function should add the out_key to the tensordict spec, so that subsequent transform steps can make use of its output value.
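The failure mode can be illustrated with a toy model of spec propagation through a chain of transforms. All class and function names here are illustrative stand-ins, not the torchrl API; the point is that a transform that writes a key at runtime but never registers it in the spec breaks validation for every downstream consumer of that key:

```python
# Toy model of spec propagation through a Compose-like chain.
# Names are hypothetical; this is not the torchrl API.

class SpecAwareTransform:
    def __init__(self, in_key, out_key):
        self.in_key, self.out_key = in_key, out_key

    def transform_output_spec(self, spec):
        # Correct behaviour: register the new output key in the spec.
        spec = dict(spec)
        spec[self.out_key] = "float32"
        return spec

class BuggyTransform(SpecAwareTransform):
    def transform_output_spec(self, spec):
        # Analogue of the bug: writes out_key at runtime,
        # but never registers it in the spec.
        return dict(spec)

def validate(transforms, spec):
    # Walk the chain, letting each transform check and rewrite the spec.
    for t in transforms:
        if t.in_key not in spec:
            raise KeyError(f"{t.in_key!r} missing from spec")
        spec = t.transform_output_spec(spec)
    return spec

spec = {"reward_vec": "float32"}

# With the buggy transform, the downstream consumer fails validation:
try:
    validate([BuggyTransform("reward_vec", "weighted_reward"),
              SpecAwareTransform("weighted_reward", "scaled_reward")], spec)
except KeyError as e:
    print("validation failed:", e)

# With the spec registered, the whole chain validates:
final = validate([SpecAwareTransform("reward_vec", "weighted_reward"),
                  SpecAwareTransform("weighted_reward", "scaled_reward")], spec)
print(sorted(final))
```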
Workaround
Overwrite the value of an existing key that is already present in the spec before LineariseRewards runs. Subsequent transforms will then find it in the spec during initialization validation.
System info
Describe the characteristic of your environment:
- TorchRL 0.10.0, installed via pip (however, the reported version is 0.0.0+unknown)
- OS: Linux, GCC 13.3.0
- Python 3.12.3
Reason and Possible fixes
Implement the missing transform_output_spec function. Other transforms that can write new entries to the tensordict implement it as well.
Checklist
- I have checked that there is no similar issue in the repo (required)
- I have read the documentation (required)
- I have provided a minimal working example to reproduce the bug (required)