diff --git a/dev/cuda/common.h b/dev/cuda/common.h
index 61a783a60..8bb880bfc 100644
--- a/dev/cuda/common.h
+++ b/dev/cuda/common.h
@@ -315,7 +315,7 @@ void validate_result(D* device_result, const T* cpu_reference, const char* name, std::size_t num_elements, T tolerance=1e-4) {
 #ifndef ENABLE_BF16
     float epsilon = FLT_EPSILON;
 #else
-    float epsilon = 0.079;
+    float epsilon = 0.0079; // ~2^-7 (where 7 is the number of mantissa bits in BF16)
 #endif
     for (int i = 0; i < num_elements; i++) {
         // Skip masked elements
diff --git a/dev/cuda/gelu_backward.cu b/dev/cuda/gelu_backward.cu
index 3d12dd864..ae2c73758 100644
--- a/dev/cuda/gelu_backward.cu
+++ b/dev/cuda/gelu_backward.cu
@@ -80,6 +80,32 @@ __global__ void gelu_backward2(floatX* dinp, const floatX* inp, const floatX* dout, const int N) {
     }
 }
 
+template <typename Tdinp, typename Ti, typename Tdout>
+__global__ void gelu_backward3(Tdinp* dinp, const Ti* inp, const Tdout* dout, const int N) {
+    int idx = (blockIdx.x * blockDim.x + threadIdx.x) * Packed128<Tdinp>::size;
+    if (idx >= N) { return; }
+
+    Packed128<Tdinp> packed_dinp;
+    Packed128<Ti> packed_inp = load128cs(inp + idx);
+    Packed128<Tdout> packed_dout = load128(dout + idx);
+    for (int k = 0; k < Packed128<Tdinp>::size; ++k) {
+        float x = (float)packed_inp[k];
+        float cube = 0.044715f * x * x * x;
+
+        float tanh_in_out = GELU_SCALING_FACTOR * (x + cube);
+        #if !defined(PRECISE_GELU_TANH) && __CUDA_ARCH__ >= 750
+        asm ("tanh.approx.f32 %0,%1;" : "=f"(tanh_in_out) : "f"(tanh_in_out));
+        #else
+        tanh_in_out = tanhf(tanh_in_out);
+        #endif
+
+        float sech_out = 1.0f - (tanh_in_out * tanh_in_out);
+        float local_grad = 0.5f * ((1.0f + tanh_in_out) + x * sech_out * GELU_SCALING_FACTOR * (1.0f + 3.0f * 0.044715f * x * x));
+        float result = local_grad * (float)packed_dout[k];
+        packed_dinp[k] = (Tdinp)(result);
+    }
+    store128(dinp + idx, packed_dinp);
+}
 // ----------------------------------------------------------------------------
 // kernel launcher
 
@@ -95,10 +121,16 @@ void gelu_backward2(floatX* dinp, const floatX* inp, const floatX* dout, int N, const int block_size) {
     cudaCheck(cudaGetLastError());
 }
 
+void gelu_backward3(floatX* dinp, const floatX* inp, const floatX* dout, int N, const int block_size) {
+    const int grid_size = ceil_div(N, block_size * x128::size);
+    gelu_backward3<<<grid_size, block_size>>>(dinp, inp, dout, N);
+    cudaCheck(cudaGetLastError());
+}
+
 // kernel version dispatch
 void gelu_backward(int kernel_num,
-                  floatX* dinp,
-                  const floatX* inp,
+                   floatX* dinp,
+                   const floatX* inp,
                    const floatX* dout,
                    int B, int T, int C,
                    int block_size) {
@@ -109,6 +141,9 @@ void gelu_backward(int kernel_num,
         case 2:
             gelu_backward2(dinp, inp, dout, B * T * C, block_size);
             break;
+        case 3:
+            gelu_backward3(dinp, inp, dout, B * T * C, block_size);
+            break;
         default:
             printf("Invalid kernel number\n");
             exit(1);
@@ -120,7 +155,7 @@ void gelu_backward(int kernel_num,
 int main(int argc, char **argv) {
     setup_main();
 
-    int B = 8;
+    int B = 128;
     int T = 1024;
     int C = 768;
 
@@ -157,9 +192,9 @@ int main(int argc, char **argv) {
         printf("Checking block size %d.\n", block_size);
         gelu_backward(kernel_num, d_dinp, d_inp, d_dout, B, T, C, block_size);
 #if !defined(ENABLE_BF16) && !defined(ENABLE_FP16)
-        float tol = 1e-5;
+        float tol = 1e-5f;
 #else
-        float tol = 1e-2f;
+        float tol = 1e-3f;
 #endif
         validate_result(d_dinp, dinp, "dinp", B * T * C, tol);
     }
@@ -178,7 +213,7 @@ int main(int argc, char **argv) {
         // napkin math: estimate the memory bandwidth achieved
        // for each (B,T,C) output element, we do 1 read and 1 write, 4 bytes each
        // and e.g. A100 40GB PCIe is advertised at 1,555GB/s
-        long memory_ops = B * T * C * 2 * 4;
+        long memory_ops = B * T * C * 3 * (int)sizeof(floatX);
         float memory_bandwidth = memory_ops / elapsed_time / 1e6;
 
         printf("block_size %4d | time %.4f ms | bandwidth %.2f GB/s\n", block_size, elapsed_time, memory_bandwidth);
diff --git a/dev/cuda/gelu_forward.cu b/dev/cuda/gelu_forward.cu
index 01abfe2b5..f543571b5 100644
--- a/dev/cuda/gelu_forward.cu
+++ b/dev/cuda/gelu_forward.cu
@@ -66,6 +66,33 @@ __global__ void gelu_forward_kernel2(floatX* out, const floatX* inp, int N) {
     }
 }
 
+// Optimised kernel that uses the HW TANH instruction by default (define PRECISE_GELU_TANH for tanhf)
+__global__ void gelu_forward_kernel3(floatX* out, const floatX* inp, int N) {
+    int idx = (blockIdx.x * blockDim.x + threadIdx.x) * x128::size;
+    if (idx >= N) { return; }
+
+    x128 packed_out;
+    x128 packed_inp = load128cs(inp + idx); // load and do not keep in cache
+    for(int k = 0; k < packed_inp.size; ++k) {
+        float xi = (float)packed_inp[k];
+        float cube = 0.044715f * xi * xi * xi;
+
+        float tanh_in_out = GELU_SCALING_FACTOR * (xi + cube);
+        #if !defined(PRECISE_GELU_TANH) && __CUDA_ARCH__ >= 750
+        asm ("tanh.approx.f32 %0,%1;" : "=f"(tanh_in_out) : "f"(tanh_in_out));
+        #else
+        tanh_in_out = tanhf(tanh_in_out);
+        #endif
+
+        // the following uses FMUL+FMA instead of FMUL+FADD+FMUL for "0.5f * x * (1.0f + tanh_out)"
+        float half_xi = 0.5f * xi;
+        packed_out[k] = (floatX)(half_xi * tanh_in_out + half_xi);
+    }
+    // store instead of storecs (without cache streaming) in case it is useful for the
+    // data to be in the cache for the next operation after this GeLU
+    store128(out + idx, packed_out);
+}
+
 // ----------------------------------------------------------------------------
 // kernel launcher
 
@@ -81,6 +108,12 @@ void gelu_forward2(floatX* out, const floatX* inp, int N, const int block_size) {
     cudaCheck(cudaGetLastError());
 }
 
+void gelu_forward3(floatX* out, const floatX* inp, int N, const int block_size) {
+    const int grid_size = ceil_div(N, block_size * x128::size);
+    gelu_forward_kernel3<<<grid_size, block_size>>>(out, inp, N);
+    cudaCheck(cudaGetLastError());
+}
+
 // kernel version dispatch
 void gelu_forward(int kernel_num,
                   floatX* out,
@@ -94,6 +127,9 @@ void gelu_forward(int kernel_num,
         case 2:
             gelu_forward2(out, inp, B * T * C, block_size);
             break;
+        case 3:
+            gelu_forward3(out, inp, B * T * C, block_size);
+            break;
         default:
             printf("Invalid kernel number\n");
             exit(1);
@@ -105,7 +141,7 @@ void gelu_forward(int kernel_num,
 int main(int argc, const char **argv) {
     setup_main();
 
-    int B = 8;
+    int B = 128;
     int T = 1024;
     int C = 768;
 
@@ -137,9 +173,9 @@ int main(int argc, const char **argv) {
         printf("Checking block size %d.\n", block_size);
         gelu_forward(kernel_num, d_out, d_inp, B, T, C, block_size);
 #if !defined(ENABLE_BF16) && !defined(ENABLE_FP16)
-        float tol = 1e-5;
+        float tol = 1e-5f;
#else
-        float tol = 1e-2f;
+        float tol = 1e-3f;
 #endif
         validate_result(d_out, out, "out", B * T * C, tol);
     }
diff --git a/llmc/gelu.cuh b/llmc/gelu.cuh
index cd5c297b6..c1d367ca2 100644
--- a/llmc/gelu.cuh
+++ b/llmc/gelu.cuh
@@ -18,30 +18,47 @@ __global__ void gelu_forward_kernel2(floatX* out, const floatX* inp) {
     for(int k = 0; k < packed_inp.size; ++k) {
         float xi = (float)packed_inp[k];
         float cube = 0.044715f * xi * xi * xi;
-        packed_out[k] = (floatX)(0.5f * xi * (1.0f + tanhf(GELU_SCALING_FACTOR * (xi + cube))));
+
+        float tanh_in_out = GELU_SCALING_FACTOR * (xi + cube);
+        #if !defined(PRECISE_GELU_TANH) && !defined(ENABLE_FP32) && __CUDA_ARCH__ >= 750
+        asm ("tanh.approx.f32 %0,%1;" : "=f"(tanh_in_out) : "f"(tanh_in_out));
+        #else
+        tanh_in_out = tanhf(tanh_in_out);
+        #endif
+
+        // the following uses FMUL+FMA instead of FMUL+FADD+FMUL for "0.5f * x * (1.0f + tanh_out)"
+        float half_xi = 0.5f * xi;
+        packed_out[k] = (floatX)(half_xi * tanh_in_out + half_xi);
     }
     // store instead of storecs (without cache streaming) in case it is useful for the
     // data to be in the cache for the next operation after this GeLU
     store128(out + idx, packed_out);
 }
 
-__global__ void gelu_backward_inplace_kernel(floatX* d_in_out, const floatX* inp) {
-    int idx = (blockIdx.x * blockDim.x + threadIdx.x) * x128::size;
+template <typename Tdinp, typename Tdout, typename Ti>
+__global__ void gelu_backward_kernel(Tdinp* dinp, const Tdout* dout, const Ti* inp) {
+    int idx = (blockIdx.x * blockDim.x + threadIdx.x) * Packed128<Tdinp>::size;
 
-    x128 packed_dinp;
-    x128 packed_inp = load128cs(inp + idx);
-    x128 packed_dout = load128(d_in_out + idx);
-    for (int k = 0; k < packed_inp.size; ++k) {
+    Packed128<Tdinp> packed_dinp;
+    Packed128<Ti> packed_inp = load128cs(inp + idx);
+    Packed128<Tdout> packed_dout = load128(dout + idx);
+    for (int k = 0; k < Packed128<Tdinp>::size; ++k) {
         float x = (float)packed_inp[k];
         float cube = 0.044715f * x * x * x;
-        float tanh_arg = GELU_SCALING_FACTOR * (x + cube);
-        float tanh_out = tanhf(tanh_arg);
-        float coshf_out = coshf(tanh_arg);
-        float sech_out = 1.0f / (coshf_out * coshf_out);
-        float local_grad = 0.5f * (1.0f + tanh_out) + x * 0.5f * sech_out * GELU_SCALING_FACTOR * (1.0f + 3.0f * 0.044715f * x * x);
-        packed_dinp[k] = (floatX)(local_grad * (float)packed_dout[k]);
+
+        float tanh_in_out = GELU_SCALING_FACTOR * (x + cube);
+        #if !defined(PRECISE_GELU_TANH) && !defined(ENABLE_FP32) && __CUDA_ARCH__ >= 750
+        asm ("tanh.approx.f32 %0,%1;" : "=f"(tanh_in_out) : "f"(tanh_in_out));
+        #else
+        tanh_in_out = tanhf(tanh_in_out);
+        #endif
+
+        float sech_out = 1.0f - (tanh_in_out * tanh_in_out);
+        float local_grad = 0.5f * ((1.0f + tanh_in_out) + x * sech_out * GELU_SCALING_FACTOR * (1.0f + 3.0f * 0.044715f * x * x));
+        float result = local_grad * (float)packed_dout[k];
+        packed_dinp[k] = (Tdinp)(result);
     }
-    store128(d_in_out + idx, packed_dinp);
+    store128(dinp + idx, packed_dinp);
 }
 
 // ----------------------------------------------------------------------------
@@ -61,6 +78,6 @@ void gelu_backward_inplace(floatX* d_in_out, const floatX* inp, const int N, cudaStream_t stream) {
     const int block_size = 128;
     assert(N % (block_size * x128::size) == 0);
     const int grid_size = CEIL_DIV(N, block_size * x128::size);
-    gelu_backward_inplace_kernel<<<grid_size, block_size, 0, stream>>>(d_in_out, inp);
+    gelu_backward_kernel<<<grid_size, block_size, 0, stream>>>(d_in_out, d_in_out, inp);
     cudaCheck(cudaGetLastError());
 }