Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
goliaro committed Oct 18, 2023
1 parent eb14798 commit e7fa9ce
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 56 deletions.
2 changes: 0 additions & 2 deletions include/flexflow/ops/kernels/residual_rms_norm_kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,7 @@ void backward_kernel_wrapper(
ResidualRMSNormMeta const *m,
GenericTensorAccessorR const &output_grad,
GenericTensorAccessorR const &residual_output_rms_input,
GenericTensorAccessorR const &residual_input0,
GenericTensorAccessorW const &residual_input0_grad,
GenericTensorAccessorR const &residual_input1,
GenericTensorAccessorW const &residual_input1_grad,
GenericTensorAccessorR const &weight,
GenericTensorAccessorW const &weight_grad);
Expand Down
14 changes: 2 additions & 12 deletions src/ops/kernels/residual_rms_norm_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -290,9 +290,7 @@ template <typename T>
void backward_kernel(ResidualRMSNormMeta const *m,
T const *output_grad_ptr,
T const *residual_output_rms_input_ptr,
T const *residual_input0_ptr,
T *residual_input0_grad_ptr,
T const *residual_input1_ptr,
T *residual_input1_grad_ptr,
T const *weight_ptr,
T *weight_grad_ptr,
Expand Down Expand Up @@ -341,9 +339,7 @@ void backward_kernel_wrapper(
ResidualRMSNormMeta const *m,
GenericTensorAccessorR const &output_grad,
GenericTensorAccessorR const &residual_output_rms_input,
GenericTensorAccessorR const &residual_input0,
GenericTensorAccessorW const &residual_input0_grad,
GenericTensorAccessorR const &residual_input1,
GenericTensorAccessorW const &residual_input1_grad,
GenericTensorAccessorR const &weight,
GenericTensorAccessorW const &weight_grad) {
Expand All @@ -356,20 +352,16 @@ void backward_kernel_wrapper(
cudaEventRecord(t_start, stream);
}
assert(output_grad.data_type == residual_output_rms_input.data_type);
assert(residual_output_rms_input.data_type == residual_input0.data_type);
assert(residual_input0.data_type == residual_input0_grad.data_type);
assert(residual_input0_grad.data_type == residual_input1.data_type);
assert(residual_input1.data_type == residual_input1_grad.data_type);
assert(residual_output_rms_input.data_type == residual_input0_grad.data_type);
assert(residual_input0_grad.data_type == residual_input1_grad.data_type);
assert(residual_input1_grad.data_type == weight.data_type);
assert(weight.data_type == weight_grad.data_type);

if (output_grad.data_type == DT_HALF) {
backward_kernel(m,
output_grad.get_half_ptr(),
residual_output_rms_input.get_half_ptr(),
residual_input0.get_half_ptr(),
residual_input0_grad.get_half_ptr(),
residual_input1.get_half_ptr(),
residual_input1_grad.get_half_ptr(),
weight.get_half_ptr(),
weight_grad.get_half_ptr(),
Expand All @@ -378,9 +370,7 @@ void backward_kernel_wrapper(
backward_kernel(m,
output_grad.get_float_ptr(),
residual_output_rms_input.get_float_ptr(),
residual_input0.get_float_ptr(),
residual_input0_grad.get_float_ptr(),
residual_input1.get_float_ptr(),
residual_input1_grad.get_float_ptr(),
weight.get_float_ptr(),
weight_grad.get_float_ptr(),
Expand Down
62 changes: 20 additions & 42 deletions src/ops/residual_rms_norm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -511,68 +511,52 @@ void ResidualRMSNorm::backward(FFModel const &ff) {
EXCLUSIVE,
outputs[0]->region));
launcher.add_field(1, FID_DATA);
// regions[2](I): residual input 0
launcher.add_region_requirement(RegionRequirement(inputs[0]->part,
0 /*projection id*/,
READ_ONLY,
EXCLUSIVE,
inputs[0]->region));
launcher.add_field(2, FID_DATA);
// regions[3](I/O): residual input grad 0
// regions[2](I/O): residual input grad 0
launcher.add_region_requirement(RegionRequirement(inputs[0]->part_grad,
0 /*projection id*/,
READ_WRITE,
EXCLUSIVE,
inputs[0]->region_grad));
launcher.add_field(3, FID_DATA);
// regions[4](I): residual input 1
launcher.add_region_requirement(RegionRequirement(inputs[1]->part,
0 /*projection id*/,
READ_ONLY,
EXCLUSIVE,
inputs[1]->region));
launcher.add_field(4, FID_DATA);
// regions[5](I/O): residual input grad 1
launcher.add_field(2, FID_DATA);
// regions[3](I/O): residual input grad 1
launcher.add_region_requirement(RegionRequirement(inputs[1]->part_grad,
0 /*projection id*/,
READ_WRITE,
EXCLUSIVE,
inputs[1]->region_grad));
launcher.add_field(5, FID_DATA);
// regions[3](I): gamma
launcher.add_field(3, FID_DATA);
// regions[4](I): gamma
launcher.add_region_requirement(RegionRequirement(weights[0]->part,
0 /*projection id*/,
READ_ONLY,
EXCLUSIVE,
weights[0]->region));
launcher.add_field(6, FID_DATA);
// regions[4](I/O): gamma_grad
launcher.add_field(4, FID_DATA);
// regions[5](I/O): gamma_grad
launcher.add_region_requirement(RegionRequirement(weights[0]->part_grad,
0 /*projection id*/,
READ_WRITE,
EXCLUSIVE,
weights[0]->region_grad));
launcher.add_field(7, FID_DATA);
launcher.add_field(5, FID_DATA);

runtime->execute_index_space(ctx, launcher);
}

/*
regions[0](I): RMS output_grad
regions[1](I): Residual output / RMS input
regions[2](I): Residual input 0
regions[3](I/O): Residual input 0 grad
regions[4](I): Residual input 1
regions[5](I/O): Residual input 1 grad
regions[6](I): weight
regions[7](I/O): weight_grad
regions[2](I/O): Residual input 0 grad
regions[3](I/O): Residual input 1 grad
regions[4](I): weight
regions[5](I/O): weight_grad
*/
void ResidualRMSNorm::backward_task(Task const *task,
std::vector<PhysicalRegion> const &regions,
Context ctx,
Runtime *runtime) {
assert(task->regions.size() == 8);
assert(regions.size() == 8);
assert(task->regions.size() == 6);
assert(regions.size() == 6);
ResidualRMSNormMeta const *m = *((ResidualRMSNormMeta **)task->local_args);
GenericTensorAccessorR output_grad = helperGetGenericTensorAccessorRO(
m->output_type[0], regions[0], task->regions[0], FID_DATA, ctx, runtime);
Expand All @@ -583,34 +567,28 @@ void ResidualRMSNorm::backward_task(Task const *task,
FID_DATA,
ctx,
runtime);
GenericTensorAccessorR residual_input0 = helperGetGenericTensorAccessorRO(
m->input_type[0], regions[2], task->regions[2], FID_DATA, ctx, runtime);
GenericTensorAccessorW residual_input0_grad =
helperGetGenericTensorAccessorRW(m->input_type[0],
regions[3],
task->regions[3],
regions[2],
task->regions[2],
FID_DATA,
ctx,
runtime);
GenericTensorAccessorR residual_input1 = helperGetGenericTensorAccessorRO(
m->input_type[0], regions[4], task->regions[4], FID_DATA, ctx, runtime);
GenericTensorAccessorW residual_input1_grad =
helperGetGenericTensorAccessorRW(m->input_type[0],
regions[5],
task->regions[5],
regions[3],
task->regions[3],
FID_DATA,
ctx,
runtime);
GenericTensorAccessorR weight = helperGetGenericTensorAccessorRO(
m->weight_type[0], regions[6], task->regions[6], FID_DATA, ctx, runtime);
m->weight_type[0], regions[4], task->regions[4], FID_DATA, ctx, runtime);
GenericTensorAccessorW weight_grad = helperGetGenericTensorAccessorRW(
m->weight_type[0], regions[7], task->regions[7], FID_DATA, ctx, runtime);
m->weight_type[0], regions[5], task->regions[5], FID_DATA, ctx, runtime);
backward_kernel_wrapper(m,
output_grad,
residual_output_rms_input,
residual_input0,
residual_input0_grad,
residual_input1,
residual_input1_grad,
weight,
weight_grad);
Expand Down

0 comments on commit e7fa9ce

Please sign in to comment.