Torch Distributions

https://docs.pytorch.org/docs/stable/distributions.html

  • Categorical
  • Normal

There are only two APIs you need to care about:

  • sample() - draw a number from the distribution
    • sampling from a Gaussian is implemented as mu + sigma * torch.randn_like(mu)
  • log_prob() - compute the log probability of that number (for continuous distributions, a log density)
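A minimal sketch of the two calls for both distribution types (the concrete parameter values here are just illustrative):

```python
import torch
from torch.distributions import Normal, Categorical

# Gaussian: sample a value and evaluate its log density
normal = Normal(torch.tensor(0.0), torch.tensor(1.0))
x = normal.sample()       # a random draw; not differentiable
lp = normal.log_prob(x)   # log density at x; differentiable w.r.t. mu and sigma

# Categorical: sample an index and evaluate its log probability
cat = Categorical(logits=torch.tensor([1.0, 2.0, 3.0]))
a = cat.sample()          # an integer action index in {0, 1, 2}
lp_a = cat.log_prob(a)    # log probability of that index; differentiable w.r.t. logits
```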

Why is log_prob() differentiable, but sampling an action is not?

Think about it: log_prob() just evaluates the (log) PDF at a point, which is an ordinary differentiable function of the distribution's parameters. sample(), on the other hand, is a black-box random draw that breaks the computation graph. But can't you write the draw as mean + sigma * torch.randn(...)? Yes — that is exactly the Reparametrization Trick, which moves the randomness into a parameter-free noise term so that gradients can flow through the sample to mu and sigma.
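This is easy to check directly: sample() returns a tensor detached from the parameters, while rsample() builds the draw as mu + sigma * eps, so backpropagating through it populates the parameter gradients. A minimal sketch:

```python
import torch
from torch.distributions import Normal

mu = torch.tensor(0.0, requires_grad=True)
sigma = torch.tensor(1.0, requires_grad=True)
dist = Normal(mu, sigma)

# sample(): the draw happens under no_grad, so it is cut off from mu/sigma
x = dist.sample()
assert not x.requires_grad

# rsample(): y = mu + sigma * eps with eps ~ N(0, 1), so gradients flow
y = dist.rsample()
y.backward()
# dy/dmu = 1, so mu.grad is exactly 1.0
assert mu.grad is not None
```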

  • Practically, in PyTorch, use the rsample() method (only distributions with has_rsample support it).
  • See SAC implementation
# Categorical policy (discrete actions), e.g. an MLPCategoricalActor
def _distribution(self, obs):
    logits = self.logits_net(obs)
    return Categorical(logits=logits)

# Gaussian policy (continuous actions), e.g. an MLPGaussianActor
def _distribution(self, obs):
    mu = self.mu_net(obs)
    std = torch.exp(self.log_std)  # log_std is a learned parameter
    return Normal(mu, std)
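To see a method like this in use, here is a hypothetical minimal Gaussian actor wired up in the same style (the class name, layer shapes, and log_std initialization are assumptions, not taken from any particular implementation):

```python
import torch
import torch.nn as nn
from torch.distributions import Normal

class GaussianActor(nn.Module):
    """Minimal sketch of a continuous-action actor in the style above."""
    def __init__(self, obs_dim, act_dim):
        super().__init__()
        self.mu_net = nn.Linear(obs_dim, act_dim)          # mean head
        self.log_std = nn.Parameter(-0.5 * torch.ones(act_dim))  # state-independent log std

    def _distribution(self, obs):
        mu = self.mu_net(obs)
        std = torch.exp(self.log_std)
        return Normal(mu, std)

obs = torch.randn(4, 3)                     # batch of 4 observations, obs_dim=3
pi = GaussianActor(obs_dim=3, act_dim=2)._distribution(obs)
act = pi.sample()                           # (4, 2) actions, no gradient
logp = pi.log_prob(act).sum(-1)             # (4,) log-likelihoods, differentiable
```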

Takeaway: understand the two main methods, sample() (or rsample() when you need gradients) and log_prob().