Isn't that backwards? You need fairly good resolution during training or your gradients will be pointing all over the place. Once you've found a good minimum point, moving a little away from it with reduced precision is probably OK.
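To make the "moving a little away from it" point concrete, here is a minimal sketch of quantizing after training, assuming plain symmetric round-to-nearest int4 (real schemes are usually per-group and more careful). The weight values and the helper name are made up for illustration.

```python
def quantize_dequantize_int4(weights):
    """Symmetric round-to-nearest int4: integer codes in -8..7 times one scale."""
    # Assumption: a single scale for the whole list, chosen so the
    # largest-magnitude weight lands on code +/-7.
    scale = max(abs(w) for w in weights) / 7
    codes = [max(-8, min(7, round(w / scale))) for w in weights]
    return [c * scale for c in codes], scale

# Hypothetical "trained" weights, not taken from any real model.
trained = [0.91, -0.42, 0.07, -0.66, 0.33]
restored, scale = quantize_dequantize_int4(trained)
max_error = max(abs(a - b) for a, b in zip(trained, restored))
print(f"step size {scale:.3f}, worst-case shift {max_error:.3f}")
```

With this scheme each weight moves by at most half a quantization step, which is the "little away" in question. Whether the loss tolerates that shift is the empirical part.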
I have no idea what the right answer is, but I think the argument for int4 training is that the loss measurements would take the lower resolution of the model as a whole into account.
Is it better to have billions of high resolution parameters and quantize them at the end, or to train low resolution parameters where the training algorithms see the lower resolution? It’s beyond me, but I’d love to know.
But by default, training algos don't see the lower resolution; your gradient just doesn't work as well. There is a body of research on how to make training aware of / adapt to the lower precision.
I think the answer is it depends, and further, a dynamic approach may be best. Imagine you are going on a hike, and you have different maps at various resolutions (levels of detail). When planning the hike, you will want to see the zoomed out picture to get general directions, elevations and landmarks identified. Then you can zoom in to see the actual trails themselves, to identify your route, and then you zoom in even further when you are on the ground actually walking, avoiding obstacles along the way.
Different resolutions draw your attention to different types of features.
It can go further than that: it seems like the weight gradients are the main thing where the precision is a bottleneck (see https://arxiv.org/abs/1805.11046).
> Probably the best method is to just train it on int4 in the first place
Unclear why you think that since experiments show the opposite.
In general the gradient seems to get too "bumpy" to do good gradient descent at lower levels of precision.
There are some papers showing that making the training loop aware of quantization can help ultimate quantized performance, but I'm not aware of this being implemented at large scale.
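A toy sketch of the difference, not any particular paper's method: fit one weight with a 16-level grid, once storing the weight only on the grid, and once keeping a full-precision "latent" copy that the forward pass rounds (the idea usually called a straight-through estimator). The data, the grid step of 0.1, the learning rate, and the function names are all assumptions chosen for illustration.

```python
def quantize_int4(w, scale=0.1):
    """Round w to the nearest of 16 signed levels (-8..7) times scale."""
    q = max(-8, min(7, round(w / scale)))
    return q * scale

def train(aware, steps=200, lr=0.03):
    # Hypothetical toy data generated from y = 0.6 * x.
    data = [(x, 0.6 * x) for x in (-1.0, -0.5, 0.5, 1.0)]
    w = 0.0
    for _ in range(steps):
        w_fwd = quantize_int4(w)  # the forward pass always sees the int4 value
        # Gradient of mean squared error, evaluated at the quantized weight
        # and applied as if rounding were the identity (straight-through).
        grad = sum(2 * (w_fwd * x - y) * x for x, y in data) / len(data)
        w -= lr * grad
        if not aware:
            # Naive low-precision storage: the update is rounded away
            # whenever it is smaller than half a grid step.
            w = quantize_int4(w)
    return quantize_int4(w)

naive = train(aware=False)
qat = train(aware=True)
print("weights stored in int4:", naive)
print("latent full-precision copy:", qat)
```

In this toy case the naive run never leaves its starting point because every update rounds back to the same grid value, while the run with a latent copy accumulates small updates until they cross a grid boundary and ends up on the grid point nearest 0.6.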
What if you smooth the gradient, either by interpolating/removing data points that make the surface "jagged", or maybe change the "window" of gradient descent, meaning instead of using a tangent (derivative, infinitesimally small window) you use a secant (???, window of specified length, likely calculated from the data space).
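As a rough one-dimensional illustration of the tangent vs. secant idea (a sketch under made-up assumptions, not an established training method): if the weight is rounded to a grid, the loss is a staircase, so a derivative taken over a tiny window is zero almost everywhere, while a difference taken over a window about one grid step wide still sees the overall slope. The grid step of 0.1, the target of 0.6, and the function names are hypothetical.

```python
def quantize(w, scale=0.1):
    """Round w to the nearest multiple of scale."""
    return round(w / scale) * scale

def loss(w):
    """Squared error against a hypothetical target of 0.6, seen through the grid."""
    return (quantize(w) - 0.6) ** 2

def slope(w, window):
    """Central difference of the loss over a window of the given half-width."""
    return (loss(w + window) - loss(w - window)) / (2 * window)

w = 0.22
tangent_like = slope(w, 1e-6)  # both sample points land on the same stair
secant_like = slope(w, 0.1)    # sample points land on different stairs
print(tangent_like, secant_like)
```

For comparison, the slope of the unquantized loss at the same point is 2 * (0.22 - 0.6) = -0.76, so the wide-window estimate is in the right neighborhood here. The catch is that a wide window averages over any detail narrower than the window.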
Sure, there are multiple ways to reduce the complexity of your loss-space, but the issue is that you usually want these small gradient values because they are important. Roughly if you "smooth over what appears to be a small hole" often you'll miss a large space that needs to be explored (obviously this is multi-dimensional but you get the idea).
So then you would need to do some kind of mesh simplification that also preserves the topology, that makes sense.
I'm not quite sure I understand what they are describing in 2.3.1, are they scaling those small gradient magnitudes larger to try to "pull" you into those holes faster?
I was thinking that a way to go about it would be to just increase the "mesh resolution" near the small hole, which in this case would mean using higher precision in the area local to the hole.
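On the question of scaling small gradient magnitudes up: it is not clear from this thread whether that is what section 2.3.1 describes, but a common trick of that shape in mixed-precision training is loss scaling. The gradient is multiplied by a constant before it is stored in the low-precision format so it does not underflow to zero, then divided back out in full precision. A minimal sketch using half precision via Python's `struct` module, with a made-up gradient value and scale factor:

```python
import struct

def to_fp16(x):
    """Round a Python float to IEEE 754 half precision and back."""
    return struct.unpack("e", struct.pack("e", x))[0]

tiny_grad = 1e-8     # hypothetical gradient, below fp16's smallest subnormal (about 6e-8)
loss_scale = 1024.0  # hypothetical scale factor

# Stored directly in fp16, the gradient underflows to zero.
lost = to_fp16(tiny_grad)

# Scaled up first, it survives the fp16 round trip and is unscaled in full precision.
recovered = to_fp16(tiny_grad * loss_scale) / loss_scale

print(lost, recovered)
```

So the scaling is less about pulling the optimizer into holes faster and more about keeping small values representable at all.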
I suspect that changing the resolution around hot points in the manifold would be a more expensive task than training the model on a higher global resolution. Optimization algorithms currently do not maintain state on the loss-manifold.
My naive (and I do mean naive) thought here is that you just need a cheap detection function of when you need to swap precision. I'm pretty stuck on the geometric interpretation here but basically if the training step is "within a radius" of a known hot point of the manifold then you swap precision. It's very possible though that I am hallucinating something that is not possible, I don't actually understand how this stuff really works yet.
The challenge here is knowing the shape of the manifold inside a 65-billion-dimensional sphere of radius epsilon around the position being evaluated. To calculate this you would need to sample points within that radius of the current point. As these points will be lower-precision by default, you would have minimal knowledge of the shape of the manifold within the sphere if epsilon is smaller than the minimum precision.
It might be possible to work around this by estimating the gradient volatility through the n^th order derivatives, but you would then also have to deal with mixed precision SIMD which hardware doesn't really support.
My takeaway was that the reduced performance of natively trained models was more about numerical instability in the training process than a statement about the limitations of low-precision models.
edit: to clarify, it may be possible to get this loss back and there is reason to be optimistic