Ever needed to report FLOPs for your Torch/JAX code? We've got you covered! (More importantly, here's a survey on why bunnies flop.)
The profiling methodology is borrowed from https://gitlab.com/NERSC/roofline-on-nvidia-gpus/-/tree/roofline-hackathon-2020.
An Nsight Compute (`ncu`) installation is required. Make sure access to the GPU performance counters is enabled for your user; otherwise `ncu` reports `ERR_NVGPUCTRPERM` and no FLOP counts can be collected.
(Image Credit: Tuong Phung)
Define the function that you want to profile:
JAX
```python
import jax
import jax.numpy as jnp

SIZE = 500
x = jax.random.normal(jax.random.PRNGKey(0), (SIZE, SIZE))
y = jax.random.normal(jax.random.PRNGKey(1), (SIZE, SIZE))

def func(x, y):
    return jnp.einsum("ij,jk->ik", x, y)
```
Torch
```python
import torch

SIZE = 500
x = torch.randn(SIZE, SIZE).to(device='cuda')
y = torch.randn(SIZE, SIZE).to(device='cuda')

def func(x, y):
    return torch.einsum("ij,jk->ik", x, y)
```
Wrap the function in the `flops_counter` decorator:

```python
from fast_flops import flops_counter

@flops_counter
def func_flops(func, x, y):
    return func(x, y)
```
Run it through JIT (don't worry, warmup iterations are run to strip out the JIT overhead) and execute!
JAX
```python
func = jax.jit(func)
func_flops(func, x, y)
```
Torch
```python
func = torch.compile(func, fullgraph=True, mode='max-autotune')
func_flops(func, x, y)
```
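Putting the Torch pieces together, a complete decorated file would look roughly like the sketch below (reconstructed from the snippets above; the actual examples/matmul/test_matmul_torch.py may differ in details):

```python
# Sketch of a full decorated file, assembled from the snippets above.
# The real examples/matmul/test_matmul_torch.py may differ in details.
import torch
from fast_flops import flops_counter

SIZE = 500
x = torch.randn(SIZE, SIZE).to(device='cuda')
y = torch.randn(SIZE, SIZE).to(device='cuda')

def func(x, y):
    return torch.einsum("ij,jk->ik", x, y)

@flops_counter
def func_flops(func, x, y):
    return func(x, y)

func = torch.compile(func, fullgraph=True, mode='max-autotune')
func_flops(func, x, y)
```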
The pipeline can then be executed with

```bash
bash launch_profiler.sh examples/matmul/test_matmul_torch.py
```

and the output looks something like:
```
Measured Time: 3.196823494576654e-05
Measured GFLOP/s: 2353.1630879907675
Measured FLOPS: 67108864.0
```
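For the curious: the NERSC roofline methodology linked above derives FLOPs and kernel time from Nsight Compute SASS instruction counters. Below is a minimal sketch of that arithmetic; the metric names follow the NERSC repo, and whether fast_flops aggregates exactly these counters is an assumption.

```python
# Sketch: deriving FLOPs and GFLOP/s from Nsight Compute metrics, following the
# NERSC roofline methodology. `metrics` maps ncu metric names to values for one
# kernel; whether fast_flops uses exactly these counters is an assumption.
def flops_from_metrics(metrics: dict) -> float:
    total = 0.0
    # adds and muls count as 1 FLOP, FMAs as 2; summed over FP64 (d), FP32 (f), FP16 (h)
    for prec in ("d", "f", "h"):
        add = metrics.get(f"sm__sass_thread_inst_executed_op_{prec}add_pred_on.sum", 0)
        mul = metrics.get(f"sm__sass_thread_inst_executed_op_{prec}mul_pred_on.sum", 0)
        fma = metrics.get(f"sm__sass_thread_inst_executed_op_{prec}fma_pred_on.sum", 0)
        total += add + mul + 2 * fma
    # Tensor-core work is counted separately (e.g. sm__inst_executed_pipe_tensor.sum,
    # scaled by the FLOPs per tensor instruction for the architecture).
    return total

def gflops_per_second(metrics: dict) -> float:
    # kernel time = elapsed SM cycles / SM clock rate
    time_s = metrics["sm__cycles_elapsed.avg"] / metrics["sm__cycles_elapsed.avg.per_second"]
    return flops_from_metrics(metrics) / time_s / 1e9
```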
TODO:
- Only turn on ncu during the hot loop, in addition to the nvtx annotations (see the sketch below)
- Axe run_profiler.py and postprocess.py; ideally this should just work with `python decorated_file.py`
- Plotting utilities and CSV plumbing
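For the first TODO item, marking the hot loop with an NVTX range so profiling can be restricted to it would look roughly like this (the range name `"hot_loop"` and the exact `ncu` NVTX filtering flags used by launch_profiler.sh are assumptions):

```python
# Sketch: wrap only the hot loop in an NVTX range so the profiler can be scoped
# to it (e.g. via ncu's NVTX filtering). The range name "hot_loop" is illustrative.
import torch

torch.cuda.nvtx.range_push("hot_loop")
func_flops(func, x, y)   # only this region would be profiled
torch.cuda.nvtx.range_pop()
```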