From 0218e673db79eda513a71054694b8845a4b1ee1b Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Fri, 1 Nov 2024 07:05:24 +0000 Subject: [PATCH] [fix] fix use_fp8 flag --- colossalai/shardformer/layer/_operation.py | 6 ++---- tests/kit/model_zoo/transformers/__init__.py | 3 ++- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index d918076075e6..8c2e6e7c5d92 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -723,9 +723,7 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function): """ @staticmethod - def forward( - ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring, fp8_communication - ): + def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring, fp8_communication): ctx.save_for_backward(input_, weight, bias) ctx.use_bias = bias is not None ctx.process_group = process_group @@ -793,7 +791,7 @@ def backward(ctx, grad_output): if ctx.async_grad_reduce_scatter: handle.wait() - return output, grad_weight, grad_bias, None, None, None, None, None, None + return output, grad_weight, grad_bias, None, None, None, None, None class _SplitForwardGatherBackward(torch.autograd.Function): diff --git a/tests/kit/model_zoo/transformers/__init__.py b/tests/kit/model_zoo/transformers/__init__.py index 4adc386192d3..02996823166a 100644 --- a/tests/kit/model_zoo/transformers/__init__.py +++ b/tests/kit/model_zoo/transformers/__init__.py @@ -2,7 +2,8 @@ from .bert import * from .blip2 import * from .bloom import * -from .chatglm2 import * + +# from .chatglm2 import * from .command import * from .deepseek import * from .falcon import *