PyTorch forward hooks
Forward hooks are very useful for debugging: they let you inspect each module's output without changing the model's code.
import torch
from torch import nn

# Keep every hook handle so the hooks can be removed again when done.
handles = []


def stat_hook(name):
    """Return a forward hook that prints shape, mean and std of a module's output.

    If the module returns a tuple, its first element is used; outputs that are
    not tensors are ignored.
    """
    def hook(module, inputs, output):
        t = output[0] if isinstance(output, tuple) else output
        if torch.is_tensor(t):
            # Compute stats on a detached float copy: mean()/std() reject
            # integer dtypes, and low-precision accumulation is less accurate.
            t = t.detach().float()
            # std() of fewer than two elements is NaN and emits a warning;
            # report NaN without the warning.
            std = t.std().item() if t.numel() > 1 else float("nan")
            print(name, t.shape, t.mean().item(), std)
    return hook


# Assumes `model` is an nn.Module defined elsewhere.
for name, mod in model.named_modules():
    if isinstance(mod, nn.Linear):
        handles.append(mod.register_forward_hook(stat_hook(name)))

# ... run forward passes ...

# Later: remove the hooks so they stop firing, and forget the stale handles.
for h in handles:
    h.remove()
handles.clear()


# Find first layer that diverges
# Compare two models that should compute the same thing (e.g. a reference and a
# re-implementation) and report the first leaf module whose output differs.
# Assumes torch is imported and `ref`, `test` (nn.Modules with the same module
# structure) and the input `x` are defined elsewhere.
# NOTE: put both models in eval() mode first if they use dropout, otherwise the
# outputs differ for reasons unrelated to the bug being hunted.
ATOL = 1e-4  # largest absolute difference still treated as a match

acts_a, acts_b = {}, {}


def save_acts(store, name):
    """Return a forward hook that sets store[name] to a detached float32 CPU copy of the output.

    For tuple outputs the first element is stored; non-tensor outputs are skipped.
    """
    def hook(m, inp, out):
        t = out[0] if isinstance(out, tuple) else out
        if torch.is_tensor(t):
            store[name] = t.detach().float().cpu()
    return hook


hook_handles = []
try:
    # Modules of the two models are paired positionally; zip() stops at the
    # shorter sequence, so the module structures are assumed to match.
    for (n1, m1), (_, m2) in zip(ref.named_modules(), test.named_modules()):
        if len(list(m1.children())) == 0:  # leaf modules only
            # Key BOTH stores by the reference name so acts_a[k] and acts_b[k]
            # are the paired modules even if the models name them differently.
            hook_handles.append(m1.register_forward_hook(save_acts(acts_a, n1)))
            hook_handles.append(m2.register_forward_hook(save_acts(acts_b, n1)))
    with torch.no_grad():
        ref(x)
        test(x)
finally:
    # Never leave debugging hooks attached to the models.
    for h in hook_handles:
        h.remove()

# acts_a is ordered by when each module first produced output during ref(x),
# so the first hit below is the earliest divergence.
for k, a in acts_a.items():
    b = acts_b.get(k)
    if b is None:
        print("first mismatch:", k, "no output recorded for test model")
        break
    if a.shape != b.shape:
        print("first mismatch:", k, "shape", tuple(a.shape), "vs", tuple(b.shape))
        break
    if a.numel() == 0:
        continue  # .max() of an empty tensor raises; empty outputs trivially match
    err = (a - b).abs().max()
    # Written as `not (err <= ATOL)` so a NaN difference also counts as a
    # mismatch; `err > ATOL` is False for NaN and would silently skip it.
    if not (err <= ATOL):
        print("first mismatch:", k, err.item())
        break


import torch
def first_tensor(x):
    """Return the first tensor found in ``x``, searching depth-first.

    ``x`` may be a tensor, or a tuple/list/dict nested to any depth (dicts are
    searched by value, in iteration order). Returns ``None`` if no tensor is
    found.
    """
    if torch.is_tensor(x):
        return x
    if isinstance(x, (tuple, list)):
        children = x
    elif isinstance(x, dict):
        children = x.values()
    else:
        return None
    for child in children:
        found = first_tensor(child)
        if found is not None:
            return found
    return None
def collect_activations(model, module_filter):
    """Register forward hooks that record module outputs.

    Args:
        model: root ``nn.Module``; every entry of ``model.named_modules()``
            (including ``model`` itself, named ``""``) is offered to
            ``module_filter``.
        module_filter: callable ``(name, module) -> bool``; a hook is registered
            on each module for which it returns a truthy value.

    Returns:
        ``(activations, handles)``. ``activations`` is a dict filled in during
        later forward passes, mapping module name to a detached float32 CPU
        copy of the first tensor in that module's output (found with
        ``first_tensor``). A module that runs more than once keeps only its
        latest output. ``handles`` holds the hook handles; call ``h.remove()``
        on each when done.

    If ``module_filter`` raises, every hook registered so far is removed before
    the exception propagates, so no hooks are left behind on ``model``.
    """
    activations = {}
    handles = []

    def make_hook(name):
        # Factory so each hook closes over its own `name` (avoids the
        # late-binding closure pitfall of defining the hook inside the loop).
        def hook(module, args, output):
            t = first_tensor(output)
            if t is not None:
                activations[name] = t.detach().float().cpu()
        return hook

    try:
        for name, module in model.named_modules():
            if module_filter(name, module):
                handles.append(module.register_forward_hook(make_hook(name)))
    except BaseException:
        # Don't leak half-registered hooks; the caller never receives `handles`.
        for h in handles:
            h.remove()
        raise
    return activations, handles


# Named modules
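`model.named_modules()`
yields `(name, module)` pairs for every module in the model, starting with the model itself (whose name is the empty string `""`):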
# List every module with its dotted name; the root module comes first, named "".
for name, module in model.named_modules():
    print(name, module)
Example output (the first line is the root module, printed with an empty name):
Transformer(...)
wte Embedding(...)
blocks.0 Block(...)
blocks.0.ln1 LayerNorm(...)
blocks.0.attn Attention(...)
blocks.0.mlp MLP(...)
blocks.0.mlp.fc Linear(...)