Faster GELU forward & backward using MUFU.TANH for SM7.5+ #721

Open · wants to merge 3 commits into master
2 changes: 1 addition & 1 deletion dev/cuda/common.h
@@ -315,7 +315,7 @@ void validate_result(D* device_result, const T* cpu_reference, const char* name,
#ifndef ENABLE_BF16
float epsilon = FLT_EPSILON;
#else
float epsilon = 0.079;
float epsilon = 0.0079; // ~2^-7 (where 7 is the number of mantissa bits in BF16)
#endif
for (int i = 0; i < num_elements; i++) {
// Skip masked elements
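For reference, BF16 keeps 7 explicit mantissa bits, so its machine epsilon is 2^-7 = 0.0078125; the previous tolerance of 0.079 was roughly an order of magnitude too loose. A minimal host-side sanity check (illustrative only, not part of this PR):

#include <cmath>
#include <cstdio>
int main() {
    float bf16_eps = ldexpf(1.0f, -7);           // 2^-7 = 0.0078125
    printf("BF16 epsilon ~= %.7f\n", bf16_eps);  // close to the 0.0079 tolerance used above
    return 0;
}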
47 changes: 41 additions & 6 deletions dev/cuda/gelu_backward.cu
@@ -80,6 +80,32 @@ __global__ void gelu_backward2(floatX* dinp, const floatX* inp, const floatX* do
}
}

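// Kernel 3: vectorised backward pass; uses the hardware tanh instruction (tanh.approx.f32,
// i.e. MUFU.TANH) on SM7.5+ unless PRECISE_GELU_TANH is defined, in which case it falls back to tanhf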
template <typename Ti, typename Tdout, typename Tdinp>
__global__ void gelu_backward3(Tdinp* dinp, const Ti* inp, const Tdout* dout, const int N) {
int idx = (blockIdx.x * blockDim.x + threadIdx.x) * Packed128<Tdout>::size;
if (idx >= N) { return; }

Packed128<Tdinp> packed_dinp;
Packed128<Ti> packed_inp = load128cs(inp + idx);
Packed128<Tdout> packed_dout = load128(dout + idx);
for (int k = 0; k < Packed128<Tdout>::size; ++k) {
float x = (float)packed_inp[k];
float cube = 0.044715f * x * x * x;

float tanh_in_out = GELU_SCALING_FACTOR * (x + cube);
#if !defined(PRECISE_GELU_TANH) && __CUDA_ARCH__ >= 750
asm ("tanh.approx.f32 %0,%1;" : "=f"(tanh_in_out) : "f"(tanh_in_out));
#else
tanh_in_out = tanhf(tanh_in_out);
#endif

float sech_out = 1.0f - (tanh_in_out * tanh_in_out);
float local_grad = 0.5f * ((1.0f + tanh_in_out) + x * sech_out * GELU_SCALING_FACTOR * (1.0f + 3.0f * 0.044715f * x * x));
float result = local_grad * (float)packed_dout[k];
packed_dinp[k] = (Tdinp)(result);
}
store128(dinp + idx, packed_dinp);
}
// ----------------------------------------------------------------------------
// kernel launcher

@@ -95,10 +121,16 @@ void gelu_backward2(floatX* dinp, const floatX* inp, const floatX* dout, int N,
cudaCheck(cudaGetLastError());
}

void gelu_backward3(floatX* dinp, const floatX* inp, const floatX* dout, int N, const int block_size) {
const int grid_size = ceil_div(N, block_size * x128::size);
gelu_backward3<<<grid_size, block_size>>>(dinp, inp, dout, N);
cudaCheck(cudaGetLastError());
}

// kernel version dispatch
void gelu_backward(int kernel_num,
floatX* dinp,
const floatX* inp,
const floatX* dout,
int B, int T, int C,
int block_size) {
@@ -109,6 +141,9 @@ void gelu_backward(int kernel_num,
case 2:
gelu_backward2(dinp, inp, dout, B * T * C, block_size);
break;
case 3:
gelu_backward3(dinp, inp, dout, B * T * C, block_size);
break;
default:
printf("Invalid kernel number\n");
exit(1);
@@ -120,7 +155,7 @@
int main(int argc, char **argv) {
setup_main();

int B = 8;
int B = 128;
int T = 1024;
int C = 768;

@@ -157,9 +192,9 @@ int main(int argc, char **argv) {
printf("Checking block size %d.\n", block_size);
gelu_backward(kernel_num, d_dinp, d_inp, d_dout, B, T, C, block_size);
#if !defined(ENABLE_BF16) && !defined(ENABLE_FP16)
float tol = 1e-5;
float tol = 1e-5f;
#else
float tol = 1e-2f;
float tol = 1e-3f;
#endif
validate_result(d_dinp, dinp, "dinp", B * T * C, tol);
}
@@ -178,7 +213,7 @@ int main(int argc, char **argv) {
// napkin math: estimate the memory bandwidth achieved
// for each (B,T,C) element, the backward pass does 2 reads (inp, dout) and 1 write (dinp), sizeof(floatX) bytes each
// and e.g. A100 40GB PCIe is advertised at 1,555GB/s
long memory_ops = B * T * C * 2 * 4;
long memory_ops = B * T * C * 3 * (int)sizeof(floatX);
float memory_bandwidth = memory_ops / elapsed_time / 1e6;

printf("block_size %4d | time %.4f ms | bandwidth %.2f GB/s\n", block_size, elapsed_time, memory_bandwidth);
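The updated napkin math accounts for the backward pass touching three tensors per element (read inp, read dout, write dinp) at sizeof(floatX) bytes each, rather than the 2 accesses of 4 bytes assumed before. A rough worked example under the new defaults (illustrative numbers, assuming BF16 and a hypothetical 0.5 ms runtime):

// B = 128, T = 1024, C = 768, floatX = __nv_bfloat16 (2 bytes)
// memory_ops = 128 * 1024 * 768 * 3 * 2 = 603,979,776 bytes  (~0.6 GB per pass)
// at 0.5 ms elapsed time: 603,979,776 / 0.5 / 1e6 ~= 1208 GB/s reported bandwidth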
42 changes: 39 additions & 3 deletions dev/cuda/gelu_forward.cu
@@ -66,6 +66,33 @@ __global__ void gelu_forward_kernel2(floatX* out, const floatX* inp, int N) {
}
}

// Kernel 3: optimised forward pass; uses the hardware tanh instruction (tanh.approx.f32 / MUFU.TANH)
// by default on SM7.5+, with a tanhf fallback when PRECISE_GELU_TANH is defined
__global__ void gelu_forward_kernel3(floatX* out, const floatX* inp, int N) {
int idx = (blockIdx.x * blockDim.x + threadIdx.x) * x128::size;
if (idx >= N) { return; }

x128 packed_out;
x128 packed_inp = load128cs(inp + idx); // load and do not keep in cache
for(int k = 0; k < packed_inp.size; ++k) {
float xi = (float)packed_inp[k];
float cube = 0.044715f * xi * xi * xi;

float tanh_in_out = GELU_SCALING_FACTOR * (xi + cube);
#if !defined(PRECISE_GELU_TANH) && __CUDA_ARCH__ >= 750
asm ("tanh.approx.f32 %0,%1;" : "=f"(tanh_in_out) : "f"(tanh_in_out));
#else
tanh_in_out = tanhf(tanh_in_out);
#endif

// the following uses FMUL+FMA instead of FMUL+FADD+FMUL for "0.5f * x * (1.0f + tanh_out)"
float half_xi = 0.5f * xi;
packed_out[k] = (floatX)(half_xi * tanh_in_out + half_xi);
}
// store instead of storecs (without cache streaming) in case it is useful for the
// data to be in the cache for the next operation after this GeLU
store128(out + idx, packed_out);
}

// ----------------------------------------------------------------------------
// kernel launcher

@@ -81,6 +108,12 @@ void gelu_forward2(floatX* out, const floatX* inp, int N, const int block_size)
cudaCheck(cudaGetLastError());
}

void gelu_forward3(floatX* out, const floatX* inp, int N, const int block_size) {
const int grid_size = ceil_div(N, block_size * x128::size);
gelu_forward_kernel3<<<grid_size, block_size>>>(out, inp, N);
cudaCheck(cudaGetLastError());
}

// kernel version dispatch
void gelu_forward(int kernel_num,
floatX* out,
@@ -94,6 +127,9 @@ void gelu_forward(int kernel_num,
case 2:
gelu_forward2(out, inp, B * T * C, block_size);
break;
case 3:
gelu_forward3(out, inp, B * T * C, block_size);
break;
default:
printf("Invalid kernel number\n");
exit(1);
@@ -105,7 +141,7 @@ int main(int argc, const char **argv) {
int main(int argc, const char **argv) {
setup_main();

int B = 8;
int B = 128;
int T = 1024;
int C = 768;

@@ -137,9 +173,9 @@ int main(int argc, const char **argv) {
printf("Checking block size %d.\n", block_size);
gelu_forward(kernel_num, d_out, d_inp, B, T, C, block_size);
#if !defined(ENABLE_BF16) && !defined(ENABLE_FP16)
float tol = 1e-5;
float tol = 1e-5f;
#else
float tol = 1e-2f;
float tol = 1e-3f;
#endif
validate_result(d_out, out, "out", B * T * C, tol);
}
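Note that the fast path is compile-time gated: the file must be built for compute capability 7.5 or newer, otherwise the __CUDA_ARCH__ check falls back to tanhf, and defining PRECISE_GELU_TANH forces the precise path regardless of architecture. The FMA rewrite in the loop relies on 0.5*x*(1 + t) = (0.5*x)*t + (0.5*x). A scalar sketch of the same computation (illustrative only; assumes GELU_SCALING_FACTOR as defined in this file):

__device__ float gelu_forward_scalar(float x) {
    float u = GELU_SCALING_FACTOR * (x + 0.044715f * x * x * x);
    float t = tanhf(u);              // replaced by tanh.approx.f32 on SM7.5+ in kernel 3
    float half_x = 0.5f * x;
    return fmaf(half_x, t, half_x);  // == 0.5f * x * (1.0f + t), one FMUL + one FFMA
}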
47 changes: 32 additions & 15 deletions llmc/gelu.cuh
@@ -18,30 +18,47 @@ __global__ void gelu_forward_kernel2(floatX* out, const floatX* inp) {
for(int k = 0; k < packed_inp.size; ++k) {
float xi = (float)packed_inp[k];
float cube = 0.044715f * xi * xi * xi;
packed_out[k] = (floatX)(0.5f * xi * (1.0f + tanhf(GELU_SCALING_FACTOR * (xi + cube))));

float tanh_in_out = GELU_SCALING_FACTOR * (xi + cube);
#if !defined(PRECISE_GELU_TANH) && !defined(ENABLE_FP32) && __CUDA_ARCH__ >= 750
asm ("tanh.approx.f32 %0,%1;" : "=f"(tanh_in_out) : "f"(tanh_in_out));
#else
tanh_in_out = tanhf(tanh_in_out);
#endif

// the following uses FMUL+FMA instead of FMUL+FADD+FMUL for "0.5f * x * (1.0f + tanh_out)"
float half_xi = 0.5f * xi;
packed_out[k] = (floatX)(half_xi * tanh_in_out + half_xi);
}
// store instead of storecs (without cache streaming) in case it is useful for the
// data to be in the cache for the next operation after this GeLU
store128(out + idx, packed_out);
}

__global__ void gelu_backward_inplace_kernel(floatX* d_in_out, const floatX* inp) {
int idx = (blockIdx.x * blockDim.x + threadIdx.x) * x128::size;
template <typename Ti, typename Tdout, typename Tdinp>
__global__ void gelu_backward_kernel(Tdinp* dinp, const Tdout* dout, const Ti* inp) {
int idx = (blockIdx.x * blockDim.x + threadIdx.x) * Packed128<Tdout>::size;

x128 packed_dinp;
x128 packed_inp = load128cs(inp + idx);
x128 packed_dout = load128(d_in_out + idx);
for (int k = 0; k < packed_inp.size; ++k) {
Packed128<Tdinp> packed_dinp;
Packed128<Ti> packed_inp = load128cs(inp + idx);
Packed128<Tdout> packed_dout = load128(dout + idx);
for (int k = 0; k < Packed128<Tdout>::size; ++k) {
float x = (float)packed_inp[k];
float cube = 0.044715f * x * x * x;
float tanh_arg = GELU_SCALING_FACTOR * (x + cube);
float tanh_out = tanhf(tanh_arg);
float coshf_out = coshf(tanh_arg);
float sech_out = 1.0f / (coshf_out * coshf_out);
float local_grad = 0.5f * (1.0f + tanh_out) + x * 0.5f * sech_out * GELU_SCALING_FACTOR * (1.0f + 3.0f * 0.044715f * x * x);
packed_dinp[k] = (floatX)(local_grad * (float)packed_dout[k]);

float tanh_in_out = GELU_SCALING_FACTOR * (x + cube);
#if !defined(PRECISE_GELU_TANH) && !defined(ENABLE_FP32) && __CUDA_ARCH__ >= 750
asm ("tanh.approx.f32 %0,%1;" : "=f"(tanh_in_out) : "f"(tanh_in_out));
#else
tanh_in_out = tanhf(tanh_in_out);
#endif

float sech_out = 1.0f - (tanh_in_out * tanh_in_out);
float local_grad = 0.5f * ((1.0f + tanh_in_out) + x * sech_out * GELU_SCALING_FACTOR * (1.0f + 3.0f * 0.044715f * x * x));
float result = local_grad * (float)packed_dout[k];
packed_dinp[k] = (Tdinp)(result);
}
store128(d_in_out + idx, packed_dinp);
store128(dinp + idx, packed_dinp);
}

// ----------------------------------------------------------------------------
@@ -61,6 +78,6 @@ void gelu_backward_inplace(floatX* d_in_out, const floatX* inp, const int N, cud
const int block_size = 128;
assert(N % (block_size * x128::size) == 0);
const int grid_size = CEIL_DIV(N, block_size * x128::size);
gelu_backward_inplace_kernel<<<grid_size, block_size, 0, stream>>>(d_in_out, inp);
gelu_backward_kernel<<<grid_size, block_size, 0, stream>>>(d_in_out, d_in_out, inp);
cudaCheck(cudaGetLastError());
}
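
Two notes on the refactored backward kernel above. The templated signature (Ti, Tdout, Tdinp) lets the same kernel serve the in-place launcher (d_in_out passed as both dout and dinp) as well as potential out-of-place or mixed-precision callers. The coshf call of the old kernel is also gone: since sech^2(u) = 1 - tanh^2(u), the already-computed tanh value can be reused. A sketch of the derivation behind local_grad (standard calculus, not taken from the PR):

// GELU(x)   ~= 0.5 * x * (1 + tanh(u)),   u = s * (x + 0.044715 * x^3),   s = sqrt(2/pi)
// dGELU/dx   = 0.5 * (1 + tanh(u)) + 0.5 * x * sech^2(u) * s * (1 + 3 * 0.044715 * x^2)
// sech^2(u)  = 1 - tanh^2(u)              // the identity used for sech_out above
// => local_grad = 0.5 * ((1 + tanh(u)) + x * (1 - tanh^2(u)) * s * (1 + 3 * 0.044715 * x^2))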