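PyTorch forward hooks

Forward hooks are very useful for debugging: they let you inspect each module's output during a forward pass without changing the model's code.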

handles = []

def stat_hook(name):
    """Build a forward hook that prints summary statistics of a module's output.

    Args:
        name: Label printed before the statistics (below, the module's
            qualified name from ``named_modules()``).

    Returns:
        A callable with the ``register_forward_hook`` signature
        ``hook(module, inputs, output)``. Each call prints ``name``, the
        output's shape, and its mean and standard deviation. If ``output``
        is a tuple or list, its first element is inspected; empty sequences
        and non-tensor outputs are ignored.
    """
    def hook(module, inputs, output):
        if isinstance(output, (tuple, list)):
            # Modules may return several values; inspect only the first one.
            output = output[0] if output else None
        if not torch.is_tensor(output):
            return
        t = output.detach()
        if not (t.is_floating_point() or t.is_complex()):
            # mean()/std() reject integer and bool dtypes, so upcast those only;
            # floating outputs keep their own precision.
            t = t.float()
        print(name, t.shape, t.mean().item(), t.std().item())
    return hook
 
# Attach a statistics-printing hook to every nn.Linear in the model and keep
# the returned handles so the hooks can be removed afterwards.
for name, mod in model.named_modules():
    if isinstance(mod, nn.Linear):
        handles.append(mod.register_forward_hook(stat_hook(name)))

# later: detach every hook so forward passes stop printing statistics.
for h in handles:
    h.remove()
# Drop the dead handles so re-running this snippet doesn't accumulate them.
handles.clear()

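Find the first layer that diverges between two models (record each leaf module's output in a reference and a test model, then report the earliest layer whose outputs disagree)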

acts_a, acts_b = {}, {}

def save_acts(store, name):
    """Build a forward hook that records a module's output in ``store``.

    Args:
        store: Dict that receives the captured activation under ``name``;
            a later call for the same name overwrites the earlier value.
        name: Key under which the activation is stored.

    Returns:
        A ``register_forward_hook``-compatible callable. It saves a detached
        float32 CPU copy of the output, or of its first element when the
        output is a tuple or list. Empty sequences and non-tensor outputs
        are ignored.
    """
    def hook(m, inp, out):
        if isinstance(out, (tuple, list)):
            # Modules may return several values; compare only the first one.
            out = out[0] if out else None
        if torch.is_tensor(out):
            # Detached CPU copy: holds no autograd graph and can be compared
            # regardless of the device the model ran on.
            store[name] = out.detach().float().cpu()
    return hook
 
# Hook every leaf module of both models. Modules are paired by position in
# named_modules(), so this assumes ref and test share the same module tree.
hook_handles = []
for (n1, m1), (n2, m2) in zip(ref.named_modules(), test.named_modules()):
    if len(list(m1.children())) == 0:       # leaf modules only
        hook_handles.append(m1.register_forward_hook(save_acts(acts_a, n1)))
        hook_handles.append(m2.register_forward_hook(save_acts(acts_b, n2)))

try:
    with torch.no_grad():
        ref(x)
        test(x)
finally:
    # Remove the hooks so ref/test are left uninstrumented, even if a forward fails.
    for h in hook_handles:
        h.remove()

tol = 1e-4
# acts_a is filled in the order ref first ran each hooked module, so the
# first entry that disagrees is the earliest divergence.
for k, a in acts_a.items():
    b = acts_b.get(k)
    if b is None:
        print("first mismatch:", k, "not recorded for test")
        break
    if a.shape != b.shape:
        # Report instead of silently broadcasting mismatched shapes.
        print("first mismatch:", k, "shape", tuple(a.shape), "vs", tuple(b.shape))
        break
    if a.numel() == 0:
        continue                            # max() of an empty tensor raises
    # Exact matches (including inf == inf and NaN in both) count as zero error.
    same = (a == b) | (a.isnan() & b.isnan())
    err = torch.where(same, torch.zeros_like(a), (a - b).abs()).max()
    # `not err <= tol` is True for NaN, which `err > tol` would silently miss.
    if not err <= tol:
        print("first mismatch:", k, err.item())
        break
import torch
 
def first_tensor(x):
    """Return the first Tensor found in ``x``, searching depth-first.

    A Tensor is returned as-is. Tuples and lists are scanned in order and
    dicts in value order, recursing into nested containers. Any other
    object, or a container holding no Tensor, yields ``None``.
    """
    if torch.is_tensor(x):
        return x
    if isinstance(x, (tuple, list)):
        children = x
    elif isinstance(x, dict):
        children = x.values()
    else:
        return None
    for child in children:
        found = first_tensor(child)
        if found is not None:
            return found
    return None
 
 
def collect_activations(model, module_filter):
    """Register forward hooks that capture the outputs of selected modules.

    Args:
        model: Module whose ``named_modules()`` are scanned.
        module_filter: Predicate called as ``module_filter(name, module)``;
            a hook is attached only where it returns a truthy value.

    Returns:
        ``(activations, handles)``. ``activations`` is a dict, initially
        empty, that each hooked module fills on every forward pass with a
        detached float32 CPU copy of the first tensor in its output; later
        calls overwrite earlier ones. ``handles`` holds the hook handles;
        call ``remove()`` on each to detach the hooks.
    """
    captured = {}

    def _recorder(module_name):
        def _on_forward(module, args, output):
            tensor = first_tensor(output)
            if tensor is None:
                return
            captured[module_name] = tensor.detach().float().cpu()
        return _on_forward

    hook_handles = [
        submodule.register_forward_hook(_recorder(module_name))
        for module_name, submodule in model.named_modules()
        if module_filter(module_name, submodule)
    ]
    return captured, hook_handles

Named modules

`model.named_modules()` yields a `(name, module)` pair for the model itself and for every submodule, recursively:

# Print each qualified submodule name next to its module repr. The first
# pair is the model itself, under the empty name "".
for name, module in model.named_modules():
    print(name, module)

Example output (module reprs abbreviated; the root model's name is the empty string):

                  Transformer(...)
wte               Embedding(...)
blocks.0          Block(...)
blocks.0.ln1      LayerNorm(...)
blocks.0.attn     Attention(...)
blocks.0.mlp      MLP(...)
blocks.0.mlp.fc   Linear(...)