
Commit

[fix] hotfix normalization
duanjunwen committed Dec 23, 2024
1 parent 130229f commit b553453
Showing 1 changed file with 17 additions and 11 deletions.
colossalai/shardformer/layer/normalization.py (28 changes: 17 additions & 11 deletions)
@@ -76,18 +76,24 @@ def forward(self, input):
 
     FusedRMSNormWithHook = NPUFusedRMSNormWithHook
 else:
-    from apex.normalization import FusedRMSNorm as ApexFusedRMSNorm
-
-    class CUDAFusedRMSNormWithHook(ApexFusedRMSNorm):
-        def __init__(self, normalized_shape, eps=0.00001, elementwise_affine=True):
-            super().__init__(normalized_shape, eps, elementwise_affine)
-
-        def forward(self, input):
-            output = super().forward(input)
-            output = hook_parameter_in_backward(output, self.weight)
-            return output
+    try:
+        from apex.normalization import FusedRMSNorm as ApexFusedRMSNorm
+
+        class CUDAFusedRMSNormWithHook(ApexFusedRMSNorm):
+            def __init__(self, normalized_shape, eps=0.00001, elementwise_affine=True):
+                super().__init__(normalized_shape, eps, elementwise_affine)
+
+            def forward(self, input):
+                output = super().forward(input)
+                output = hook_parameter_in_backward(output, self.weight)
+                return output
+
+        FusedRMSNormWithHook = CUDAFusedRMSNormWithHook
+    except ImportError:
+        warnings.warn(
+            "Please install apex from source (https://github.com/NVIDIA/apex) to use the fused RMSNorm kernel"
+        )
 
-    FusedRMSNormWithHook = CUDAFusedRMSNormWithHook
 
 FAST_LAYERNORM_SUPPORTED_SIZE = [
     1024,
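
The change turns the apex import from a hard dependency into a soft one: when apex is absent, importing the module now emits a warning instead of raising ImportError. Note that the except branch only warns and never binds FusedRMSNormWithHook, so code that later reaches the CUDA branch without apex fails with a NameError. Below is a minimal, self-contained sketch of the same guarded-import pattern; the None sentinel and the caller-side check are hypothetical additions for illustration, not part of this commit.

import warnings

try:
    # Optional dependency: apex provides the fused CUDA RMSNorm kernel.
    from apex.normalization import FusedRMSNorm as ApexFusedRMSNorm

    FusedRMSNormWithHook = ApexFusedRMSNorm  # stand-in for the hooked subclass in the diff
except ImportError:
    FusedRMSNormWithHook = None  # hypothetical sentinel; the committed code leaves the name unbound
    warnings.warn(
        "Please install apex from source (https://github.com/NVIDIA/apex) to use the fused RMSNorm kernel"
    )

# Hypothetical caller-side check: fail fast with a clear message instead of a NameError.
if FusedRMSNormWithHook is None:
    print("apex not installed; a fallback to a non-fused RMSNorm would go here")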
