residual rms norm backward
goliaro committed Oct 18, 2023
1 parent 4d55b40 commit eb14798
Showing 5 changed files with 327 additions and 1 deletion.
2 changes: 2 additions & 0 deletions include/flexflow/model.h
@@ -171,6 +171,8 @@ enum TaskIDs {
RMSNORM_PEFT_BWD_TASK_ID,
RESIDUAL_RMSNORM_INIT_TASK_ID,
RESIDUAL_RMSNORM_INF_TASK_ID,
RESIDUAL_RMSNORM_BWD_TASK_ID,
RESIDUAL_RMSNORM_PEFT_BWD_TASK_ID,
BEAM_TOPK_INIT_TASK_ID,
BEAM_TOPK_INF_TASK_ID,
INC_MULTIHEAD_SELF_ATTENTION_INIT_TASK_ID,
10 changes: 10 additions & 0 deletions include/flexflow/ops/kernels/residual_rms_norm_kernels.h
@@ -48,6 +48,16 @@ void forward_kernel_wrapper(ResidualRMSNormMeta const *m,
GenericTensorAccessorR const &weight,
GenericTensorAccessorW const &residual_output,
GenericTensorAccessorW const &output);
void backward_kernel_wrapper(
ResidualRMSNormMeta const *m,
GenericTensorAccessorR const &output_grad,
GenericTensorAccessorR const &residual_output_rms_input,
GenericTensorAccessorR const &residual_input0,
GenericTensorAccessorW const &residual_input0_grad,
GenericTensorAccessorR const &residual_input1,
GenericTensorAccessorW const &residual_input1_grad,
GenericTensorAccessorR const &weight,
GenericTensorAccessorW const &weight_grad);
} // namespace ResidualRMSNorm
} // namespace Kernels
} // namespace FlexFlow
4 changes: 4 additions & 0 deletions include/flexflow/ops/residual_rms_norm.h
@@ -74,6 +74,10 @@ class ResidualRMSNorm : public Op {
std::vector<Legion::PhysicalRegion> const &regions,
Legion::Context ctx,
Legion::Runtime *runtime);
static void backward_task(Legion::Task const *task,
std::vector<Legion::PhysicalRegion> const &regions,
Legion::Context ctx,
Legion::Runtime *runtime);
bool measure_operator_cost(Simulator *sim,
MachineView const &pc,
CostMetrics &cost_metrics) const override;
181 changes: 181 additions & 0 deletions src/ops/kernels/residual_rms_norm_kernels.cu
@@ -79,6 +79,23 @@ __inline__ __device__ T WarpReduceSum(T val) {
return val;
}

template <typename T>
__inline__ __device__ T BlockReduceSum(T val, T *shared) {
int const lid = threadIdx.x % C10_WARP_SIZE;
int const wid = threadIdx.x / C10_WARP_SIZE;
val = WarpReduceSum(val);
__syncthreads();
if (lid == 0) {
shared[wid] = val;
}
__syncthreads();
val = (threadIdx.x < (blockDim.x / C10_WARP_SIZE)) ? shared[lid] : T(0);
if (wid == 0) {
val = WarpReduceSum(val);
}
return val;
}
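
For context, the body of WarpReduceSum is collapsed by the diff hunk above. A typical shuffle-based implementation, shown here only as a sketch consistent with the visible tail (the name WarpReduceSumSketch is illustrative, not from the source):

template <typename T>
__inline__ __device__ T WarpReduceSumSketch(T val) {
  // Tree reduction within a warp: halve the stride each step so that,
  // after log2(C10_WARP_SIZE) steps, lane 0 holds the warp's sum.
  for (int offset = C10_WARP_SIZE / 2; offset > 0; offset /= 2) {
    val += __shfl_down_sync(0xffffffff, val, offset);
  }
  return val;
}

The new BlockReduceSum overload above then combines per-warp partials through shared memory; it implicitly assumes blockDim.x is a multiple of C10_WARP_SIZE and at most C10_WARP_SIZE * C10_WARP_SIZE threads, so the partial sums fit in a single warp for the final reduction.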

template <typename T>
__inline__ __device__ T BlockReduceSum(T val, T *shared, int max_num_threads) {
int const lid = threadIdx.x % C10_WARP_SIZE;
@@ -219,6 +236,170 @@ void forward_kernel_wrapper(ResidualRMSNormMeta const *m,
}
}

template <typename T>
__global__ void ComputeInternalGradientsCUDAKernel(
int64_t N, T const *dY, T const *X, T const *gamma, T const *rrms, T *c2) {
__shared__ T ds_storage[C10_WARP_SIZE];
const int64_t i = blockIdx.x;
T ds = 0;
for (int64_t j = threadIdx.x; j < N; j += blockDim.x) {
int const index = i * N + j;
ds += dY[index] * X[index] * gamma[j];
}
ds = BlockReduceSum<T>(ds, ds_storage);
if (threadIdx.x == 0) {
c2[i] = -ds * (rrms[i] * rrms[i] * rrms[i]) / static_cast<T>((int)N);
}
}
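
ComputeInternalGradientsCUDAKernel produces, per row i, the data-dependent coefficient c2 of the RMSNorm input gradient. A sketch of the derivation (writing r_i for rrms[i] and ignoring the epsilon inside the rsqrt for brevity):

  y_{ij} = \gamma_j \, x_{ij} \, r_i, \qquad
  r_i = \Big( \tfrac{1}{N} \sum_k x_{ik}^2 \Big)^{-1/2}, \qquad
  \frac{\partial r_i}{\partial x_{ij}} = -\frac{x_{ij} \, r_i^3}{N}

  \frac{\partial L}{\partial x_{ij}}
    = r_i \, \gamma_j \, dY_{ij}
      - \frac{r_i^3}{N} \Big( \sum_k dY_{ik} \, \gamma_k \, x_{ik} \Big) x_{ij}
    = c_1[i] \, \gamma_j \, dY_{ij} + c_2[i] \, x_{ij}

Here c_1[i] = r_i (passed via m->rms_ptr) and c_2[i] is the value stored into c2 above (stashed in m->norm_ptr); RMSNormBackwardCUDAKernel below applies exactly this expression.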

template <typename T>
__global__ void RMSNormBackwardCUDAKernel(int64_t N,
T const *dY,
T const *X,
T const *gamma,
T const *c1,
T const *c2,
T *dX1,
T *dX2) {
const int64_t i = blockIdx.x;
for (int64_t j = threadIdx.x; j < N; j += blockDim.x) {
const int64_t index = i * N + j;
T dX_val = c1[i] * dY[index] * gamma[j] + c2[i] * X[index];
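// X = input0 + input1 is an element-wise sum, so the same gradient
// accumulates unchanged into both residual input gradients.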
dX1[index] += dX_val;
dX2[index] += dX_val;
}
}

// Assuming the batch size is not very large, this direct implementation is
// the most efficient approach.
template <typename T>
__global__ void GammaBackwardCUDAKernel(
int64_t M, int64_t N, T const *dY, T const *X, T const *rrms, T *dg) {
const int64_t j = blockIdx.x * blockDim.x + threadIdx.x;
if (j < N) {
T sum1 = 0;
for (int64_t i = 0; i < M; ++i) {
const int64_t index = i * N + j;
sum1 += dY[index] * X[index] * rrms[i];
}
dg[j] = sum1;
}
}
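
The weight gradient follows directly from y_{ij} = \gamma_j \, x_{ij} \, r_i (a sketch, same notation as above):

  \frac{\partial L}{\partial \gamma_j} = \sum_{i=1}^{M} dY_{ij} \, x_{ij} \, r_i

Each thread owns one column j and loops serially over the M rows, which is why the small-batch assumption in the comment above matters; note also that dg[j] is overwritten rather than accumulated.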

template <typename T>
void backward_kernel(ResidualRMSNormMeta const *m,
T const *output_grad_ptr,
T const *residual_output_rms_input_ptr,
T const *residual_input0_ptr,
T *residual_input0_grad_ptr,
T const *residual_input1_ptr,
T *residual_input1_grad_ptr,
T const *weight_ptr,
T *weight_grad_ptr,
cudaStream_t stream) {
const int64_t M = m->batch_size;
const int64_t N = m->num_elements;
ComputeInternalGradientsCUDAKernel<T>
<<<M, kCUDABlockReduceNumThreads, 0, stream>>>(
N,
output_grad_ptr,
residual_output_rms_input_ptr,
weight_ptr,
static_cast<T *>(m->rms_ptr),
static_cast<T *>(m->norm_ptr));

RMSNormBackwardCUDAKernel<T>
<<<M, kCUDANumThreads, 0, stream>>>(N,
output_grad_ptr,
residual_output_rms_input_ptr,
weight_ptr,
static_cast<T *>(m->rms_ptr),
static_cast<T *>(m->norm_ptr),
residual_input0_grad_ptr,
residual_input1_grad_ptr);
const int64_t B = (N + kCUDANumThreads - 1) / kCUDANumThreads;
GammaBackwardCUDAKernel<T>
<<<B, kCUDANumThreads, 0, stream>>>(M,
N,
output_grad_ptr,
residual_output_rms_input_ptr,
static_cast<T *>(m->rms_ptr),
weight_grad_ptr);
}
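
For unit-testing the three kernels, a single-threaded CPU reference is handy. The sketch below mirrors backward_kernel step by step; the function name and signature are hypothetical, not part of the FlexFlow API:

#include <cstdint>
#include <vector>

template <typename T>
void residual_rms_norm_backward_ref(int64_t M, int64_t N,
                                    T const *dY, T const *X, T const *gamma,
                                    T const *rrms,
                                    T *dX1, T *dX2, T *dgamma) {
  // Step 1: per-row coefficient c2 (ComputeInternalGradientsCUDAKernel).
  std::vector<T> c2(M);
  for (int64_t i = 0; i < M; ++i) {
    T ds = 0;
    for (int64_t j = 0; j < N; ++j) {
      ds += dY[i * N + j] * X[i * N + j] * gamma[j];
    }
    c2[i] = -ds * rrms[i] * rrms[i] * rrms[i] / static_cast<T>(N);
  }
  // Step 2: input gradients, accumulated into both residual inputs
  // (RMSNormBackwardCUDAKernel).
  for (int64_t i = 0; i < M; ++i) {
    for (int64_t j = 0; j < N; ++j) {
      const int64_t idx = i * N + j;
      T dx = rrms[i] * dY[idx] * gamma[j] + c2[i] * X[idx];
      dX1[idx] += dx;
      dX2[idx] += dx;
    }
  }
  // Step 3: weight gradient, overwritten per column (GammaBackwardCUDAKernel).
  for (int64_t j = 0; j < N; ++j) {
    T sum = 0;
    for (int64_t i = 0; i < M; ++i) {
      sum += dY[i * N + j] * X[i * N + j] * rrms[i];
    }
    dgamma[j] = sum;
  }
}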

/*
regions[0](I): RMS output_grad
regions[1](I): Residual output / RMS input
regions[2](I): Residual input 0
regions[3](I/O): Residual input 0 grad
regions[4](I): Residual input 1
regions[5](I/O): Residual input 1 grad
regions[6](I): weight
regions[7](I/O): weight_grad
*/
void backward_kernel_wrapper(
ResidualRMSNormMeta const *m,
GenericTensorAccessorR const &output_grad,
GenericTensorAccessorR const &residual_output_rms_input,
GenericTensorAccessorR const &residual_input0,
GenericTensorAccessorW const &residual_input0_grad,
GenericTensorAccessorR const &residual_input1,
GenericTensorAccessorW const &residual_input1_grad,
GenericTensorAccessorR const &weight,
GenericTensorAccessorW const &weight_grad) {
cudaStream_t stream;
checkCUDA(get_legion_stream(&stream));
cudaEvent_t t_start, t_end;
if (m->profiling) {
cudaEventCreate(&t_start);
cudaEventCreate(&t_end);
cudaEventRecord(t_start, stream);
}
assert(output_grad.data_type == residual_output_rms_input.data_type);
assert(residual_output_rms_input.data_type == residual_input0.data_type);
assert(residual_input0.data_type == residual_input0_grad.data_type);
assert(residual_input0_grad.data_type == residual_input1.data_type);
assert(residual_input1.data_type == residual_input1_grad.data_type);
assert(residual_input1_grad.data_type == weight.data_type);
assert(weight.data_type == weight_grad.data_type);

if (output_grad.data_type == DT_HALF) {
backward_kernel(m,
output_grad.get_half_ptr(),
residual_output_rms_input.get_half_ptr(),
residual_input0.get_half_ptr(),
residual_input0_grad.get_half_ptr(),
residual_input1.get_half_ptr(),
residual_input1_grad.get_half_ptr(),
weight.get_half_ptr(),
weight_grad.get_half_ptr(),
stream);
} else if (output_grad.data_type == DT_FLOAT) {
backward_kernel(m,
output_grad.get_float_ptr(),
residual_output_rms_input.get_float_ptr(),
residual_input0.get_float_ptr(),
residual_input0_grad.get_float_ptr(),
residual_input1.get_float_ptr(),
residual_input1_grad.get_float_ptr(),
weight.get_float_ptr(),
weight_grad.get_float_ptr(),
stream);
} else {
assert(false && "Unsupported data type");
}

if (m->profiling) {
cudaEventRecord(t_end, stream);
checkCUDA(cudaEventSynchronize(t_end));
float elapsed = 0;
checkCUDA(cudaEventElapsedTime(&elapsed, t_start, t_end));
cudaEventDestroy(t_start);
cudaEventDestroy(t_end);
printf("[ResidualRMSNorm] backward time (CF) = %.2fms\n", elapsed);
}
}

} // namespace ResidualRMSNorm
} // namespace Kernels
} // namespace FlexFlow
131 changes: 130 additions & 1 deletion src/ops/residual_rms_norm.cc
@@ -485,8 +485,137 @@ Node ResidualRMSNorm::deserialize(FFModel &ff,
}

void ResidualRMSNorm::backward(FFModel const &ff) {
ArgumentMap argmap;
Context ctx = ff.config.lg_ctx;
Runtime *runtime = ff.config.lg_hlr;
set_argumentmap_for_backward(ff, argmap);
IndexLauncher launcher(RESIDUAL_RMSNORM_BWD_TASK_ID,
parallel_is,
TaskArgument(NULL, 0),
argmap,
Predicate::TRUE_PRED,
false /*must*/,
0 /*mapper_id*/,
outputs[0]->machine_view.hash());
// regions[0](I): RMS output_grad
launcher.add_region_requirement(RegionRequirement(outputs[0]->part_grad,
0 /*projection id*/,
READ_ONLY,
EXCLUSIVE,
outputs[0]->region_grad));
launcher.add_field(0, FID_DATA);
// regions[1](I): residual output / RMS input
launcher.add_region_requirement(RegionRequirement(outputs[0]->part,
0 /*projection id*/,
READ_ONLY,
EXCLUSIVE,
outputs[0]->region));
launcher.add_field(1, FID_DATA);
// regions[2](I): residual input 0
launcher.add_region_requirement(RegionRequirement(inputs[0]->part,
0 /*projection id*/,
READ_ONLY,
EXCLUSIVE,
inputs[0]->region));
launcher.add_field(2, FID_DATA);
// regions[3](I/O): residual input grad 0
launcher.add_region_requirement(RegionRequirement(inputs[0]->part_grad,
0 /*projection id*/,
READ_WRITE,
EXCLUSIVE,
inputs[0]->region_grad));
launcher.add_field(3, FID_DATA);
// regions[4](I): residual input 1
launcher.add_region_requirement(RegionRequirement(inputs[1]->part,
0 /*projection id*/,
READ_ONLY,
EXCLUSIVE,
inputs[1]->region));
launcher.add_field(4, FID_DATA);
// regions[5](I/O): residual input grad 1
launcher.add_region_requirement(RegionRequirement(inputs[1]->part_grad,
0 /*projection id*/,
READ_WRITE,
EXCLUSIVE,
inputs[1]->region_grad));
launcher.add_field(5, FID_DATA);
// regions[6](I): gamma
launcher.add_region_requirement(RegionRequirement(weights[0]->part,
0 /*projection id*/,
READ_ONLY,
EXCLUSIVE,
weights[0]->region));
launcher.add_field(6, FID_DATA);
// regions[7](I/O): gamma_grad
launcher.add_region_requirement(RegionRequirement(weights[0]->part_grad,
0 /*projection id*/,
READ_WRITE,
EXCLUSIVE,
weights[0]->region_grad));
launcher.add_field(7, FID_DATA);

runtime->execute_index_space(ctx, launcher);
}

/*
regions[0](I): RMS output_grad
regions[1](I): Residual output / RMS input
regions[2](I): Residual input 0
regions[3](I/O): Residual input 0 grad
regions[4](I): Residual input 1
regions[5](I/O): Residual input 1 grad
regions[6](I): weight
regions[7](I/O): weight_grad
*/
void ResidualRMSNorm::backward_task(Task const *task,
std::vector<PhysicalRegion> const &regions,
Context ctx,
Runtime *runtime) {
assert(task->regions.size() == 8);
assert(regions.size() == 8);
ResidualRMSNormMeta const *m = *((ResidualRMSNormMeta **)task->local_args);
GenericTensorAccessorR output_grad = helperGetGenericTensorAccessorRO(
m->output_type[0], regions[0], task->regions[0], FID_DATA, ctx, runtime);
GenericTensorAccessorR residual_output_rms_input =
helperGetGenericTensorAccessorRO(m->input_type[0],
regions[1],
task->regions[1],
FID_DATA,
ctx,
runtime);
GenericTensorAccessorR residual_input0 = helperGetGenericTensorAccessorRO(
m->input_type[0], regions[2], task->regions[2], FID_DATA, ctx, runtime);
GenericTensorAccessorW residual_input0_grad =
helperGetGenericTensorAccessorRW(m->input_type[0],
regions[3],
task->regions[3],
FID_DATA,
ctx,
runtime);
GenericTensorAccessorR residual_input1 = helperGetGenericTensorAccessorRO(
m->input_type[0], regions[4], task->regions[4], FID_DATA, ctx, runtime);
GenericTensorAccessorW residual_input1_grad =
helperGetGenericTensorAccessorRW(m->input_type[0],
regions[5],
task->regions[5],
FID_DATA,
ctx,
runtime);
GenericTensorAccessorR weight = helperGetGenericTensorAccessorRO(
m->weight_type[0], regions[6], task->regions[6], FID_DATA, ctx, runtime);
GenericTensorAccessorW weight_grad = helperGetGenericTensorAccessorRW(
m->weight_type[0], regions[7], task->regions[7], FID_DATA, ctx, runtime);
backward_kernel_wrapper(m,
output_grad,
residual_output_rms_input,
residual_input0,
residual_input0_grad,
residual_input1,
residual_input1_grad,
weight,
weight_grad);
}

Op *ResidualRMSNorm::materialize(FFModel &ff,
ParallelTensor inputs[],
int num_inputs) const {
