diff --git a/include/cutlass/detail/collective/mixed_input_utils.hpp b/include/cutlass/detail/collective/mixed_input_utils.hpp index c78cf4ba5..f9f348b9f 100644 --- a/include/cutlass/detail/collective/mixed_input_utils.hpp +++ b/include/cutlass/detail/collective/mixed_input_utils.hpp @@ -836,7 +836,7 @@ struct MixedInputUtils { } } else { - auto stage = make_tensor_like(src_vm); + auto stage = make_tensor_like(src_vm(_, 0)); CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size<1>(dst_vm); ++i) { LayoutAwareConvert(src_vm(_, i), stage); @@ -868,7 +868,7 @@ struct MixedInputUtils { } } else { - auto stage = make_tensor_like(src_vm); + auto stage = make_tensor_like(src_vm(_, 0)); CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size<1>(dst_vm); ++i) { LayoutAwareConvert(src_vm(_, i), stage);