How I’m Learning JAX
Recently I’ve been trying to learn more about JAX, Google’s next-generation ML framework. These are the resources I’ve found most helpful so far:
- The JAX docs, which are excellent
- Patrick Kidger’s “Learning JAX as a PyTorch developer”
- JAX AI Stack tutorials
- The book JAX in Action, along with its accompanying Python notebooks
- Codebases: Gemma, MaxText
If you know of any other good resources, please let me know.