JAX
I need to learn JAX; it was originally recommended to me by Ryan Erlich.
In Flax, you can call `tabulate` on a model and it prints a table of the layers, their input/output shapes, and the model parameters. Very cool. Found from here:
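```python
import jax
import jax.numpy as jnp
x = jnp.ones((512, 128), dtype=jnp.float32)  # dummy batch: 512 examples, 128 features
rngs = {"params": jax.random.PRNGKey(0), "dropout": jax.random.PRNGKey(1)}
model_float32 = MLPClassifier(dtype=jnp.float32)  # a flax.linen.Module from the original source (not shown)
model_float32.tabulate(rngs, x, train=True, console_kwargs={"force_jupyter": True})  # force_jupyter: render in a notebook
```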
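- In practice, you have to pass the same arguments you would pass to `model.init`: the RNGs, a sample input, and any call kwargs such as `train=True`, because `tabulate` runs the model to build the table.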
- In PyTorch, consider using torchsummary: https://pypi.org/project/torch-summary/
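If you only need parameter shapes and counts, for example for a hand-rolled model without Flax, a plain-JAX sketch that walks the parameter pytree is enough. `init_mlp_params` and `summarize_params` are hypothetical helpers, not library functions. Unlike `tabulate`, this never runs the model, so it shows no input/output shapes. `jax.tree_util.tree_flatten_with_path` and `keystr` need a reasonably recent JAX release.

```python
import math

import jax
import jax.numpy as jnp


def init_mlp_params(key, sizes):
    """Hypothetical helper: builds a dict-of-dicts parameter pytree for an MLP."""
    params = {}
    keys = jax.random.split(key, len(sizes) - 1)
    for i, (k, n_in, n_out) in enumerate(zip(keys, sizes[:-1], sizes[1:])):
        params[f"Dense_{i}"] = {
            "kernel": jax.random.normal(k, (n_in, n_out)) / math.sqrt(n_in),
            "bias": jnp.zeros((n_out,)),
        }
    return params


def summarize_params(params):
    """Hypothetical helper: prints path, shape, and size of every leaf; returns the total."""
    leaves_with_paths, _ = jax.tree_util.tree_flatten_with_path(params)
    total = 0
    for path, leaf in leaves_with_paths:
        print(f"{jax.tree_util.keystr(path):<25} {str(leaf.shape):<12} {leaf.size:>8,}")
        total += leaf.size
    print(f"{'Total':<25} {'':<12} {total:>8,}")
    return total


params = init_mlp_params(jax.random.PRNGKey(0), [128, 256, 10])
n_params = summarize_params(params)
```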