[fix] fix use_fp8 flag
duanjunwen committed Nov 1, 2024 · 1 parent 5b5fbcf · commit 0218e67
Showing 2 changed files with 4 additions and 5 deletions.
6 changes: 2 additions & 4 deletions colossalai/shardformer/layer/_operation.py
@@ -723,9 +723,7 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
     """

     @staticmethod
-    def forward(
-        ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring, fp8_communication
-    ):
+    def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring, fp8_communication):
         ctx.save_for_backward(input_, weight, bias)
         ctx.use_bias = bias is not None
         ctx.process_group = process_group
@@ -793,7 +791,7 @@ def backward(ctx, grad_output):
         if ctx.async_grad_reduce_scatter:
             handle.wait()

-        return output, grad_weight, grad_bias, None, None, None, None, None, None
+        return output, grad_weight, grad_bias, None, None, None, None, None


 class _SplitForwardGatherBackward(torch.autograd.Function):
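The two hunks above are coupled: a `torch.autograd.Function` must return from `backward` exactly one value per argument that `forward` received (besides `ctx`), so removing the unused `overlap` parameter from `forward` also requires dropping one `None` from `backward`'s return tuple. A minimal sketch of that invariant, using a hypothetical `_ScaledLinear` function rather than ColossalAI's actual code:

import torch

# Minimal sketch (hypothetical, not ColossalAI code): backward() must return
# one value per forward() argument after ctx. forward() below takes three
# arguments, so backward() returns exactly three values.
class _ScaledLinear(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_, weight, scale):
        ctx.save_for_backward(input_, weight)
        ctx.scale = scale
        return (input_ @ weight.t()) * scale

    @staticmethod
    def backward(ctx, grad_output):
        input_, weight = ctx.saved_tensors
        grad_input = (grad_output * ctx.scale) @ weight
        grad_weight = (grad_output * ctx.scale).t() @ input_
        # Gradients for input_ and weight; None for the non-tensor scale.
        # Dropping an argument from forward() (like `overlap` in this commit)
        # means dropping the corresponding entry here as well.
        return grad_input, grad_weight, None

x = torch.randn(4, 8, requires_grad=True)
w = torch.randn(6, 8, requires_grad=True)
_ScaledLinear.apply(x, w, 0.5).sum().backward()  # arity matches, so this runs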
3 changes: 2 additions & 1 deletion tests/kit/model_zoo/transformers/__init__.py
@@ -2,7 +2,8 @@
 from .bert import *
 from .blip2 import *
 from .bloom import *
-from .chatglm2 import *
+
+# from .chatglm2 import *
 from .command import *
 from .deepseek import *
 from .falcon import *
