A mistake? Weights/grads/optimizer states memory for mixed precision
#104
by donglongfei - opened
The formulas for mixed-precision memory seem to be missing the term m_(grad_fp32) = 4 * N. The total memory should be 2N + 2N + 4N + 4N = 12N, which matches the text: "The default nowadays for mixed precision training is to generally use BF16 for most of the computations –requiring 2 bytes per parameter and gradient– as well as an additional copy of the model weights and gradients in FP32, thus 12 bytes per parameter in total." As written, the formulas only account for 8 bytes per parameter.
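To make the arithmetic concrete, here is a minimal sketch of the per-parameter accounting under the assumption described above (BF16 weights and gradients plus FP32 copies of both); the variable names are illustrative, not from the original text:

```python
BYTES_BF16 = 2  # bytes per BF16 value
BYTES_FP32 = 4  # bytes per FP32 value

m_params_bf16 = BYTES_BF16  # BF16 model weights used in compute
m_grads_bf16 = BYTES_BF16   # BF16 gradients
m_params_fp32 = BYTES_FP32  # FP32 master copy of the weights
m_grads_fp32 = BYTES_FP32   # FP32 copy of the gradients (the missing term)

total = m_params_bf16 + m_grads_bf16 + m_params_fp32 + m_grads_fp32
print(total)  # 12 bytes per parameter, matching the quoted 12N
```

Dropping `m_grads_fp32` from the sum gives 8 bytes per parameter, which is the discrepancy pointed out above.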