TorchRL Episode Data Format (TED)

https://pytorch.org/rl/stable/reference/data.html#ted-format

TorchRL has its own way of writing state transitions, slightly different from the conventional OpenAI Gym API.

This layout, however, is one of the things that makes parallelization possible.

A rollout generally follows this shape:

>>> data = env.reset()
>>> data = policy(data)
>>> print(env.step(data))
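
The snippet above is schematic. Here is a minimal runnable sketch, assuming gymnasium is installed; "Pendulum-v1" and the random action are just stand-ins for a real task and policy:

from torchrl.envs import GymEnv

env = GymEnv("Pendulum-v1")
data = env.reset()             # tensordict with "observation", "done", "terminated", "truncated"
data = env.rand_action(data)   # stand-in for policy(data): writes an "action" entry
data = env.step(data)          # adds the "next" sub-tensordict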

Generally, a step through an MDP yields:

  1. a new observation
  2. an indicator of task completion (terminated, truncated, done), and
  3. a reward signal

This can get more complicated in multi-agent RL settings, and a reward is sometimes not even needed (as in imitation learning scenarios); for comparison, a conventional Gym-style step is sketched below.
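
For reference, the conventional Gymnasium API returns those same pieces of information as a flat tuple instead of nesting them. A sketch, assuming gymnasium and "CartPole-v1" are available:

import gymnasium as gym

env = gym.make("CartPole-v1")
obs, info = env.reset()
action = env.action_space.sample()
obs, reward, terminated, truncated, info = env.step(action)  # new observation, reward, completion flags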

So the information belonging to time t is rooted at the top level of the tensordict:

General Rule

Everything that belongs to time step t is stored at the root of the tensordict, while everything that belongs to t+1 is stored in the "next" entry of the tensordict.
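
In practice this means a single transition is read with root keys for time t and nested ("next", ...) keys for t+1. A small sketch, assuming data is the tensordict returned by env.step(data) as above (key names match the printout below):

obs_t   = data["observation"]            # observation at time t (root)
act_t   = data["action"]                 # action taken at time t (root)
obs_tp1 = data["next", "observation"]    # observation at t+1 (under "next")
rew_t   = data["next", "reward"]         # reward for taking act_t (under "next")
done    = data["next", "done"]           # done flag after the step (under "next")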

Here's the full example again, this time with its printed output:

data = env.reset()
data = policy(data)
print(env.step(data))
 
"""
TensorDict(
    fields={
        action: Tensor(...),  # The action taken at time t
        done: Tensor(...),  # The done state when the action was taken (at reset)
        next: TensorDict(  # all of this content comes from the call to `step`
            fields={
                done: Tensor(...),  # The done state after the action has been taken
                observation: Tensor(...),  # The observation resulting from the action
                reward: Tensor(...),  # The reward resulting from the action
                terminated: Tensor(...),  # The terminated state after the action has been taken
                truncated: Tensor(...),  # The truncated state after the action has been taken
            batch_size=torch.Size([]),
            device=cpu,
            is_shared=False),
        observation: Tensor(...),  # the observation at reset
        terminated: Tensor(...),  # the terminated at reset
        truncated: Tensor(...),  # the truncated at reset
    batch_size=torch.Size([]),
    device=cpu,
    is_shared=False)
"""