Agent skill
using-jax
Applies JAX patterns for scientific Python development. Use when working with JAX, distrax, numpyro, blackjax, or scientific computing. Covers vmap, JIT, RNG handling.
Install this agent skill to your Project
npx add-skill https://github.com/majiayu000/claude-skill-registry/tree/main/skills/development/using-jax
SKILL.md
JAX Scientific Computing
Core Rules
- Pure functions - No side effects
- JIT outer functions - `@jax.jit` on hot paths
- vmap not loops - `jax.vmap(fn)` instead of list comprehensions
- Split RNG keys - Never reuse keys
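The first two rules can be sketched together as follows. This is a minimal illustration, not part of the original skill: the function names `normalize` and `loss` and the toy data are assumptions chosen for the example.

```python
import jax
import jax.numpy as jnp

def normalize(x):
    # Pure: the output depends only on the input; nothing outside is read or mutated.
    return (x - jnp.mean(x)) / jnp.std(x)

# JIT the outermost function on the hot path; inner calls are traced and compiled with it.
@jax.jit
def loss(w, x, y):
    pred = normalize(x) * w
    return jnp.mean((pred - y) ** 2)

x = jnp.arange(4.0)   # toy inputs (hypothetical)
y = jnp.zeros(4)      # toy targets (hypothetical)
value = loss(2.0, x, y)
```

Decorating only `loss` is enough here: `normalize` is traced as part of it, so it does not need its own `@jax.jit`.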
Patterns
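```python
import jax

# RNG: always split, and use each key at most once
key = jax.random.PRNGKey(0)
key, k1, k2 = jax.random.split(key, 3)
# Batching: vmap not loops (fn and inputs are placeholders)
batched = jax.vmap(fn)(inputs)
# Loops: use scan (step_fn returns (new_carry, per_step_output))
_, results = jax.lax.scan(step_fn, init, xs)
```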
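To see the three patterns run end to end, here is a small self-contained sketch. The concrete `fn`, `step_fn`, shapes, and seed are assumptions made up for illustration; only the `split`, `vmap`, and `scan` calls come from the patterns above.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
# Split once, then use each subkey exactly one time.
key, k1, k2 = jax.random.split(key, 3)
inputs = jax.random.normal(k1, (8, 3))  # batch of 8 vectors
noise = jax.random.normal(k2, (8,))

def fn(x):
    # Written for a single example of shape (3,).
    return jnp.sum(x ** 2)

# vmap maps fn over the leading axis, giving shape (8,).
batched = jax.vmap(fn)(inputs)

def step_fn(carry, x):
    # carry is the running total; the second return value is stacked into results.
    total = carry + x
    return total, total

init = jnp.zeros(())
xs = batched + noise
final, results = jax.lax.scan(step_fn, init, xs)
```

In this sketch `results` holds the running sum of `xs` and `final` equals its last entry; `key` is kept for any further splits, while `k1` and `k2` are not used again.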
Gotchas
- Arrays are immutable
- No Python control flow on traced values in JIT - use `jax.lax.cond`, `jax.lax.scan`
- Check NaNs: `jnp.isnan(x).any()`
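Each gotcha can be illustrated briefly. This is a hedged sketch: `safe_log` and the sample values are hypothetical names and data introduced only for the example.

```python
import jax
import jax.numpy as jnp

x = jnp.zeros(3)
# In-place assignment (x[0] = 1.0) raises an error; .at[...].set(...) returns an updated copy.
y = x.at[0].set(1.0)

@jax.jit
def safe_log(v):
    # A Python `if v > 0:` would fail here because v is a tracer under jit,
    # so the branch is expressed with jax.lax.cond instead.
    return jax.lax.cond(v > 0, lambda u: jnp.log(u), lambda u: jnp.zeros_like(u), v)

out = safe_log(jnp.float32(-1.0))

# NaN check on a result that is known to contain one (log of a negative number).
has_nan = bool(jnp.isnan(jnp.log(jnp.array([-1.0, 1.0]))).any())
```

Both branches passed to `jax.lax.cond` must return values of the same shape and dtype, which is why the fallback uses `jnp.zeros_like(u)`.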