Can someone versed in the ways of math explain how this is different from previous quantization methods?
And specifically, seeing how going from fp16 to 8-bit mostly gives the same perplexity while anything further seems to lose quality / dumb down the model, how is this even less precise method able to achieve this?
If I understand it correctly, this seems to be more than just quantizing: the models are apparently trained in this format as well. So it's possible that the many layers adjust themselves in a way that "cancels out" the inaccuracies of the lower bit count.
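Roughly, "trained in this format" means something like the sketch below (PyTorch, illustrative names, not the paper's actual code): the forward pass only ever sees ternary weights, but gradients update a full-precision latent copy via a straight-through estimator, so the rest of the network learns around the rounding error from the start.

```python
# Minimal sketch of quantization-aware training with ternary weights,
# assuming absmean-style scaling. The optimizer updates full-precision
# "latent" weights, while the forward pass uses their ternary rounding.
import torch
import torch.nn as nn

def ternary_quantize(w: torch.Tensor) -> torch.Tensor:
    # Scale by the mean absolute weight, round each entry to {-1, 0, +1}, rescale.
    scale = w.abs().mean().clamp(min=1e-5)
    return (w / scale).round().clamp(-1, 1) * scale

class TernaryLinear(nn.Module):
    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        # Full-precision latent weights; these are what the optimizer updates.
        self.weight = nn.Parameter(torch.randn(out_features, in_features) * 0.02)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        w = self.weight
        # Straight-through estimator: forward with quantized weights, backward
        # as if the quantization step were the identity.
        w_q = w + (ternary_quantize(w) - w).detach()
        return x @ w_q.t()
```

The point being that the loss is computed against the ternary weights the whole time, so "cancelling out" the inaccuracies is exactly what training optimizes for.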
So modern NNs aren't really using the network nodes in the structure they physically have; they essentially build a virtual neural network out of combinations of nodes (which is how you can model hundreds of parameters with only a dozen or so nodes).
So as the number of nodes scales up, the precision of each individual node probably matters less and less. Which is what they found here: it reaches parity with the full-precision baseline at 3B and then starts exceeding it at larger sizes, up to the 2T tested.
Seemingly, when trained from scratch, the virtual network can find adequate precision from ternary physical nodes where needed. This is different from the information loss when an already trained floating-point network has its weights quantized to a smaller precision and sees a performance drop.
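For contrast, here's a toy sketch of plain post-training round-to-nearest quantization (not a real scheme like GPTQ or AWQ, which add per-group scales and error compensation on top of this idea): the already trained weights get snapped to a k-bit grid once, and nothing downstream ever gets a chance to compensate.

```python
# Toy post-training quantization: snap an already trained weight matrix onto a
# symmetric k-bit grid after the fact. At 8 bits the grid is fine enough that
# the rounding error is tiny; at 2 bits (effectively ternary) it is large, and
# since training is already over, no other layer can adjust for it.
import torch

def round_to_k_bits(w: torch.Tensor, k: int) -> torch.Tensor:
    qmax = 2 ** (k - 1) - 1                      # e.g. 127 for 8-bit, 1 for 2-bit
    scale = w.abs().max().clamp(min=1e-8) / qmax
    return (w / scale).round().clamp(-qmax, qmax) * scale

w = torch.randn(4096, 4096)                      # stands in for trained fp weights (illustration only)
for k in (8, 4, 2):
    err = (w - round_to_k_bits(w, k)).abs().mean().item()
    print(f"{k}-bit round trip: mean |error| = {err:.4f}")
```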
Not only is this approach more efficient, it also seems to perform better at larger network sizes, which is probably the most interesting part.