TorchRL Episode Data Format (TED)
https://pytorch.org/rl/stable/reference/data.html#ted-format
TorchRL has its own, slightly unusual way of writing state transitions, different from the conventional OpenAI Gym API.
This, however, is one of the things that makes parallelization easy.
It generally follows this shape:
>>> data = env.reset()
>>> data = policy(data)
>>> print(env.step(data))
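Concretely, this might look as follows (a minimal sketch; it assumes torchrl and gymnasium are installed, uses Pendulum-v1 purely as an example environment, and substitutes env.rand_action for a trained policy):
from torchrl.envs import GymEnv

env = GymEnv("Pendulum-v1")
data = env.reset()             # tensordict with the reset observation and done flags
data = env.rand_action(data)   # stand-in for policy(data): writes a random "action" entry
data = env.step(data)          # fills the "next" sub-tensordict with the step results
print(data)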
Generally, a single MDP step produces:
- a new observation
- an indicator of task completion (terminated, truncated, done), and
- a reward signal
This can get more complicated in multi-agent RL settings, and a reward is sometimes not even needed (e.g., in imitation learning scenarios).
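For contrast, the conventional Gymnasium API hands these back as a flat tuple rather than a nested structure (a sketch assuming gymnasium is installed):
import gymnasium as gym

env = gym.make("Pendulum-v1")
obs, info = env.reset()
action = env.action_space.sample()
obs, reward, terminated, truncated, info = env.step(action)  # flat tuple, no nesting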
So the information for time step t sits at the root of the tensordict:
General Rule
Everything that belongs to time step t is stored at the root of the tensordict, while everything that belongs to t+1 is stored in the "next" entry of the tensordict.
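In practice this means pre-step and post-step values are read with nested keys (a sketch, continuing from the data returned by env.step above):
data["observation"]           # observation at time t (what the policy saw)
data["action"]                # action taken at time t
data["next", "observation"]   # observation at time t+1
data["next", "reward"]        # reward received for the action
data["next", "done"]          # done flag after the action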
Here’s an example:
data = env.reset()
data = policy(data)
print(env.step(data))
"""
TensorDict(
    fields={
        action: Tensor(...),       # The action taken at time t
        done: Tensor(...),         # The done state when the action was taken (at reset)
        next: TensorDict(          # all of this content comes from the call to `step`
            fields={
                done: Tensor(...),         # The done state after the action has been taken
                observation: Tensor(...),  # The observation resulting from the action
                reward: Tensor(...),       # The reward resulting from the action
                terminated: Tensor(...),   # The terminated state after the action has been taken
                truncated: Tensor(...)},   # The truncated state after the action has been taken
            batch_size=torch.Size([]),
            device=cpu,
            is_shared=False),
        observation: Tensor(...),  # the observation at reset
        terminated: Tensor(...),   # the terminated at reset
        truncated: Tensor(...)},   # the truncated at reset
    batch_size=torch.Size([]),
    device=cpu,
    is_shared=False)
"""