support deepseekv4 flash fsdp2 training and reduce Transformers v5 model loading memory pressure #190

Open · meichangsu1 wants to merge 8 commits into modelscope:main from meichangsu1:dsv4_fsdp2_ljl


Conversation

@meichangsu1 (Collaborator) commented May 12, 2026

PR type

  • Bug Fix
  • New Feature
  • Document Updates
  • More Models or Datasets Support

PR information

This PR adds DeepSeek V4 support across template processing, Transformers model initialization, and training cookbook examples.

  1. DeepSeek V4 template support
  2. Reduce Transformers v5 model loading memory pressure (a hedged sketch follows this list)
  • Refine pretrained model initialization for distributed/FSDP training.
  • Add rank-aware empty-model initialization for non-rank0 workers when memory-efficient init is enabled.
  • Build empty models from the config under init_empty_weights, then rely on the later FSDP wrapping / broadcast-based loading flow.
  • Preserve normal from_pretrained behavior for non-memory-efficient paths.
  • This avoids excessive peak memory use and improves compatibility with newer Transformers loading behavior.
  3. Add DeepSeek V4 FSDP2 + LoRA cookbook
  • Add cookbook/transformers/deepseek_v4_flash.py.
  • Provide an end-to-end local training example with:
    • DeepseekV4Template
    • FSDP2/native FSDP strategy
    • memory-efficient initialization
  • Current cookbook scope is FSDP2 + LoRA training only. EP is not supported in this example yet.
  4. Training correctness and current limitations
  • Training accuracy has been aligned with the expected baseline.
  • Add dtype alignment handling around LoRA/FSDP integration.
  • For NPU, add a temporary workaround to align parameter dtype more aggressively with the base model dtype to keep training behavior correct.
  • NPU correctness is aligned, but runtime performance still needs further optimization.
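
A minimal sketch of the rank-aware initialization in item 2, assuming accelerate's init_empty_weights and a standard from_pretrained load on rank 0; the function name build_model and the materialization step are illustrative, not this PR's code:

```python
# Illustrative sketch only: build_model and the materialization step are
# assumptions, not this PR's implementation.
import torch.distributed as dist
from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM

def build_model(model_id: str, memory_efficient_init: bool = True):
    rank = dist.get_rank() if dist.is_initialized() else 0
    if memory_efficient_init and rank != 0:
        # Non-rank0 workers build the model on the meta device, so no real
        # weight storage is allocated and peak host memory stays low.
        config = AutoConfig.from_pretrained(model_id)
        with init_empty_weights():
            model = AutoModelForCausalLM.from_config(config)
    else:
        # Rank 0 (and all ranks on the non-memory-efficient path) keeps the
        # normal from_pretrained behavior.
        model = AutoModelForCausalLM.from_pretrained(model_id)
    # The later FSDP wrapping / broadcast-based loading flow materializes
    # the meta tensors from rank 0's weights.
    return model
```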

Experiment results

  • DeepSeek V4 FSDP2 + LoRA training flow has been validated.
  • Accuracy behavior has been aligned with the expected baseline.
  • EP training is not covered by the current cookbook.
  • NPU training correctness is verified; performance optimization remains follow-up work.

Below are the two-layer loss alignment results on NPU and GPU. The loss curves overlap almost perfectly, and the loss difference gradually converges to zero in the final stage.

[Figure: two-layer loss alignment curves on NPU and GPU]

Commits

Introduce BaseAgentTemplate and DeepSeekV4AgentTemplate for agent-based interactions. Add ReactCompatMixin for parsing and formatting ReAct-style tool calls, including the Action/Action Input/Observation keywords. Implement ToolDesc and AgentKeyword dataclasses to support structured tool descriptions and agent keywords.
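
For illustration, ReAct-style tool-call parsing might look like the sketch below; the regex, ToolCall fields, and function name are assumptions, not the actual ReactCompatMixin code:

```python
# Hypothetical sketch of ReAct-style tool-call parsing; the pattern and
# ToolCall fields are assumptions, not the PR's ToolDesc/AgentKeyword code.
import re
from dataclasses import dataclass
from typing import List

@dataclass
class ToolCall:
    action: str
    action_input: str

_REACT_RE = re.compile(
    r'Action:\s*(?P<action>.+?)\s*'
    r'Action Input:\s*(?P<input>.+?)'
    r'(?=\nObservation:|\nAction:|\Z)',
    re.DOTALL,
)

def parse_react_tool_calls(text: str) -> List[ToolCall]:
    # Each match pairs an Action keyword with its Action Input block.
    return [
        ToolCall(m.group('action').strip(), m.group('input').strip())
        for m in _REACT_RE.finditer(text)
    ]
```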
- Update `_get_decoder_layers` to first search for modules matching `_no_split_modules` names
- Add `_get_no_split_module_names` helper to collect no-split module names from model hierarchy
- Add `_normalize_no_split_modules` utility for consistent set conversion
- Change return type hint from `nn.ModuleList` to `List[nn.Module]` for flexibility
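
A hedged sketch of what these helpers could look like; the signatures follow the commit bullets, but the bodies are assumptions rather than the PR's implementation:

```python
# Assumed shapes for the helpers named above; not the PR's exact code.
from typing import List, Sequence, Set, Union
import torch.nn as nn

def _normalize_no_split_modules(
        no_split: Union[str, Sequence[str], None]) -> Set[str]:
    # Consistent set conversion for str / list / tuple / None inputs.
    if no_split is None:
        return set()
    if isinstance(no_split, str):
        return {no_split}
    return set(no_split)

def _get_no_split_module_names(model: nn.Module) -> Set[str]:
    # Collect _no_split_modules declarations across the model hierarchy.
    names: Set[str] = set()
    for module in model.modules():
        names |= _normalize_no_split_modules(
            getattr(module, '_no_split_modules', None))
    return names

def _get_decoder_layers(model: nn.Module) -> List[nn.Module]:
    # First search for modules whose class name matches a no-split name.
    no_split = _get_no_split_module_names(model)
    return [m for m in model.modules() if type(m).__name__ in no_split]
```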
…in FSDP strategy

- Broadcast non-persistent buffers from rank 0 to all ranks instead of only restoring on rank 0
- Add source metadata validation to ensure shape/dtype consistency before distributing tensors
- Fix tie_weights() call to execute on all ranks instead of only rank 0
- Improve error handling with explicit KeyError and RuntimeError for state dict mismatches
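
For illustration only, the buffer broadcast could be sketched as below; the traversal and the use of nn.Module's private _non_persistent_buffers_set attribute are assumptions:

```python
# Illustrative only: relies on nn.Module's private
# _non_persistent_buffers_set attribute, which the PR may not use.
import torch.distributed as dist
import torch.nn as nn

def broadcast_non_persistent_buffers(model: nn.Module) -> None:
    for module in model.modules():
        for name, buf in module.named_buffers(recurse=False):
            if name in module._non_persistent_buffers_set:
                # Rank 0 holds the restored values; other ranks receive them.
                dist.broadcast(buf, src=0)
```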
…nd update references

- Rename class and module exports for consistency with naming conventions
- Update default TEMPLATE_ID in deepseek_v4_flash.py to use new template name
- Refactor encoding/decoding logic for chat message processing with improved tool call handling
…4 example

- Change LORA_TARGET_MODULES from 'all-linear' to specific module list for DeepSeek V4
- Remove gradient_accumulation_steps parameter from forward_backward call
- Fix NPU device dtype alignment to convert all parameters to base dtype when on NPU
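
A minimal sketch of the NPU dtype alignment named in the last bullet, assuming the base dtype is known; the function name and the blanket cast are illustrative:

```python
# Hypothetical NPU workaround sketch; base_dtype and the blanket cast are
# assumptions based on the commit message.
import torch

def align_npu_param_dtypes(model, base_dtype: torch.dtype) -> None:
    # Cast every parameter to the base model dtype so LoRA/FSDP mixed
    # dtypes do not break training on NPU.
    for param in model.parameters():
        if param.dtype != base_dtype:
            param.data = param.data.to(base_dtype)
```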
@gemini-code-assist (bot) left a comment

Code Review

This pull request introduces support for the DeepSeek-V4 model, including a training cookbook, a specialized chat template with thinking mode and tool-calling capabilities, and significant enhancements to the Native FSDP strategy. Key changes include the implementation of rank-0 broadcasting for pretrained weights and non-persistent buffers to support memory-efficient initialization, as well as NPU-specific dtype alignment. Feedback highlights a critical bug in the FSDP state dict broadcasting where DTensor parameters skip synchronization, potentially leading to uninitialized memory shards. Additionally, it is recommended to rename the eval function in the cookbook to avoid shadowing the Python built-in and to replace the regex library dependency with the standard re module for consistency.

Comment on lines +554 to +562
```python
if isinstance(sharded_param, DTensor):
    sharded_tensor = distribute_tensor(
        full_tensor,
        sharded_param.device_mesh,
        sharded_param.placements,
    )
else:
    dist.broadcast(full_tensor, src=0)
    sharded_tensor = full_tensor
```

high

In _broadcast_sharded_state_dict, the dist.broadcast call is currently skipped when sharded_param is a DTensor. However, distribute_tensor requires the input tensor to be logically global (synchronized across all ranks) to correctly produce shards. Without this broadcast, non-rank0 ranks will shard uninitialized memory from torch.empty. The broadcast should be performed for all parameters before sharding.

Suggested change

```diff
-if isinstance(sharded_param, DTensor):
-    sharded_tensor = distribute_tensor(
-        full_tensor,
-        sharded_param.device_mesh,
-        sharded_param.placements,
-    )
-else:
-    dist.broadcast(full_tensor, src=0)
-    sharded_tensor = full_tensor
+dist.broadcast(full_tensor, src=0)
+if isinstance(sharded_param, DTensor):
+    sharded_tensor = distribute_tensor(
+        full_tensor,
+        sharded_param.device_mesh,
+        sharded_param.placements,
+    )
+else:
+    sharded_tensor = full_tensor
```

```python
    return dataset


def eval(model):
```

medium

The function name eval shadows the Python built-in eval() function. It is recommended to rename it to something more descriptive, such as evaluate or run_eval, to avoid confusion and potential name resolution issues. Note that the call site at line 135 should also be updated.

Suggested change

```diff
-def eval(model):
+def evaluate(model):
```

```python
logger.info(f'Current is step {step} of {len(dataloader)}, metric: {metric}')

if step > 0 and step % SAVE_STEPS == 0:
    metrics = eval(model)
```

medium

Update the call site to match the renamed evaluation function.

Suggested change

```diff
-    metrics = eval(model)
+    metrics = evaluate(model)
```


```python
import copy
import json
import regex as re
```

medium

This file introduces a dependency on the 3rd-party regex library, while other files in the project (like src/twinkle/template/utils.py) use the standard library re. Since the patterns used here are simple and don't require advanced regex features, it's better to use the standard re module to maintain consistency and avoid unnecessary dependencies.

Suggested change

```diff
-import regex as re
+import re
```

@meichangsu1 changed the title from "Dsv4 fsdp2 ljl" to "support deepseekv4 flash fsdp2 training and reduce Transformers v5 model loading memory pressure" on May 12, 2026
- Update default model ID from DeepSeek-V4-flash-bfa16 to DeepSeek-V4-Flash
- Add comment explaining FP4/FP8 weight conversion requirement
- Fix LoRA target module names to match actual model architecture
…th DCP

Replace the custom `_broadcast_sharded_state_dict` function with `_load_rank0_full_state_dict` that leverages `torch.distributed.checkpoint`'s `set_model_state_dict` with `broadcast_from_rank0=True`. This simplifies the code by using the official DCP API for distributing rank0 full state dict to FSDP2 shards, removing manual tensor distribution and metadata broadcasting logic.
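
A minimal sketch of that DCP-based flow, assuming rank 0 holds the full state dict and the other ranks pass an empty mapping; the wrapper name follows the commit message, but the body is illustrative:

```python
# Sketch of _load_rank0_full_state_dict using the official DCP API;
# the surrounding call convention is an assumption.
from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    set_model_state_dict,
)

def _load_rank0_full_state_dict(model, full_state_dict):
    # broadcast_from_rank0=True distributes rank 0's full state dict and
    # reshards it into the FSDP2 DTensor layout in a single call.
    set_model_state_dict(
        model,
        full_state_dict,
        options=StateDictOptions(
            full_state_dict=True,
            broadcast_from_rank0=True,
        ),
    )
```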
Use Platform.device_prefix() instead of iterating over model parameters to check for NPU device, improving efficiency and correctness.