diff --git a/include/flexflow/ops/kernels/residual_rms_norm_kernels.h b/include/flexflow/ops/kernels/residual_rms_norm_kernels.h index 26a5686f0b..75dcfc945f 100644 --- a/include/flexflow/ops/kernels/residual_rms_norm_kernels.h +++ b/include/flexflow/ops/kernels/residual_rms_norm_kernels.h @@ -52,9 +52,7 @@ void backward_kernel_wrapper( ResidualRMSNormMeta const *m, GenericTensorAccessorR const &output_grad, GenericTensorAccessorR const &residual_output_rms_input, - GenericTensorAccessorR const &residual_input0, GenericTensorAccessorW const &residual_input0_grad, - GenericTensorAccessorR const &residual_input1, GenericTensorAccessorW const &residual_input1_grad, GenericTensorAccessorR const &weight, GenericTensorAccessorW const &weight_grad); diff --git a/src/ops/kernels/residual_rms_norm_kernels.cu b/src/ops/kernels/residual_rms_norm_kernels.cu index 75dee4808c..2fc4cc95c2 100644 --- a/src/ops/kernels/residual_rms_norm_kernels.cu +++ b/src/ops/kernels/residual_rms_norm_kernels.cu @@ -290,9 +290,7 @@ template void backward_kernel(ResidualRMSNormMeta const *m, T const *output_grad_ptr, T const *residual_output_rms_input_ptr, - T const *residual_input0_ptr, T *residual_input0_grad_ptr, - T const *residual_input1_ptr, T *residual_input1_grad_ptr, T const *weight_ptr, T *weight_grad_ptr, @@ -341,9 +339,7 @@ void backward_kernel_wrapper( ResidualRMSNormMeta const *m, GenericTensorAccessorR const &output_grad, GenericTensorAccessorR const &residual_output_rms_input, - GenericTensorAccessorR const &residual_input0, GenericTensorAccessorW const &residual_input0_grad, - GenericTensorAccessorR const &residual_input1, GenericTensorAccessorW const &residual_input1_grad, GenericTensorAccessorR const &weight, GenericTensorAccessorW const &weight_grad) { @@ -356,10 +352,8 @@ void backward_kernel_wrapper( cudaEventRecord(t_start, stream); } assert(output_grad.data_type == residual_output_rms_input.data_type); - assert(residual_output_rms_input.data_type == residual_input0.data_type); - assert(residual_input0.data_type == residual_input0_grad.data_type); - assert(residual_input0_grad.data_type == residual_input1.data_type); - assert(residual_input1.data_type == residual_input1_grad.data_type); + assert(residual_output_rms_input.data_type == residual_input0_grad.data_type); + assert(residual_input0_grad.data_type == residual_input1_grad.data_type); assert(residual_input1_grad.data_type == weight.data_type); assert(weight.data_type == weight_grad.data_type); @@ -367,9 +361,7 @@ void backward_kernel_wrapper( backward_kernel(m, output_grad.get_half_ptr(), residual_output_rms_input.get_half_ptr(), - residual_input0.get_half_ptr(), residual_input0_grad.get_half_ptr(), - residual_input1.get_half_ptr(), residual_input1_grad.get_half_ptr(), weight.get_half_ptr(), weight_grad.get_half_ptr(), @@ -378,9 +370,7 @@ void backward_kernel_wrapper( backward_kernel(m, output_grad.get_float_ptr(), residual_output_rms_input.get_float_ptr(), - residual_input0.get_float_ptr(), residual_input0_grad.get_float_ptr(), - residual_input1.get_float_ptr(), residual_input1_grad.get_float_ptr(), weight.get_float_ptr(), weight_grad.get_float_ptr(), diff --git a/src/ops/residual_rms_norm.cc b/src/ops/residual_rms_norm.cc index d382f05394..1e0b652163 100644 --- a/src/ops/residual_rms_norm.cc +++ b/src/ops/residual_rms_norm.cc @@ -511,48 +511,34 @@ void ResidualRMSNorm::backward(FFModel const &ff) { EXCLUSIVE, outputs[0]->region)); launcher.add_field(1, FID_DATA); - // regions[2](I): residual input 0 - launcher.add_region_requirement(RegionRequirement(inputs[0]->part, - 0 /*projection id*/, - READ_ONLY, - EXCLUSIVE, - inputs[0]->region)); - launcher.add_field(2, FID_DATA); - // regions[3](I/O): residual input grad 0 + // regions[2](I/O): residual input grad 0 launcher.add_region_requirement(RegionRequirement(inputs[0]->part_grad, 0 /*projection id*/, READ_WRITE, EXCLUSIVE, inputs[0]->region_grad)); - launcher.add_field(3, FID_DATA); - // regions[4](I): residual input 1 - launcher.add_region_requirement(RegionRequirement(inputs[1]->part, - 0 /*projection id*/, - READ_ONLY, - EXCLUSIVE, - inputs[1]->region)); - launcher.add_field(4, FID_DATA); - // regions[5](I/O): residual input grad 1 + launcher.add_field(2, FID_DATA); + // regions[3](I/O): residual input grad 1 launcher.add_region_requirement(RegionRequirement(inputs[1]->part_grad, 0 /*projection id*/, READ_WRITE, EXCLUSIVE, inputs[1]->region_grad)); - launcher.add_field(5, FID_DATA); - // regions[3](I): gamma + launcher.add_field(3, FID_DATA); + // regions[4](I): gamma launcher.add_region_requirement(RegionRequirement(weights[0]->part, 0 /*projection id*/, READ_ONLY, EXCLUSIVE, weights[0]->region)); - launcher.add_field(6, FID_DATA); - // regions[4](I/O): gamma_grad + launcher.add_field(4, FID_DATA); + // regions[5](I/O): gamma_grad launcher.add_region_requirement(RegionRequirement(weights[0]->part_grad, 0 /*projection id*/, READ_WRITE, EXCLUSIVE, weights[0]->region_grad)); - launcher.add_field(7, FID_DATA); + launcher.add_field(5, FID_DATA); runtime->execute_index_space(ctx, launcher); } @@ -560,19 +546,17 @@ void ResidualRMSNorm::backward(FFModel const &ff) { /* regions[0](I): RMS output_grad regions[1](I): Residual output / RMS input - regions[2](I): Residual input 0 - regions[3](I/O): Residual input 0 grad - regions[4](I): Residual input 1 - regions[5](I/O): Residual input 1 grad - regions[6](I): weight - regions[7](I/O): weight_grad + regions[2](I/O): Residual input 0 grad + regions[3](I/O): Residual input 1 grad + regions[4](I): weight + regions[5](I/O): weight_grad */ void ResidualRMSNorm::backward_task(Task const *task, std::vector const ®ions, Context ctx, Runtime *runtime) { - assert(task->regions.size() == 8); - assert(regions.size() == 8); + assert(task->regions.size() == 6); + assert(regions.size() == 6); ResidualRMSNormMeta const *m = *((ResidualRMSNormMeta **)task->local_args); GenericTensorAccessorR output_grad = helperGetGenericTensorAccessorRO( m->output_type[0], regions[0], task->regions[0], FID_DATA, ctx, runtime); @@ -583,34 +567,28 @@ void ResidualRMSNorm::backward_task(Task const *task, FID_DATA, ctx, runtime); - GenericTensorAccessorR residual_input0 = helperGetGenericTensorAccessorRO( - m->input_type[0], regions[2], task->regions[2], FID_DATA, ctx, runtime); GenericTensorAccessorW residual_input0_grad = helperGetGenericTensorAccessorRW(m->input_type[0], - regions[3], - task->regions[3], + regions[2], + task->regions[2], FID_DATA, ctx, runtime); - GenericTensorAccessorR residual_input1 = helperGetGenericTensorAccessorRO( - m->input_type[0], regions[4], task->regions[4], FID_DATA, ctx, runtime); GenericTensorAccessorW residual_input1_grad = helperGetGenericTensorAccessorRW(m->input_type[0], - regions[5], - task->regions[5], + regions[3], + task->regions[3], FID_DATA, ctx, runtime); GenericTensorAccessorR weight = helperGetGenericTensorAccessorRO( - m->weight_type[0], regions[6], task->regions[6], FID_DATA, ctx, runtime); + m->weight_type[0], regions[4], task->regions[4], FID_DATA, ctx, runtime); GenericTensorAccessorW weight_grad = helperGetGenericTensorAccessorRW( - m->weight_type[0], regions[7], task->regions[7], FID_DATA, ctx, runtime); + m->weight_type[0], regions[5], task->regions[5], FID_DATA, ctx, runtime); backward_kernel_wrapper(m, output_grad, residual_output_rms_input, - residual_input0, residual_input0_grad, - residual_input1, residual_input1_grad, weight, weight_grad);