Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Faster GELU forward & backward using MUFU.TANH for SM7.5+ #721

Open
wants to merge 3 commits into
base: master
Choose a base branch
from

Conversation

ademeure
Copy link
Contributor

These are faster GELU kernels by using the HW instruction NVIDIA introduced for this in Turing (SM7.5) but never exposed outside of PTX as far as I can tell, possibly because it's slightly less accurate - but based on the val loss we get which is slightly better for the backwards pass (and within noise for forward), I am pretty sure it's fine for our purposes!

This is only somewhat faster on H100 PCIe but should be much faster on H200/Blackwell as they have more DRAM bandwidth relative to compute, and also much faster with FP8 (this was originally done in the context of the FP8 branch where it was >50% faster!)

This also includes the change for backward suggested in #307

Included the changes in /dev/cuda/ as well:

FORWARD PASS - BEFORE
block_size   32 | time 0.0311 ms | bandwidth 810.06 GB/s
block_size   64 | time 0.0277 ms | bandwidth 907.20 GB/s
block_size  128 | time 0.0277 ms | bandwidth 908.80 GB/s
block_size  256 | time 0.0277 ms | bandwidth 907.33 GB/s
block_size  512 | time 0.0281 ms | bandwidth 895.78 GB/s
block_size 1024 | time 0.0307 ms | bandwidth 819.99 GB/s

FORWARD PASS - AFTER
block_size   32 | time 0.0310 ms | bandwidth 810.72 GB/s
block_size   64 | time 0.0271 ms | bandwidth 927.50 GB/s
block_size  128 | time 0.0271 ms | bandwidth 928.29 GB/s
block_size  256 | time 0.0271 ms | bandwidth 928.59 GB/s
block_size  512 | time 0.0274 ms | bandwidth 919.46 GB/s
block_size 1024 | time 0.0288 ms | bandwidth 872.31 GB/s

BACKWARD PASS - BEFORE
block_size   32 | time 0.0417 ms | bandwidth 1207.47 GB/s
block_size   64 | time 0.0402 ms | bandwidth 1252.53 GB/s
block_size  128 | time 0.0403 ms | bandwidth 1249.39 GB/s
block_size  256 | time 0.0403 ms | bandwidth 1248.84 GB/s
block_size  512 | time 0.0404 ms | bandwidth 1246.94 GB/s
block_size 1024 | time 0.0407 ms | bandwidth 1235.60 GB/s

BACKWARD PASS - AFTER
block_size   32 | time 0.0333 ms | bandwidth 1509.47 GB/s
block_size   64 | time 0.0325 ms | bandwidth 1546.38 GB/s
block_size  128 | time 0.0326 ms | bandwidth 1545.97 GB/s
block_size  256 | time 0.0326 ms | bandwidth 1542.07 GB/s
block_size  512 | time 0.0329 ms | bandwidth 1531.70 GB/s
block_size 1024 | time 0.0351 ms | bandwidth 1433.92 GB/s

In terms of loss for a few steps of Tiny Shakespeare, 'make train_gpt2cu && ./train_gpt2cu -r 0 -ge 0 -e "d12"' gives:

BEFORE
step   74/74 | loss 5.935039 (+nanz)| norm 1.2940 (+nanz)| lr 3.00e-04 | 28.42 ms | 15.3% bf16 MFU | 144989 tok/s
val loss 5.955256

AFTER
step   74/74 | loss 5.920402 (+nanz)| norm 1.2239 (+nanz)| lr 3.00e-04 | 27.98 ms | 15.6% bf16 MFU | 146466 tok/s
val loss 5.948031

so actually slightly better (but potentially noise)! Using the new backward but the old forward pass gives a val loss of 5.942228, but again, it might be noise. Either way looks to be good enough as far as I can tell!

I believe this NVIDIA forum thread (and a few others) talk a little bit about this HW instruction: https://forums.developer.nvidia.com/t/hardware-accelerated-tanh-on-turing/173291

@ademeure
Copy link
Contributor Author

Oops, /dev/cuda/ numbers are for non-MUFU.TANH part of the changes (e.g. faster way to do the derivative for backward) because I compiled without specifying the arch, so it's not SM7.5 😅 The val loss numbers for llm.c are correct though.

Can't test it until back on Tuesday but expecting it to be way faster

…increase batch sizes for GELU fwd/bwd to hit closer to peak
@ademeure
Copy link
Contributor Author

ademeure commented Aug 7, 2024

zI think there's an error for both /dev/cuda/common.h and test_gpt2.cu using an epsilon of 0.079 instead of 0.0079 for BF16 which makes the error threshold too high (missing some fairly large errors) - but even after fixing that, this seems to pass all the tests :)

