> Training proceeded for 700,000 steps (mini-batches of size 4,096) starting from randomly initialised parameters, using 5,000 first-generation TPUs to generate self-play games and 64 second-generation TPUs to train the neural networks. Further details of the training procedure are provided in the Methods.
TPUs were used for neural network inference and training, while the game logic and the MCTS itself ran on the CPU, implemented in C++.
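To make the CPU-side search concrete, here is a minimal Python sketch of the PUCT child-selection rule used by AlphaZero-style MCTS. This is not the paper's implementation (which is C++); the function name and the toy statistics are hypothetical, and `c_puct=1.25` is just an illustrative constant.

```python
import math

def puct_score(parent_visits, child_visits, child_value_sum, prior, c_puct=1.25):
    """Score one child under the PUCT rule: Q(s, a) + U(s, a)."""
    # Mean action value Q; unvisited children default to 0.
    q = child_value_sum / child_visits if child_visits > 0 else 0.0
    # Exploration bonus U, weighted by the network's prior for this move.
    u = c_puct * prior * math.sqrt(parent_visits) / (1 + child_visits)
    return q + u

# Hypothetical statistics for three children of the current node.
children = {
    "a": {"visits": 10, "value_sum": 6.0, "prior": 0.5},
    "b": {"visits": 2, "value_sum": 1.8, "prior": 0.3},
    "c": {"visits": 0, "value_sum": 0.0, "prior": 0.2},
}
parent_visits = sum(c["visits"] for c in children.values())

# Selection step: descend into the child maximizing Q + U.
best = max(
    children,
    key=lambda k: puct_score(
        parent_visits,
        children[k]["visits"],
        children[k]["value_sum"],
        children[k]["prior"],
    ),
)
```

Here the high-Q but already well-explored child `"a"` loses out to `"b"`, whose exploration bonus is still large, which is the point of the rule: visit counts in the denominator push the search toward under-explored moves.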
JAX is awesome, though; I use it for all my neural network work!
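For what it's worth, here is the kind of JAX pattern I mean: a jitted SGD step on a toy linear least-squares problem. Everything here (the model, the data, the learning rate) is a made-up example to show `jax.jit` and `jax.grad`, not anything from the paper.

```python
import jax
import jax.numpy as jnp

def loss_fn(w, x, y):
    """Mean squared error of a linear model pred = x @ w."""
    pred = x @ w
    return jnp.mean((pred - y) ** 2)

@jax.jit  # compile the whole update into one XLA computation
def sgd_step(w, x, y, lr=0.1):
    # jax.grad transforms loss_fn into a function returning dloss/dw.
    g = jax.grad(loss_fn)(w, x, y)
    return w - lr * g

# Toy data: with x as the identity, the optimum is simply w == y.
x = jnp.array([[1.0, 0.0], [0.0, 1.0]])
y = jnp.array([2.0, -1.0])
w = jnp.zeros(2)
for _ in range(100):
    w = sgd_step(w, x, y)
```

After 100 steps, `w` has converged close to `[2.0, -1.0]`. The appeal is that the same few lines scale from this toy to large networks, with `jit` handling compilation and `grad` handling differentiation.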