support deepseekv4 flash fsdp2 training and reduce Transformers v5 model loading memory pressure #190
meichangsu1 wants to merge 8 commits
Conversation
Introduce `BaseAgentTemplate` and `DeepSeekV4AgentTemplate` for agent-based interactions. Add `ReactCompatMixin` for parsing and formatting ReAct-style tool calls, including the Action/Action Input/Observation keywords. Implement `ToolDesc` and `AgentKeyword` dataclasses to support structured tool descriptions and agent keywords.
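For a rough illustration of the ReAct compatibility layer described above, a minimal parser for the Action/Action Input/Observation keywords might look like the following. This is a hedged sketch: the regex, function name, and return shape are assumptions, not the mixin's actual implementation.

```python
import json
import re
from typing import Optional

# Illustrative pattern: capture the tool name after "Action:" and the
# arguments after "Action Input:", stopping at "Observation:" or end of text.
ACTION_RE = re.compile(
    r'Action:\s*(?P<name>.+?)\s*'
    r'Action Input:\s*(?P<args>.+?)\s*'
    r'(?=Observation:|$)',
    re.DOTALL,
)

def parse_react_tool_call(text: str) -> Optional[dict]:
    """Extract one ReAct-style tool call from model output, if present."""
    match = ACTION_RE.search(text)
    if match is None:
        return None
    name = match.group('name').strip()
    raw_args = match.group('args').strip()
    try:
        arguments = json.loads(raw_args)
    except json.JSONDecodeError:
        arguments = raw_args  # keep the raw string if it is not valid JSON
    return {'name': name, 'arguments': arguments}
```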
- Update `_get_decoder_layers` to first search for modules matching `_no_split_modules` names
- Add `_get_no_split_module_names` helper to collect no-split module names from the model hierarchy
- Add `_normalize_no_split_modules` utility for consistent set conversion
- Change the return type hint from `nn.ModuleList` to `List[nn.Module]` for flexibility (see the sketch after this list)
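A minimal sketch of how such a lookup could work. The helper names follow the commit message, but the bodies are assumptions rather than the PR's actual code.

```python
from typing import List, Sequence, Set, Union

import torch.nn as nn

def _normalize_no_split_modules(
    no_split: Union[str, Sequence[str], None],
) -> Set[str]:
    # Accept a single name, a sequence of names, or None; always return a set.
    if no_split is None:
        return set()
    if isinstance(no_split, str):
        return {no_split}
    return set(no_split)

def _get_no_split_module_names(model: nn.Module) -> Set[str]:
    # Walk the module hierarchy and union every `_no_split_modules` attribute.
    names: Set[str] = set()
    for module in model.modules():
        names |= _normalize_no_split_modules(
            getattr(module, '_no_split_modules', None)
        )
    return names

def _get_decoder_layers(model: nn.Module) -> List[nn.Module]:
    # Prefer modules whose class name appears in `_no_split_modules`;
    # fall back to the first nn.ModuleList (e.g. `model.layers`) otherwise.
    no_split = _get_no_split_module_names(model)
    layers = [m for m in model.modules() if type(m).__name__ in no_split]
    if layers:
        return layers
    for module in model.modules():
        if isinstance(module, nn.ModuleList):
            return list(module)
    return []
```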
…in FSDP strategy
- Broadcast non-persistent buffers from rank 0 to all ranks instead of only restoring them on rank 0 (see the sketch after this list)
- Add source metadata validation to ensure shape/dtype consistency before distributing tensors
- Fix the `tie_weights()` call to execute on all ranks instead of only rank 0
- Improve error handling with explicit `KeyError` and `RuntimeError` for state dict mismatches
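For intuition: buffers registered with `persistent=False` do not appear in `state_dict()`, so the state-dict broadcast never covers them and they must be synchronized out of band. A minimal sketch of that step, assuming the default process group (not the PR's exact code):

```python
import torch.distributed as dist
import torch.nn as nn

def _broadcast_non_persistent_buffers(model: nn.Module) -> None:
    # state_dict() contains only persistent buffers; named_buffers() yields
    # all of them, so the set difference is exactly the non-persistent ones.
    persistent_keys = set(model.state_dict().keys())
    for name, buffer in model.named_buffers():
        if name not in persistent_keys:
            # Rank 0 holds the correctly initialized value; every other
            # rank receives it in place.
            dist.broadcast(buffer, src=0)
```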
…nd update references
- Rename class and module exports for consistency with naming conventions
- Update the default `TEMPLATE_ID` in `deepseek_v4_flash.py` to use the new template name
- Refactor encoding/decoding logic for chat message processing with improved tool call handling
…4 example
- Change `LORA_TARGET_MODULES` from `'all-linear'` to a specific module list for DeepSeek V4
- Remove the `gradient_accumulation_steps` parameter from the `forward_backward` call
- Fix NPU device dtype alignment to convert all parameters to the base dtype when on NPU
Code Review
This pull request introduces support for the DeepSeek-V4 model, including a training cookbook, a specialized chat template with thinking mode and tool-calling capabilities, and significant enhancements to the Native FSDP strategy. Key changes include the implementation of rank-0 broadcasting for pretrained weights and non-persistent buffers to support memory-efficient initialization, as well as NPU-specific dtype alignment. Feedback highlights a critical bug in the FSDP state dict broadcasting where `DTensor` parameters skip synchronization, potentially leading to uninitialized memory shards. Additionally, it is recommended to rename the `eval` function in the cookbook to avoid shadowing the Python built-in, and to replace the `regex` library dependency with the standard `re` module for consistency.
```python
if isinstance(sharded_param, DTensor):
    sharded_tensor = distribute_tensor(
        full_tensor,
        sharded_param.device_mesh,
        sharded_param.placements,
    )
else:
    dist.broadcast(full_tensor, src=0)
    sharded_tensor = full_tensor
```
In `_broadcast_sharded_state_dict`, the `dist.broadcast` call is currently skipped when `sharded_param` is a `DTensor`. However, `distribute_tensor` requires the input tensor to be logically global (synchronized across all ranks) to correctly produce shards. Without this broadcast, non-rank0 ranks will shard uninitialized memory from `torch.empty`. The broadcast should be performed for all parameters before sharding.
Suggested change:

```diff
-if isinstance(sharded_param, DTensor):
-    sharded_tensor = distribute_tensor(
-        full_tensor,
-        sharded_param.device_mesh,
-        sharded_param.placements,
-    )
-else:
-    dist.broadcast(full_tensor, src=0)
-    sharded_tensor = full_tensor
+dist.broadcast(full_tensor, src=0)
+if isinstance(sharded_param, DTensor):
+    sharded_tensor = distribute_tensor(
+        full_tensor,
+        sharded_param.device_mesh,
+        sharded_param.placements,
+    )
+else:
+    sharded_tensor = full_tensor
```
```python
    return dataset


def eval(model):
```
The function name `eval` shadows the Python built-in `eval()`. It is recommended to rename it to something more descriptive, such as `evaluate` or `run_eval`, to avoid confusion and potential name-resolution issues. Note that the call site at line 135 should also be updated.
Suggested change:

```diff
-def eval(model):
+def evaluate(model):
```
```python
logger.info(f'Current is step {step} of {len(dataloader)}, metric: {metric}')

if step > 0 and step % SAVE_STEPS == 0:
    metrics = eval(model)
```
```python
import copy
import json
import regex as re
```
This file introduces a dependency on the third-party `regex` library, while other files in the project (like `src/twinkle/template/utils.py`) use the standard-library `re`. Since the patterns used here are simple and don't require advanced regex features, it's better to use the standard `re` module to maintain consistency and avoid an unnecessary dependency.
Suggested change:

```diff
-import regex as re
+import re
```
- Update the default model ID from DeepSeek-V4-flash-bfa16 to DeepSeek-V4-Flash
- Add a comment explaining the FP4/FP8 weight conversion requirement
- Fix LoRA target module names to match the actual model architecture
…th DCP
Replace the custom `_broadcast_sharded_state_dict` function with `_load_rank0_full_state_dict`, which leverages `torch.distributed.checkpoint`'s `set_model_state_dict` with `broadcast_from_rank0=True`. This simplifies the code by using the official DCP API for distributing the rank-0 full state dict to FSDP2 shards, removing the manual tensor distribution and metadata broadcasting logic.
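A minimal sketch of what this DCP-based path can look like. The helper name comes from the commit message; the body is an assumption, not the PR's actual code.

```python
import torch.nn as nn
from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    set_model_state_dict,
)

def _load_rank0_full_state_dict(model: nn.Module, full_state_dict: dict) -> None:
    # `full_state_dict=True` tells DCP the input is an unsharded state dict;
    # `broadcast_from_rank0=True` makes rank 0 broadcast each tensor and lets
    # DCP re-shard it onto the local DTensor placements of the FSDP2 model.
    options = StateDictOptions(
        full_state_dict=True,
        broadcast_from_rank0=True,
    )
    # Non-rank0 ranks may pass an empty dict; values arrive via broadcast.
    set_model_state_dict(model, full_state_dict, options=options)
```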
Use `Platform.device_prefix()` instead of iterating over model parameters to check for the NPU device, improving efficiency and correctness.
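A hedged sketch of the change described here, assuming `Platform.device_prefix()` returns the device type string (e.g. `'npu'`); the import path, function name, and `base_dtype` default are illustrative, not the project's actual code.

```python
import torch
import torch.nn as nn

def align_npu_dtype(model: nn.Module,
                    base_dtype: torch.dtype = torch.bfloat16) -> None:
    # Assumed import path for the project's platform abstraction.
    from twinkle import Platform

    # Before, the device type was discovered by scanning every parameter:
    #   is_npu = any(p.device.type == 'npu' for p in model.parameters())
    # Asking the platform abstraction once is cheaper and unambiguous.
    if Platform.device_prefix() == 'npu':
        # Align all parameters to the base dtype on NPU.
        model.to(dtype=base_dtype)
```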
PR type
PR information
This PR adds DeepSeek V4 support across template processing, Transformers model initialization, and training cookbook examples.
- Add `DeepseekV4Template` and register it in `twinkle.template.apply_chat_template`.
- Load the model with `init_empty_weights`, then rely on the later FSDP wrapping / broadcast-based loading flow.
- Keep the existing `from_pretrained` behavior for non-memory-efficient paths.
- Add a training cookbook example at `cookbook/transformers/deepseek_v4_flash.py` that uses `DeepseekV4Template` (see the sketch after this list).
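A minimal sketch of the memory-efficient initialization path described above. The model ID is an assumed placeholder, and the real flow lives inside the Native FSDP strategy rather than user code.

```python
from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM

MODEL_ID = 'deepseek-ai/DeepSeek-V4-Flash'  # assumed model id

config = AutoConfig.from_pretrained(MODEL_ID)
with init_empty_weights():
    # Parameters are created on the meta device: no host memory is allocated.
    model = AutoModelForCausalLM.from_config(config)

# The model is then FSDP2-wrapped; only rank 0 reads the checkpoint, and the
# shards on the other ranks are filled via the rank-0 broadcast path.
```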
Experiment results
Below are the two-layer loss alignment results on NPU and GPU. The loss curves overlap almost perfectly, and the loss difference gradually converges to zero in the final stage.