From 54fe0fea6fb3d7807ff0f214f9243e2ccdee9662 Mon Sep 17 00:00:00 2001 From: zeroRains Date: Fri, 17 Jan 2025 06:28:18 +0000 Subject: [PATCH 1/3] add the custom_ops for paddlenlp --- csrc/utils/tune_cutlass_fp8_dual_gemm.py | 3 +- paddlenlp/__init__.py | 1 + paddlenlp/custom_ops/__init__.py | 15 + paddlenlp/custom_ops/custom_ops.py | 515 ++++++++++++++++++ .../transformers/bloom/modeling.py | 4 +- .../transformers/chatglm/modeling.py | 2 +- .../transformers/chatglm_v2/modeling.py | 4 +- .../transformers/fused_transformer_layers.py | 24 +- .../transformers/generation_utils.py | 34 +- .../experimental/transformers/gpt/modeling.py | 2 +- .../transformers/llama/modeling.py | 8 +- .../transformers/mixtral/modeling.py | 6 +- .../experimental/transformers/opt/modeling.py | 2 +- .../experimental/transformers/proposers.py | 2 +- .../transformers/qwen/modeling.py | 4 +- .../transformers/qwen2/modeling.py | 6 +- .../transformers/qwen2_moe/modeling.py | 6 +- .../transformers/ring_flash_attention.py | 2 +- paddlenlp/trl/llm_utils.py | 4 +- 19 files changed, 589 insertions(+), 55 deletions(-) create mode 100644 paddlenlp/custom_ops/__init__.py create mode 100644 paddlenlp/custom_ops/custom_ops.py diff --git a/csrc/utils/tune_cutlass_fp8_dual_gemm.py b/csrc/utils/tune_cutlass_fp8_dual_gemm.py index 6fa6b32e4dc1..68b22943d5ea 100644 --- a/csrc/utils/tune_cutlass_fp8_dual_gemm.py +++ b/csrc/utils/tune_cutlass_fp8_dual_gemm.py @@ -15,7 +15,8 @@ import argparse import paddle -from paddlenlp_ops import cutlass_fp8_fp8_fp8_dual_gemm_fused + +from paddlenlp.custom_ops import cutlass_fp8_fp8_fp8_dual_gemm_fused def setup_args(): diff --git a/paddlenlp/__init__.py b/paddlenlp/__init__.py index d2af409f98d5..ec9ffd5a1a15 100644 --- a/paddlenlp/__init__.py +++ b/paddlenlp/__init__.py @@ -41,6 +41,7 @@ import paddle from . import ( + custom_ops, data, dataaug, datasets, diff --git a/paddlenlp/custom_ops/__init__.py b/paddlenlp/custom_ops/__init__.py new file mode 100644 index 000000000000..00a224c0b6b6 --- /dev/null +++ b/paddlenlp/custom_ops/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from custom_ops import * diff --git a/paddlenlp/custom_ops/custom_ops.py b/paddlenlp/custom_ops/custom_ops.py new file mode 100644 index 000000000000..ad9efdfc2ae4 --- /dev/null +++ b/paddlenlp/custom_ops/custom_ops.py @@ -0,0 +1,515 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import paddlenlp_ops as _ops + + +def append_attention( + qkv, + key_cache, + value_cache, + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + padding_offsets, + cum_offsets, + block_tables, + encoder_batch_ids, + encoder_tile_ids_per_batch, + encoder_num_blocks, + kv_batch_ids, + kv_tile_ids_per_batch, + kv_num_blocks, + decoder_batch_ids, + decoder_tile_ids_per_batch, + decoder_num_blocks, + max_enc_len_this_time, + max_dec_len_this_time, + max_len_kv, + rotary_embs, + attn_mask, + qkv_bias, + qkv_out_scales, + cache_k_quant_scales, + cache_v_quant_scales, + cache_k_dequant_scales, + cache_v_dequant_scales, + cache_k_zp, + cache_v_zp, + out_linear_shifts, + out_linear_smooths, + compute_type, + cache_quant_type, + use_neox_rotary_style, + max_input_length, + quant_max_bound, + quant_min_bound, + out_linear_in_scale, + speculate_max_draft_token_num, + causal, + speculate_decoder, +): + return _ops.append_attention( + qkv, + key_cache, + value_cache, + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + padding_offsets, + cum_offsets, + block_tables, + encoder_batch_ids, + encoder_tile_ids_per_batch, + encoder_num_blocks, + kv_batch_ids, + kv_tile_ids_per_batch, + kv_num_blocks, + decoder_batch_ids, + decoder_tile_ids_per_batch, + decoder_num_blocks, + max_enc_len_this_time, + max_dec_len_this_time, + max_len_kv, + rotary_embs, + attn_mask, + qkv_bias, + qkv_out_scales, + cache_k_quant_scales, + cache_v_quant_scales, + cache_k_dequant_scales, + cache_v_dequant_scales, + cache_k_zp, + cache_v_zp, + out_linear_shifts, + out_linear_smooths, + compute_type, + cache_quant_type, + use_neox_rotary_style, + max_input_length, + quant_max_bound, + quant_min_bound, + out_linear_in_scale, + speculate_max_draft_token_num, + causal, + speculate_decoder, + ) + + +def avx_weight_only(x, weight, alog, trans): + return _ops.avx_weight_only(x, weight, alog, trans) + + +def dequant_int8(intput, out_scale, dtype): + return _ops.dequant_int8(intput, out_scale, dtype) + + +def encode_rotary_qk(q, kv, rotary_emb, seq_lens, rotary_emb_dims, use_neox): + return _ops.encode_rotary_qk(q, kv, rotary_emb, seq_lens, rotary_emb_dims, use_neox) + + +def flash_attn_bwd(q, k, v, out, softmax_lse, seed_offset, attn_mask, out_grad, dropout, causal): + return _ops.flash_attn_bwd(q, k, v, out, softmax_lse, seed_offset, attn_mask, out_grad, dropout, causal) + + +def cutlass_fp8_fp8_fp8_dual_gemm_fused( + x, y0, y1, bias0, bias1, transpose_x, transpose_y, scale0, scale1, scale_out, act +): + return _ops.cutlass_fp8_fp8_fp8_dual_gemm_fused( + x, y0, y1, bias0, bias1, transpose_x, transpose_y, scale0, scale1, scale_out, act + ) + + +def cutlass_fp8_fp8_half_gemm_fused(x, y, bias, transpose_x, transpose_y, scale, output_type, act): + return _ops.cutlass_fp8_fp8_half_gemm_fused(x, y, bias, transpose_x, transpose_y, scale, output_type, act) + + +def fused_get_rotary_embedding(input_ids, position_ids, head_dim_shape_tensor, prompt_num, theta, use_neox): + return _ops.fused_get_rotary_embedding(input_ids, position_ids, head_dim_shape_tensor, prompt_num, theta, use_neox) + + +def gemm_dequant(x, y, scale, out_dtype): + return _ops.gemm_dequant(x, y, scale, out_dtype) + + +def get_block_shape_and_split_kv_block( + seq_lens_encoder, + seq_lens_decoder, + max_enc_len_this_time, + max_dec_len_this_time, + seq_lens_this_time, + cum_offsets, + group_size, + block_size, + decoder_step_token_num, +): + return 
_ops.get_block_shape_and_split_kv_block( + seq_lens_encoder, + seq_lens_decoder, + max_enc_len_this_time, + max_dec_len_this_time, + seq_lens_this_time, + cum_offsets, + group_size, + block_size, + decoder_step_token_num, + ) + + +def get_output(x, rank_id, wait_flag): + return _ops.get_output(x, rank_id, wait_flag) + + +def get_padding_offset(input_ids, cum_offsets, token_num, seq_len): + return _ops.get_padding_offset(input_ids, cum_offsets, token_num, seq_len) + + +def get_padding_offset_v2(input_ids, cum_offsets, token_num, seq_len, draft_tokens, seq_lens_encoder): + return _ops.get_padding_offset_v2(input_ids, cum_offsets, token_num, seq_len, draft_tokens, seq_lens_encoder) + + +def get_token_penalty_multi_scores( + pre_ids, logits, penalty_scores, frequency_scores, presence_scores, cur_len, min_len, eos_token_id +): + return _ops.get_token_penalty_multi_scores( + pre_ids, logits, penalty_scores, frequency_scores, presence_scores, cur_len, min_len, eos_token_id + ) + + +def ngram_match( + input_ids, + input_ids_len, + pre_ids, + step_idx, + draft_token_num, + draft_tokens, + seq_lens_this_time, + seq_lens_encoder, + seq_lens_decoder, + real_batch_size, + max_ngram_size, + max_draft_tokens, +): + return _ops.ngram_match( + input_ids, + input_ids_len, + pre_ids, + step_idx, + draft_token_num, + draft_tokens, + seq_lens_this_time, + seq_lens_encoder, + seq_lens_decoder, + real_batch_size, + max_ngram_size, + max_draft_tokens, + ) + + +def qkv_transpose_split(qkv, padding_offset, seq_lens, input_ids, num_head, head_size): + return _ops.qkv_transpose_split(qkv, padding_offset, seq_lens, input_ids, num_head, head_size) + + +def quant_int8(intput, shift, smooth, scale, round_type, max_bound, min_bound): + return _ops.quant_int8(intput, shift, smooth, scale, round_type, max_bound, min_bound) + + +def rebuild_padding(tmp_out, padding_offset, seq_lens, input_ids): + return _ops.rebuild_padding(tmp_out, padding_offset, seq_lens, input_ids) + + +def rebuild_padding_v2( + tmp_out, cum_offsets, seq_lens_decoder, seq_lens_encoder, output_padding_offset, max_input_length +): + return _ops.rebuild_padding_v2( + tmp_out, cum_offsets, seq_lens_decoder, seq_lens_encoder, output_padding_offset, max_input_length + ) + + +def save_output(x, not_need_stop, rank_id): + return _ops.save_output(x, not_need_stop, rank_id) + + +def save_with_output(x, batch_idx, step_idx, file_path, rank_id): + return _ops.save_with_output(x, batch_idx, step_idx, file_path, rank_id) + + +def set_preids_token_penalty_multi_scores( + pre_ids, + input_ids, + seq_lens_encoder, + seq_lens_decoder, + step_idx, + stop_flags, + logits, + penalty_scores, + frequency_scores, + presence_scores, + temperatures, + bad_tokens, + cur_len, + min_len, + eos_token_id, +): + return _ops.set_preids_token_penalty_multi_scores( + pre_ids, + input_ids, + seq_lens_encoder, + seq_lens_decoder, + step_idx, + stop_flags, + logits, + penalty_scores, + frequency_scores, + presence_scores, + temperatures, + bad_tokens, + cur_len, + min_len, + eos_token_id, + ) + + +def set_stop_value_multi_ends(topk_ids, stop_flags, end_ids, mode): + return _ops.set_stop_value_multi_ends(topk_ids, stop_flags, end_ids, mode) + + +def set_value_by_flags_and_idx(pre_ids_all, pre_ids_now, step_idx, stop_flags): + return _ops.set_value_by_flags_and_idx(pre_ids_all, pre_ids_now, step_idx, stop_flags) + + +def speculate_get_output(x, rank_id, wait_flag): + return _ops.speculate_get_output(x, rank_id, wait_flag) + + +def speculate_get_output_padding_offset(output_cum_offsets_tmp, 
out_token_num, seq_lens_output, max_seq_len): + return _ops.speculate_get_output_padding_offset( + output_cum_offsets_tmp, out_token_num, seq_lens_output, max_seq_len + ) + + +def speculate_get_seq_lens_output(seq_lens_this_time, seq_lens_encoder, seq_lens_decoder): + return _ops.speculate_get_seq_lens_output(seq_lens_this_time, seq_lens_encoder, seq_lens_decoder) + + +def speculate_get_token_penalty_multi_scores( + pre_ids, + logits, + penalty_scores, + frequency_scores, + presence_scores, + temperatures, + bad_tokens, + cur_len, + min_len, + eos_token_id, + seq_lens_this_time, + output_padding_offset, + output_cum_offsets, + max_seq_len, +): + return _ops.speculate_get_token_penalty_multi_scores( + pre_ids, + logits, + penalty_scores, + frequency_scores, + presence_scores, + temperatures, + bad_tokens, + cur_len, + min_len, + eos_token_id, + seq_lens_this_time, + output_padding_offset, + output_cum_offsets, + max_seq_len, + ) + + +def speculate_save_output(accept_tokens, accept_num, not_need_stop, rank_id): + return _ops.speculate_save_output(accept_tokens, accept_num, not_need_stop, rank_id) + + +def speculate_set_value_by_flags_and_idx( + pre_ids_all, + accept_tokens, + accept_num, + stop_flags, + seq_lens_this_time, + seq_lens_encoder, + seq_lens_decoder, + step_idx, +): + return _ops.speculate_set_value_by_flags_and_idx( + pre_ids_all, + accept_tokens, + accept_num, + stop_flags, + seq_lens_this_time, + seq_lens_encoder, + seq_lens_decoder, + step_idx, + ) + + +def speculate_verify_and_update( + accept_tokens, + accept_num, + step_idx, + seq_lens_encoder, + seq_lens_decoder, + stop_flags, + not_need_stop, + draft_tokens, + seq_lens_this_time, + verify_tokens, + verify_scores, + max_dec_len, + end_tokens, + is_block_step, + output_cum_offsets, + actual_candidate_len, + actual_draft_token_nums, + topp, + max_seq_len, + verify_window, + enable_topp, +): + return _ops.speculate_verify_and_update( + accept_tokens, + accept_num, + step_idx, + seq_lens_encoder, + seq_lens_decoder, + stop_flags, + not_need_stop, + draft_tokens, + seq_lens_this_time, + verify_tokens, + verify_scores, + max_dec_len, + end_tokens, + is_block_step, + output_cum_offsets, + actual_candidate_len, + actual_draft_token_nums, + topp, + max_seq_len, + verify_window, + enable_topp, + ) + + +def top_p_candidates(probs, top_p, output_padding_offset, candidates_len, max_seq_len): + return _ops.top_p_candidates(probs, top_p, output_padding_offset, candidates_len, max_seq_len) + + +def top_p_sampling_reject(probs, top_p, seed): + return _ops.top_p_sampling_reject(probs, top_p, seed) + + +def transpose_remove_padding(input, seq_lens, padding_offset): + return _ops.transpose_remove_padding(input, seq_lens, padding_offset) + + +def update_inputs_v2( + stop_flags, + step_idx, + not_need_stop, + seq_lens_this_time, + seq_lens_encoder, + seq_lens_decoder, + max_dec_len, + input_ids, + stop_nums, + next_tokens, + is_block_step, + end_ids, + kwargs_next_tokens, +): + return _ops.update_inputs_v2( + stop_flags, + step_idx, + not_need_stop, + seq_lens_this_time, + seq_lens_encoder, + seq_lens_decoder, + max_dec_len, + input_ids, + stop_nums, + next_tokens, + is_block_step, + end_ids, + kwargs_next_tokens, + ) + + +def write_cache_kv(input_k, input_v, cache_kv, sequence_lengths): + return _ops.write_cache_kv(input_k, input_v, cache_kv, sequence_lengths) + + +def xft_greedy_search(probs): + return _ops.xft_greedy_search(probs) + + +def xft_transformer( + input, + ln1Gamma, + qkvWeight, + attnOutWeight, + ln2Gamma, + gateWeight, + 
upWeight, + downWeight, + pastSeqLen, + currentSeqLen, + step, + hiddensize, + totalLayer, + computeType, + cacheDtype, + activation, + normType, + attHeadDim, + attHeadNum, + kvHeadNum, + maxPositions, + maxPosEmbed, + intermediateSize, +): + return _ops.xft_transformer( + input, + ln1Gamma, + qkvWeight, + attnOutWeight, + ln2Gamma, + gateWeight, + upWeight, + downWeight, + pastSeqLen, + currentSeqLen, + step, + hiddensize, + totalLayer, + computeType, + cacheDtype, + activation, + normType, + attHeadDim, + attHeadNum, + kvHeadNum, + maxPositions, + maxPosEmbed, + intermediateSize, + ) diff --git a/paddlenlp/experimental/transformers/bloom/modeling.py b/paddlenlp/experimental/transformers/bloom/modeling.py index 2d1218802449..4362864fc22e 100644 --- a/paddlenlp/experimental/transformers/bloom/modeling.py +++ b/paddlenlp/experimental/transformers/bloom/modeling.py @@ -221,7 +221,7 @@ def set_input_embeddings(self, new_embeddings: Tensor): def remove_padding(self, input_ids, seq_lens_this_time): cum_offsets_now = paddle.cumsum(paddle.max(seq_lens_this_time) - seq_lens_this_time) token_num = paddle.sum(seq_lens_this_time) - from paddlenlp_ops import get_padding_offset + from paddlenlp.custom_ops import get_padding_offset ids_remove_padding, cum_offsets, padding_offset = get_padding_offset( input_ids, cum_offsets_now, token_num, seq_lens_this_time @@ -597,7 +597,7 @@ def set_transformer_block(self, transformer_config): def remove_padding(self, input_ids, seq_lens_this_time, draft_tokens=None, seq_lens_encoder=None): cum_offsets_now = paddle.cumsum(self.max_seq_len - seq_lens_this_time) token_num = paddle.sum(seq_lens_this_time) - from paddlenlp_ops import get_padding_offset_v2 + from paddlenlp.custom_ops import get_padding_offset_v2 ids_remove_padding, cum_offsets, padding_offset, cu_seqlens_q, cu_seqlens_k = get_padding_offset_v2( input_ids, cum_offsets_now, token_num, seq_lens_this_time, draft_tokens, seq_lens_encoder diff --git a/paddlenlp/experimental/transformers/chatglm/modeling.py b/paddlenlp/experimental/transformers/chatglm/modeling.py index 26a014f3c602..ea1f8461ec01 100644 --- a/paddlenlp/experimental/transformers/chatglm/modeling.py +++ b/paddlenlp/experimental/transformers/chatglm/modeling.py @@ -268,7 +268,7 @@ def __init__(self, config: ChatGLMConfig): def remove_padding(self, input_ids, seq_lens_this_time): cum_offsets_now = paddle.cumsum(paddle.max(seq_lens_this_time) - seq_lens_this_time) token_num = paddle.sum(seq_lens_this_time) - from paddlenlp_ops import get_padding_offset + from paddlenlp.custom_ops import get_padding_offset ids_remove_padding, cum_offsets, padding_offset = get_padding_offset( input_ids, cum_offsets_now, token_num, seq_lens_this_time diff --git a/paddlenlp/experimental/transformers/chatglm_v2/modeling.py b/paddlenlp/experimental/transformers/chatglm_v2/modeling.py index 58181020d594..8341513b6eca 100644 --- a/paddlenlp/experimental/transformers/chatglm_v2/modeling.py +++ b/paddlenlp/experimental/transformers/chatglm_v2/modeling.py @@ -202,7 +202,7 @@ def set_input_embeddings(self, value): def remove_padding(self, input_ids, seq_lens_this_time): cum_offsets_now = paddle.cumsum(paddle.max(seq_lens_this_time) - seq_lens_this_time) token_num = paddle.sum(seq_lens_this_time) - from paddlenlp_ops import get_padding_offset + from paddlenlp.custom_ops import get_padding_offset ids_remove_padding, cum_offsets, padding_offset = get_padding_offset( input_ids, cum_offsets_now, token_num, seq_lens_this_time @@ -391,7 +391,7 @@ def set_transformer_block(self, 
transformer_config): def remove_padding(self, input_ids, seq_lens_this_time, draft_tokens=None, seq_lens_encoder=None): cum_offsets_now = paddle.cumsum(self.max_seq_len - seq_lens_this_time) token_num = paddle.sum(seq_lens_this_time) - from paddlenlp_ops import get_padding_offset_v2 + from paddlenlp.custom_ops import get_padding_offset_v2 ids_remove_padding, cum_offsets, padding_offset, cu_seqlens_q, cu_seqlens_k = get_padding_offset_v2( input_ids, cum_offsets_now, token_num, seq_lens_this_time, draft_tokens, seq_lens_encoder diff --git a/paddlenlp/experimental/transformers/fused_transformer_layers.py b/paddlenlp/experimental/transformers/fused_transformer_layers.py index a1f67843be31..ac53ef7e7b1b 100644 --- a/paddlenlp/experimental/transformers/fused_transformer_layers.py +++ b/paddlenlp/experimental/transformers/fused_transformer_layers.py @@ -45,7 +45,7 @@ if ( paddle.device.get_all_custom_device_type() is not None and len(paddle.device.get_all_custom_device_type()) > 0 ) or paddle.is_compiled_with_cuda(): - from paddlenlp_ops import rebuild_padding_v2 + from paddlenlp.custom_ops import rebuild_padding_v2 def use_cutlass_fp8_gemm(): @@ -55,14 +55,16 @@ def use_cutlass_fp8_gemm(): if paddle.is_compiled_with_cuda(): if use_cutlass_fp8_gemm(): logger.info("cutlass fp8 gemm is used. you can turn it off by setting FLAGS_CUTLASS_FP8_GEMM to False.") - from paddlenlp_ops import ( + from paddlenlp.custom_ops import ( cutlass_fp8_fp8_fp8_dual_gemm_fused as fp8_dual_gemm_fused, ) - from paddlenlp_ops import cutlass_fp8_fp8_half_gemm_fused as fp8_gemm_fused + from paddlenlp.custom_ops import ( + cutlass_fp8_fp8_half_gemm_fused as fp8_gemm_fused, + ) else: from paddle.linalg import fp8_fp8_half_gemm_fused as fp8_gemm_fused try: - from paddlenlp_ops import ( + from paddlenlp.custom_ops import ( dequant_int8, encode_rotary_qk, qkv_transpose_split, @@ -1034,7 +1036,7 @@ def forward( if self.config.append_attn: - from paddlenlp_ops import get_block_shape_and_split_kv_block + from paddlenlp.custom_ops import get_block_shape_and_split_kv_block ( kwargs["encoder_batch_ids"], @@ -1649,7 +1651,7 @@ def forward( step_idx=None, **kwargs, ): - from paddlenlp_ops import xft_transformer + from paddlenlp.custom_ops import xft_transformer xft_out = xft_transformer( paddle.cast(src, "float32"), # input @@ -2028,7 +2030,7 @@ def compute_out_linear(self, fmha_out, i): out_linear_out = dequant_int8(out_linear_out, self.linear_out_scales[i], self._dtype) else: if self.use_gemm_dequant: - from paddlenlp_ops import gemm_dequant + from paddlenlp.custom_ops import gemm_dequant out_linear_out = gemm_dequant( fmha_out, self.linear_weights[i], self.linear_out_scales[i], self._dtype @@ -2089,7 +2091,7 @@ def compute_ffn2(self, ffn1_out, i): ffn2_out = dequant_int8(ffn2_out, self.ffn2_out_scales[i], self._dtype) else: if self.use_gemm_dequant: - from paddlenlp_ops import gemm_dequant + from paddlenlp.custom_ops import gemm_dequant ffn2_out = gemm_dequant(ffn1_out, self.ffn2_weights[i], self.ffn2_out_scales[i], self._dtype) else: @@ -2149,7 +2151,7 @@ def compute_attn( **kwargs, ): if self.config.append_attn: - from paddlenlp_ops import append_attention + from paddlenlp.custom_ops import append_attention fmha_out = append_attention( qkv_out, @@ -2343,7 +2345,7 @@ def compute_attn( cache_quant_type_str = "cache_int8" if self.config.append_attn: - from paddlenlp_ops import append_attention + from paddlenlp.custom_ops import append_attention fmha_out = append_attention( qkv_out, @@ -2702,7 +2704,7 @@ def compute_attn( 
cache_quant_type_str = "cache_int8" if self.config.append_attn: - from paddlenlp_ops import append_attention + from paddlenlp.custom_ops import append_attention fmha_out = append_attention( qkv_out, diff --git a/paddlenlp/experimental/transformers/generation_utils.py b/paddlenlp/experimental/transformers/generation_utils.py index e224f7b2a1c9..3ec185bf2b1e 100644 --- a/paddlenlp/experimental/transformers/generation_utils.py +++ b/paddlenlp/experimental/transformers/generation_utils.py @@ -205,7 +205,7 @@ def update_model_kwargs_for_generation(self, cache, just_decoder, next_tokens, e model_kwargs["stop_flags"] = paddle.logical_or(model_kwargs["stop_flags"], length_cond) if cache is None: next_tokens = paddle.where(just_decoder, paddle.full_like(next_tokens, -1), next_tokens) - from paddlenlp_ops import set_stop_value_multi_ends + from paddlenlp.custom_ops import set_stop_value_multi_ends next_tokens, model_kwargs["stop_flags"] = set_stop_value_multi_ends( next_tokens, model_kwargs["stop_flags"], eos_token_id, 2 @@ -305,7 +305,7 @@ def _post_process_(outputs, top_p, temperature, step_idx_ori, model_kwargs): ) # not update when continue decode else: step_idx = model_kwargs["step_idx"] - from paddlenlp_ops import set_value_by_flags_and_idx + from paddlenlp.custom_ops import set_value_by_flags_and_idx model_kwargs["stop_flags"] = set_value_by_flags_and_idx( model_kwargs["pre_ids"], @@ -318,7 +318,7 @@ def _post_process_(outputs, top_p, temperature, step_idx_ori, model_kwargs): logits = paddle.cast(logits, paddle.float32) logits = logits_processors(model_kwargs["all_input_ids"], logits, decoding_step=step_idx_ori) - from paddlenlp_ops import get_token_penalty_multi_scores + from paddlenlp.custom_ops import get_token_penalty_multi_scores logits = get_token_penalty_multi_scores( model_kwargs["pre_ids"], @@ -337,7 +337,7 @@ def _post_process_(outputs, top_p, temperature, step_idx_ori, model_kwargs): # compute next_tokens if use_faster_top_p_sampling(): - from paddlenlp_ops import top_p_sampling_reject + from paddlenlp.custom_ops import top_p_sampling_reject next_tokens = top_p_sampling_reject(probs, top_p, 0) else: @@ -356,7 +356,7 @@ def _post_process_(outputs, top_p, temperature, step_idx_ori, model_kwargs): else: model_kwargs["all_input_ids"] = paddle.concat([model_kwargs["all_input_ids"], next_tokens], axis=1) - from paddlenlp_ops import save_with_output + from paddlenlp.custom_ops import save_with_output save_with_output( next_tokens, @@ -560,7 +560,7 @@ def get_output_padding_offset(self, seq_lens_this_time, seq_lens_encoder, seq_le In the senerio of speculate decoding, the length of output token after rebuild_padding is no longer bsz. So we need to calculate the output_padding_offset after rebuild_padding. 
""" - from paddlenlp_ops import ( + from paddlenlp.custom_ops import ( speculate_get_output_padding_offset, speculate_get_seq_lens_output, ) @@ -693,7 +693,7 @@ def _post_process_( step_idx = model_kwargs["step_idx"] logits = paddle.cast(outputs, paddle.float32) - from paddlenlp_ops import set_preids_token_penalty_multi_scores + from paddlenlp.custom_ops import set_preids_token_penalty_multi_scores set_preids_token_penalty_multi_scores( model_kwargs["pre_ids"], @@ -718,7 +718,7 @@ def _post_process_( # compute next_tokens if use_faster_top_p_sampling(): - from paddlenlp_ops import top_p_sampling_reject + from paddlenlp.custom_ops import top_p_sampling_reject next_tokens = top_p_sampling_reject(probs, top_p, 0) else: @@ -727,7 +727,7 @@ def _post_process_( if self.config.tensor_parallel_degree > 1: paddle.distributed.broadcast(next_tokens, 0) - from paddlenlp_ops import update_inputs_v2 + from paddlenlp.custom_ops import update_inputs_v2 update_inputs_v2( model_kwargs["stop_flags"], @@ -745,7 +745,7 @@ def _post_process_( model_kwargs["next_tokens"], ) - from paddlenlp_ops import save_output + from paddlenlp.custom_ops import save_output save_output( next_tokens, @@ -799,7 +799,7 @@ def _post_process_( step_idx = model_kwargs["step_idx"] logits = paddle.cast(outputs, paddle.float32) - from paddlenlp_ops import speculate_get_token_penalty_multi_scores + from paddlenlp.custom_ops import speculate_get_token_penalty_multi_scores speculate_get_token_penalty_multi_scores( model_kwargs["pre_ids"], @@ -821,7 +821,7 @@ def _post_process_( # sample probs = F.softmax(logits) - from paddlenlp_ops import ( + from paddlenlp.custom_ops import ( speculate_save_output, speculate_set_value_by_flags_and_idx, speculate_verify_and_update, @@ -1026,7 +1026,7 @@ def update_model_kwargs_for_generation(self, cache, just_decoder, next_tokens, e model_kwargs["stop_flags"] = paddle.logical_or(model_kwargs["stop_flags"], length_cond) if cache is None: next_tokens = paddle.where(just_decoder, paddle.full_like(next_tokens, -1), next_tokens) - from paddlenlp_ops import set_stop_value_multi_ends + from paddlenlp.custom_ops import set_stop_value_multi_ends next_tokens, model_kwargs["stop_flags"] = set_stop_value_multi_ends( next_tokens, model_kwargs["stop_flags"], eos_token_id @@ -1097,7 +1097,7 @@ def _post_process_(outputs, top_p, temperature, step_idx_ori, model_kwargs): else: step_idx = model_kwargs["step_idx"] - from paddlenlp_ops import set_value_by_flags_and_idx + from paddlenlp.custom_ops import set_value_by_flags_and_idx model_kwargs["stop_flags"] = set_value_by_flags_and_idx( model_kwargs["pre_ids"], @@ -1109,7 +1109,7 @@ def _post_process_(outputs, top_p, temperature, step_idx_ori, model_kwargs): logits = paddle.cast(logits, paddle.float32) logits = logits_processors(model_kwargs["all_input_ids"], logits, decoding_step=step_idx_ori) - from paddlenlp_ops import get_token_penalty_multi_scores + from paddlenlp.custom_ops import get_token_penalty_multi_scores logits = get_token_penalty_multi_scores( model_kwargs["pre_ids"], @@ -1124,7 +1124,7 @@ def _post_process_(outputs, top_p, temperature, step_idx_ori, model_kwargs): logits = logits / temperature probs = F.softmax(logits) - from paddlenlp_ops import xft_greedy_search + from paddlenlp.custom_ops import xft_greedy_search next_tokens = xft_greedy_search(probs) @@ -1138,7 +1138,7 @@ def _post_process_(outputs, top_p, temperature, step_idx_ori, model_kwargs): else: model_kwargs["all_input_ids"] = paddle.concat([model_kwargs["all_input_ids"], next_tokens], axis=1) - 
from paddlenlp_ops import save_with_output + from paddlenlp.custom_ops import save_with_output save_with_output( next_tokens, diff --git a/paddlenlp/experimental/transformers/gpt/modeling.py b/paddlenlp/experimental/transformers/gpt/modeling.py index d021f858cac0..fdecf3c0ba6b 100644 --- a/paddlenlp/experimental/transformers/gpt/modeling.py +++ b/paddlenlp/experimental/transformers/gpt/modeling.py @@ -198,7 +198,7 @@ def set_input_embeddings(self, value): def remove_padding(self, input_ids, seq_lens_this_time): cum_offsets_now = paddle.cumsum(paddle.max(seq_lens_this_time) - seq_lens_this_time) token_num = paddle.sum(seq_lens_this_time) - from paddlenlp_ops import get_padding_offset + from paddlenlp.custom_ops import get_padding_offset ids_remove_padding, cum_offsets, padding_offset = get_padding_offset( input_ids, cum_offsets_now, token_num, seq_lens_this_time diff --git a/paddlenlp/experimental/transformers/llama/modeling.py b/paddlenlp/experimental/transformers/llama/modeling.py index 59f47405910a..1bc92b606aae 100644 --- a/paddlenlp/experimental/transformers/llama/modeling.py +++ b/paddlenlp/experimental/transformers/llama/modeling.py @@ -112,7 +112,7 @@ def __init__(self, config: LlamaConfig): ) def forward(self, input): - from paddlenlp_ops import avx_weight_only + from paddlenlp.custom_ops import avx_weight_only return avx_weight_only(input, self.weight, self.alog, trans=False) @@ -687,7 +687,7 @@ def set_input_embeddings(self, value): def remove_padding(self, input_ids, seq_lens_this_time): cum_offsets_now = paddle.cumsum(paddle.max(seq_lens_this_time) - seq_lens_this_time) token_num = paddle.sum(seq_lens_this_time) - from paddlenlp_ops import get_padding_offset + from paddlenlp.custom_ops import get_padding_offset ids_remove_padding, cum_offsets, padding_offset = get_padding_offset( input_ids, cum_offsets_now, token_num, seq_lens_this_time @@ -774,7 +774,7 @@ def forward( position_offset = 0 if not is_decoder and pre_caches is not None: position_offset = 128 - from paddlenlp_ops import fused_get_rotary_embedding + from paddlenlp.custom_ops import fused_get_rotary_embedding new_rope = fused_get_rotary_embedding( input_ids, position_ids, self.head_dim_shape_tensor, position_offset, self.rope_theta, self.use_neox @@ -1388,7 +1388,7 @@ def set_transformer_block(self, transformer_config): def remove_padding(self, input_ids, seq_lens_this_time, draft_tokens=None, seq_lens_encoder=None): cum_offsets_now = paddle.cumsum(self.max_seq_len - seq_lens_this_time) token_num = paddle.sum(seq_lens_this_time) - from paddlenlp_ops import get_padding_offset_v2 + from paddlenlp.custom_ops import get_padding_offset_v2 ids_remove_padding, cum_offsets, padding_offset, cu_seqlens_q, cu_seqlens_k = get_padding_offset_v2( input_ids, cum_offsets_now, token_num, seq_lens_this_time, draft_tokens, seq_lens_encoder diff --git a/paddlenlp/experimental/transformers/mixtral/modeling.py b/paddlenlp/experimental/transformers/mixtral/modeling.py index 27e638d9d9f1..02723969b059 100644 --- a/paddlenlp/experimental/transformers/mixtral/modeling.py +++ b/paddlenlp/experimental/transformers/mixtral/modeling.py @@ -367,7 +367,7 @@ def set_input_embeddings(self, value): def remove_padding(self, input_ids, seq_lens_this_time): cum_offsets_now = paddle.cumsum(paddle.max(seq_lens_this_time) - seq_lens_this_time) token_num = paddle.sum(seq_lens_this_time) - from paddlenlp_ops import get_padding_offset + from paddlenlp.custom_ops import get_padding_offset ids_remove_padding, cum_offsets, padding_offset = get_padding_offset( 
input_ids, cum_offsets_now, token_num, seq_lens_this_time @@ -454,7 +454,7 @@ def forward( position_offset = 0 if not is_decoder and pre_caches is not None: position_offset = 128 - from paddlenlp_ops import fused_get_rotary_embedding + from paddlenlp.custom_ops import fused_get_rotary_embedding new_rope = fused_get_rotary_embedding( input_ids, position_ids, self.head_dim_shape_tensor, position_offset, self.rope_theta, self.use_neox @@ -1086,7 +1086,7 @@ def set_transformer_block(self, transformer_config): def remove_padding(self, input_ids, seq_lens_this_time, draft_tokens=None, seq_lens_encoder=None): cum_offsets_now = paddle.cumsum(self.max_seq_len - seq_lens_this_time) token_num = paddle.sum(seq_lens_this_time) - from paddlenlp_ops import get_padding_offset_v2 + from paddlenlp.custom_ops import get_padding_offset_v2 ids_remove_padding, cum_offsets, padding_offset, cu_seqlens_q, cu_seqlens_k = get_padding_offset_v2( input_ids, cum_offsets_now, token_num, seq_lens_this_time, draft_tokens, seq_lens_encoder diff --git a/paddlenlp/experimental/transformers/opt/modeling.py b/paddlenlp/experimental/transformers/opt/modeling.py index 3ce294c3b709..ad61b921c00c 100644 --- a/paddlenlp/experimental/transformers/opt/modeling.py +++ b/paddlenlp/experimental/transformers/opt/modeling.py @@ -147,7 +147,7 @@ def set_input_embeddings(self, value): def remove_padding(self, input_ids, seq_lens_this_time): cum_offsets_now = paddle.cumsum(paddle.max(seq_lens_this_time) - seq_lens_this_time) token_num = paddle.sum(seq_lens_this_time) - from paddlenlp_ops import get_padding_offset + from paddlenlp.custom_ops import get_padding_offset ids_remove_padding, cum_offsets, padding_offset = get_padding_offset( input_ids, cum_offsets_now, token_num, seq_lens_this_time diff --git a/paddlenlp/experimental/transformers/proposers.py b/paddlenlp/experimental/transformers/proposers.py index f2a1d2b0a50f..202c90fce897 100644 --- a/paddlenlp/experimental/transformers/proposers.py +++ b/paddlenlp/experimental/transformers/proposers.py @@ -72,7 +72,7 @@ def run(self, model_inputs: dict[str, paddle.Tensor], **kargs): seq_lens_encoder = model_inputs["seq_lens_encoder"].cpu() seq_lens_decoder = model_inputs["seq_lens_decoder"].cpu() - from paddlenlp_ops import ngram_match + from paddlenlp.custom_ops import ngram_match ngram_match( self.input_ids_cpu, diff --git a/paddlenlp/experimental/transformers/qwen/modeling.py b/paddlenlp/experimental/transformers/qwen/modeling.py index 44852e6be50d..79106cad42c7 100644 --- a/paddlenlp/experimental/transformers/qwen/modeling.py +++ b/paddlenlp/experimental/transformers/qwen/modeling.py @@ -239,7 +239,7 @@ def set_state_dict(self, state_dict): def remove_padding(self, input_ids, seq_lens_this_time): cum_offsets_now = paddle.cumsum(paddle.max(seq_lens_this_time) - seq_lens_this_time) token_num = paddle.sum(seq_lens_this_time) - from paddlenlp_ops import get_padding_offset + from paddlenlp.custom_ops import get_padding_offset ids_remove_padding, cum_offsets, padding_offset = get_padding_offset( input_ids, cum_offsets_now, token_num, seq_lens_this_time @@ -326,7 +326,7 @@ def forward( if not is_decoder and pre_caches is not None: position_offset = 128 - from paddlenlp_ops import fused_get_rotary_embedding + from paddlenlp.custom_ops import fused_get_rotary_embedding new_rope = fused_get_rotary_embedding( input_ids, position_ids, self.head_dim_shape_tensor, position_offset, theta, True diff --git a/paddlenlp/experimental/transformers/qwen2/modeling.py 
b/paddlenlp/experimental/transformers/qwen2/modeling.py index 6098079d9084..1344256621b3 100644 --- a/paddlenlp/experimental/transformers/qwen2/modeling.py +++ b/paddlenlp/experimental/transformers/qwen2/modeling.py @@ -925,7 +925,7 @@ def set_state_dict(self, state_dict): def remove_padding(self, input_ids, seq_lens_this_time): cum_offsets_now = paddle.cumsum(paddle.max(seq_lens_this_time) - seq_lens_this_time) token_num = paddle.sum(seq_lens_this_time) - from paddlenlp_ops import get_padding_offset + from paddlenlp.custom_ops import get_padding_offset ids_remove_padding, cum_offsets, padding_offset = get_padding_offset( input_ids, cum_offsets_now, token_num, seq_lens_this_time @@ -1012,7 +1012,7 @@ def forward( if not is_decoder and pre_caches is not None: position_offset = 128 - from paddlenlp_ops import fused_get_rotary_embedding + from paddlenlp.custom_ops import fused_get_rotary_embedding new_rope = fused_get_rotary_embedding( input_ids, position_ids, self.head_dim_shape_tensor, position_offset, self.rope_theta, self.use_neox @@ -1233,7 +1233,7 @@ def set_transformer_block(self, transformer_config): def remove_padding(self, input_ids, seq_lens_this_time, draft_tokens=None, seq_lens_encoder=None): cum_offsets_now = paddle.cumsum(self.max_seq_len - seq_lens_this_time) token_num = paddle.sum(seq_lens_this_time) - from paddlenlp_ops import get_padding_offset_v2 + from paddlenlp.custom_ops import get_padding_offset_v2 ids_remove_padding, cum_offsets, padding_offset, cu_seqlens_q, cu_seqlens_k = get_padding_offset_v2( input_ids, cum_offsets_now, token_num, seq_lens_this_time, draft_tokens, seq_lens_encoder diff --git a/paddlenlp/experimental/transformers/qwen2_moe/modeling.py b/paddlenlp/experimental/transformers/qwen2_moe/modeling.py index 1aa0969b4a11..86681f045689 100644 --- a/paddlenlp/experimental/transformers/qwen2_moe/modeling.py +++ b/paddlenlp/experimental/transformers/qwen2_moe/modeling.py @@ -472,7 +472,7 @@ def set_state_dict(self, state_dict): def remove_padding(self, input_ids, seq_lens_this_time): cum_offsets_now = paddle.cumsum(paddle.max(seq_lens_this_time) - seq_lens_this_time) token_num = paddle.sum(seq_lens_this_time) - from paddlenlp_ops import get_padding_offset + from paddlenlp.custom_ops import get_padding_offset ids_remove_padding, cum_offsets, padding_offset = get_padding_offset( input_ids, cum_offsets_now, token_num, seq_lens_this_time @@ -559,7 +559,7 @@ def forward( if not is_decoder and pre_caches is not None: position_offset = 128 - from paddlenlp_ops import fused_get_rotary_embedding + from paddlenlp.custom_ops import fused_get_rotary_embedding new_rope = fused_get_rotary_embedding( input_ids, position_ids, self.head_dim_shape_tensor, position_offset, self.rope_theta, self.use_neox @@ -777,7 +777,7 @@ def set_transformer_block(self, transformer_config): def remove_padding(self, input_ids, seq_lens_this_time, draft_tokens=None, seq_lens_encoder=None): cum_offsets_now = paddle.cumsum(self.max_seq_len - seq_lens_this_time) token_num = paddle.sum(seq_lens_this_time) - from paddlenlp_ops import get_padding_offset_v2 + from paddlenlp.custom_ops import get_padding_offset_v2 ids_remove_padding, cum_offsets, padding_offset, cu_seqlens_q, cu_seqlens_k = get_padding_offset_v2( input_ids, cum_offsets_now, token_num, seq_lens_this_time, draft_tokens, seq_lens_encoder diff --git a/paddlenlp/transformers/ring_flash_attention.py b/paddlenlp/transformers/ring_flash_attention.py index b3faf2463dff..936309bbd559 100644 --- a/paddlenlp/transformers/ring_flash_attention.py +++ 
b/paddlenlp/transformers/ring_flash_attention.py @@ -223,7 +223,7 @@ def balanced_ring_flash_attention_bwd_func( attn_masks_list = paddle.split(attn_mask, num_or_sections=cp_size * 2, axis=3) try: - from paddlenlp_ops import flash_attn_bwd + from paddlenlp.custom_ops import flash_attn_bwd except (ImportError, ModuleNotFoundError): from paddlenlp.utils.log import logger diff --git a/paddlenlp/trl/llm_utils.py b/paddlenlp/trl/llm_utils.py index d5fa8dc76354..15f9ccf464de 100644 --- a/paddlenlp/trl/llm_utils.py +++ b/paddlenlp/trl/llm_utils.py @@ -607,7 +607,7 @@ def read_res(model_name_or_path: str, tensor_queue: mp.Queue, result_queue: mp.Q logger.info("Start read result message") logger.info(f"Current path is {os.getcwd()}") - from paddlenlp_ops import get_output + from paddlenlp.custom_ops import get_output while True: get_output(output_tensor, 0, True) @@ -641,7 +641,7 @@ def speculate_read_res(model_name_or_path: str, tensor_queue: mp.Queue, result_q logger.info("Start speculate read result message") logger.info(f"Current path is {os.getcwd()}") - from paddlenlp_ops import speculate_get_output + from paddlenlp.custom_ops import speculate_get_output while True: speculate_get_output(output_tensor, 0, True) From e237bc592656613b4a864354ef6ddef310beff84 Mon Sep 17 00:00:00 2001 From: zeroRains Date: Fri, 17 Jan 2025 06:35:05 +0000 Subject: [PATCH 2/3] fix --- llm/output.json | 1 + paddlenlp/custom_ops/__init__.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) create mode 100644 llm/output.json diff --git a/llm/output.json b/llm/output.json new file mode 100644 index 000000000000..4368c58cd712 --- /dev/null +++ b/llm/output.json @@ -0,0 +1 @@ +{"src": "解释一下温故而知新", "tgt": "", "output": "\"温故而知新\",这句话出自《论语·为政》。这句话的意思是,通过回顾和复习旧的知识,可以加深对新知识的理解和记忆。这句话强调了学习的重要性,因为只有通过反复学习和复习,才能真正掌握新知识,从而更好地理解和应用新知识。"} diff --git a/paddlenlp/custom_ops/__init__.py b/paddlenlp/custom_ops/__init__.py index 00a224c0b6b6..e8b0768b835a 100644 --- a/paddlenlp/custom_ops/__init__.py +++ b/paddlenlp/custom_ops/__init__.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -from custom_ops import * +from .custom_ops import * From 768daa7ed0d4a968df58e86b2f1cabe8529e6f61 Mon Sep 17 00:00:00 2001 From: zeroRains Date: Fri, 17 Jan 2025 06:37:45 +0000 Subject: [PATCH 3/3] fix --- csrc/utils/tune_cutlass_fp8_dual_gemm.py | 3 +-- llm/output.json | 1 - 2 files changed, 1 insertion(+), 3 deletions(-) delete mode 100644 llm/output.json diff --git a/csrc/utils/tune_cutlass_fp8_dual_gemm.py b/csrc/utils/tune_cutlass_fp8_dual_gemm.py index 68b22943d5ea..6fa6b32e4dc1 100644 --- a/csrc/utils/tune_cutlass_fp8_dual_gemm.py +++ b/csrc/utils/tune_cutlass_fp8_dual_gemm.py @@ -15,8 +15,7 @@ import argparse import paddle - -from paddlenlp.custom_ops import cutlass_fp8_fp8_fp8_dual_gemm_fused +from paddlenlp_ops import cutlass_fp8_fp8_fp8_dual_gemm_fused def setup_args(): diff --git a/llm/output.json b/llm/output.json deleted file mode 100644 index 4368c58cd712..000000000000 --- a/llm/output.json +++ /dev/null @@ -1 +0,0 @@ -{"src": "解释一下温故而知新", "tgt": "", "output": "\"温故而知新\",这句话出自《论语·为政》。这句话的意思是,通过回顾和复习旧的知识,可以加深对新知识的理解和记忆。这句话强调了学习的重要性,因为只有通过反复学习和复习,才能真正掌握新知识,从而更好地理解和应用新知识。"}
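
Usage sketch: patches 1-2 route Python-level imports of the inference custom ops through the new paddlenlp.custom_ops wrapper module (a thin pass-through to the compiled paddlenlp_ops extension), while patch 3 reverts csrc/utils/tune_cutlass_fp8_dual_gemm.py to the original paddlenlp_ops import and removes llm/output.json again. The minimal example below mirrors the remove_padding() helpers touched throughout this series; it assumes a CUDA build of PaddleNLP with the paddlenlp_ops extension installed, and the batch contents are illustrative only.

import paddle

# After this series, callers import the op through the wrapper module
# instead of importing paddlenlp_ops directly.
from paddlenlp.custom_ops import get_padding_offset

# A padded batch of 3 sequences with 3, 1 and 2 valid tokens respectively
# (illustrative values; vocab size 32000 is an assumption).
seq_lens_this_time = paddle.to_tensor([3, 1, 2], dtype="int32")
input_ids = paddle.randint(0, 32000, shape=[3, 3], dtype="int64")

# Same bookkeeping as the remove_padding() methods updated in this patch:
# per-sequence padding offsets and the total number of real tokens.
cum_offsets_now = paddle.cumsum(paddle.max(seq_lens_this_time) - seq_lens_this_time)
token_num = paddle.sum(seq_lens_this_time)

# Flatten the batch into a padding-free token stream plus offset tensors,
# exactly as done in the modeling files changed by this series.
ids_remove_padding, cum_offsets, padding_offset = get_padding_offset(
    input_ids, cum_offsets_now, token_num, seq_lens_this_time
)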