I am really looking forward to JAX taking over from PyTorch/CUDA over the next few years. The whole PTX kerfuffle with the DeepSeek team shows the value of investing in lower-level approaches to squeeze the most out of your hardware.
I like JAX but I'm not sure how an ML framework debate like "JAX vs PyTorch" is relevant to DeepSeek/PTX. The JAX API is at a similar level of abstraction to PyTorch [0]. Both are Python libraries and sit a few layers of abstraction above PTX/CUDA and their TPU equivalents.
[0] PyTorch arguably encompasses two levels, though: a pure functional library like the JAX API, plus a "neural network" framework on top of it. JAX doesn't have the latter and leaves that to separate libraries like Flax.
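To make the footnote concrete, here is a rough sketch of what "only the functional level" looks like in plain JAX. The names `init_dense`, `dense`, and `loss` are made up for illustration and are not part of JAX or Flax; the point is only that the parameters live in an ordinary dict that you pass around yourself, which is the bookkeeping a library like Flax would otherwise handle.

```python
import jax
import jax.numpy as jnp

def init_dense(key, in_dim, out_dim):
    # Hypothetical helper: parameters are just a dict of arrays, not a module object.
    return {
        "w": jax.random.normal(key, (in_dim, out_dim)) * 0.01,
        "b": jnp.zeros(out_dim),
    }

def dense(params, x):
    # A pure function: the output depends only on the arguments.
    return x @ params["w"] + params["b"]

def loss(params, x, y):
    # Mean squared error of the layer output against the targets.
    return jnp.mean((dense(params, x) - y) ** 2)

params = init_dense(jax.random.PRNGKey(0), 3, 2)
x = jnp.ones((4, 3))
y = jnp.zeros((4, 2))

# jax.grad differentiates with respect to the first argument (the params dict)
# and returns gradients with the same dict structure.
grads = jax.grad(loss)(params, x, y)
```

Nothing here is lower-level than the PyTorch equivalent; it is the same layer of abstraction with the state handling made explicit.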
The interesting thing about this comment is that JAX is generally even higher-level than PyTorch. Since everything is compiled, you just express a logical program and let the compiler (XLA) worry about the rest.
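As a small illustration of "express a logical program and let the compiler worry about the rest", here is a sketch assuming a recent JAX install. The function `scaled_softmax` is an invented example; `jax.jit` and the `.lower(...).as_text()` inspection call are the real JAX pieces.

```python
import jax
import jax.numpy as jnp

def scaled_softmax(x, scale):
    # Plain array math: no kernel, memory-layout, or fusion decisions appear here.
    z = x * scale
    z = z - jnp.max(z, axis=-1, keepdims=True)
    e = jnp.exp(z)
    return e / jnp.sum(e, axis=-1, keepdims=True)

# jax.jit traces the Python function and hands the whole program to XLA.
fast_softmax = jax.jit(scaled_softmax)

x = jnp.arange(6.0).reshape(2, 3)
out = fast_softmax(x, 0.5)

# The lowered IR text is what XLA receives; it is still hardware-independent
# and sits well above anything like PTX.
lowered_text = jax.jit(scaled_softmax).lower(x, 0.5).as_text()
```

Whatever fusion or scheduling happens is XLA's choice, not something the Python source controls, which is roughly the opposite of hand-writing PTX.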
Are you suggesting that XLA would be where this "lower level" approach would reside since it can do more automatic optimization?