The bandwidth calculations for gelu_backward were broken, and the batch size was wayyyy too small so it couldn't saturate H100 on either kernel, this is the correct performance (compiling for SM9.0):

FORWARD PASS - BEFORE
block_size   32 | time 0.3724 ms | bandwidth 1081.16 GB/s
block_size   64 | time 0.2472 ms | bandwidth 1628.73 GB/s
block_size  128 | time 0.2477 ms | bandwidth 1625.77 GB/s
block_size  256 | time 0.2496 ms | bandwidth 1613.01 GB/s
block_size  512 | time 0.2552 ms | bandwidth 1577.94 GB/s
block_size 1024 | time 0.2771 ms | bandwidth 1452.96 GB/s

FORWARD PASS - AFTER
block_size   32 | time 0.3738 ms | bandwidth 1077.24 GB/s
block_size   64 | time 0.2327 ms | bandwidth 1730.08 GB/s
block_size  128 | time 0.2327 ms | bandwidth 1730.10 GB/s
block_size  256 | time 0.2330 ms | bandwidth 1728.38 GB/s
block_size  512 | time 0.2340 ms | bandwidth 1721.00 GB/s
block_size 1024 | time 0.2370 ms | bandwidth 1699.07 GB/s

BACKWARD PASS - BEFORE
block_size   32 | time 0.4575 ms | bandwidth 1320.26 GB/s
block_size   64 | time 0.4231 ms | bandwidth 1427.44 GB/s
block_size  128 | time 0.4232 ms | bandwidth 1427.10 GB/s
block_size  256 | time 0.4223 ms | bandwidth 1430.12 GB/s
block_size  512 | time 0.4187 ms | bandwidth 1442.61 GB/s
block_size 1024 | time 0.4234 ms | bandwidth 1426.65 GB/s

BACKWARD PASS - AFTER
block_size   32 | time 0.3722 ms | bandwidth 1622.76 GB/s
block_size   64 | time 0.3290 ms | bandwidth 1835.54 GB/s
block_size  128 | time 0.3291 ms | bandwidth 1835.42 GB/s
block_size  256 | time 0.3290 ms | bandwidth 1835.81 GB/s
block_size  512 | time 0.3290 ms | bandwidth 1836.01 GB/s
block_size 1024 | time 0.3297 ms | bandwidth 1832.09 GB/s

===> +6% for forward and +27% for backward.

@karpathy
Copy link
Owner

This sounds cool, I guess you only tried for the little shakespeare training run, i wonder if the slight accuracy decrease could cause training instabilities, probably should try a bigger run?

…s is negligible for BF16, proven with gelu_precision_test branch)
@ademeure
Copy link
Contributor Author

ademeure commented Sep 3, 2024

Haven't tried a bigger run, but made very quick & dirty precision tests ("gelu_precision_test" branch) to see if it made any difference after you round to BF16 (i.e. with only 7 mantissa bits, is the maximum error basically a single rounding error?)

It is indeed negligible for forward! It's trickier to test for backward (with 2x BF16 inputs there are 2^32 possible bits, forward is only 2^16) but error seems extremely small except with insanely large dout in the millions/billions which would point to a much bigger problem anyway! (and even then the error isn't that bad relative to the magnitude of the inputs).

For forward, only 60 out of 65536 inputs result in any difference (the other 65476 inputs have the exact same BF16 outputs bit-for-bit). The worst error is with inputs -4.875 to -5.15625 where the output gets rounded down from a very small number to zero but that's basically nothing compared to what happens when we use FP8 (which couldn't represent anything near those tiny outputs anyway).

So I'm pretty sure it's fine :) But I edited my PR so that it would never be used in FP32 mode to make sure that remains a good reference point.

