Torch Distributions
https://docs.pytorch.org/docs/stable/distributions.html
- Categorical
- Normal
There are only two APIs that you care about:
sample()
- Get a number. Sampling from a Gaussian is implemented as mu + sigma * torch.randn_like(mu).
log_prob()
- The log probability of that number (a log-density value), computed directly from the PDF.
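A minimal sketch of the two calls on a unit Gaussian, checking that log_prob() really is just the closed-form Gaussian log-density evaluated at the sampled point:

```python
import math
import torch
from torch.distributions import Normal

dist = Normal(torch.tensor(0.0), torch.tensor(1.0))

x = dist.sample()       # draw a number; internally mu + sigma * randn_like
lp = dist.log_prob(x)   # log-density of that number under the Gaussian

# For mu=0, sigma=1 the log-density is -x^2/2 - 0.5 * log(2*pi)
manual = -0.5 * x**2 - 0.5 * math.log(2 * math.pi)
print(torch.allclose(lp, manual, atol=1e-6))  # True
```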
Why is
log_prob()
differentiable, but sampling an action is not? Think about it:
log_prob()
just evaluates the (log) PDF at a point, which is an ordinary differentiable function of the distribution's parameters. Sampling, by contrast, is a black box that cuts the graph. But you can also do torch.randn_like(mu) * sigma + mu yourself? Yes, that's essentially the Reparametrization Trick, so that your gradients can flow.
- Practically, in PyTorch, use the
rsample()
function.
- See the SAC implementation.
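A quick sketch contrasting sample() and rsample() on a Gaussian with learnable parameters (the mu/log_std tensors here are made up for illustration):

```python
import torch
from torch.distributions import Normal

mu = torch.tensor(0.0, requires_grad=True)
log_std = torch.tensor(0.0, requires_grad=True)
dist = Normal(mu, log_std.exp())

x = dist.sample()
print(x.requires_grad)   # False: the graph is cut at the sampling step

y = dist.rsample()       # reparameterized: mu + std * eps, eps ~ N(0, 1)
print(y.requires_grad)   # True: gradients flow back into mu and log_std

y.backward()
print(mu.grad)           # dy/dmu = 1, so the gradient is tensor(1.)
```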
```python
# Categorical policy head (discrete actions)
def _distribution(self, obs):
    logits = self.logits_net(obs)
    return Categorical(logits=logits)

# Gaussian policy head (continuous actions)
def _distribution(self, obs):
    mu = self.mu_net(obs)
    std = torch.exp(self.log_std)
    return Normal(mu, std)
```
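To see why log_prob() differentiability matters for these policy heads, here is a sketch of a REINFORCE-style loss through a Categorical head. The tiny linear net and the returns are stand-ins, not from the source:

```python
import torch
import torch.nn as nn
from torch.distributions import Categorical

# Hypothetical stand-in for self.logits_net: 4-dim obs, 2 discrete actions.
logits_net = nn.Linear(4, 2)

obs = torch.randn(3, 4)                    # batch of 3 observations
dist = Categorical(logits=logits_net(obs))

actions = dist.sample()                    # sampling itself is non-differentiable
returns = torch.tensor([1.0, -0.5, 2.0])   # made-up returns for illustration

# Policy-gradient loss: gradients flow through log_prob, not through sample.
loss = -(dist.log_prob(actions) * returns).mean()
loss.backward()
print(logits_net.weight.grad.shape)        # torch.Size([2, 4])
```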
Takeaway: understand the two main methods, sample() (use rsample() when gradients must flow) and log_prob().