Support for FP16/BF16 in train_gpt2.cu (1.86x Perf) #218
Conversation
It is trivial to use the exact same code with everything in FP32: at the top of train_gpt2.cu, simply replace "typedef __nv_bfloat16 floatX;" with "typedef float floatX;". This is now able to train in BF16 for many layers, and kinda-sorta works for test_gpt2.cu, though the loss converges much more slowly than FP32 for now (need to debug how to improve that afterwards):

LOSS MISMATCH AT STEP 1: 4.598247 4.059707

It does eventually converge:

step 99: loss 0.001294 (took 8.712021 ms)
Debugged the BF16 convergence issue and fixed it by adding stochastic rounding support. Simplified the code by making all activations the same type (but params can still be different types, partly due to severe perf issues with atomicAdd otherwise). The PR is now in a good state in my opinion, and worth thinking about what it would take to integrate it.
BTW, this approach should work perfectly fine for FP8 as well. The main issue (besides loss scaling) in getting it working is that non-Lt cuBLAS doesn't support FP8 at all, so we can't use StridedBatched GEMMs for attention, we need padding to move the other cuBLAS calls to Lt, etc. By hacking things so that all the "cannot be FP8" GEMMs stay at BF16 while halving k (obviously not functionally correct), I got FP8 to run at ~29.5ms (vs ~43ms for BF16), so it does seem to scale reasonably well despite suffering a little from Amdahl's Law. It should be possible to keep gradients as e5m2 and everything else as e4m3 by just adding a "floatG" type for e5m2 and casting appropriately, since the storage requirements are the same. What would not work without a LOT more complexity is using types with different sizes (e.g. activations at FP8 and gradients at BF16), but I think we agreed we shouldn't really need that anytime soon.
merging this. we'll iterate in master.
Now finished and reasonably happy with it!
1.86x performance on my RTX 4090.
This allows the same train_gpt2.cu to work as full FP32, full BF16, full FP16, or full (BF/FP)16 + FP32 layernorm, simply by changing the define at the top of the file. Stochastic rounding is also included for the Adam kernel (but nothing else at this point; possibly worth adding to gradients in general when we move to FP8?)
To simplify the logic compared to the first version of the PR, all activation tensors are now always "floatX"; we cannot mix and match. However, because atomicAdd on 16-bit values in some of the backward kernels is HORRIBLY slow (10x slower or worse), and because this kind of flexibility seems useful in general for layernorm accuracy, layernorm is kept at FP32 by defining "floatN" as "float".
I reduced the amount of code duplication by using very lightweight templates for the kernel types. It's still a BIG change though; unfortunately I don't think there's any way around that!