Yes! Though the answer has some subtlety. tf.while_loop will run several iterations in parallel (up to its parallel_iterations argument), but that is not the same as _batching_ those iterations into single vectorized ops. For that, you'll need the experimental parallel_for feature [0]. Using it should get you roughly the same speed as JAX.
Tricky!
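A minimal sketch of the difference, assuming TensorFlow 2.x, where (as far as I know) the parallel_for/pfor machinery is exposed publicly as tf.vectorized_map. W, xs and per_example are made-up examples, not taken from the linked code. Both functions compute the same per-example matvec. The while_loop version issues one small op per iteration (possibly overlapping them), while vectorized_map rewrites the per-example function into batched ops over the leading dimension.

```python
import tensorflow as tf

# Made-up example data: a fixed 4x3 weight matrix and a batch of 8 input vectors.
W = tf.reshape(tf.range(12, dtype=tf.float32), [4, 3])
xs = tf.reshape(tf.range(24, dtype=tf.float32) / 10.0, [8, 3])


def per_example(x):
    """Work for one loop iteration: x has shape [3], result has shape [4]."""
    return tf.linalg.matvec(W, x)


@tf.function
def loop_version(xs):
    # tf.while_loop may overlap up to `parallel_iterations` iterations,
    # but each iteration still issues its own small matvec.
    n = tf.shape(xs)[0]
    out = tf.TensorArray(tf.float32, size=n)

    def body(i, out):
        return i + 1, out.write(i, per_example(xs[i]))

    _, out = tf.while_loop(
        lambda i, out: i < n, body, (tf.constant(0), out), parallel_iterations=10
    )
    return out.stack()


@tf.function
def vectorized_version(xs):
    # tf.vectorized_map traces per_example once and uses pfor to rewrite it
    # into batched ops over the leading dimension of xs.
    return tf.vectorized_map(per_example, xs)


print(loop_version(xs).shape, vectorized_version(xs).shape)  # (8, 4) (8, 4)
```

The trade-off is memory: the vectorized version materializes intermediates for all iterations at once, so it can use considerably more memory than the loop.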
[0]: https://github.com/tensorflow/tensorflow/tree/b3e00739468080...