[WIP]: Support mtp + dp #4611

Open
RunningLeon wants to merge 8 commits into InternLM:main from RunningLeon:mtp-dp-tp

@RunningLeon (Collaborator) commented May 21, 2026

Motivation

Support dp > 1, ep > 1, and tp > 1 for qwen3_5_mtp.

Modification

Please briefly describe what modification is made in this PR.

BC-breaking (Optional)

Does the modification introduce changes that break the backward-compatibility of the downstream repositories?
If so, please describe how it breaks the compatibility and how the downstream projects should modify their code to keep compatibility with this PR.

Use cases (Optional)

If this PR introduces a new feature, it is better to list some use cases here, and update the documentation.

Checklist

  1. Pre-commit or other linting tools are used to fix the potential lint issues.
  2. The modification is covered by complete unit tests. If not, please add more unit tests to ensure the correctness.
  3. If the modification has a dependency on downstream projects of a newer version, this PR should be tested with all supported versions of downstream projects.
  4. The documentation has been modified accordingly, like docstring or example tutorials.

Copilot AI review requested due to automatic review settings May 21, 2026 13:21
Copilot AI left a comment

Pull request overview

This PR extends the PyTorch engine’s speculative decoding (Qwen3.5 MTP) to better support distributed setups (notably dp > 1, and draft-side TP), and wires the necessary metadata/communication hooks through the model-agent and strategy layers.

Changes:

  • Add draft-model DistConfig into SpecDecodeConfig and propagate it through config building and spec-agent initialization.
  • Adjust KV-cache sizing/accounting to consider draft TP and update executor tests accordingly.
  • Update DP + speculative decode flow in the model agent (sampling, broadcasting, and post-broadcast synchronization) and add DeepEP-related runtime configuration.
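
The first change above (a draft-side DistConfig carried inside SpecDecodeConfig) can be sketched roughly as follows. This is a minimal illustration, not the PR's code: the field names `dp`, `tp`, `ep`, the helper `build_draft_dist_config`, and the fallback for other methods (keep dp, drop tp/ep) are all assumptions made for the example. Only the "copy the target layout for qwen3_5_mtp" behavior comes from the review summary.

```python
from dataclasses import dataclass, replace
from typing import Optional


@dataclass(frozen=True)
class DistConfig:
    """Hypothetical stand-in; field names are illustrative only."""
    dp: int = 1
    tp: int = 1
    ep: int = 1


@dataclass
class SpecDecodeConfig:
    """Hypothetical stand-in that carries the draft-side parallel layout."""
    method: str
    dist_config: Optional[DistConfig] = None


def build_draft_dist_config(method: str, target: DistConfig) -> DistConfig:
    """Derive the draft model's parallel layout from the target's."""
    if method == 'qwen3_5_mtp':
        # Per the review summary: the draft copies the target layout.
        return replace(target)
    # Assumption for this sketch: other methods keep dp but run without tp/ep.
    return replace(target, tp=1, ep=1)


target = DistConfig(dp=2, tp=4, ep=8)
mtp_cfg = SpecDecodeConfig(
    'qwen3_5_mtp', build_draft_dist_config('qwen3_5_mtp', target))
other_cfg = SpecDecodeConfig(
    'other_method', build_draft_dist_config('other_method', target))
```

Keeping the draft layout as its own object, rather than reusing the target's, lets cache sizing and communication code ask the draft config directly instead of branching on the method name.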

Reviewed changes

Copilot reviewed 19 out of 20 changed files in this pull request and generated 2 comments.

| File | Description |
| --- | --- |
| tests/pytorch/engine/test_executor_base.py | Adds unit tests for draft-TP-aware cache block sizing. |
| lmdeploy/pytorch/strategies/base/model_inputs.py | Introduces make_dummy_extra_inputs hook for dummy DP forwards. |
| lmdeploy/pytorch/strategies/base/model_agent.py | Adds a post_broadcast hook to the strategy API. |
| lmdeploy/pytorch/strategies/ar_spec/model_inputs.py | Implements dummy extra-input creation for AR-spec. |
| lmdeploy/pytorch/strategies/ar_spec/model_agent.py | Updates broadcast payloads and adds a post-broadcast draft-token sync. |
| lmdeploy/pytorch/spec_decode/spec_agent.py | Refactors SpecModelAgent init via BaseSpecModelAgent, updates cache-engine build for draft TP, and adjusts DP-aware speculative flow. |
| lmdeploy/pytorch/spec_decode/base.py | Moves common spec-agent wiring into the base class and adds a post_broadcast wrapper. |
| lmdeploy/pytorch/spec_decode/\_\_init\_\_.py | Changes per-rank enable logic for speculative decoding and constructs spec/base agents accordingly. |
| lmdeploy/pytorch/models/qwen3_5_mtp.py | Removes TP/DP flags from a module construction path (to support broader parallel modes). |
| lmdeploy/pytorch/model_inputs.py | Extends DPMeta with spec-decode-related fields (decoding flags, batch sizes). |
| lmdeploy/pytorch/envs.py | Reads DEEPEP_MAX_TOKENS_PER_RANK as an int config value. |
| lmdeploy/pytorch/engine/model_agent/inputs_maker.py | Includes dummy extra_inputs and return_logits when fabricating DP dummy forwards. |
| lmdeploy/pytorch/engine/model_agent/agent.py | Reworks sampling/spec-decode integration and DP metadata handling to support dp+spec flows. |
| lmdeploy/pytorch/engine/executor/base.py | Makes cache sizing draft-TP-aware and adjusts per-rank cache accounting. |
| lmdeploy/pytorch/engine/engine.py | Passes dist_config into specdecode config building. |
| lmdeploy/pytorch/engine/config_builder.py | Builds a draft-side DistConfig (copy target for Qwen3.5 MTP) and injects it into SpecDecodeConfig. |
| lmdeploy/pytorch/config.py | Adds dist_config field to SpecDecodeConfig and passes it into ModelConfig.from_pretrained. |
| lmdeploy/pytorch/backends/cuda/token_dispatcher.py | Uses DEEPEP_MAX_TOKENS_PER_RANK for DeepEP dispatch sizing. |
| lmdeploy/pytorch/backends/cuda/graph_runner.py | Gates DeepEP setup/cleanup by EP mode; adjusts dp-meta TP-size sync for query length. |
| .gitignore | Ignores additional workspace directories. |
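
For the envs.py entry, reading DEEPEP_MAX_TOKENS_PER_RANK as an integer could look something like the sketch below. The helper name `env_to_int`, its `environ` parameter, and the default of 128 are placeholders chosen for illustration; the scraped page does not show the PR's actual helper or default.

```python
import os
from typing import Mapping


def env_to_int(name: str, default: int,
               environ: Mapping[str, str] = os.environ) -> int:
    """Read an integer setting, falling back to `default` if unset or invalid."""
    raw = environ.get(name)
    if raw is None or raw.strip() == '':
        return default
    try:
        return int(raw)
    except ValueError:
        return default


# 128 is a placeholder default, not the value used by the PR.
fake_env = {'DEEPEP_MAX_TOKENS_PER_RANK': '256'}
max_tokens_per_rank = env_to_int(
    'DEEPEP_MAX_TOKENS_PER_RANK', default=128, environ=fake_env)
```

Passing the environment as a mapping keeps the helper easy to test without mutating the real process environment.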
Comments suppressed due to low confidence (1)

lmdeploy/pytorch/strategies/base/model_agent.py:175

  • ModelAgentStrategy.post_broadcast is marked with @contextmanager but the default implementation has no yield (it’s just pass). Entering this context will raise "generator didn't yield" if any strategy relies on the default. Make it a real no-op context manager (single yield) or mark it @abstractmethod to require overrides.
    @contextmanager
    def post_broadcast(self, extra_inputs: ExtraInputs, dist_ctx: 'DistContext'):
        """Post broadcast extra inputs."""
        pass
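
The suggested fix can be illustrated with a small standalone sketch. The class names `BrokenStrategy` and `NoOpStrategy` and the helper `try_enter` are made up for this example and are not part of lmdeploy. Because a body consisting only of `pass` is not a generator, entering the decorated hook fails. On CPython this typically surfaces as a `TypeError` rather than the "generator didn't yield" `RuntimeError`, so the sketch catches both.

```python
from contextlib import contextmanager


class BrokenStrategy:
    @contextmanager
    def post_broadcast(self, extra_inputs, dist_ctx):
        """No `yield`, so this is not a generator and cannot be entered."""
        pass


class NoOpStrategy:
    @contextmanager
    def post_broadcast(self, extra_inputs, dist_ctx):
        """Default hook: do nothing before or after the wrapped block."""
        yield


def try_enter(strategy) -> str:
    """Enter the hook once; return 'ok' or the name of the raised exception."""
    try:
        with strategy.post_broadcast(extra_inputs=None, dist_ctx=None):
            return 'ok'
    except (TypeError, RuntimeError) as exc:
        return type(exc).__name__


broken_result = try_enter(BrokenStrategy())
noop_result = try_enter(NoOpStrategy())
```

A single bare `yield` is enough for a no-op default. Marking the method `@abstractmethod` instead would force every strategy to implement it, which is the other option the comment mentions.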


Comment on lines 46 to 50
    def broadcast(self, src: int, group, async_op=False):
        dist.broadcast(self.output_draft_token_ids, src=src, group=group, async_op=async_op)
        dist.broadcast(self.last_token_indices, src=src, group=group, async_op=async_op)
        dist.broadcast(self.output_token_ids, src=src, group=group, async_op=async_op)
        handle = dist.broadcast(self.num_rejected_tokens, src=src, group=group, async_op=async_op)
        return handle

    else:
        from .base import BaseSpecModelAgent
        return BaseSpecModelAgent(specdecode_config)
        return BaseSpecModelAgent(specdecode_config,

@RunningLeon RunningLeon requested a review from grimoire May 26, 2026 12:08
Comment thread lmdeploy/pytorch/backends/cuda/graph_runner.py Outdated

    """Build spec decode config."""
    def _build_draft_dist_ctx(dist_config):
        # TODO support tp > 1, ep > 1 for other methods
        if speculative_config.method == 'qwen3_5_mtp':
This is too dirty

Comment thread lmdeploy/pytorch/spec_decode/base.py Outdated
Comment thread lmdeploy/pytorch/spec_decode/spec_agent.py Outdated