[49060]: INPUT -1.281250000000000 ===> -0.128906250000000 vs -0.127929687500000 ===> DIFF: 0.000976562500000 (0.75757575%)
[49061]: INPUT -1.289062500000000 ===> -0.127929687500000 vs -0.126953125000000 ===> DIFF: 0.000976562500000 (0.76335877%)
[49215]: INPUT -2.984375000000000 ===> -0.003829956054688 vs -0.003814697265625 ===> DIFF: 0.000015258789062 (0.39840639%)
[49240]: INPUT -3.375000000000000 ===> -0.000991821289062 vs -0.000999450683594 ===> DIFF: 0.000007629394531 (0.76923078%)
[49241]: INPUT -3.390625000000000 ===> -0.000938415527344 vs -0.000942230224609 ===> DIFF: 0.000003814697266 (0.40650403%)
[49242]: INPUT -3.406250000000000 ===> -0.000885009765625 vs -0.000888824462891 ===> DIFF: 0.000003814697266 (0.43103448%)
[49246]: INPUT -3.468750000000000 ===> -0.000698089599609 vs -0.000694274902344 ===> DIFF: 0.000003814697266 (0.54644805%)
[49247]: INPUT -3.484375000000000 ===> -0.000656127929688 vs -0.000652313232422 ===> DIFF: 0.000003814697266 (0.58139533%)
[49248]: INPUT -3.500000000000000 ===> -0.000617980957031 vs -0.000614166259766 ===> DIFF: 0.000003814697266 (0.61728394%)
[49249]: INPUT -3.515625000000000 ===> -0.000579833984375 vs -0.000576019287109 ===> DIFF: 0.000003814697266 (0.65789473%)
[49250]: INPUT -3.531250000000000 ===> -0.000545501708984 vs -0.000541687011719 ===> DIFF: 0.000003814697266 (0.69930071%)
[49251]: INPUT -3.546875000000000 ===> -0.000511169433594 vs -0.000507354736328 ===> DIFF: 0.000003814697266 (0.74626863%)
[49252]: INPUT -3.562500000000000 ===> -0.000480651855469 vs -0.000478744506836 ===> DIFF: 0.000001907348633 (0.39682543%)
[49253]: INPUT -3.578125000000000 ===> -0.000450134277344 vs -0.000453948974609 ===> DIFF: 0.000003814697266 (0.84745765%)
[49255]: INPUT -3.609375000000000 ===> -0.000396728515625 vs -0.000398635864258 ===> DIFF: 0.000001907348633 (0.48076925%)
[49256]: INPUT -3.625000000000000 ===> -0.000371932983398 vs -0.000373840332031 ===> DIFF: 0.000001907348633 (0.51282054%)
[49257]: INPUT -3.640625000000000 ===> -0.000349044799805 vs -0.000350952148438 ===> DIFF: 0.000001907348633 (0.54644805%)
[49258]: INPUT -3.656250000000000 ===> -0.000326156616211 vs -0.000328063964844 ===> DIFF: 0.000001907348633 (0.58479536%)
[49261]: INPUT -3.703125000000000 ===> -0.000268936157227 vs -0.000267028808594 ===> DIFF: 0.000001907348633 (0.70921981%)
[49262]: INPUT -3.718750000000000 ===> -0.000251770019531 vs -0.000249862670898 ===> DIFF: 0.000001907348633 (0.75757575%)
[49263]: INPUT -3.734375000000000 ===> -0.000234603881836 vs -0.000232696533203 ===> DIFF: 0.000001907348633 (0.81300807%)
[49264]: INPUT -3.750000000000000 ===> -0.000219345092773 vs -0.000217437744141 ===> DIFF: 0.000001907348633 (0.86956519%)
[49265]: INPUT -3.765625000000000 ===> -0.000205039978027 vs -0.000203132629395 ===> DIFF: 0.000001907348633 (0.93023252%)
[49266]: INPUT -3.781250000000000 ===> -0.000191688537598 vs -0.000190734863281 ===> DIFF: 0.000000953674316 (0.49751243%)
[49267]: INPUT -3.796875000000000 ===> -0.000178337097168 vs -0.000179290771484 ===> DIFF: 0.000000953674316 (0.53475934%)
[49269]: INPUT -3.828125000000000 ===> -0.000155448913574 vs -0.000156402587891 ===> DIFF: 0.000000953674316 (0.61349690%)
[49270]: INPUT -3.843750000000000 ===> -0.000144958496094 vs -0.000145912170410 ===> DIFF: 0.000000953674316 (0.65789473%)
[49272]: INPUT -3.875000000000000 ===> -0.000125885009766 vs -0.000126838684082 ===> DIFF: 0.000000953674316 (0.75757575%)
[49273]: INPUT -3.890625000000000 ===> -0.000117301940918 vs -0.000117778778076 ===> DIFF: 0.000000476837158 (0.40650403%)
[49275]: INPUT -3.921875000000000 ===> -0.000101566314697 vs -0.000101089477539 ===> DIFF: 0.000000476837158 (0.46948358%)
[49276]: INPUT -3.937500000000000 ===> -0.000094413757324 vs -0.000093936920166 ===> DIFF: 0.000000476837158 (0.50505048%)
[49277]: INPUT -3.953125000000000 ===> -0.000087738037109 vs -0.000087261199951 ===> DIFF: 0.000000476837158 (0.54347825%)
[49278]: INPUT -3.968750000000000 ===> -0.000081539154053 vs -0.000080585479736 ===> DIFF: 0.000000953674316 (1.16959071%)
[49279]: INPUT -3.984375000000000 ===> -0.000075817108154 vs -0.000074863433838 ===> DIFF: 0.000000953674316 (1.25786161%)
[49281]: INPUT -4.031250000000000 ===> -0.000060319900513 vs -0.000060796737671 ===> DIFF: 0.000000476837158 (0.79051387%)
[49282]: INPUT -4.062500000000000 ===> -0.000051975250244 vs -0.000052213668823 ===> DIFF: 0.000000238418579 (0.45871559%)
[49283]: INPUT -4.093750000000000 ===> -0.000044584274292 vs -0.000044822692871 ===> DIFF: 0.000000238418579 (0.53475934%)
[49285]: INPUT -4.156250000000000 ===> -0.000032663345337 vs -0.000032424926758 ===> DIFF: 0.000000238418579 (0.72992700%)
[49290]: INPUT -4.312500000000000 ===> -0.000014543533325 vs -0.000014424324036 ===> DIFF: 0.000000119209290 (0.81967211%)
[49291]: INPUT -4.343750000000000 ===> -0.000012278556824 vs -0.000012159347534 ===> DIFF: 0.000000119209290 (0.97087377%)
[49294]: INPUT -4.437500000000000 ===> -0.000007271766663 vs -0.000007152557373 ===> DIFF: 0.000000119209290 (1.63934422%)
[49295]: INPUT -4.468750000000000 ===> -0.000006139278412 vs -0.000005990266800 ===> DIFF: 0.000000149011612 (2.42718434%)
[49296]: INPUT -4.500000000000000 ===> -0.000005096197128 vs -0.000004947185516 ===> DIFF: 0.000000149011612 (2.92397666%)
[49297]: INPUT -4.531250000000000 ===> -0.000004321336746 vs -0.000004172325134 ===> DIFF: 0.000000149011612 (3.44827580%)
[49298]: INPUT -4.562500000000000 ===> -0.000003531575203 vs -0.000003665685654 ===> DIFF: 0.000000134110451 (3.79746819%)
[49302]: INPUT -4.687500000000000 ===> -0.000001676380634 vs -0.000001817941666 ===> DIFF: 0.000000141561031 (8.44444370%)
[49303]: INPUT -4.718750000000000 ===> -0.000001408159733 vs -0.000001683831215 ===> DIFF: 0.000000275671482 (19.57671928%)
[49304]: INPUT -4.750000000000000 ===> -0.000001132488251 vs -0.000001274049282 ===> DIFF: 0.000000141561031 (12.50000000%)
[49306]: INPUT -4.812500000000000 ===> -0.000000715255737 vs -0.000000860542059 ===> DIFF: 0.000000145286322 (20.31250000%)
[49307]: INPUT -4.843750000000000 ===> -0.000000577419996 vs -0.000000722706318 ===> DIFF: 0.000000145286322 (25.16128922%)
[49308]: INPUT -4.875000000000000 ===> -0.000000581145287 vs 0.000000000000000 ===> DIFF: 0.000000581145287 (100.00000000%)
[49309]: INPUT -4.906250000000000 ===> -0.000000439584255 vs 0.000000000000000 ===> DIFF: 0.000000439584255 (100.00000000%)
[49310]: INPUT -4.937500000000000 ===> -0.000000294297934 vs 0.000000000000000 ===> DIFF: 0.000000294297934 (100.00000000%)
[49311]: INPUT -4.968750000000000 ===> -0.000000296160579 vs 0.000000000000000 ===> DIFF: 0.000000296160579 (100.00000000%)
[49312]: INPUT -5.000000000000000 ===> -0.000000298023224 vs 0.000000000000000 ===> DIFF: 0.000000298023224 (100.00000000%)
[49313]: INPUT -5.031250000000000 ===> -0.000000149942935 vs 0.000000000000000 ===> DIFF: 0.000000149942935 (100.00000000%)
[49314]: INPUT -5.062500000000000 ===> -0.000000150874257 vs 0.000000000000000 ===> DIFF: 0.000000150874257 (100.00000000%)
[49315]: INPUT -5.093750000000000 ===> -0.000000151805580 vs 0.000000000000000 ===> DIFF: 0.000000151805580 (100.00000000%)
[49316]: INPUT -5.125000000000000 ===> -0.000000152736902 vs 0.000000000000000 ===> DIFF: 0.000000152736902 (100.00000000%)
[49317]: INPUT -5.156250000000000 ===> -0.000000153668225 vs 0.000000000000000 ===> DIFF: 0.000000153668225 (100.00000000%)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants