Hacker News

I am really looking forward to JAX taking over PyTorch/CUDA over the next few years. The whole PTX kerfuffle with the DeepSeek team shows the value of investing in lower-level approaches to squeeze the most out of your hardware.


Most PyTorch users don't even bother with the simplest performance optimizations, and you are talking about PTX.


I like JAX but I'm not sure how an ML framework debate like "JAX vs PyTorch" is relevant to DeepSeek/PTX. The JAX API is at a similar level of abstraction to PyTorch [0]. Both are Python libraries and sit a few layers of abstraction above PTX/CUDA and their TPU equivalents.

[0] Although PyTorch arguably encompasses 2 levels, with both a pure functional library like the JAX API, as well as a "neural network" framework on top of it. Whereas JAX doesn't have the latter and leaves that to separate libraries like Flax.
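To make the "pure functional library" point concrete, here is a minimal sketch (my own illustration, assuming `jax` is installed; not code from the thread) of the JAX idiom: a layer is just a function of explicit parameters and inputs, with no stateful module object wrapped around it.

```python
import jax.numpy as jnp

def linear(params, x):
    # Pure function of (params, inputs) -- state is passed in explicitly,
    # which is the style PyTorch's functional layer shares with the JAX API.
    w, b = params
    return x @ w + b

params = (jnp.ones((3, 2)), jnp.zeros(2))
x = jnp.ones((4, 3))
y = linear(params, x)
print(y.shape)  # (4, 2)
```

A "neural network" framework (torch.nn, or Flax on the JAX side) then layers parameter management on top of functions like this.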


The interesting thing about this comment is that JAX is actually higher-level than PyTorch, generally speaking. Since everything is compiled, you just express a logical program and let the compiler (XLA) worry about the rest.
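A minimal sketch of that "express the logical program" point (my own example, assuming `jax` is available): `jax.jit` traces the Python function once and hands the whole computation to XLA, which is free to fuse the three ops below into a single kernel.

```python
import jax
import jax.numpy as jnp

@jax.jit
def fused(x):
    # You write the math; XLA decides the kernels, fusion, and layout.
    return jnp.sum(jnp.tanh(x) ** 2)

x = jnp.arange(8.0)
print(fused(x))
```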

Are you suggesting that XLA would be where this "lower level" approach would reside since it can do more automatic optimization?


I'm curious, what does paradigmatic JAX look like? Is there an equivalent of picoGPT [1] for JAX?

[1] https://github.com/jaymody/picoGPT/blob/main/gpt2.py


yeah it looks exactly like that file but replace "import numpy as np" with "import jax.numpy as np" :)
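For instance, two of the helpers from that picoGPT file work as-is under the swap described above; this is a hedged sketch (the only change from the NumPy version is the import line):

```python
import jax.numpy as np  # the one-line swap: was "import numpy as np"

def gelu(x):
    # GELU activation via the tanh approximation, as in picoGPT.
    return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)))

def softmax(x):
    # Subtract the max for numerical stability before exponentiating.
    exp_x = np.exp(x - np.max(x, axis=-1, keepdims=True))
    return exp_x / np.sum(exp_x, axis=-1, keepdims=True)

s = softmax(np.array([1.0, 2.0, 3.0]))
print(s)
```

The payoff is that the same code is now traceable, so you can wrap it in `jax.jit` or `jax.grad` for free.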


What PTX kerfuffle are you referring to?


You do understand that PTX is part of CUDA, right?



