Add RoPE positional encoding - llama3 feature branch #756

Open · wants to merge 15 commits into llama3
62 changes: 61 additions & 1 deletion llmc/attention.cuh
@@ -189,11 +189,55 @@ __global__ void softmax_autoregressive_backward_inplace_kernel(floatX* datt, con
}
}

__global__ void rope_rotate_kernel(floatX* q, floatX* k, float* rope_freqs, int B, int NH, int T, int HS, int is_backward) {
// thanks to the nice mathematical properties of RoPE this is both our fwd & bwd pass kernel!
// the per-pair rotation matrix is orthogonal, so its inverse is just the rotation by -theta:
// the only difference between fwd and bwd is that we toggle the sign of the sin term in the rotation
// q, k are of shape (B, NH, T, HS)
// rope_freqs is of shape (T, HS/2)
int n = HS / x128::size;
int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (thread_idx >= B * NH * T * n) { return; }

int b = thread_idx / (NH * T * n);
int rest = thread_idx % (NH * T * n);
int nh = rest / (T * n);
rest = rest % (T * n);
int t = rest / n;
int i = rest % n;

float* rope_freqs_t = rope_freqs + t * (HS / 2) + i * (x128::size / 2);
f128 freqs_reg = load128(rope_freqs_t); // caching the frequencies

int idx = b * NH * T * HS + nh * T * HS + t * HS + i * x128::size;
x128 q_reg = load128cs(&q[idx]);
x128 k_reg = load128cs(&k[idx]);
x128 qout_reg, kout_reg;
for (int k = 0; k < x128::size / 2; k++) { // div by 2 because we're processing tuples of 2
// rotate q
floatX x1 = q_reg[2*k];
floatX x2 = q_reg[2*k + 1];
floatX q_out1 = (floatX)((float)x1 * cosf(freqs_reg[k]) + (is_backward ? 1 : -1) * (float)x2 * sinf(freqs_reg[k]));
floatX q_out2 = (floatX)((float)x2 * cosf(freqs_reg[k]) + (is_backward ? -1 : 1) * (float)x1 * sinf(freqs_reg[k]));
qout_reg[2*k] = q_out1;
qout_reg[2*k + 1] = q_out2;
// rotate k
x1 = k_reg[2*k];
x2 = k_reg[2*k + 1];
floatX k_out1 = (floatX)((float)x1 * cosf(freqs_reg[k]) + (is_backward ? 1 : -1) * (float)x2 * sinf(freqs_reg[k]));
floatX k_out2 = (floatX)((float)x2 * cosf(freqs_reg[k]) + (is_backward ? -1 : 1) * (float)x1 * sinf(freqs_reg[k]));
kout_reg[2*k] = k_out1;
kout_reg[2*k + 1] = k_out2;
}

store128cs(&q[idx], qout_reg);
store128cs(&k[idx], kout_reg);
}

// ----------------------------------------------------------------------------
// kernel launchers

void attention_forward(floatX* out, floatX* qkvr, floatX* att,
floatX* inp,
floatX* inp, int use_rope, float* rope_freqs,
int B, int T, int C, int NH, cudaStream_t stream) {
NVTX_RANGE_FN();
// Note: `inp` is not needed for backward pass, so we re-use it as a scratch buffer.
@@ -214,6 +258,13 @@ void attention_forward(floatX* out, floatX* qkvr, floatX* att,
int num_blocks = CEIL_DIV(total_threads, block_size);
permute_kernel<<<num_blocks, block_size, 0, stream>>>(q, k, v, inp, B, T, NH, HS);

if (use_rope) {
assert(HS % x128::size == 0);
total_threads = B * NH * T * (HS / x128::size);
num_blocks = CEIL_DIV(total_threads, block_size);
rope_rotate_kernel<<<num_blocks, block_size, 0, stream>>>(q, k, rope_freqs, B, NH, T, HS, 0);
}

floatX* preatt = inp; // reuse inp as scratch buffer
matmul_cublaslt(preatt, k, q, nullptr, T, T, HS, stream, true, false, B * NH, T * HS, T * HS, T * T);

@@ -239,6 +290,7 @@ void attention_forward(floatX* out, floatX* qkvr, floatX* att,
void attention_backward(floatX* dinp, floatX* dqkvr, floatX* datt, floatX* scratch,
const floatX* dout,
const floatX* qkvr, const floatX* att,
int use_rope, float* rope_freqs,
int B, int T, int C, int NH, cudaStream_t stream) {
NVTX_RANGE_FN();
const int block_size = 256;
@@ -269,6 +321,14 @@ void attention_backward(floatX* dinp, floatX* dqkvr, floatX* datt, floatX* scrat
matmul_cublaslt(dq, k, dpreatt, nullptr, HS, T, T, stream, false, false, B * NH, T * HS, T * T, T * HS);
// backward into k
matmul_cublaslt(dk, q, dpreatt, nullptr, HS, T, T, stream, false, true, B * NH, T * HS, T * T, T * HS);

if (use_rope) {
assert(HS % x128::size == 0);
int total_threads = B * NH * T * (HS / x128::size);
num_blocks = CEIL_DIV(total_threads, block_size);
rope_rotate_kernel<<<num_blocks, block_size, 0, stream>>>(dq, dk, rope_freqs, B, NH, T, HS, 1);
}

// backward into inp
num_blocks = CEIL_DIV(B * NH * T * HS, block_size);
permute_kernel_backward<<<num_blocks, block_size, 0, stream>>>(dinp, dq, dk, dv, B, T, NH, HS);
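
A note on the fwd/bwd symmetry that rope_rotate_kernel relies on: each (x1, x2) pair is multiplied by a 2x2 rotation matrix, which is orthogonal, so flipping the sign of the sin term (the is_backward path) applies the inverse rotation, and routing the gradients dq/dk through that inverse is exactly the backward pass. Below is a minimal CPU sketch of the per-pair math, with a hypothetical rope_rotate_pair helper and floatX assumed to be plain float; it is an illustration of the convention, not code from this PR.

#include <math.h>
#include <stdio.h>

// rotate one (x1, x2) pair by angle `freq`; is_backward flips the sign of the
// sin term, which applies the inverse rotation (same convention as rope_rotate_kernel)
void rope_rotate_pair(float* x1, float* x2, float freq, int is_backward) {
    float c = cosf(freq), s = sinf(freq);
    float sign = is_backward ? -1.0f : 1.0f;
    float out1 = *x1 * c - sign * (*x2) * s;
    float out2 = *x2 * c + sign * (*x1) * s;
    *x1 = out1;
    *x2 = out2;
}

int main(void) {
    float x1 = 0.3f, x2 = -1.2f;
    rope_rotate_pair(&x1, &x2, 0.7f, 0); // forward rotation
    rope_rotate_pair(&x1, &x2, 0.7f, 1); // backward rotation undoes it
    printf("%f %f\n", x1, x2);           // prints ~0.300000 -1.200000
    return 0;
}

Applying the forward rotation and then the backward one returns the original pair up to float rounding, which is why a single kernel with an is_backward flag covers both passes.
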
49 changes: 37 additions & 12 deletions llmc/encoder.cuh
@@ -17,7 +17,7 @@ In the backward pass, the gradients flow to both, handled by different kernels
// CUDA kernels

__global__ void encoder_forward_kernel3(floatX* out,
const int* inp, const floatX* wte, const floatX* wpe,
const int* inp, const floatX* wte, const floatX* wpe, int use_rope,
int B, int T, int C) {
int idx = (blockIdx.x * blockDim.x + threadIdx.x) * x128::size;
int N = B * T * C;
@@ -36,9 +36,16 @@ __global__ void encoder_forward_kernel3(floatX* out,

x128 packed_out;
x128 wte128 = load128cs(wte_ix);
x128 wpe128 = load128cs(wpe_tc);
x128 wpe128;
if (!use_rope) {
wpe128 = load128cs(wpe_tc);
}
for (int k = 0; k < x128::size; k++) {
packed_out[k] = (floatX)((float)wte128[k] + (float)wpe128[k]);
if (!use_rope) {
packed_out[k] = (floatX)((float)wte128[k] + (float)wpe128[k]);
} else {
packed_out[k] = wte128[k];
}
}
store128(out_btc, packed_out);
}
@@ -151,33 +158,51 @@ __global__ void wpe_backward_kernel(floatX* dwpe,
store128(dwpe_tc, packed_dwpe);
}

__global__ void init_rope_freqs_kernel(float* rope_freqs, float rope_base_freq) {
int m = blockIdx.x;
int d_half = blockDim.x;
int i = threadIdx.x + 1;
int out_idx = m * d_half + i - 1;

float theta_i = __powf(rope_base_freq, -2.0f * (float)(i - 1) / (2.f * (float)d_half));
rope_freqs[out_idx] = (float)m * theta_i;
}

// ----------------------------------------------------------------------------
// kernel launchers

void init_rope_freqs(float* rope_freqs, int max_seq_len, int HS, float rope_base_freq, cudaStream_t stream) {
NVTX_RANGE_FN();
init_rope_freqs_kernel<<<max_seq_len, HS, 0, stream>>>(rope_freqs, rope_base_freq);
cudaCheck(cudaGetLastError());
}

void encoder_forward(floatX* out,
const int* inp, const floatX* wte, const floatX* wpe,
const int* inp, const floatX* wte, const floatX* wpe, int use_rope,
int B, int T, int C, cudaStream_t stream) {
NVTX_RANGE_FN();
const int block_size = 256;
const int N = B * T * C;
const int grid_size = CEIL_DIV(N, (int)(block_size * x128::size));
encoder_forward_kernel3<<<grid_size, block_size, 0, stream>>>(out, inp, wte, wpe, B, T, C);
encoder_forward_kernel3<<<grid_size, block_size, 0, stream>>>(out, inp, wte, wpe, use_rope, B, T, C);
cudaCheck(cudaGetLastError());
}

// Fully deterministic (see comments in wte_backward_kernel and wpe_backward_kernel for more details)
void encoder_backward(floatX* dwte, floatX* dwpe, floatX* scratch, // gpu outputs & scratch
int* workload_indices, int4* bucket_info, // cpu scratch buffers
const floatX* dout, const int* inp, const int* inputs_cpu, // cpu/gpu inputs
int B, int T, int C, unsigned int seed, cudaStream_t stream) {
int use_rope, int B, int T, int C, unsigned int seed, cudaStream_t stream) {
NVTX_RANGE_FN();

// Launch wpe kernel first (so it runs on the GPU in parallel with the CPU pre-processing for wte)
const int block_size = 256;
const int N = T * C / x128::size;
const int grid_size = CEIL_DIV(N, block_size);
wpe_backward_kernel<<<grid_size, block_size, 0, stream>>>(dwpe, dout, inp, B, T, C, seed);
cudaCheck(cudaGetLastError());
if (!use_rope) {
// Launch wpe kernel first (so it runs on the GPU in parallel with the CPU pre-processing for wte)
const int block_size = 256;
const int N = T * C / x128::size;
const int grid_size = CEIL_DIV(N, block_size);
wpe_backward_kernel<<<grid_size, block_size, 0, stream>>>(dwpe, dout, inp, B, T, C, seed);
cudaCheck(cudaGetLastError());
}

// check the GPU scratch buffer is large enough to hold the bucket info and workload indices
// todo - this is trivially true given hardcoded scratch buffer size here, is this useful?
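
For reference, init_rope_freqs_kernel fills the standard RoPE frequency table: for position m and pair index i in [1, HS/2], theta_i = rope_base_freq^(-2(i-1)/HS) and rope_freqs[m * HS/2 + i - 1] = m * theta_i. Note the launcher is handed HS/2 by the call site in train_llama3.cu, so blockDim.x equals the number of (x1, x2) pairs. A hypothetical host-side reference for the same table, assuming plain float precision, that could be diffed against a cudaMemcpy of the GPU buffer:

#include <math.h>
#include <stdlib.h>

// CPU reference for the (max_seq_len, HS/2) rope_freqs table filled by
// init_rope_freqs_kernel: freqs[m][i-1] = m * base^(-2*(i-1)/HS)
// (HS here is the full head size; caller frees the returned buffer)
float* rope_freqs_cpu(int max_seq_len, int HS, float base) {
    int d_half = HS / 2;
    float* freqs = (float*)malloc(max_seq_len * d_half * sizeof(float));
    for (int m = 0; m < max_seq_len; m++) {
        for (int i = 1; i <= d_half; i++) {
            float theta_i = powf(base, -2.0f * (float)(i - 1) / (float)HS);
            freqs[m * d_half + (i - 1)] = (float)m * theta_i;
        }
    }
    return freqs;
}
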
1 change: 1 addition & 0 deletions llmc/zero.cuh
@@ -529,6 +529,7 @@ void multi_gpu_async_reduce_gradient(
cudaCheck(cudaStreamWaitEvent(config->nccl_stream, config->compute_nccl_sync));
ncclCheck(ncclGroupStart()); // NCCL group: aggregate all pointers in a single NCCL GPU kernel.
for (int i = 0; i < N; ++i) {
if (pointers[i] == NULL) continue;
if(config->zero_stage == 0) {
ncclCheck(ncclAllReduce(
pointers[i], pointers[i],
49 changes: 43 additions & 6 deletions train_llama3.cu
@@ -319,6 +319,9 @@ typedef struct {
// todo - if other functions need cpu scratch buffers in the future, reuse as generic scratch?
int* workload_indices; // encoder_backward, B*T*num_c_groups (int)
int4* bucket_info; // encoder_backward, B*T*num_c_groups (int4) - size for worst case
int use_rope; // use rope position encoding
float rope_base_freq; // base frequency for rope position encoding
float* rope_freqs; // rope position encoding frequencies
} GPT2;

void gpt2_init_common(GPT2 *model) {
@@ -348,6 +351,10 @@ void gpt2_init_common(GPT2 *model) {
model->init_state = true;
model->recompute = 1; // good default: recompute gelu but not layernorm
model->gelu_fusion = 0; //deviceProp.major >= 9 ? 2 : 0; // default: off for now (default must match main())
// architecture specific settings
model->use_rope = 0; // use rope position encoding
model->rope_base_freq = 10000.0f; // base frequency for rope position encoding
model->rope_freqs = NULL; // rope position encoding frequencies
}

void gpt2_allocate_weights(GPT2 *model) {
@@ -362,6 +369,15 @@ void gpt2_allocate_weights(GPT2 *model) {
// create memory for model parameters on the device
assert(model->params_memory == nullptr);
model->params_memory = malloc_and_point_parameters(&model->params, model->param_elements, model->param_sizeof);

// allocate memory for rope frequencies
if (model->use_rope) {
int HS = model->config.channels / model->config.num_heads;
assert(HS % 2 == 0); // HS must be even for RoPE
cudaCheck(cudaMalloc((float**)&model->rope_freqs, model->config.max_seq_len * (HS / 2) * sizeof(float)));
// TODO(gordicaleksa): would floatX mess up the rope frequencies due to a lower precision?
init_rope_freqs(model->rope_freqs, model->config.max_seq_len, HS / 2, model->rope_base_freq, main_stream);
}
}

void gpt2_allocate_state(GPT2 *model, int B, int T) {
@@ -677,7 +693,7 @@ void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T) {
// forward pass
ParameterTensors params = model->params; // for brevity
ActivationTensors acts = model->acts;
encoder_forward(acts.encoded, model->inputs, params.wte, params.wpe, B, T, C, main_stream); // encoding goes into residual[0]
encoder_forward(acts.encoded, model->inputs, params.wte, params.wpe, model->use_rope, B, T, C, main_stream); // encoding goes into residual[0]

// first layernorm isn't fused
layernorm_forward((model->recompute < 2) ? acts.ln1 : acts.lnf, acts.ln1_mean, acts.ln1_rstd, acts.encoded, params.ln1w, params.ln1b, B, T, C, main_stream);
@@ -727,7 +743,7 @@ void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T) {
// these are only needed as scratchpads for the forward pass, but
// need not be stored for backward
matmul_forward_cublaslt(scratch, l_ln1, l_qkvw, l_qkvb, B, T, C, 3*C, main_stream);
attention_forward(l_atty, l_qkvr, l_att, scratch, B, T, C, NH, main_stream);
attention_forward(l_atty, l_qkvr, l_att, scratch, model->use_rope, model->rope_freqs, B, T, C, NH, main_stream);
#endif

matmul_forward_cublaslt(scratch, l_atty, l_attprojw, l_attprojb, B, T, C, C, main_stream);
@@ -915,7 +931,7 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int
// we need B x T x (4)C buffers. l_atty and l_fch aren't needed anymore at this point, so reuse their memory
floatX* buffer_a = l_atty;
floatX* buffer_b = l_fch_pre_gelu; // this is B x T x 4C, so even larger than what we need
attention_backward(dl_bt4c, buffer_b, scratchX, buffer_a, dl_btc, l_qkvr, l_att, B, T, C, NH, main_stream);
attention_backward(dl_bt4c, buffer_b, scratchX, buffer_a, dl_btc, l_qkvr, l_att, model->use_rope, model->rope_freqs, B, T, C, NH, main_stream);
#endif
if(model->recompute >= 2) {
layernorm_forward(l_ln1, l_ln1_mean, l_ln1_rstd, residual, l_ln1w, l_ln1b, B, T, C, main_stream);
@@ -947,7 +963,7 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int
}
}
encoder_backward(grads.wte, grads.wpe, scratchX, model->workload_indices, model->bucket_info,
dresidual, model->inputs, inputs, B, T, C, random_u32(&model->rng_state), main_stream);
dresidual, model->inputs, inputs, model->use_rope, B, T, C, random_u32(&model->rng_state), main_stream);

// Aggregate all gradients that are not part of the transformer blocks
if(last_step) {
@@ -959,7 +975,7 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int
#endif
cudaCheck(cudaMemcpyAsync(&model->mean_loss, model->accumulated_mean_loss, sizeof(float), cudaMemcpyDeviceToHost, main_stream));
// reduce the gradients for non-transformer block parameters
floatX* const pointers[] = {grads.wte, grads.wpe, grads.lnfw, grads.lnfb};
floatX* const pointers[] = {grads.wte, model->use_rope ? NULL : grads.wpe, grads.lnfw, grads.lnfb};
const size_t nelem[] = {Vp * C, T * C, C, C};
multi_gpu_async_reduce_gradient(pointers, nelem, &multi_gpu_config, main_stream);
}
@@ -1004,6 +1020,10 @@ float gpt2_calculate_grad_norm(GPT2 *model, MultiGpuConfig* multi_gpu_config) {
// grads_memory only contains the averaged gradients at the local shards,
// so we only calculate the grad norm at the grads_memory belonging to the local shards
for (int i = 0; i < NUM_PARAMETER_TENSORS; i++) {
if (model->use_rope && i == 1) {
// skip the wpe tensor if we are using RoPE -> minor optimization, results would be correct without this as well
continue;
}
ShardInfo tensor = gpt2_get_tensor_at_layer(model, 0, i);
ShardInfo shard = multi_gpu_get_shard_offset(tensor.size, multi_gpu_config, 1);
ptrdiff_t offset = tensor.offset + shard.offset;
@@ -1060,6 +1080,11 @@ void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, flo
// AdamW update
// handle adamw for all the transformer blocks
for (int i = 0; i < NUM_PARAMETER_TENSORS; i++) {
if (model->use_rope && i == 1) {
// skip the wpe tensor if we are using RoPE
continue;
}

// generate a unique seed for each tensor
unsigned int seed = random_u32(&model->rng_state);

@@ -1404,6 +1429,8 @@ void error_usage() {
// memory management
fprintf(stderr, " -z <int> zero_stage, Zero Optimization Stage, 0,1,2,3 (default = 0)\n");
fprintf(stderr, " -r <int> recompute: less memory but less speed. (default = 1), 0|1|2 = none,gelu,gelu+ln\n");
// architectural settings
fprintf(stderr, " -er <int> enable RoPE positional embeddings? (default = 0)\n");
// multi-node settings
fprintf(stderr, " -pn <int> num_processes (default = 1)\n");
fprintf(stderr, " -pr <int> process_rank (default = 0)\n");
@@ -1449,6 +1476,9 @@ int main(int argc, char *argv[]) {
int recompute = 1; // recompute during backward setting, 0 = none, 1 = recompute gelu
int zero_stage = 0; // Zero Optimization Stage for Multi-GPU training
int hellaswag_eval = 0;
// architectural settings
int use_rope = 0; // use RoPE positional embeddings
float rope_base_freq = 10000.0f; // base frequency for RoPE
// multi-node settings
int num_processes = 1; // this should be set by the slurm environment
int process_rank = 0; // this should be set by the slurm environment
@@ -1463,7 +1493,7 @@
// read in the args
if (argv[i][1] == 'i') { train_data_pattern = argv[i+1]; }
else if (argv[i][1] == 'j') { val_data_pattern = argv[i+1]; }
else if (argv[i][1] == 'e') { load_filename = argv[i+1]; }
else if (argv[i][1] == 'e' && argv[i][2] == '\0') { load_filename = argv[i+1]; }
else if (argv[i][1] == 'o') { output_log_dir = argv[i+1]; }
else if (argv[i][1] == 'n' && argv[i][2] == '\0') { checkpoint_every = atoi(argv[i+1]); }
else if (argv[i][1] == 'y') { resume = atoi(argv[i+1]); }
@@ -1498,6 +1528,7 @@
else if (argv[i][1] == 's' && argv[i][2] == 'g') { skip_update_gradz = atof(argv[i+1]); }
else if (argv[i][1] == 'n' && argv[i][2] == 'k') { checkpoints_keep = atoi(argv[i+1]); }
else if (argv[i][1] == 'n' && argv[i][2] == 'm') { major_checkpoint_every = atoi(argv[i+1]); }
else if (argv[i][1] == 'e' && argv[i][2] == 'r') { use_rope = atoi(argv[i+1]); }
else { error_usage(); }
}

@@ -1571,6 +1602,12 @@
// build the GPT-2 model
GPT2 model;
gpt2_init_common(&model);
// architectural modifications
#ifdef ENABLE_CUDNN
use_rope = 0; // RoPE is not supported with cudnn atm
#endif
model.use_rope = use_rope;
model.rope_base_freq = rope_base_freq;
if (resuming == 1) {
// if `-y 1` was set, then we are resuming from the latest checkpoint
// if we are using master weights, we'll init them later inside load_state()
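
One subtle point in the argument parsing above: the dispatch keys on argv[i][1] (and sometimes argv[i][2]), so once -er exists the -e branch needs the extra argv[i][2] == '\0' check, otherwise -er 1 would be consumed as a load_filename. A small standalone sketch of that dispatch, with classify_flag as a hypothetical helper rather than code from the trainer:

#include <stdio.h>

// minimal illustration of the two-character flag dispatch in main():
// "-e" must check that the flag ends after 'e' so that "-er" falls through
// to the RoPE branch instead of being treated as load_filename
const char* classify_flag(const char* arg) {
    if (arg[0] != '-') return "not a flag";
    if (arg[1] == 'e' && arg[2] == '\0') return "-e  -> load_filename";
    if (arg[1] == 'e' && arg[2] == 'r')  return "-er -> use_rope";
    return "other flag";
}

int main(void) {
    printf("%s\n", classify_flag("-e"));   // -e  -> load_filename
    printf("%s\n", classify_flag("-er"));  // -er -> use_rope
    return 0;
}

With the guard in place, RoPE is enabled by passing -er 1 on the command line; note that the #ifdef ENABLE_CUDNN block above forces use_rope back to 0, since RoPE is not wired into the cuDNN attention path yet (per the comment in main).
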