diff --git a/docker/Dockerfile b/docker/Dockerfile index 439ecddb34..bba404c965 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -1,14 +1,17 @@ -ARG CUDA_VERSION=12.8.0 +ARG CUDA_VERSION=13.0.0 FROM nvidia/cuda:${CUDA_VERSION}-cudnn-devel-ubuntu22.04 ARG PYTHON_VERSION=3.10 ARG MAMBA_VERSION=24.7.1-0 -ARG VLLM_VERSION=0.16.0 +ARG VLLM_VERSION=0.21.0 +ARG NIXL_REF=v1.1.0 ARG FLASH_MLA_REF=47c35a7 +ARG DEEPGEMM_REF=891d57b4db1071624b5c8fa0d1e51cb317fa709f ARG TARGETPLATFORM ARG ENABLE_DEEPEP=1 ARG ENABLE_NIXL=1 ARG ENABLE_CACHE=1 +ARG ENABLE_SM100=0 ENV PATH=/opt/conda/bin:$PATH \ CONDA_PREFIX=/opt/conda @@ -44,13 +47,18 @@ WORKDIR /root COPY ./requirements.txt /lightllm/requirements.txt RUN pip install -U pip -RUN pip install -r /lightllm/requirements.txt --no-cache-dir -RUN pip install --no-cache-dir vllm==${VLLM_VERSION} -RUN git clone https://github.com/deepseek-ai/FlashMLA.git /root/FlashMLA && \ +RUN pip install --no-cache-dir \ + --extra-index-url https://download.pytorch.org/whl/cu130 \ + vllm==${VLLM_VERSION} +RUN pip install -r /lightllm/requirements.txt --no-cache-dir \ + --extra-index-url https://download.pytorch.org/whl/cu130 +RUN export CPATH=/usr/local/cuda/targets/x86_64-linux/include/cccl:/usr/local/cuda/targets/x86_64-linux/include${CPATH:+:${CPATH}} && \ + git clone https://github.com/deepseek-ai/FlashMLA.git /root/FlashMLA && \ cd /root/FlashMLA && \ git checkout ${FLASH_MLA_REF} && \ git submodule update --init --recursive && \ - FLASH_MLA_DISABLE_SM100=1 pip install --no-cache-dir . + FLASH_MLA_DISABLE_SM100="$(if [ "${ENABLE_SM100}" = "1" ]; then echo 0; else echo 1; fi)" \ + pip install --no-cache-dir . RUN apt-get update && apt-get install -y libnuma-dev && rm -rf /var/lib/apt/lists/* @@ -78,27 +86,20 @@ RUN if [ "${ENABLE_NIXL}" = "1" ] || [ "${ENABLE_DEEPEP}" = "1" ]; then \ RUN if [ "${ENABLE_DEEPEP}" = "1" ]; then \ set -e; \ ln -sf /usr/lib/x86_64-linux-gnu/libmlx5.so.1 /usr/lib/x86_64-linux-gnu/libmlx5.so; \ - NVSHMEM_VERSION=3.3.9; \ - CUDA_ARCHS=90; \ - wget https://developer.download.nvidia.com/compute/redist/nvshmem/${NVSHMEM_VERSION}/source/nvshmem_src_cuda12-all-all-${NVSHMEM_VERSION}.tar.gz \ - && tar -xf nvshmem_src_cuda12-all-all-${NVSHMEM_VERSION}.tar.gz && mv nvshmem_src nvshmem \ - && cd nvshmem \ - && rm -f /root/nvshmem_src_cuda12-all-all-${NVSHMEM_VERSION}.tar.gz \ - && NVSHMEM_SHMEM_SUPPORT=0 \ - NVSHMEM_UCX_SUPPORT=0 \ - NVSHMEM_USE_NCCL=0 \ - NVSHMEM_MPI_SUPPORT=0 \ - NVSHMEM_IBGDA_SUPPORT=1 \ - NVSHMEM_PMIX_SUPPORT=0 \ - NVSHMEM_TIMEOUT_DEVICE_POLLING=0 \ - NVSHMEM_USE_GDRCOPY=1 \ - cmake -S . -B build/ -DCMAKE_INSTALL_PREFIX=/root/nvshmem/install -DCMAKE_CUDA_ARCHITECTURES=${CUDA_ARCHS} \ - && cmake --build build --target install -j64; \ - DEEPEP_COMMIT=b6ce310bb0b75079682d09bc2ebc063a074fbd58; \ - cd /root && git clone https://github.com/deepseek-ai/DeepEP.git && cd DeepEP && git checkout ${DEEPEP_COMMIT} && cd ..; \ - cd /root/DeepEP && NVSHMEM_DIR=/root/nvshmem/install python setup.py install; \ + python -m pip install --upgrade --no-deps \ + "nvidia-nccl-cu13==2.30.4" \ + "nvidia-nvshmem-cu13==3.6.5"; \ + cd /root && git clone https://github.com/deepseek-ai/DeepEP.git && cd DeepEP && git checkout b306af06afd412c88e51e71802951606e40b7358; \ + ln -sf /opt/conda/lib/python${PYTHON_VERSION}/site-packages/nvidia/nvshmem/lib/libnvshmem_host.so.3 /opt/conda/lib/python${PYTHON_VERSION}/site-packages/nvidia/nvshmem/lib/libnvshmem_host.so; \ + ln -sf /opt/conda/lib/python${PYTHON_VERSION}/site-packages/nvidia/nccl/lib/libnccl.so.2 /opt/conda/lib/python${PYTHON_VERSION}/site-packages/nvidia/nccl/lib/libnccl.so; \ + pip install --no-build-isolation .; \ fi +RUN cd /root && git clone https://github.com/deepseek-ai/DeepGEMM.git && \ + cd DeepGEMM && git checkout ${DEEPGEMM_REF} && \ + git submodule update --init --recursive && \ + pip install --no-build-isolation . + RUN if [ "${ENABLE_NIXL}" = "1" ]; then \ apt-get update && apt-get install -y cmake automake autotools-dev libtool libz-dev && \ DEBIAN_FRONTEND=noninteractive apt-get -y install --reinstall libibverbs-dev rdma-core ibverbs-utils libibumad-dev; \ @@ -126,7 +127,7 @@ RUN if [ "${ENABLE_NIXL}" = "1" ]; then \ apt-get update && apt-get install -y pkg-config tmux net-tools && \ cd /usr/local/src; \ pip install --upgrade meson pybind11 patchelf; \ - git clone https://github.com/ai-dynamo/nixl.git -b main && \ + git clone https://github.com/ai-dynamo/nixl.git -b ${NIXL_REF} && \ cd nixl && \ rm -rf build && \ mkdir build && \ diff --git a/docker/scripts/build.sh b/docker/scripts/build.sh index 355d6c65b3..bc1fd73da3 100644 --- a/docker/scripts/build.sh +++ b/docker/scripts/build.sh @@ -18,21 +18,23 @@ set -euo pipefail # --no-nixl Disable NIXL (default: enabled) # --no-cache Disable cache (default: enabled) # --lite Disable DEEPEP, NIXL and cache in one shot -# --cuda-version CUDA version (default: 12.8.0) +# --cuda-version CUDA version (default: 13.0.0) # --image-prefix Image prefix (default: lightllm) # --image-tag Image tag (default: generated from enabled features) +# --enable-sm100 Enable SM100 support (default: disabled) # -h / --help Show help ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/../.." && pwd)" cd "${ROOT_DIR}" IMAGE_PREFIX="${IMAGE_PREFIX:-lightllm}" -CUDA_VERSION="${CUDA_VERSION:-12.8.0}" +CUDA_VERSION="${CUDA_VERSION:-13.0.0}" IMAGE_TAG="${IMAGE_TAG:-}" ENABLE_DEEPEP="${ENABLE_DEEPEP:-1}" ENABLE_NIXL="${ENABLE_NIXL:-1}" ENABLE_CACHE="${ENABLE_CACHE:-1}" +ENABLE_SM100="${ENABLE_SM100:-0}" print_help() { sed -n '1,80p' "$0" | sed 's/^# \{0,1\}//' @@ -43,6 +45,7 @@ while [[ $# -gt 0 ]]; do --no-deepep) ENABLE_DEEPEP=0 ;; --no-nixl) ENABLE_NIXL=0 ;; --no-cache) ENABLE_CACHE=0 ;; + --enable-sm100) ENABLE_SM100=1 ;; --lite) ENABLE_DEEPEP=0 ENABLE_NIXL=0 @@ -78,13 +81,16 @@ done # - Other combos: composed from enabled feature names if [[ -z "${IMAGE_TAG}" ]]; then tag_parts=() + if [[ "${ENABLE_SM100}" -eq 1 ]]; then + tag_parts+=("sm100") + fi if [[ "${ENABLE_NIXL}" -eq 1 ]]; then tag_parts+=("nixl") fi if [[ "${ENABLE_DEEPEP}" -eq 1 ]]; then tag_parts+=("deepep") fi - if [[ "${ENABLE_NIXL}" -eq 1 && "${ENABLE_DEEPEP}" -eq 1 && "${ENABLE_CACHE}" -eq 1 ]]; then + if [[ "${ENABLE_SM100}" -eq 0 && "${ENABLE_NIXL}" -eq 1 && "${ENABLE_DEEPEP}" -eq 1 && "${ENABLE_CACHE}" -eq 1 ]]; then IMAGE_TAG="cuda${CUDA_VERSION}" else prefix="" @@ -100,6 +106,6 @@ DOCKER_BUILDKIT=1 docker build -f docker/Dockerfile \ --build-arg ENABLE_DEEPEP="${ENABLE_DEEPEP}" \ --build-arg ENABLE_NIXL="${ENABLE_NIXL}" \ --build-arg ENABLE_CACHE="${ENABLE_CACHE}" \ + --build-arg ENABLE_SM100="${ENABLE_SM100}" \ --progress=plain \ -t "${IMAGE_PREFIX}:${IMAGE_TAG}" . - diff --git a/docs/CN/source/tutorial/api_server_args.rst b/docs/CN/source/tutorial/api_server_args.rst index c19cc92667..8e7f9d78e8 100644 --- a/docs/CN/source/tutorial/api_server_args.rst +++ b/docs/CN/source/tutorial/api_server_args.rst @@ -464,6 +464,14 @@ PD 分离模式参数 示例可以在 test/advanced_config/mixed_quantization/llamacls-mix-down.yaml 中找到。 +.. option:: --expert_dtype + + EP MoE 专家量化类型,可选值: + + * ``fp8`` + * ``fp4``,仅支持 SM100 GPU + * ``None`` (默认) + .. option:: --vit_quant_type ViT 量化方法,可选值: diff --git a/docs/EN/source/tutorial/api_server_args.rst b/docs/EN/source/tutorial/api_server_args.rst index ad5b381304..84785de3b7 100644 --- a/docs/EN/source/tutorial/api_server_args.rst +++ b/docs/EN/source/tutorial/api_server_args.rst @@ -465,6 +465,14 @@ Quantization Parameters Examples can be found in test/advanced_config/mixed_quantization/llamacls-mix-down.yaml. +.. option:: --expert_dtype + + Expert quantization dtype for EP MoE, optional values: + + * ``fp8`` + * ``fp4``: SM100 GPUs only + * ``None`` (default) + .. option:: --vit_quant_type ViT quantization method, optional values: diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 23e8d36da9..473dcbafda 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -85,6 +85,7 @@ def __init__(self, kvargs): self.disable_cudagraph = kvargs.get("disable_cudagraph", False) self.quant_type = kvargs.get("quant_type", "none") self.quant_cfg_path = kvargs.get("quant_cfg", None) + self.expert_dtype = kvargs.get("expert_dtype", None) self.mem_fraction = kvargs.get("mem_fraction", 0.9) self.tp_world_size_ = get_dp_world_size() self.enable_tpsp_mix_mode = get_env_start_args().enable_tpsp_mix_mode @@ -156,7 +157,7 @@ def _verify_params(self): return def _init_quant(self): - self.quant_cfg = Quantcfg(self.config, self.quant_type, self.quant_cfg_path) + self.quant_cfg = Quantcfg(self.config, self.quant_type, self.quant_cfg_path, self.expert_dtype) logger.info(f"Initial quantization. " f"The default quantization method is {self.quant_cfg.quant_type}") def _init_weights(self, start_layer_index=0): diff --git a/lightllm/common/basemodel/layer_infer/cache_tensor_manager.py b/lightllm/common/basemodel/layer_infer/cache_tensor_manager.py index 7889e8090e..8bcf99b992 100644 --- a/lightllm/common/basemodel/layer_infer/cache_tensor_manager.py +++ b/lightllm/common/basemodel/layer_infer/cache_tensor_manager.py @@ -33,6 +33,7 @@ class BufNode: inner_tensor: torch.Tensor shape_key: Tuple[int, torch.dtype] storage_weak_ptr: int + free_use_count_bias: int = 0 shape_to_tensor: Dict[Union[torch.Size, Iterable[int]], torch.Tensor] = field(default_factory=dict) def __del__(self): @@ -99,7 +100,8 @@ def alloc_tensor( # 回收可能消亡的 tensor for ptr in self.changed_ptr: t_buf_node = self.ptr_to_bufnode[ptr] - if self.use_count(ptr) == 1 + len(t_buf_node.shape_to_tensor): + free_use_count = t_buf_node.free_use_count_bias + 1 + len(t_buf_node.shape_to_tensor) + if self.use_count(ptr) <= free_use_count: self.free_shape_dtype_to_bufs[t_buf_node.shape_key].append(t_buf_node) self.changed_ptr.clear() @@ -131,6 +133,7 @@ def alloc_tensor( self.ptr_to_bufnode[storage_weak_ptr] = buf_node if shape not in buf_node.shape_to_tensor: buf_node.shape_to_tensor[shape] = buf_node.inner_tensor.view(shape) + buf_node.free_use_count_bias = self.use_count(storage_weak_ptr) - (1 + len(buf_node.shape_to_tensor)) mark_tensor = buf_node.shape_to_tensor[shape] ans = mark_tensor.data # 返回一个新的引用, 否则引用计数会无法判断 ans.storage_weak_ptr = buf_node.storage_weak_ptr diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py index c9b8cfa3eb..4d4614c007 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py @@ -4,11 +4,16 @@ from lightllm.distributed import dist_group_manager from lightllm.common.triton_utils.autotuner import Autotuner from lightllm.common.quantization.quantize_method import WeightPack -from lightllm.utils.envs_utils import get_deepep_num_max_dispatch_tokens_per_rank +from lightllm.utils.envs_utils import ( + get_deepep_num_max_dispatch_tokens_per_rank_prefill, + get_deepep_num_max_dispatch_tokens_per_rank_decode, +) from lightllm.common.basemodel.triton_kernel.fused_moe.grouped_fused_moe_ep import ( - fused_experts_impl, + fused_experts, + get_ep_num_sms, masked_group_gemm, - _deepgemm_grouped_fp8_nt_contiguous, + deepgemm_grouped_fp8_nt_contiguous, + quantize_fused_experts_input, ) from lightllm.common.basemodel.triton_kernel.quantization.fp8act_quant_kernel import ( per_token_group_quant_fp8, @@ -72,23 +77,15 @@ def _fused_experts( router_logits: Optional[torch.Tensor] = None, is_prefill: Optional[bool] = None, ): - w13_weight, w13_scale = w13.weight, w13.weight_scale - w2_weight, w2_scale = w2.weight, w2.weight_scale - use_fp8_w8a8 = self.quant_method.method_name != "none" - output = fused_experts_impl( + output = fused_experts( hidden_states=input_tensor, - w1=w13_weight, - w2=w2_weight, + w13=w13, + w2=w2, topk_weights=topk_weights, topk_idx=topk_ids.to(torch.long), num_experts=self.total_expert_num_contain_redundancy, # number of all experts contain redundancy - buffer=dist_group_manager.ep_buffer, + quant_method=self.quant_method, is_prefill=is_prefill, - use_fp8_w8a8=use_fp8_w8a8, - use_fp8_all2all=use_fp8_w8a8, - use_int8_w8a16=False, # default to False - w1_scale=w13_scale, - w2_scale=w2_scale, previous_event=None, # for overlap ) return output @@ -118,13 +115,13 @@ def low_latency_dispatch( ) topk_idx = topk_idx.to(torch.long) - num_max_dispatch_tokens_per_rank = get_deepep_num_max_dispatch_tokens_per_rank() + num_max_dispatch_tokens_per_rank = get_deepep_num_max_dispatch_tokens_per_rank_decode() use_fp8_w8a8 = self.quant_method.method_name != "none" - recv_x, masked_m, handle, event, hook = dist_group_manager.ep_buffer.low_latency_dispatch( - hidden_states, - topk_idx, - num_max_dispatch_tokens_per_rank, - self.total_expert_num_contain_redundancy, + recv_x, masked_m, handle, event, hook = dist_group_manager.ep_low_latency_buffer.low_latency_dispatch( + topk_idx=topk_idx, + x=hidden_states, + num_max_dispatch_tokens_per_rank=num_max_dispatch_tokens_per_rank, + num_experts=self.total_expert_num_contain_redundancy, use_fp8=use_fp8_w8a8, async_finish=False, return_recv_hook=True, @@ -155,13 +152,8 @@ def select_experts_and_quant_input( num_expert_group=n_group, scoring_func=scoring_func, ) - w13_weight, w13_scale = w13.weight, w13.weight_scale - block_size_k = 0 - if w13_weight.ndim == 3: - block_size_k = w13_weight.shape[2] // w13_scale.shape[2] - assert block_size_k == 128, "block_size_k must be 128" - qinput_tensor, input_scale = per_token_group_quant_fp8(hidden_states, block_size_k, dtype=w13_weight.dtype) - return topk_weights, topk_idx.to(torch.long), (qinput_tensor, input_scale) + qinput_tensor = quantize_fused_experts_input(hidden_states, w13, self.quant_method) + return topk_weights, topk_idx.to(torch.long), qinput_tensor def dispatch( self, @@ -171,38 +163,26 @@ def dispatch( overlap_event: Optional[Any] = None, ): buffer = dist_group_manager.ep_buffer - # get_dispatch_layout - ( - num_tokens_per_rank, - num_tokens_per_rdma_rank, - num_tokens_per_expert, - is_token_in_rank, - previous_event, - ) = buffer.get_dispatch_layout( - topk_idx, - self.total_expert_num_contain_redundancy, - previous_event=overlap_event, - async_finish=True, - allocate_on_comm_stream=True, - ) - recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, event = buffer.dispatch( + num_max_tokens_per_rank = get_deepep_num_max_dispatch_tokens_per_rank_prefill() + recv_x, recv_topk_idx, recv_topk_weights, handle, event = buffer.dispatch( qinput_tensor, topk_idx=topk_idx, topk_weights=topk_weights, - num_tokens_per_rank=num_tokens_per_rank, - num_tokens_per_rdma_rank=num_tokens_per_rdma_rank, - is_token_in_rank=is_token_in_rank, - num_tokens_per_expert=num_tokens_per_expert, - previous_event=previous_event, - async_finish=True, - allocate_on_comm_stream=True, + num_experts=self.total_expert_num_contain_redundancy, + num_max_tokens_per_rank=num_max_tokens_per_rank, expert_alignment=128, + num_sms=get_ep_num_sms(), + previous_event=overlap_event, + async_with_compute_stream=True, + allocate_on_comm_stream=True, + do_cpu_sync=True, + do_handle_copy=False, ) def hook(): event.current_stream_wait() - return recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, hook + return recv_x, recv_topk_idx, recv_topk_weights, handle.num_recv_tokens_per_expert_list, handle, hook def masked_group_gemm( self, @@ -281,7 +261,7 @@ def prefilled_group_gemm( # groupgemm (contiguous layout) gemm_out_a = torch.empty((all_tokens, N), device=device, dtype=hidden_dtype) - _deepgemm_grouped_fp8_nt_contiguous(input_tensor, (w13_weight, w13_scale), gemm_out_a, m_indices) + deepgemm_grouped_fp8_nt_contiguous(input_tensor, (w13_weight, w13_scale), gemm_out_a, m_indices) # silu_and_mul_fwd + qaunt # TODO fused kernel @@ -295,7 +275,7 @@ def prefilled_group_gemm( # groupgemm (contiguous layout) gemm_out_b = torch.empty((all_tokens, K), device=device, dtype=hidden_dtype) - _deepgemm_grouped_fp8_nt_contiguous( + deepgemm_grouped_fp8_nt_contiguous( (qsilu_out, qsilu_out_scale), (w2_weight, w2_scale), gemm_out_b, m_indices ) # gather and local reduce @@ -319,7 +299,7 @@ def low_latency_combine( topk_weights: torch.Tensor, handle: Any, ): - combined_x, event_overlap, hook = dist_group_manager.ep_buffer.low_latency_combine( + combined_x, event_overlap, hook = dist_group_manager.ep_low_latency_buffer.low_latency_combine( gemm_out_b, topk_idx, topk_weights, handle, async_finish=False, return_recv_hook=True ) return combined_x, hook @@ -335,8 +315,9 @@ def combine( gemm_out_b, handle, topk_weights=None, - async_finish=True, + num_sms=get_ep_num_sms(), previous_event=overlap_event, + async_with_compute_stream=True, allocate_on_comm_stream=True, ) diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe_ep.py b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe_ep.py index 2c6d013bd5..cb2e370cb9 100644 --- a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe_ep.py +++ b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe_ep.py @@ -1,10 +1,8 @@ """Fused MoE kernel.""" -import os import torch import triton -import triton.language as tl from typing import Any, Callable, Dict, Optional, Tuple -import torch.distributed as dist +from lightllm.distributed import dist_group_manager from lightllm.utils.log_utils import init_logger from lightllm.common.basemodel.triton_kernel.fused_moe.moe_silu_and_mul import silu_and_mul_fwd from lightllm.common.basemodel.triton_kernel.fused_moe.moe_silu_and_mul_mix_quant_ep import ( @@ -15,11 +13,16 @@ tma_align_input_scale, ) from lightllm.common.basemodel.triton_kernel.fused_moe.deepep_scatter_gather import ep_scatter, ep_gather -from lightllm.utils.envs_utils import get_deepep_num_max_dispatch_tokens_per_rank +from lightllm.utils.envs_utils import ( + get_deepep_num_max_dispatch_tokens_per_rank_prefill, + get_deepep_num_max_dispatch_tokens_per_rank_decode, +) from lightllm.common.triton_utils.autotuner import Autotuner -import numpy as np +from lightllm.utils.device_utils import is_sm100_gpu logger = init_logger(__name__) +_MEGA_MOE_STATES: Dict[Tuple[int, int, int, int], Dict[str, Any]] = {} +SUPPORTED_EP_EXPERT_DTYPES = ("deepgemm-fp8w8a8-b128", "deepgemm-fp4fp8-b32") try: from deep_ep import Buffer, EventOverlap @@ -31,6 +34,29 @@ HAS_DEEPGEMM = False +def get_ep_num_sms() -> int: + return getattr(dist_group_manager, "ep_num_sms", None) or 0 + + +def use_sm100_mega_moe(quant_method: Any) -> bool: + return is_sm100_gpu() and quant_method.method_name == "deepgemm-fp4fp8-b32" + + +def check_ep_expert_dtype(quant_method: Any): + expert_dtype = getattr(quant_method, "method_name", None) + if expert_dtype not in SUPPORTED_EP_EXPERT_DTYPES: + raise ValueError( + "EP MoE requires --expert_dtype to be one of ['fp8', 'fp4'], " + f"but the resolved fused_moe quant method is `{expert_dtype}`. " + "Please start with --expert_dtype fp8 or --expert_dtype fp4. " + "Note that --expert_dtype fp4 is only supported on SM100 GPUs." + ) + if expert_dtype == "deepgemm-fp4fp8-b32" and not is_sm100_gpu(): + raise RuntimeError( + "--expert_dtype fp4 requires an SM100 GPU for EP MoE; " "please use --expert_dtype fp8 on non-SM100 GPUs." + ) + + def masked_group_gemm( recv_x: Tuple[torch.Tensor, torch.Tensor], masked_m: torch.Tensor, @@ -59,6 +85,138 @@ def masked_group_gemm( return gemm_out_b +def _get_mega_moe_cache_state(w13: Any, w2: Any): + state_key = ( + w13.weight.data_ptr(), + w13.weight_scale.data_ptr(), + w2.weight.data_ptr(), + w2.weight_scale.data_ptr(), + ) + return _MEGA_MOE_STATES.setdefault(state_key, {}) + + +def _get_mega_moe_weights(w13: Any, w2: Any, state: Dict[str, Any]): + if "weight_cache" not in state: + state["weight_cache"] = deep_gemm.transform_weights_for_mega_moe( + (w13.weight, w13.weight_scale), + (w2.weight, w2.weight_scale), + ) + return state["weight_cache"] + + +def _get_mega_moe_cumulative_stats(num_local_experts: int, device: torch.device, state: Dict[str, Any]): + stats = state.get("stats") + if stats is None or stats.numel() != num_local_experts or stats.device != device: + stats = torch.zeros((num_local_experts,), device=device, dtype=torch.int32) + state["stats"] = stats + return stats + + +def mega_moe_impl( + hidden_states: torch.Tensor, + w13: Any, + w2: Any, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + quant_method: Any, +): + if not (HAS_DEEPGEMM and hasattr(deep_gemm, "fp8_fp4_mega_moe")): + raise RuntimeError("deep_gemm does not provide fp8-fp4 Mega MoE kernel") + + from deep_gemm.utils import per_token_cast_to_fp8 + + buffer = getattr(dist_group_manager, "ep_mega_moe_buffer", None) + if buffer is None: + raise RuntimeError("SM100 Mega MoE requires dist_group_manager.ep_mega_moe_buffer to be initialized") + + num_tokens = hidden_states.shape[0] + if num_tokens > buffer.num_max_tokens_per_rank: + raise RuntimeError( + f"Mega MoE got {num_tokens} tokens, exceeding num_max_tokens_per_rank={buffer.num_max_tokens_per_rank}" + ) + + qinput_tensor = per_token_cast_to_fp8( + hidden_states, + use_ue8m0=True, + gran_k=quant_method.block_size, + use_packed_ue8m0=True, + ) + state = _get_mega_moe_cache_state(w13, w2) + l1_weights, l2_weights = _get_mega_moe_weights(w13, w2, state) + stats = _get_mega_moe_cumulative_stats(w13.weight.shape[0], hidden_states.device, state) + buffer.x[:num_tokens].copy_(qinput_tensor[0]) + buffer.x_sf[:num_tokens].copy_(qinput_tensor[1]) + buffer.topk_idx[:num_tokens].copy_(topk_ids) + buffer.topk_weights[:num_tokens].copy_(topk_weights) + + output = torch.empty_like(hidden_states) + deep_gemm.fp8_fp4_mega_moe( + output, + l1_weights, + l2_weights, + buffer, + cumulative_local_expert_recv_stats=stats, + ) + return output + + +def quantize_fused_experts_input( + hidden_states: torch.Tensor, + w13: Any, + quant_method: Any, +): + check_ep_expert_dtype(quant_method) + if use_sm100_mega_moe(quant_method): + from deep_gemm.utils import per_token_cast_to_fp8 + + return per_token_cast_to_fp8( + hidden_states, + use_ue8m0=True, + gran_k=quant_method.block_size, + use_packed_ue8m0=True, + ) + + block_size_k = 0 + if w13.weight.ndim == 3: + block_size_k = w13.weight.shape[2] // w13.weight_scale.shape[2] + assert block_size_k == 128, "block_size_k must be 128" + return per_token_group_quant_fp8(hidden_states, block_size_k, dtype=w13.weight.dtype) + + +def fused_experts( + hidden_states: torch.Tensor, + w13: Any, + w2: Any, + topk_weights: torch.Tensor, + topk_idx: torch.Tensor, + num_experts: int, + quant_method: Any, + is_prefill: Optional[bool], + previous_event: Optional[Any] = None, +): + check_ep_expert_dtype(quant_method) + if use_sm100_mega_moe(quant_method): + return mega_moe_impl(hidden_states, w13, w2, topk_weights, topk_idx, quant_method) + + buffer = dist_group_manager.ep_buffer if is_prefill else dist_group_manager.ep_low_latency_buffer + return fused_experts_impl( + hidden_states=hidden_states, + w1=w13.weight, + w2=w2.weight, + topk_weights=topk_weights, + topk_idx=topk_idx, + num_experts=num_experts, + buffer=buffer, + is_prefill=is_prefill, + use_fp8_w8a8=True, + use_fp8_all2all=True, + use_int8_w8a16=False, + w1_scale=w13.weight_scale, + w2_scale=w2.weight_scale, + previous_event=previous_event, + ) + + def fused_experts_impl( hidden_states: torch.Tensor, # [M, K] w1: torch.Tensor, # [group, N, K] @@ -66,14 +224,14 @@ def fused_experts_impl( topk_weights: torch.Tensor, # [M, topk] topk_idx: torch.Tensor, # [M, topk] num_experts: int, - buffer: "Buffer", + buffer: Any, is_prefill: bool, use_fp8_w8a8: bool = False, use_fp8_all2all: bool = False, use_int8_w8a16: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, - previous_event: Optional["EventOverlap"] = None, + previous_event: Optional[Any] = None, ): # Check constraints. assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" @@ -99,39 +257,27 @@ def fused_experts_impl( combined_x = None if is_prefill: qinput_tensor, input_scale = per_token_group_quant_fp8(hidden_states, block_size_k, dtype=w1.dtype) - - # get_dispatch_layout - ( - num_tokens_per_rank, - num_tokens_per_rdma_rank, - num_tokens_per_expert, - is_token_in_rank, - previous_event, - ) = buffer.get_dispatch_layout( - topk_idx, num_experts, previous_event=previous_event, async_finish=False, allocate_on_comm_stream=False - ) - + allocate_on_comm_stream = previous_event is not None # normal dispatch # recv_x [recive_num_tokens, hidden] recv_x_scale [recive_num_tokens, hidden // block_size] # recv_topk_idx [recive_num_tokens, topk_num] # recv_topk_weights [recive_num_tokens, topk_num] # num_recv_tokens_per_expert_list list [cur_node_expert_num] padding with expert_alignment=128 - recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, event = buffer.dispatch( + recv_x, recv_topk_idx, recv_topk_weights, handle, _ = buffer.dispatch( (qinput_tensor, input_scale), topk_idx=topk_idx, topk_weights=topk_weights, - num_tokens_per_rank=num_tokens_per_rank, - num_tokens_per_rdma_rank=num_tokens_per_rdma_rank, - is_token_in_rank=is_token_in_rank, - num_tokens_per_expert=num_tokens_per_expert, - previous_event=previous_event, - async_finish=False, - allocate_on_comm_stream=False, + num_experts=num_experts, + num_max_tokens_per_rank=get_deepep_num_max_dispatch_tokens_per_rank_prefill(), expert_alignment=128, + previous_event=previous_event, + allocate_on_comm_stream=allocate_on_comm_stream, + do_cpu_sync=True, + do_handle_copy=False, ) # scatter - all_tokens = sum(num_recv_tokens_per_expert_list) # calcu padding all nums. + all_tokens = sum(handle.num_recv_tokens_per_expert_list) # calcu padding all nums. # gather_out shape [recive_num_tokens, hidden] gather_out = torch.empty_like(recv_x[0], device=hidden_states.device, dtype=hidden_states.dtype) if all_tokens > 0: @@ -149,7 +295,7 @@ def fused_experts_impl( output_index = torch.empty_like(recv_topk_idx) num_recv_tokens_per_expert = torch.tensor( - num_recv_tokens_per_expert_list, dtype=torch.int32, pin_memory=True, device="cpu" + handle.num_recv_tokens_per_expert_list, dtype=torch.int32, pin_memory=True, device="cpu" ).cuda(non_blocking=True) expert_start_loc = torch.empty_like(num_recv_tokens_per_expert) @@ -169,7 +315,7 @@ def fused_experts_impl( # groupgemm (contiguous layout) gemm_out_a = torch.empty((all_tokens, N), device=hidden_states.device, dtype=hidden_states.dtype) input_tensor[1] = tma_align_input_scale(input_tensor[1]) - _deepgemm_grouped_fp8_nt_contiguous(input_tensor, (w1, w1_scale), gemm_out_a, m_indices) + deepgemm_grouped_fp8_nt_contiguous(input_tensor, (w1, w1_scale), gemm_out_a, m_indices) # silu_and_mul_fwd + qaunt # TODO fused kernel @@ -183,7 +329,7 @@ def fused_experts_impl( # groupgemm (contiguous layout) gemm_out_b = torch.empty((all_tokens, K), device=hidden_states.device, dtype=hidden_states.dtype) - _deepgemm_grouped_fp8_nt_contiguous((qsilu_out, qsilu_out_scale), (w2, w2_scale), gemm_out_b, m_indices) + deepgemm_grouped_fp8_nt_contiguous((qsilu_out, qsilu_out_scale), (w2, w2_scale), gemm_out_b, m_indices) # gather and local reduce ep_gather(gemm_out_b, recv_topk_idx, recv_topk_weights, output_index, gather_out) @@ -202,13 +348,12 @@ def fused_experts_impl( gather_out, handle, topk_weights=None, - async_finish=False, previous_event=previous_event, - allocate_on_comm_stream=False, + allocate_on_comm_stream=allocate_on_comm_stream, ) else: # low latency dispatch - num_max_dispatch_tokens_per_rank = get_deepep_num_max_dispatch_tokens_per_rank() + num_max_dispatch_tokens_per_rank = get_deepep_num_max_dispatch_tokens_per_rank_decode() expected_m = triton.cdiv(hidden_states.shape[0] * buffer.group_size * topk_idx.shape[1], num_experts) recv_x, masked_m, handle, event, hook = buffer.low_latency_dispatch( hidden_states, @@ -228,7 +373,7 @@ def fused_experts_impl( return combined_x -def _deepgemm_grouped_fp8_nt_contiguous( +def deepgemm_grouped_fp8_nt_contiguous( input_tuple: Tuple[torch.Tensor, torch.Tensor], w_tuple: Tuple[torch.Tensor, torch.Tensor], out: torch.Tensor, @@ -255,3 +400,22 @@ def _deepgemm_grouped_fp8_nt_masked( if hasattr(deep_gemm, "m_grouped_gemm_fp8_fp8_bf16_nt_masked"): return deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(input_tuple, w_tuple, out, masked_m, expected_m) raise RuntimeError("deep_gemm does not provide grouped_gemm_fp8 NT contiguous GEMM kernel in this version") + + +def deepgemm_grouped_fp8_fp4_nt_contiguous( + input_tuple: Tuple[torch.Tensor, torch.Tensor], + w_tuple: Tuple[torch.Tensor, torch.Tensor], + out: torch.Tensor, + grouped_layout: torch.Tensor, + use_psum_layout: bool = False, +): + if HAS_DEEPGEMM and hasattr(deep_gemm, "m_grouped_fp8_fp4_gemm_nt_contiguous"): + return deep_gemm.m_grouped_fp8_fp4_gemm_nt_contiguous( + input_tuple, + w_tuple, + out, + grouped_layout, + use_psum_layout=use_psum_layout, + recipe=(1, 1, 32), + ) + raise RuntimeError("deep_gemm does not provide grouped fp8-fp4 NT contiguous GEMM kernel") diff --git a/lightllm/common/kv_cache_mem_manager/mem_manager.py b/lightllm/common/kv_cache_mem_manager/mem_manager.py index 0454c86628..0a1deba499 100755 --- a/lightllm/common/kv_cache_mem_manager/mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/mem_manager.py @@ -67,8 +67,7 @@ def profile_size(self, mem_fraction): torch.cuda.empty_cache() world_size = dist.get_world_size() - - available_memory = get_available_gpu_memory(world_size) * mem_fraction + available_memory = get_available_gpu_memory(world_size) - get_total_gpu_memory() * (1 - mem_fraction) cell_size = self.get_cell_size() self.size = int(available_memory * 1024 ** 3 / cell_size) if world_size > 1: diff --git a/lightllm/common/quantization/__init__.py b/lightllm/common/quantization/__init__.py index 1f08432c6a..cd534d53ec 100644 --- a/lightllm/common/quantization/__init__.py +++ b/lightllm/common/quantization/__init__.py @@ -7,18 +7,36 @@ from .awq import * from .no_quant import * from lightllm.utils.log_utils import init_logger +from lightllm.utils.device_utils import is_sm100_gpu logger = init_logger(__name__) +EXPERT_DTYPE_TO_QUANT_TYPE = { + "fp8": "deepgemm-fp8w8a8-b128", + "fp4": "deepgemm-fp4fp8-b32", +} +SUPPORTED_EXPERT_DTYPES = tuple(EXPERT_DTYPE_TO_QUANT_TYPE) + class Quantcfg: - def __init__(self, network_config, quant_type="none", custom_cfg_path=None): + def __init__(self, network_config, quant_type="none", custom_cfg_path=None, expert_dtype=None): self.layer_num = network_config["n_layer"] self.quant_type = quant_type + self.expert_dtype = expert_dtype self.network_config_ = network_config self._parse_custom_cfg(custom_cfg_path) self._parse_network_config(network_config) + def _get_expert_quant_type(self, expert_dtype): + quant_type = EXPERT_DTYPE_TO_QUANT_TYPE.get(expert_dtype) + if quant_type is None: + raise ValueError( + f"unsupported expert_dtype `{expert_dtype}`; expected one of {list(SUPPORTED_EXPERT_DTYPES)}" + ) + if expert_dtype == "fp4" and not is_sm100_gpu(): + raise RuntimeError("expert_dtype `fp4` requires an SM100 GPU; please use `fp8` on non-SM100 GPUs.") + return quant_type + def _parse_network_config(self, network_config): hf_quantization_config = network_config.get("quantization_config", None) if hf_quantization_config is None: @@ -44,6 +62,19 @@ def _mapping_quant_method(self): else: self.quant_type = "vllm-fp8w8a8-b128" logger.info(f"select fp8w8a8-b128 quant way: {self.quant_type}") + + # fp8 量化下,部分 MoE 模型(如 DeepSeek-V4),可以单独声明 expert 权重精度, + # 按其值给 fused_moe 选用对应的 deepgemm 量化方法。 + expert_dtype = self.expert_dtype or self.network_config_.get("expert_dtype", None) + if expert_dtype is None: + return + target = self._get_expert_quant_type(expert_dtype) + for layer_num in range(self.layer_num): + if self.expert_dtype is not None: + self.quant_cfg[layer_num]["fused_moe"] = target + else: + self.quant_cfg[layer_num].setdefault("fused_moe", target) + logger.info(f"select fused_moe quant way from expert_dtype=`{expert_dtype}`: {target}") elif self.hf_quantization_method == "awq": self.quant_type = "awq" if is_awq_marlin_compatible(self.hf_quantization_config): diff --git a/lightllm/common/quantization/deepgemm.py b/lightllm/common/quantization/deepgemm.py index 137455a821..ec1ee90fd4 100644 --- a/lightllm/common/quantization/deepgemm.py +++ b/lightllm/common/quantization/deepgemm.py @@ -126,6 +126,78 @@ def _create_weight( return mm_param, mm_param_list +@QUANTMETHODS.register(["deepgemm-fp4fp8-b32"], platform="cuda") +class DeepGEMMFP8FP4B32QuantizationMethod(DeepGEMMBaseQuantizationMethod): + def __init__(self): + super().__init__() + self.block_size = 32 + self.weight_suffix = "weight" + self.weight_zero_point_suffix = None + self.weight_scale_suffix = None + self.has_weight_scale = True + self.has_weight_zero_point = False + + @property + def method_name(self): + return "deepgemm-fp4fp8-b32" + + def quantize(self, weight: torch.Tensor, output: WeightPack): + from deep_gemm.utils import per_token_cast_to_fp4 + import deep_gemm + + weight = weight.cuda(output.weight.device) + if weight.dim() == 2: + n, k = weight.shape + packed_weight, weight_scale = per_token_cast_to_fp4(weight, use_ue8m0=True, gran_k=self.block_size) + weight_scale = deep_gemm.transform_sf_into_required_layout(weight_scale, n, k, (1, self.block_size), None) + else: + num_groups, n, k = weight.shape + packed_weight = torch.empty((num_groups, n, k // 2), device=weight.device, dtype=torch.int8) + weight_scale = torch.empty((num_groups, n, k // self.block_size), device=weight.device, dtype=torch.float32) + for i in range(num_groups): + packed_weight[i], weight_scale[i] = per_token_cast_to_fp4( + weight[i], use_ue8m0=True, gran_k=self.block_size + ) + weight_scale = deep_gemm.transform_sf_into_required_layout( + weight_scale, n, k, (1, self.block_size), num_groups + ) + output.weight.copy_(packed_weight) + output.weight_scale.copy_(weight_scale) + return + + def apply( + self, + input_tensor: torch.Tensor, + weight_pack: "WeightPack", + out: Optional[torch.Tensor] = None, + workspace: Optional[torch.Tensor] = None, + use_custom_tensor_mananger: bool = True, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + raise NotImplementedError("deepgemm-fp4fp8-b32 is only implemented for fused MoE expert weights") + + def _create_weight( + self, out_dims: Union[int, List[int]], in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1 + ) -> Tuple[WeightPack, List[WeightPack]]: + out_dim = sum(out_dims) if isinstance(out_dims, list) else out_dims + assert in_dim % 2 == 0, "FP4 packed weight requires even input dimension" + assert in_dim % self.block_size == 0, "FP4 scale dimension must be divisible by block_size" + expert_prefix = (num_experts,) if num_experts > 1 else () + weight = torch.empty(expert_prefix + (out_dim, in_dim // 2), dtype=torch.int8).cuda(device_id) + weight_scale = torch.empty(expert_prefix + (out_dim, in_dim // self.block_size), dtype=torch.int32).cuda( + device_id + ) + mm_param = WeightPack(weight=weight, weight_scale=weight_scale) + mm_param_list = self._split_weight_pack( + mm_param, + weight_out_dims=out_dims, + weight_split_dim=-2, + weight_scale_out_dims=out_dims, + weight_scale_split_dim=-2, + ) + return mm_param, mm_param_list + + def _deepgemm_fp8_nt(a_tuple, b_tuple, out): if HAS_DEEPGEMM: if hasattr(deep_gemm, "gemm_fp8_fp8_bf16_nt"): diff --git a/lightllm/distributed/communication_op.py b/lightllm/distributed/communication_op.py index f01f1c87f7..f15badde25 100644 --- a/lightllm/distributed/communication_op.py +++ b/lightllm/distributed/communication_op.py @@ -27,7 +27,8 @@ from lightllm.utils.device_utils import has_nvlink from lightllm.utils.envs_utils import ( get_env_start_args, - get_deepep_num_max_dispatch_tokens_per_rank, + get_deepep_num_max_dispatch_tokens_per_rank_prefill, + get_deepep_num_max_dispatch_tokens_per_rank_decode, get_redundancy_expert_num, ) from lightllm.utils.dist_utils import ( @@ -36,7 +37,7 @@ create_new_group_for_current_dp, create_dp_special_inter_group, ) -from lightllm.utils.device_utils import get_device_sm_count +from lightllm.utils.device_utils import get_device_sm_count, is_sm100_gpu from lightllm.utils.torch_dtype_utils import get_torch_dtype logger = init_logger(__name__) @@ -106,6 +107,10 @@ def all_gather_into_tensor(self, output_: torch.Tensor, input_: torch.Tensor, as class DistributeGroupManager: def __init__(self): self.groups = [] + self.ep_buffer = None + self.ep_low_latency_buffer = None + self.ep_mega_moe_buffer = None + self.ep_num_sms = None def __len__(self): return len(self.groups) @@ -127,52 +132,92 @@ def get_default_group(self) -> CustomProcessGroup: def get_group(self, group_index: int) -> CustomProcessGroup: return self.groups[group_index] - def new_deepep_group(self, n_routed_experts, hidden_size): + def new_deepep_group( + self, + n_routed_experts, + hidden_size, + num_experts_per_tok: int = 1, + moe_intermediate_size: Optional[int] = None, + ): enable_ep_moe = get_env_start_args().enable_ep_moe - num_max_dispatch_tokens_per_rank = get_deepep_num_max_dispatch_tokens_per_rank() + prefill_num_max_dispatch_tokens_per_rank = get_deepep_num_max_dispatch_tokens_per_rank_prefill() + decode_num_max_dispatch_tokens_per_rank = get_deepep_num_max_dispatch_tokens_per_rank_decode() if not enable_ep_moe: self.ep_buffer = None + self.ep_low_latency_buffer = None + self.ep_mega_moe_buffer = None + self.ep_num_sms = None return assert HAS_DEEPEP, "deep_ep is required for expert parallelism" - self._set_num_sms_for_deep_gemm() global_world_size = get_global_world_size() deepep_group = dist.new_group(list(range(global_world_size))) - low_latency_mode, num_rdma_bytes = True, 0 - if low_latency_mode: - self.ll_num_tokens, self.ll_hidden = num_max_dispatch_tokens_per_rank, hidden_size - self.ll_num_experts = n_routed_experts + get_redundancy_expert_num() * global_world_size - num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint( - self.ll_num_tokens, self.ll_hidden, global_world_size, self.ll_num_experts - ) - self.ep_buffer = deep_ep.Buffer( + self.ll_num_tokens = prefill_num_max_dispatch_tokens_per_rank + self.ll_decode_num_tokens = decode_num_max_dispatch_tokens_per_rank + self.ll_hidden = hidden_size + self.ll_num_experts = n_routed_experts + get_redundancy_expert_num() * global_world_size + self.ep_buffer = deep_ep.ElasticBuffer( deepep_group, - int(1e9), - num_rdma_bytes, - low_latency_mode=low_latency_mode, - num_qps_per_rank=(self.ll_num_experts // global_world_size if low_latency_mode else 1), + num_max_tokens_per_rank=self.ll_num_tokens, + hidden=self.ll_hidden, + num_topk=num_experts_per_tok, + use_fp8_dispatch=True, + allow_multiple_reduction=False, ) + self.ep_mega_moe_buffer = None + self.ep_low_latency_buffer = None + if not is_sm100_gpu(): + num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint( + self.ll_decode_num_tokens, self.ll_hidden, global_world_size, self.ll_num_experts + ) + self.ep_low_latency_buffer = deep_ep.Buffer( + deepep_group, + int(1e9), + num_rdma_bytes, + low_latency_mode=True, + num_qps_per_rank=(self.ll_num_experts // global_world_size), + ) + else: + if moe_intermediate_size is None: + raise ValueError("SM100 Mega MoE requires moe_intermediate_size or intermediate_size in model config") + + import deep_gemm + + self.ep_mega_moe_buffer = deep_gemm.get_symm_buffer_for_mega_moe( + deepep_group, + self.ll_num_experts, + self.ll_num_tokens, + num_experts_per_tok, + self.ll_hidden, + moe_intermediate_size, + ) + theoretical_sms = self.ep_buffer.get_theoretical_num_sms(self.ll_num_experts, num_experts_per_tok) + self._set_num_sms_for_deep_gemm(theoretical_sms) - def _set_num_sms_for_deep_gemm(self): + def _set_num_sms_for_deep_gemm(self, deepep_sms: int): try: try: from deep_gemm.jit_kernels.utils import set_num_sms except: from deep_gemm import set_num_sms - deepep_sms = int(os.getenv("DEEPEP_SMS", deep_ep.Buffer.num_sms)) device_sms = get_device_sm_count() - deep_ep.Buffer.set_num_sms(deepep_sms) - set_num_sms(device_sms - deepep_sms) + deepep_sms = max(0, min(deepep_sms, max(device_sms - 2, 0))) + self.ep_num_sms = deepep_sms + if self.ep_low_latency_buffer is not None: + deep_ep.Buffer.set_num_sms(deepep_sms - deepep_sms % 2) + set_num_sms(max(device_sms - deepep_sms, 2)) except BaseException as e: logger.warning(f"set num sms for deep_gemm failed: {e}") def clear_deepep_buffer(self): """ - prefill 之后需要clean 一下,ep buffer 才能正常执行 decode。 + Prefill after using ElasticBuffer may leave the legacy low-latency buffer dirty for decode. """ - if hasattr(self, "ep_buffer") and self.ep_buffer is not None: - self.ep_buffer.clean_low_latency_buffer(self.ll_num_tokens, self.ll_hidden, self.ll_num_experts) + if self.ep_low_latency_buffer is not None: + self.ep_low_latency_buffer.clean_low_latency_buffer( + self.ll_decode_num_tokens, self.ll_hidden, self.ll_num_experts + ) def all_reduce( diff --git a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py index fa2dee444f..be819c94a0 100644 --- a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py @@ -7,6 +7,7 @@ from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer from lightllm.models.deepseek2.triton_kernel.rotary_emb import rotary_emb_fwd from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo +from lightllm.common.basemodel.triton_kernel.fused_moe.grouped_fused_moe_ep import use_sm100_mega_moe from functools import partial from lightllm.models.llama.yarn_rotary_utils import get_deepseek_mscale from lightllm.utils.envs_utils import get_env_start_args @@ -295,7 +296,7 @@ def overlap_tpsp_token_forward( infer_state1: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, ): - if not self.is_moe: + if not self.is_moe or use_sm100_mega_moe(layer_weight.experts.quant_method): return super().overlap_tpsp_token_forward( input_embdings, input_embdings1, infer_state, infer_state1, layer_weight ) @@ -421,7 +422,7 @@ def overlap_tpsp_context_forward( infer_state1: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, ): - if not self.is_moe: + if not self.is_moe or use_sm100_mega_moe(layer_weight.experts.quant_method): return super().overlap_tpsp_context_forward( input_embdings, input_embdings1, infer_state, infer_state1, layer_weight ) @@ -447,9 +448,9 @@ def overlap_tpsp_context_forward( _0_topk_weight, _0_topk_idx, _0_qinput_tensor = layer_weight.experts.select_experts_and_quant_input( _0_input1, _0_router_logits ) - from deep_ep import Buffer + from deep_ep import ElasticBuffer - _0_overlap_event = Buffer.capture() + _0_overlap_event = ElasticBuffer.capture() # 1 attention _1_input1 = self._att_norm(input_embdings1, infer_state1, layer_weight) @@ -486,8 +487,7 @@ def overlap_tpsp_context_forward( _1_topk_weight, _1_topk_idx, _1_qinput_tensor = layer_weight.experts.select_experts_and_quant_input( _1_input1, _1_router_logits ) - - _1_overlap_event = Buffer.capture() + _1_overlap_event = ElasticBuffer.capture() # 0 shared expert if self.n_shared_experts is not None: @@ -518,7 +518,7 @@ def overlap_tpsp_context_forward( infer_state1.hook() infer_state1.hook = None - _0_combine_event = Buffer.capture() + _0_combine_event = ElasticBuffer.capture() # 0 combine execute _0_ffn_out, _0_hook = layer_weight.experts.combine(_0_moe_out, _0_handle, _0_combine_event) infer_state.hook = _0_hook @@ -533,7 +533,7 @@ def overlap_tpsp_context_forward( infer_state.hook() infer_state.hook = None - _1_combine_event = Buffer.capture() + _1_combine_event = ElasticBuffer.capture() if self.n_shared_experts is not None: _0_ffn_out.add_(_0_shared_output) diff --git a/lightllm/models/deepseek2/model.py b/lightllm/models/deepseek2/model.py index e596eed97c..ea6620b4e4 100644 --- a/lightllm/models/deepseek2/model.py +++ b/lightllm/models/deepseek2/model.py @@ -48,7 +48,12 @@ def _init_some_value(self): def _init_custom(self): self._init_to_get_yarn_rotary() - dist_group_manager.new_deepep_group(self.config["n_routed_experts"], self.config["hidden_size"]) + dist_group_manager.new_deepep_group( + self.config["n_routed_experts"], + self.config["hidden_size"], + self.config.get("num_experts_per_tok", 1), + self.config.get("moe_intermediate_size", self.config.get("intermediate_size")), + ) def _verify_params(self): return super()._verify_params() diff --git a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py index 899531448b..d6eaebe2fd 100644 --- a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py @@ -227,7 +227,15 @@ def _get_indices( import deep_gemm - logits = deep_gemm.fp8_mqa_logits(q_fp8, (k_fp8_, k_scale_), weights.squeeze(-1), ks, ke) + logits = deep_gemm.fp8_mqa_logits( + q_fp8, + (k_fp8_, k_scale_), + weights.squeeze(-1), + ks, + ke, + clean_logits=False, + max_seqlen_k=infer_state.max_kv_seq_len, + ) from sgl_kernel import fast_topk_v2 @@ -235,7 +243,6 @@ def _get_indices( score=logits, lengths=lengths, topk=self.index_topk, - row_starts=ks, ) b_topk_index = torch.where(b_topk_index != -1, b_topk_index + ks.view(-1, 1), -1) # 将 topk index 转化为 mem index @@ -251,7 +258,7 @@ def _get_indices( @staticmethod def _rotate_activation(x: torch.Tensor) -> torch.Tensor: assert x.dtype == torch.bfloat16 - from sgl_kernel import hadamard_transform + from lightllm.models.deepseek3_2.triton_kernel.hadamard_transform import hadamard_transform hidden_size = x.size(-1) assert (hidden_size & (hidden_size - 1)) == 0, "Hidden size must be a power of 2 for Hadamard transform." diff --git a/lightllm/models/deepseek3_2/triton_kernel/extract_indexer_ks.py b/lightllm/models/deepseek3_2/triton_kernel/extract_indexer_ks.py index d0f8b45f81..f02fc30942 100644 --- a/lightllm/models/deepseek3_2/triton_kernel/extract_indexer_ks.py +++ b/lightllm/models/deepseek3_2/triton_kernel/extract_indexer_ks.py @@ -112,4 +112,4 @@ def extract_indexer_ks( num_stages=1, ) - return O_fp8, O_scale + return O_fp8, O_scale.squeeze(-1) diff --git a/lightllm/models/deepseek3_2/triton_kernel/hadamard_transform.py b/lightllm/models/deepseek3_2/triton_kernel/hadamard_transform.py new file mode 100644 index 0000000000..eabf703f56 --- /dev/null +++ b/lightllm/models/deepseek3_2/triton_kernel/hadamard_transform.py @@ -0,0 +1,80 @@ +import functools + +import torch +import triton +import triton.language as tl + + +@triton.jit +def _butterfly_stage(x, GROUPS: tl.constexpr, STEP: tl.constexpr, BLOCK_R: tl.constexpr, BLOCK_N: tl.constexpr): + x_grouped = tl.reshape(x, (BLOCK_R, GROUPS, 2, STEP)) + x_grouped = tl.permute(x_grouped, (0, 1, 3, 2)) + left, right = tl.split(x_grouped) + x_pair = tl.join(left + right, left - right) + x_pair = tl.permute(x_pair, (0, 1, 3, 2)) + return tl.reshape(x_pair, (BLOCK_R, BLOCK_N)) + + +@triton.jit +def _hadamard_transform_kernel( + X, + Y, + n_rows, + scale: tl.constexpr, + BLOCK_R: tl.constexpr, + BLOCK_N: tl.constexpr, +): + pid = tl.program_id(0) + rows = pid * BLOCK_R + tl.arange(0, BLOCK_R) + mask = rows[:, None] < n_rows + offsets = rows[:, None] * BLOCK_N + tl.arange(0, BLOCK_N)[None, :] + x = tl.load(X + offsets, mask=mask, other=0.0).to(tl.float32) + + x = _butterfly_stage(x, 64, 1, BLOCK_R, BLOCK_N) + x = _butterfly_stage(x, 32, 2, BLOCK_R, BLOCK_N) + x = _butterfly_stage(x, 16, 4, BLOCK_R, BLOCK_N) + x = _butterfly_stage(x, 8, 8, BLOCK_R, BLOCK_N) + x = _butterfly_stage(x, 4, 16, BLOCK_R, BLOCK_N) + x = _butterfly_stage(x, 2, 32, BLOCK_R, BLOCK_N) + x = _butterfly_stage(x, 1, 64, BLOCK_R, BLOCK_N) + + tl.store(Y + offsets, x * scale, mask=mask) + + +@functools.lru_cache(maxsize=None) +def _target_programs(device_index: int) -> int: + return torch.cuda.get_device_properties(device_index).multi_processor_count * 2 + + +def _pick_block_r(rows: int, device_index: int) -> int: + block_r = triton.next_power_of_2(max(1, rows // _target_programs(device_index))) + return max(1, min(128, block_r)) + + +def _hadamard_transform_triton(x: torch.Tensor, scale: float) -> torch.Tensor: + original_shape = x.shape + hidden_size = x.size(-1) + if not x.is_contiguous(): + x = x.contiguous() + rows = x.numel() // hidden_size + out = torch.empty_like(x) + BLOCK_R = _pick_block_r(rows, x.device.index) + grid = (triton.cdiv(rows, BLOCK_R),) + _hadamard_transform_kernel[grid]( + x, + out, + rows, + scale, + BLOCK_R=BLOCK_R, + BLOCK_N=hidden_size, + num_warps=4, + ) + return out.view(original_shape) + + +def hadamard_transform(x: torch.Tensor, scale: float = 1.0) -> torch.Tensor: + assert x.is_cuda, "hadamard_transform only supports CUDA tensors" + assert x.dtype == torch.bfloat16, "hadamard_transform expects bfloat16 input" + assert x.size(-1) == 128, "DeepSeek-V3.2 Hadamard transform expects hidden size 128" + + return _hadamard_transform_triton(x, scale) diff --git a/lightllm/models/gemma4/model.py b/lightllm/models/gemma4/model.py index e1df1ec7fd..10b1958b0e 100644 --- a/lightllm/models/gemma4/model.py +++ b/lightllm/models/gemma4/model.py @@ -130,7 +130,12 @@ def _init_att_backend1(self): def _init_custom(self): self._init_to_get_rotary_gemma4() if self.config.get("enable_moe_block", False): - dist_group_manager.new_deepep_group(self.config["num_experts"], self.config["hidden_size"]) + dist_group_manager.new_deepep_group( + self.config["num_experts"], + self.config["hidden_size"], + self.config.get("num_experts_per_tok", self.config.get("top_k_experts", 1)), + self.config.get("moe_intermediate_size", self.config.get("intermediate_size")), + ) self._init_ple_static_buffer() def _init_ple_static_buffer(self): diff --git a/lightllm/models/glm4_moe_lite/model.py b/lightllm/models/glm4_moe_lite/model.py index a8fe49ac5e..1e31306aea 100644 --- a/lightllm/models/glm4_moe_lite/model.py +++ b/lightllm/models/glm4_moe_lite/model.py @@ -25,7 +25,12 @@ def _init_config(self): def _init_custom(self): self._init_to_get_yarn_rotary() - dist_group_manager.new_deepep_group(self.config["n_routed_experts"], self.config["hidden_size"]) + dist_group_manager.new_deepep_group( + self.config["n_routed_experts"], + self.config["hidden_size"], + self.config.get("num_experts_per_tok", 1), + self.config.get("moe_intermediate_size", self.config.get("intermediate_size")), + ) def _init_to_get_yarn_rotary(self): rope_scaling = self.config.get("rope_scaling") diff --git a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py index 54e4373652..8879aa2d27 100644 --- a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py @@ -6,6 +6,7 @@ from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer from lightllm.models.llama.infer_struct import LlamaInferStateInfo from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd +from lightllm.common.basemodel.triton_kernel.fused_moe.grouped_fused_moe_ep import use_sm100_mega_moe from lightllm.utils.dist_utils import get_global_world_size from lightllm.utils.envs_utils import get_env_start_args @@ -133,7 +134,7 @@ def overlap_tpsp_token_forward( infer_state1: LlamaInferStateInfo, layer_weight: Qwen3MOETransformerLayerWeight, ): - if not self.is_moe: + if not self.is_moe or use_sm100_mega_moe(layer_weight.experts.quant_method): return super().overlap_tpsp_token_forward( input_embdings, input_embdings1, infer_state, infer_state1, layer_weight ) @@ -245,7 +246,7 @@ def overlap_tpsp_context_forward( infer_state1: LlamaInferStateInfo, layer_weight: Qwen3MOETransformerLayerWeight, ): - if not self.is_moe: + if not self.is_moe or use_sm100_mega_moe(layer_weight.experts.quant_method): return super().overlap_tpsp_context_forward( input_embdings, input_embdings1, infer_state, infer_state1, layer_weight ) @@ -270,9 +271,9 @@ def overlap_tpsp_context_forward( _0_topk_weight, _0_topk_idx, _0_qinput_tensor = layer_weight.experts.select_experts_and_quant_input( _0_input1, _0_router_logits ) - from deep_ep import Buffer + from deep_ep import ElasticBuffer - _0_overlap_event = Buffer.capture() + _0_overlap_event = ElasticBuffer.capture() # 1 attention _1_input1 = self._att_norm(input_embdings1, infer_state1, layer_weight) @@ -308,8 +309,7 @@ def overlap_tpsp_context_forward( _1_topk_weight, _1_topk_idx, _1_qinput_tensor = layer_weight.experts.select_experts_and_quant_input( _1_input1, _1_router_logits ) - - _1_overlap_event = Buffer.capture() + _1_overlap_event = ElasticBuffer.capture() # 0 moe calu _0_moe_out = layer_weight.experts.prefilled_group_gemm( @@ -332,7 +332,7 @@ def overlap_tpsp_context_forward( infer_state1.hook() infer_state1.hook = None - _0_combine_event = Buffer.capture() + _0_combine_event = ElasticBuffer.capture() # 0 combine execute _0_ffn_out, _0_hook = layer_weight.experts.combine(_0_moe_out, _0_handle, _0_combine_event) infer_state.hook = _0_hook @@ -347,7 +347,7 @@ def overlap_tpsp_context_forward( infer_state.hook() infer_state.hook = None - _1_combine_event = Buffer.capture() + _1_combine_event = ElasticBuffer.capture() input_embdings.add_(_0_ffn_out.view(-1, self.embed_dim_)) diff --git a/lightllm/models/qwen3_moe/model.py b/lightllm/models/qwen3_moe/model.py index b71d7f4878..0d4b45bfe6 100644 --- a/lightllm/models/qwen3_moe/model.py +++ b/lightllm/models/qwen3_moe/model.py @@ -27,4 +27,9 @@ def _init_custom(self): super()._init_custom() # Only initialize DeepEP group for MoE models with num_experts if "num_experts" in self.config and self.config["num_experts"] > 0: - dist_group_manager.new_deepep_group(self.config["num_experts"], self.config["hidden_size"]) + dist_group_manager.new_deepep_group( + self.config["num_experts"], + self.config["hidden_size"], + self.config.get("num_experts_per_tok", 1), + self.config.get("moe_intermediate_size", self.config.get("intermediate_size")), + ) diff --git a/lightllm/models/qwen3next/model.py b/lightllm/models/qwen3next/model.py index 4a8ee80a46..e3c51f3617 100644 --- a/lightllm/models/qwen3next/model.py +++ b/lightllm/models/qwen3next/model.py @@ -12,7 +12,6 @@ ) from lightllm.models.qwen3next.infer_struct import Qwen3NextInferStateInfo from lightllm.utils.log_utils import init_logger -from lightllm.distributed.communication_op import dist_group_manager from lightllm.utils.envs_utils import get_env_start_args from lightllm.common.kv_cache_mem_manager.qwen3next_mem_manager import Qwen3NextMemManager from lightllm.server.core.objs.start_args_type import StartArgs @@ -56,12 +55,6 @@ def _init_config(self): super()._init_config() self.num_kv_heads = max(self.config["num_key_value_heads"] // self.tp_world_size_, 1) - def _init_custom(self): - super()._init_custom() - # Only initialize DeepEP group for MoE models with num_experts - if "num_experts" in self.config and self.config["num_experts"] > 0: - dist_group_manager.new_deepep_group(self.config["num_experts"], self.config["hidden_size"]) - def _init_mem_manager(self): assert self.config["num_attention_heads"] % self.tp_world_size_ == 0 start_args: StartArgs = get_env_start_args() diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 6b30ab6874..2db6c67e77 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -620,6 +620,14 @@ def make_argument_parser() -> argparse.ArgumentParser: help="""Path of quantization config. It can be used for mixed quantization. Examples can be found in test/advanced_config/mixed_quantization/llamacls-mix-down.yaml.""", ) + parser.add_argument( + "--expert_dtype", + type=str, + default=None, + choices=["fp8", "fp4"], + help="""Expert quantization dtype for EP MoE. Supported values are + fp8 and fp4. Note that fp4 is only supported on SM100 GPUs.""", + ) parser.add_argument( "--vit_quant_type", type=str, diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index 249839b0a7..654ba0f3e5 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -4,6 +4,7 @@ import uuid import subprocess import signal +import math from lightllm.utils.net_utils import alloc_can_use_network_port, PortLocker from lightllm.utils.start_utils import process_manager, kill_recursive from .metrics.manager import start_metric_manager @@ -291,7 +292,10 @@ def normal_or_p_d_start(args): # linear att cache 参数自动设置 if args.linear_att_cache_size is None: # linear_att_cache_size 只会在 qwen3.5 等混合线性层模型中生效。 - args.linear_att_cache_size = args.running_max_req_size * 2 + default_cache_size = args.running_max_req_size * 2 + dp_size_in_node = max(1, args.dp // args.nnodes) + per_dp_cache_size = max(1, math.ceil(args.running_max_req_size / dp_size_in_node) * 2) + args.linear_att_cache_size = min(default_cache_size, per_dp_cache_size) if args.enable_cpu_cache and is_linear_att_mixed_model(args.model_dir): args.cpu_cache_token_page_size = args.linear_att_hash_page_size * args.linear_att_page_block_num diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index 05ff2658e1..6d0ee07465 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -133,6 +133,7 @@ class StartArgs: graph_max_len_in_batch: int = field(default=0) quant_type: Optional[str] = field(default=None) quant_cfg: Optional[str] = field(default=None) + expert_dtype: Optional[str] = field(default=None, metadata={"choices": ["fp8", "fp4"]}) vit_quant_type: Optional[str] = field(default=None) vit_quant_cfg: Optional[str] = field(default=None) llm_prefill_att_backend: List[str] = field( diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py index 045723d073..a41c2f265a 100644 --- a/lightllm/server/router/manager.py +++ b/lightllm/server/router/manager.py @@ -171,6 +171,7 @@ async def wait_to_model_ready(self): "batch_max_tokens": self.args.batch_max_tokens, "quant_type": self.args.quant_type, "quant_cfg": self.args.quant_cfg, + "expert_dtype": self.args.expert_dtype, "pd_rpyc_ports": self.args.pd_node_infer_rpyc_ports, # 非 pd 模式可以不设置 } diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index e47717747e..4fb6a0db9d 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -164,6 +164,7 @@ def init_model(self, kvargs): "batch_max_tokens": kvargs.get("batch_max_tokens", None), "quant_type": kvargs.get("quant_type", None), "quant_cfg": kvargs.get("quant_cfg", None), + "expert_dtype": kvargs.get("expert_dtype", None), "run_mode": self.run_mode, } self.model, self.is_multimodal = get_model(model_cfg, model_kvargs) @@ -338,6 +339,7 @@ def init_mtp_draft_model(self, main_kvargs: dict): "batch_max_tokens": main_kvargs.get("batch_max_tokens", None), "quant_type": main_kvargs.get("quant_type", None), "quant_cfg": main_kvargs.get("quant_cfg", None), + "expert_dtype": main_kvargs.get("expert_dtype", None), "run_mode": "normal", "main_model": self.model, "mtp_previous_draft_models": self.draft_models.copy(), diff --git a/lightllm/utils/backend_validator.py b/lightllm/utils/backend_validator.py index 6c5fe90309..ab5c0a88a1 100644 --- a/lightllm/utils/backend_validator.py +++ b/lightllm/utils/backend_validator.py @@ -196,12 +196,15 @@ def _validate_flashmla_sparse(): except Exception as e: return False, f"sgl_kernel.flash_mla import failed: {type(e).__name__}: {e}" - batch, heads, seq, dim = 1, 64, 128, 512 + 64 + batch, heads, seq = 1, 64, 128 + kv_lora_rank = 512 + qk_rope_head_dim = 64 + qk_dim = kv_lora_rank + qk_rope_head_dim dtype = torch.bfloat16 device = "cuda" - q = torch.randn(batch * seq, heads, dim, dtype=dtype, device=device) - kv = torch.zeros(batch * seq, 1, dim, dtype=dtype, device=device) + q = torch.randn(batch * seq, heads, qk_dim, dtype=dtype, device=device) + kv = torch.zeros(batch * seq, 1, qk_dim, dtype=dtype, device=device) index_topk = 128 topk_indices = torch.zeros(batch * seq, index_topk, dtype=torch.int32, device=device) @@ -210,8 +213,7 @@ def _validate_flashmla_sparse(): topk_indices = topk_indices.view(batch * seq, 1, index_topk) - softmax_scale = 1.0 / (dim ** 0.5) - kv_lora_rank = dim + softmax_scale = 1.0 / (qk_dim ** 0.5) try: mla_out, _, _ = flash_mla_sparse_fwd( diff --git a/lightllm/utils/device_utils.py b/lightllm/utils/device_utils.py index 43b10ec88b..58bff90560 100644 --- a/lightllm/utils/device_utils.py +++ b/lightllm/utils/device_utils.py @@ -40,6 +40,11 @@ def get_device_sm_count(): return properties["multiprocessor_count"] +@lru_cache(maxsize=None) +def is_sm100_gpu(): + return torch.cuda.get_device_capability()[0] == 10 + + @lru_cache(maxsize=None) def get_device_sm_regs_num(): import triton diff --git a/lightllm/utils/dist_check_utils.py b/lightllm/utils/dist_check_utils.py index e11da07c8c..12b0b81993 100644 --- a/lightllm/utils/dist_check_utils.py +++ b/lightllm/utils/dist_check_utils.py @@ -17,7 +17,7 @@ logger = init_logger(__name__) _CUSTOM_ALLREDUCE_WORLD_SIZES = (2, 4, 6, 8) -_TWO_GPU_CHECK_TIMEOUT_SECONDS = 60.0 +_TWO_GPU_CHECK_TIMEOUT_SECONDS = 600.0 def _start_two_gpu_check_timeout_watchdog(backend_name: str) -> threading.Event: @@ -84,6 +84,8 @@ def _flashinfer_two_gpu_check_worker(process_rank: int, init_tcp_port: int) -> N input_tensor = torch.zeros(2, 64, device=cuda_device, dtype=torch.bfloat16) else: input_tensor = torch.ones(2, 64, device=cuda_device, dtype=torch.bfloat16) + if not flashinfer_all_reduce.should_use(input_tensor): + raise RuntimeError("FlashInferAllReduce unsupported for probe tensor") output_tensor = flashinfer_all_reduce.all_reduce(input_tensor) dist.barrier() expected_reduced = torch.ones(2, 64, device=cuda_device, dtype=torch.bfloat16) diff --git a/lightllm/utils/envs_utils.py b/lightllm/utils/envs_utils.py index 350507e897..2bdd4005fa 100644 --- a/lightllm/utils/envs_utils.py +++ b/lightllm/utils/envs_utils.py @@ -69,9 +69,22 @@ def enable_env_vars(args): @lru_cache(maxsize=None) -def get_deepep_num_max_dispatch_tokens_per_rank(): +def get_deepep_num_max_dispatch_tokens_per_rank_prefill(): + # 该参数需要大于单卡最大batch size,且是8的倍数。该参数与显存占用直接相关,值越大,显存占用越大。 + # 如果未显式配置,则默认至少覆盖当前进程的 `batch_max_tokens`,避免 DeepEP V2 在 autotune + # warmup 或大 prefill batch 时因为 buffer 上界过小而报错。 + configured = os.getenv("NUM_MAX_DISPATCH_TOKENS_PER_RANK_PREFILL", None) + if configured is not None: + return int(configured) + + batch_max_tokens = get_env_start_args().batch_max_tokens or 256 + return ((int(batch_max_tokens) + 7) // 8) * 8 + + +@lru_cache(maxsize=None) +def get_deepep_num_max_dispatch_tokens_per_rank_decode(): # 该参数需要大于单卡最大batch size,且是8的倍数。该参数与显存占用直接相关,值越大,显存占用越大,如果出现显存不足,可以尝试调小该值 - return int(os.getenv("NUM_MAX_DISPATCH_TOKENS_PER_RANK", 256)) + return int(os.getenv("NUM_MAX_DISPATCH_TOKENS_PER_RANK_DECODE", 256)) def get_lightllm_gunicorn_keep_alive(): diff --git a/requirements.txt b/requirements.txt index d37ae05690..f124ce76f5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -33,7 +33,7 @@ mpmath==1.3.0 multiprocessing-logging==0.3.4 networkx==3.1 ninja==1.11.1 -numpy==1.25.1 +numpy==2.1.3 packaging==24.2 pip==23.0.1 pluggy==1.2.0 @@ -59,9 +59,9 @@ six==1.16.0 sniffio==1.3.0 sortedcontainers==2.4.0 toolz==0.12.0 -torch==2.9.1 +torch==2.11.0 tqdm==4.65.0 -transformers==4.57.1 +transformers==5.8.0 tokenizers==0.22.1 urllib3==1.26.16 uvicorn==0.19.0 @@ -71,7 +71,7 @@ zstandard==0.23.0 safetensors==0.4.5 Pillow==10.4.0 tiktoken==0.7.0 -matplotlib==3.8.2 +matplotlib==3.10.0 psutil==5.9.4 prometheus_client==0.20.0 cchardet==2.1.7 @@ -81,19 +81,20 @@ atomics==1.0.3 easydict==1.13 hypercorn==0.18.0 flashinfer-python==0.6.8.post1 -sgl-kernel==0.3.21 +flashinfer-cubin==0.6.8.post1 +sglang-kernel==0.4.2.post1 httpx==0.28.1 librosa==0.11.0 -cuda_bindings==12.9.0 +cuda_bindings==13.2.0 orjson==3.11.2 setproctitle==1.3.6 xxhash==3.6.0 -torchvision==0.24.1 +torchvision==0.26.0 interegular==0.3.3 partial_json_parser==0.2.1.1.post6 websockets==15.0.1 -cupy-cuda12x==13.6.0 -nixl==0.8.0 -xformers==0.0.33.post2 +cupy-cuda13x==14.0.1 +nixl==1.1.0 +xformers==0.0.35 redis==7.3.0 litellm>=1.52.0,<1.85 diff --git a/test/benchmark/service/benchmark_client.py b/test/benchmark/service/benchmark_client.py index 09009fc9e1..3f55bcab1e 100644 --- a/test/benchmark/service/benchmark_client.py +++ b/test/benchmark/service/benchmark_client.py @@ -27,6 +27,13 @@ def get_tokenizer( return tokenizer +def normalize_model_name(model_name: str) -> str: + if not model_name: + return model_name + normalized = model_name.rstrip("/\\") + return normalized or model_name + + def get_output_length(input_num: int, output_len: int) -> List[int]: min_len, max_len = 2, output_len * 2 mean = (min_len + max_len) * 0.5 @@ -162,7 +169,7 @@ def main(): return assert args.tokenizer_path is not None - model_name.append(args.tokenizer_path) + model_name.append(normalize_model_name(args.tokenizer_path)) seed_all(args.seed) url = args.url tokenizer = get_tokenizer(args.tokenizer_path) diff --git a/test/benchmark/service/benchmark_multiturn.py b/test/benchmark/service/benchmark_multiturn.py index 7387237f4d..7019654c38 100644 --- a/test/benchmark/service/benchmark_multiturn.py +++ b/test/benchmark/service/benchmark_multiturn.py @@ -39,6 +39,8 @@ import random import threading import time +import urllib.parse +import urllib.request from concurrent.futures import ThreadPoolExecutor, as_completed from typing import Dict, List, Optional, Tuple, Union @@ -46,6 +48,15 @@ import requests from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast +_DEFAULT_TRANSIENT_RETRIES = 2 +_PROMPT_LEN_OVERLAP_CHARS = 512 +_TRANSIENT_STREAM_ERRORS = ( + requests.exceptions.ChunkedEncodingError, + requests.exceptions.ConnectionError, + requests.exceptions.ReadTimeout, + requests.exceptions.Timeout, +) + def seed_all(seed: int) -> None: if not seed: @@ -59,6 +70,85 @@ def get_tokenizer(tokenizer_name: str) -> Union[PreTrainedTokenizer, PreTrainedT return AutoTokenizer.from_pretrained(tokenizer_name, trust_remote_code=True) +def normalize_model_name(model_name: str) -> str: + if not model_name: + return model_name + normalized = model_name.rstrip("/\\") + return normalized or model_name + + +def get_models_url(completions_url: str) -> str: + parsed = urllib.parse.urlsplit(completions_url) + path = parsed.path.rstrip("/") + for suffix in ("/chat/completions", "/completions"): + if path.endswith(suffix): + path = path[: -len(suffix)] + "/models" + return urllib.parse.urlunsplit(parsed._replace(path=path, query="", fragment="")) + return urllib.parse.urlunsplit(parsed._replace(path="/v1/models", query="", fragment="")) + + +def fetch_served_model_names(completions_url: str, timeout_s: int = 10) -> List[str]: + models_url = get_models_url(completions_url) + request = urllib.request.Request(models_url, headers={"Accept": "application/json"}) + with urllib.request.urlopen(request, timeout=timeout_s) as response: + payload = json.loads(response.read().decode("utf-8")) + return [item["id"] for item in payload.get("data", []) if item.get("id")] + + +def resolve_model_name( + completions_url: str, + requested_model_name: str, + explicit_model_name: bool, +) -> Tuple[str, Optional[str]]: + normalized_name = normalize_model_name(requested_model_name) + if normalized_name != requested_model_name: + note = f"Normalized model name from `{requested_model_name}` to `{normalized_name}`." + else: + note = None + + try: + served_model_names = fetch_served_model_names(completions_url) + except Exception as exc: + if note is not None: + note = f"{note} Failed to query served models: {exc}." + return normalized_name, note + + if requested_model_name in served_model_names: + return requested_model_name, note + if normalized_name in served_model_names: + if normalized_name != requested_model_name: + return normalized_name, ( + f"Normalized model name from `{requested_model_name}` to `{normalized_name}` " "to match `/v1/models`." + ) + return normalized_name, note + + requested_basename = os.path.basename(normalized_name) + basename_matches = [ + served_name + for served_name in served_model_names + if os.path.basename(normalize_model_name(served_name)) == requested_basename + ] + if len(basename_matches) == 1: + matched_name = basename_matches[0] + return matched_name, ( + f"Resolved model name `{requested_model_name}` to served model `{matched_name}` " "via `/v1/models`." + ) + + if not explicit_model_name and len(served_model_names) == 1: + matched_name = served_model_names[0] + return matched_name, ( + f"Using the only served model `{matched_name}` returned by `/v1/models` " + f"instead of `{requested_model_name}`." + ) + + if note is not None: + note = ( + f"{note} Available served models: {', '.join(served_model_names) or '(none)'}. " + f"Using `{normalized_name}`." + ) + return normalized_name, note + + def gen_random_token_ids(tokenizer, n: int, rng: random.Random) -> List[int]: vocab = tokenizer.vocab_size return [rng.randint(0, vocab - 1) for _ in range(n)] @@ -87,6 +177,7 @@ def gen_session_initial_prompt( def append_turn_input( tokenizer, prompt: str, + prompt_token_len: int, generated_text: str, turn_input_increment: int, rng: random.Random, @@ -98,17 +189,34 @@ def append_turn_input( new_text = decode_ids(tokenizer, new_ids) else: new_text = "" - new_prompt = prompt + generated_text + new_text - new_len = len(tokenizer.encode(new_prompt, add_special_tokens=False)) + + appended_text = generated_text + new_text + new_prompt = prompt + appended_text + if not appended_text: + return new_prompt, prompt_token_len + + # Token merges only depend on a small boundary window, so avoid + # re-encoding the entire prompt on every turn. + overlap_text = prompt[-_PROMPT_LEN_OVERLAP_CHARS:] + if overlap_text: + overlap_token_len = len(tokenizer.encode(overlap_text, add_special_tokens=False)) + merged_token_len = len(tokenizer.encode(overlap_text + appended_text, add_special_tokens=False)) + appended_token_len = max(merged_token_len - overlap_token_len, 0) + else: + appended_token_len = len(tokenizer.encode(appended_text, add_special_tokens=False)) + new_len = prompt_token_len + appended_token_len return new_prompt, new_len def stream_one_turn( + tokenizer, url: str, model_name: str, prompt: str, + prompt_token_len: int, max_new_tokens: int, request_timeout_s: int, + max_retries: int = _DEFAULT_TRANSIENT_RETRIES, ) -> Optional[Dict]: """Send one streaming completion request, return per-turn stats: { @@ -117,6 +225,8 @@ def stream_one_turn( "prompt_tokens": int, "completion_tokens": int, "cached_tokens": int, + "cached_tokens_reported": bool, + "usage_estimated": bool, "generated_text": str, } Returns None on failure.""" @@ -131,79 +241,119 @@ def stream_one_turn( } headers = {"Content-Type": "application/json"} - start_time = time.time() - first_token_time: Optional[float] = None - last_token_time: Optional[float] = None - decode_times: List[float] = [] - generated_text_parts: List[str] = [] - prompt_tokens = 0 - completion_tokens = 0 - cached_tokens = 0 - - with requests.Session() as req_session: - req_session.trust_env = False - with req_session.post( - url, - headers=headers, - json=payload, - stream=True, - timeout=(10, request_timeout_s), - ) as response: - if response.status_code != 200: - err = response.text - raise RuntimeError(f"stream_one_turn failed: status={response.status_code}, body={err[:200]}") - - for raw in response.iter_lines(): - if not raw: - continue - line = raw.strip() - if not line.startswith(b"data:"): - continue - data_str = line[len(b"data:") :].strip() - if data_str == b"[DONE]": - break - try: - chunk = json.loads(data_str) - except Exception: - continue - - # Final usage-only chunk: choices == [] and usage present - usage = chunk.get("usage") - choices = chunk.get("choices") or [] - if usage is not None and not choices: - prompt_tokens = usage.get("prompt_tokens", prompt_tokens) - completion_tokens = usage.get("completion_tokens", completion_tokens) - details = usage.get("prompt_tokens_details") or {} - cached_tokens = details.get("cached_tokens", cached_tokens) - continue - - # Token-bearing chunk - if not choices: - continue - text_piece = choices[0].get("text", "") - if text_piece == "" and choices[0].get("finish_reason") is None: - continue - - now = time.time() - if first_token_time is None: - first_token_time = now - else: - decode_times.append(now - last_token_time) - last_token_time = now - if text_piece: - generated_text_parts.append(text_piece) - - if first_token_time is None: - raise RuntimeError("stream_one_turn failed: no token received from stream") - - return { - "ttft": first_token_time - start_time, - "decode_times": decode_times, - "prompt_tokens": prompt_tokens, - "completion_tokens": completion_tokens, - "cached_tokens": cached_tokens, - "generated_text": "".join(generated_text_parts), - } + for attempt in range(max_retries + 1): + start_time = time.time() + first_token_time: Optional[float] = None + last_token_time: Optional[float] = None + decode_times: List[float] = [] + generated_text_parts: List[str] = [] + prompt_tokens = 0 + completion_tokens = 0 + cached_tokens = 0 + cached_tokens_reported = False + + try: + with requests.Session() as req_session: + req_session.trust_env = False + with req_session.post( + url, + headers=headers, + json=payload, + stream=True, + timeout=(10, request_timeout_s), + ) as response: + if response.status_code != 200: + err = response.text + if response.status_code >= 500 and attempt < max_retries: + time.sleep(0.2 * (attempt + 1)) + continue + print(f"\n[turn failed] status={response.status_code} body={err[:200]}") + return None + + for raw in response.iter_lines(): + if not raw: + continue + line = raw.strip() + if not line.startswith(b"data:"): + continue + data_str = line[len(b"data:") :].strip() + if data_str == b"[DONE]": + break + try: + chunk = json.loads(data_str) + except Exception: + continue + + # Final usage-only chunk: choices == [] and usage present + usage = chunk.get("usage") + choices = chunk.get("choices") or [] + if usage is not None and not choices: + prompt_tokens = usage.get("prompt_tokens", prompt_tokens) + completion_tokens = usage.get("completion_tokens", completion_tokens) + details = usage.get("prompt_tokens_details") + if isinstance(details, dict) and details.get("cached_tokens") is not None: + cached_tokens = details["cached_tokens"] + cached_tokens_reported = True + continue + + # Token-bearing chunk + if not choices: + continue + text_piece = choices[0].get("text", "") + if text_piece == "" and choices[0].get("finish_reason") is None: + continue + + now = time.time() + if first_token_time is None: + first_token_time = now + else: + decode_times.append(now - last_token_time) + last_token_time = now + if text_piece: + generated_text_parts.append(text_piece) + except _TRANSIENT_STREAM_ERRORS as e: + if first_token_time is None and attempt < max_retries: + time.sleep(0.2 * (attempt + 1)) + continue + + if first_token_time is not None: + print(f"\n[turn warning] {e}; discarding partial turn (attempt={attempt + 1})") + return None + + print(f"\n[turn exception] {e}") + return None + except Exception as e: + print(f"\n[turn exception] {e}") + return None + + if first_token_time is None: + if attempt < max_retries: + time.sleep(0.2 * (attempt + 1)) + continue + return None + + generated_text = "".join(generated_text_parts) + usage_estimated = False + if prompt_tokens == 0: + prompt_tokens = prompt_token_len + usage_estimated = True + if completion_tokens == 0: + estimated_completion_tokens = len(tokenizer.encode(generated_text, add_special_tokens=False)) + completion_tokens = max(estimated_completion_tokens, len(generated_text_parts)) + usage_estimated = True + + return { + "ttft": first_token_time - start_time, + "decode_times": decode_times, + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "cached_tokens": cached_tokens, + "cached_tokens_reported": cached_tokens_reported, + "usage_estimated": usage_estimated, + "generated_text": generated_text, + } + + return None def run_session( @@ -234,9 +384,11 @@ def run_session( while turn_idx < max_turns and prompt_len < max_input_len: turn_output_len = rng.randint(min_output_len, output_len) result = stream_one_turn( + tokenizer=tokenizer, url=url, model_name=model_name, prompt=prompt, + prompt_token_len=prompt_len, max_new_tokens=turn_output_len, request_timeout_s=request_timeout_s, ) @@ -248,13 +400,15 @@ def run_session( print( f"\rconc={progress_state['concurrency']} " f"finished_turns={progress_state['finished_turns']} " - f"active_sessions={progress_state['active_sessions']}", + f"active_sessions={progress_state['active_sessions']}\033[K", end="", + flush=True, ) turn_input_len = rng.randint(min_turn_input_increment, turn_input_increment) prompt, prompt_len = append_turn_input( tokenizer, prompt, + result["prompt_tokens"] or prompt_len, result["generated_text"], turn_input_len, rng, @@ -358,13 +512,14 @@ def summarize( prompt_tokens = sum(t["prompt_tokens"] for t in turns) completion_tokens = sum(t["completion_tokens"] for t in turns) cached_tokens = sum(t["cached_tokens"] for t in turns) + cached_tokens_reported_turns = sum(1 for t in turns if t.get("cached_tokens_reported")) + usage_estimated_turns = sum(1 for t in turns if t.get("usage_estimated")) total_tokens = prompt_tokens + completion_tokens qps = len(turns) / wall_time tpm_total = total_tokens / wall_time * 60.0 tpm_prompt = prompt_tokens / wall_time * 60.0 tpm_completion = completion_tokens / wall_time * 60.0 - cache_hit_ratio = cached_tokens / prompt_tokens if prompt_tokens else 0.0 out["QPS"] = round(qps, 4) out["TPM_total"] = round(tpm_total, 2) @@ -373,7 +528,18 @@ def summarize( out["total_prompt_tokens"] = prompt_tokens out["total_completion_tokens"] = completion_tokens out["total_cached_prompt_tokens"] = cached_tokens - out["cache_hit_ratio"] = round(cache_hit_ratio, 6) + out["cached_tokens_reported_turns"] = cached_tokens_reported_turns + out["usage_estimated_turns"] = usage_estimated_turns + if cached_tokens_reported_turns > 0: + cache_hit_ratio = cached_tokens / prompt_tokens if prompt_tokens else 0.0 + out["cache_hit_ratio"] = round(cache_hit_ratio, 6) + else: + out["cache_hit_ratio"] = None + out["cache_hit_ratio_note"] = ( + "Server did not return usage.prompt_tokens_details.cached_tokens. " + "For vLLM OpenAI-compatible APIs, start the server with " + "--enable-prompt-tokens-details to expose cache-hit stats." + ) out["avg_prompt_tokens_per_turn"] = round(prompt_tokens / len(turns), 2) out["avg_completion_tokens_per_turn"] = round(completion_tokens / len(turns), 2) @@ -406,10 +572,16 @@ def print_summary(summary: Dict) -> None: print(f" TPM (total) : {summary['TPM_total']}") print(f" TPM (prompt) : {summary['TPM_prompt']}") print(f" TPM (completion) : {summary['TPM_completion']}") - print( - f" Cache hit ratio : {summary['cache_hit_ratio'] * 100:.2f}% " - f"({summary['total_cached_prompt_tokens']} / {summary['total_prompt_tokens']})" - ) + if summary["cache_hit_ratio"] is None: + print(" Cache hit ratio : n/a") + print(f" Cache hit note : {summary['cache_hit_ratio_note']}") + else: + print( + f" Cache hit ratio : {summary['cache_hit_ratio'] * 100:.2f}% " + f"({summary['total_cached_prompt_tokens']} / {summary['total_prompt_tokens']})" + ) + if summary.get("usage_estimated_turns"): + print(f" Usage estimated : {summary['usage_estimated_turns']} turns") print(f" Avg prompt tokens : {summary['avg_prompt_tokens_per_turn']}") print(f" Avg output tokens : {summary['avg_completion_tokens_per_turn']}") ttft = summary["TTFT_ms"] @@ -432,7 +604,7 @@ def main() -> None: parser.add_argument( "--url", type=str, - default="http://127.0.0.1:8088/v1/completions", + default="http://127.0.0.1:8000/v1/completions", help="Streaming OpenAI completion endpoint. The benchmark relies on " "the final SSE `usage` chunk to obtain cached_tokens.", ) @@ -499,12 +671,19 @@ def main() -> None: return seed_all(args.seed) - model_name = args.model_name or args.tokenizer_path + requested_model_name = args.model_name or args.tokenizer_path + model_name, model_name_note = resolve_model_name( + args.url, + requested_model_name, + explicit_model_name=args.model_name is not None, + ) tokenizer = get_tokenizer(args.tokenizer_path) concurrency_levels = [int(x) for x in args.concurrency_levels.split(",") if x.strip()] print(f"URL : {args.url}") print(f"Model : {model_name}") + if model_name_note: + print(f"Model note : {model_name_note}") print(f"Concurrency levels : {concurrency_levels}") print(f"start_input_len : {args.start_input_len}") print(f"max_input_len : {args.max_input_len}") @@ -538,6 +717,7 @@ def main() -> None: "config": { "url": args.url, "model_name": model_name, + "requested_model_name": requested_model_name, "tokenizer_path": args.tokenizer_path, "concurrency_levels": concurrency_levels, "start_input_len": args.start_input_len, diff --git a/test/benchmark/service/benchmark_qps.py b/test/benchmark/service/benchmark_qps.py index 8249ae2c49..a9083091ea 100644 --- a/test/benchmark/service/benchmark_qps.py +++ b/test/benchmark/service/benchmark_qps.py @@ -31,6 +31,13 @@ def get_tokenizer( return tokenizer +def normalize_model_name(model_name: str) -> str: + if not model_name: + return model_name + normalized = model_name.rstrip("/\\") + return normalized or model_name + + def get_random_length(reqs_num: int, length: int, range_ratio: float) -> List[int]: lens = [] lens = np.random.randint( @@ -394,6 +401,12 @@ def main(): ) parser.add_argument("--num_clients", type=int, default=100) parser.add_argument("--tokenizer_path", type=str, default=None) + parser.add_argument( + "--model_name", + type=str, + default=None, + help="Model name passed to the server. Defaults to --tokenizer_path.", + ) parser.add_argument("--data_path", type=str, default=None) parser.add_argument("--input_num", type=int, default=2000) parser.add_argument("--input_qps", type=float, default=30.0) @@ -429,7 +442,7 @@ def main(): return assert args.tokenizer_path is not None - model_name.append(args.tokenizer_path) + model_name.append(args.model_name if args.model_name is not None else normalize_model_name(args.tokenizer_path)) seed_all(args.seed) url = args.url tokenizer = get_tokenizer(args.tokenizer_path) diff --git a/unit_tests/models/deepseek3_2/triton_kernel/test_hadamard_transform.py b/unit_tests/models/deepseek3_2/triton_kernel/test_hadamard_transform.py new file mode 100644 index 0000000000..8a54d6d9fd --- /dev/null +++ b/unit_tests/models/deepseek3_2/triton_kernel/test_hadamard_transform.py @@ -0,0 +1,83 @@ +import pytest +import torch +import triton + +from lightllm.models.deepseek3_2.triton_kernel.hadamard_transform import hadamard_transform + + +TP = 8 +INDEX_N_HEADS = 64 +INDEX_HEAD_DIM = 128 +TP_INDEX_N_HEADS = INDEX_N_HEADS // TP +SCALE = INDEX_HEAD_DIM ** -0.5 + + +def _get_sgl_kernel_hadamard_transform(): + if not torch.cuda.is_available(): + pytest.skip("CUDA is required for hadamard_transform comparison") + try: + from sgl_kernel import hadamard_transform as sgl_hadamard_transform + except ImportError: + pytest.skip("sgl_kernel.hadamard_transform is not available") + return sgl_hadamard_transform + + +def _bench(fn, x): + ms = triton.testing.do_bench_cudagraph(lambda: fn(x, scale=SCALE), return_mode="median") + return ms, fn(x, scale=SCALE) + + +@pytest.mark.parametrize("tokens", [1, 16, 128, 512, 1024, 2048, 4096, 8192, 16384]) +def test_hadamard_transform_matches_sgl_kernel_deepseek_v32_shapes(tokens): + sgl_hadamard_transform = _get_sgl_kernel_hadamard_transform() + + q = torch.randn(tokens, TP_INDEX_N_HEADS, INDEX_HEAD_DIM, dtype=torch.bfloat16, device="cuda") + k = torch.randn(tokens, INDEX_HEAD_DIM, dtype=torch.bfloat16, device="cuda") + + q_expected = sgl_hadamard_transform(q, scale=SCALE) + q_actual = hadamard_transform(q, scale=SCALE) + k_expected = sgl_hadamard_transform(k, scale=SCALE) + k_actual = hadamard_transform(k, scale=SCALE) + torch.cuda.synchronize() + + assert torch.equal(q_actual, q_expected) + assert torch.equal(k_actual, k_expected) + + +def test_hadamard_transform_perf_report_deepseek_v32_shapes(): + sgl_hadamard_transform = _get_sgl_kernel_hadamard_transform() + + print( + "\nDeepSeek-V3.2 per-rank shapes with tp=8:" + "\n q: [tokens, 8, 128]" + "\n k: [tokens, 128]" + "\n\ntokens | q_diff | k_diff | sgl_q ms | tri_q ms | sgl_k ms | tri_k ms | tri(q+k) ms | slowdown q+k" + ) + + for tokens in [1, 16, 128, 512, 1024, 2048, 4096, 8192, 16384]: + q = torch.randn(tokens, TP_INDEX_N_HEADS, INDEX_HEAD_DIM, dtype=torch.bfloat16, device="cuda") + k = torch.randn(tokens, INDEX_HEAD_DIM, dtype=torch.bfloat16, device="cuda") + + q_expected = sgl_hadamard_transform(q, scale=SCALE) + q_actual = hadamard_transform(q, scale=SCALE) + k_expected = sgl_hadamard_transform(k, scale=SCALE) + k_actual = hadamard_transform(k, scale=SCALE) + torch.cuda.synchronize() + + q_diff = (q_expected.float() - q_actual.float()).abs().max().item() + k_diff = (k_expected.float() - k_actual.float()).abs().max().item() + sgl_q_ms, _ = _bench(sgl_hadamard_transform, q) + tri_q_ms, _ = _bench(hadamard_transform, q) + sgl_k_ms, _ = _bench(sgl_hadamard_transform, k) + tri_k_ms, _ = _bench(hadamard_transform, k) + sgl_sum_ms = sgl_q_ms + sgl_k_ms + tri_sum_ms = tri_q_ms + tri_k_ms + + print( + f"{tokens:6d} | {q_diff:6.1g} | {k_diff:6.1g} | " + f"{sgl_q_ms:8.4f} | {tri_q_ms:8.4f} | {sgl_k_ms:8.4f} | {tri_k_ms:8.4f} | " + f"{tri_sum_ms:11.4f} | {tri_sum_ms / sgl_sum_ms:10.2f}x" + ) + + assert q_diff == 0 + assert k_diff == 0