%timeit (x_jax @ y_jax).block_until_ready()
579 µs ± 4.54 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

%timeit jnp.einsum('bik,bkj->bij',x_jax,y_jax, optimize=True).block_until_ready()
658 µs ± 1.38 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

%timeit jnp.einsum('bik,bkj->bij',x_jax,y_jax).block_until_ready()
660 µs ± 2.82 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
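With only two operands, `'bik,bkj->bij'` admits a single pairwise contraction, so the `optimize` flag has no path to choose between. That fits the near-identical 658 µs and 660 µs timings. The timings above are taken in eager mode, though, so they may include per-call dispatch overhead as well as the matmul kernel itself. The sketch below is a way to check that, under stated assumptions. The original does not give the shapes of `x_jax` and `y_jax`, so `B, N, K, M` are hypothetical values. The `bench` helper is likewise hypothetical. Wrapping each variant in `jax.jit` is an added step that the original benchmark did not take. The script first confirms that all three forms compute the same batched product, then times the compiled versions.

```python
import time

import jax
import jax.numpy as jnp

# Hypothetical shapes: the original benchmark does not state them.
B, N, K, M = 8, 128, 128, 128

key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
x_jax = jax.random.normal(key_x, (B, N, K), dtype=jnp.float32)
y_jax = jax.random.normal(key_y, (B, K, M), dtype=jnp.float32)

# The three variants from the %timeit session, each compiled with jax.jit
# (an added step: the original timings are eager).
matmul_jit = jax.jit(lambda x, y: x @ y)
einsum_opt_jit = jax.jit(lambda x, y: jnp.einsum('bik,bkj->bij', x, y, optimize=True))
einsum_jit = jax.jit(lambda x, y: jnp.einsum('bik,bkj->bij', x, y))


def bench(fn, *args, repeats=100):
    """Hypothetical helper: mean seconds per call, excluding the warm-up call."""
    fn(*args).block_until_ready()  # first call triggers compilation
    start = time.perf_counter()
    for _ in range(repeats):
        # block_until_ready waits for JAX's asynchronous dispatch to finish
        fn(*args).block_until_ready()
    return (time.perf_counter() - start) / repeats


if __name__ == "__main__":
    for name, fn in [("matmul", matmul_jit),
                     ("einsum optimize=True", einsum_opt_jit),
                     ("einsum default", einsum_jit)]:
        print(f"{name}: {bench(fn, x_jax, y_jax) * 1e6:.1f} µs per call")
```

If the jitted timings come out close to one another, the eager gap is more likely Python-side overhead than a slower kernel. Results will depend on the backend (CPU or GPU) and on the real shapes, so treat any numbers this prints as specific to your setup.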