
gspo: GSPO loss + DeepSpeed parity fixes (loss/grad divisors, SDP, fp32_lm_head, docs_per_step, temperature)#502

Open
bigximik wants to merge 7 commits into grpo-metrics from gspo

Conversation


@bigximik bigximik commented Apr 29, 2026

Summary

This PR adds GSPO loss to fast-LLM along with a suite of supporting fixes that together achieve full metric and training-trajectory parity with DeepSpeed's GRPO/GSPO implementation. Targets the grpo-metrics branch. Six logical units:

1. GSPO loss (sequence-level IS-ratio clipping)

Implements GSPO as an alternative policy-gradient loss alongside the existing per-token GRPO clipping. Controlled via LanguageModelGRPOLossConfig.policy_loss = "gspo".

  • New fused_gspo_loss_forward_backward kernel: computes per-segment geometric-mean log-ratio R_s, clips at [1−ε_low, 1+ε_high], and applies R_s × A_s as a uniform per-token gradient within each segment. An all_reduce(SUM) over sequence-data-parallel ranks aggregates (lrn_sum, adv_sum, tok_count) before clipping so the ratio is correct under sequence parallelism.
  • New document_index data field and LanguageModelKwargs.document_index kwarg constant to route per-token segment membership through the data pipeline.
  • 8 unit tests in tests/layers/test_gspo_loss.py (single-segment, packed sequences, ratio=1 equivalence, clipping, masking, SDP mock, gradient finite-diff, independence from per-token metrics).
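For orientation, here is a minimal non-fused PyTorch sketch of the sequence-level objective described above (illustrative only; the fused kernel's actual signature, masking, and SDP handling differ):

```python
import torch

def gspo_loss_reference(new_logprobs, old_logprobs, advantages, document_index,
                        eps_low=0.2, eps_high=0.2):
    """Illustrative GSPO loss: sequence-level IS-ratio clipping.

    All inputs are flat (num_tokens,) tensors; document_index holds the
    integer segment id of each token, advantages are constant per segment.
    """
    loss = new_logprobs.new_zeros(())
    for seg in document_index.unique():
        mask = document_index == seg
        # Geometric-mean log-ratio R_s over the segment's tokens.
        log_ratio = (new_logprobs[mask] - old_logprobs[mask]).mean()
        ratio = log_ratio.exp()
        clipped = ratio.clamp(1.0 - eps_low, 1.0 + eps_high)
        adv = advantages[mask].mean()  # A_s
        # Pessimistic clipped objective at the segment level; autograd then
        # spreads d(ratio)/d(logprob) = R_s / n_s uniformly over the segment.
        loss = loss - torch.min(ratio * adv, clipped * adv)
    return loss
```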

2. Dynamic docs_per_step accumulation

Replaces static depth_first_micro_batches with a runtime document-count target — matching DeepSpeed's gradient_accumulation_passes semantics for RL (where each microbatch holds one rollout).

  • ScheduleConfig.docs_per_step: when >0, Trainer._prefetch_to_doc_target fetches microbatches one at a time, all-reduces the per-microbatch document count, and stops once the global total ≥ docs_per_step. The final step total is broadcast to all inputs so the normalisation denominator is consistent.
  • Trainer._get_or_build_schedule builds and caches a per-N Schedule with _depth_first_override = N // breadth_first_micro_batches, so the existing schedule machinery is reused without changes to the runner.
  • Schedule._eff_{depth_first,sequential,num_inputs} properties expose the effective values for a given override.
  • 13 unit tests in tests/layers/test_docs_per_step.py.
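Conceptually, the prefetch loop does the following (hedged sketch with hypothetical names; assumes an initialised process group and glosses over device placement):

```python
import torch
import torch.distributed as dist

def prefetch_to_doc_target(fetch_next_microbatch, docs_per_step, group=None):
    """Accumulate microbatches until the global document count reaches
    docs_per_step, then stamp the step total onto every collected input."""
    inputs, global_total = [], 0
    while global_total < docs_per_step:
        batch = fetch_next_microbatch()
        # All-reduce this microbatch's document count across data-parallel ranks.
        count = torch.tensor([batch["num_documents"]], dtype=torch.int64)
        dist.all_reduce(count, op=dist.ReduceOp.SUM, group=group)
        global_total += int(count.item())
        inputs.append(batch)
    # Every input carries the same step total, so the normalisation
    # denominator is consistent on all ranks and all microbatches.
    for batch in inputs:
        batch["num_documents_in_batch"] = global_total
    return inputs, global_total
```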

3. normalize_by_documents

Adds a normalize_by_documents flag to LanguageModelGRPOLossConfig. When True, both the GRPO and GSPO paths divide the loss by num_documents_in_batch (the step-level rollout count) rather than the token count. Matches DeepSpeed's normalization where tokens_weights = 1 / batch_size.
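In sketch form (variable names hypothetical), the two modes reduce to:

```python
# The two normalisation modes for the reported loss.
if normalize_by_documents:
    loss = per_token_loss_sum / num_documents_in_batch  # DS: tokens_weights = 1/batch_size
else:
    loss = per_token_loss_sum / num_labels_in_batch     # default: per-token normalisation
```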

4. Temperature scaling for IS ratio parity

Adds a temperature field to LanguageModelGRPOLossConfig. When set to match the actor's sampling temperature (e.g. 0.7), new log-probabilities are computed at the same temperature as the stored old log-probabilities from vLLM, so the IS ratio starts near 1.0 at step 0 instead of ~1.08. Implementation: _effective_logits_scale = logits_scale_factor / temperature, substituted at all three call-sites in _forward_backward. Default temperature=1.0 preserves existing behaviour exactly.
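The substitution itself is a one-liner; a minimal sketch, with hypothetical surrounding names:

```python
import torch

def new_logprobs_at_temperature(hidden_states, lm_head_weight,
                                logits_scale_factor, temperature):
    # Dividing the logits scale by the sampling temperature makes the new
    # log-probs comparable to the stored old log-probs from vLLM; with
    # temperature=1.0 the scale, and thus existing behaviour, is unchanged.
    effective_logits_scale = logits_scale_factor / temperature
    logits = hidden_states @ lm_head_weight.t() * effective_logits_scale
    return torch.log_softmax(logits.float(), dim=-1)
```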

5. fp32_lm_head precision fix (matches vLLM's bf16_last_layer_fp32)

Adds a fp32_lm_head flag (default False) on LanguageModelHeadConfig. When True, the LM head's logits computation upcasts both input and weight to FP32 before the linear projection, matching vLLM's bf16_last_layer_fp32 quantization. This ensures the trainer computes log-probabilities at the same numerical precision as the actor's sampling, so new_logprobs ≈ old_logprobs at step 0 (IS ratio at training start ≈ 1.0, not artificially inflated by precision mismatch).

  • Commit d8cb9ef5: introduces the flag, upcasts input/weight, casts back to BF16 before downstream consumption.
  • Commit 0f90f20b: fixes the gradient flow when fp32_lm_head=True. The detached FP32 weight copy has requires_grad=False, which makes output_parallel_linear_backward skip writing to the original weight's grad_buffer. We restore the FSDP gradient contract by computing grad_weight = grad.t() @ saved_input explicitly and accumulating into the BF16 param's grad_buffer via accumulate_gradient.
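A self-contained sketch of the combined forward/backward contract from the two commits (hypothetical signature; accumulate_gradient is modelled as a plain in-place add into the param's gradient buffer):

```python
import torch

def fp32_lm_head_forward_backward(hidden, weight, grad_buffer, grad_logits_fn):
    """Sketch of the fp32_lm_head path: FP32 logits, BF16 gradient contract."""
    input_fp32 = hidden.detach().to(torch.float32)
    weight_fp32 = weight.detach().to(torch.float32)  # detached copy, requires_grad=False
    # FP32 logits, matching vLLM's bf16_last_layer_fp32 precision.
    logits = input_fp32 @ weight_fp32.t()
    grad_logits = grad_logits_fn(logits)             # d(loss)/d(logits), in FP32
    # Input gradient is cast back so the transformer backward stays in bf16.
    grad_input = (grad_logits @ weight_fp32).to(hidden.dtype)
    # The detached copy has no grad_buffer, so the weight gradient is computed
    # explicitly and accumulated into the BF16 param's buffer (the role of
    # accumulate_gradient in the real code), preserving the FSDP contract.
    grad_buffer.add_((grad_logits.t() @ input_fp32).to(grad_buffer.dtype))
    return grad_input
```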

6. Decoupled loss/gradient divisors and SDP loss double-counting fix

Even with normalize_by_documents=true, fast-LLM's reported grad_norm was ~1024× larger than DeepSpeed's, causing the default gradient_norm_clipping=0.3 to over-clip by ~500× and making training ~10 reward points slower than DS GSPO at the same step count. Two issues, fixed in commit 557a3c4c:

Asymmetric loss/gradient scaling in DS:

  • The reported DS loss uses /batch_size once (via tokens_weights = 1/batch_size, pipelinerl/finetune/rl/__init__.py:246).
  • DS gradient buffer has an ADDITIONAL /(gas × world_size) factor from scale_wrt_gas=True in engine.backward() (deepspeed/runtime/engine.py:1995-1996) and tensor.div_(world_sz) in reduce_scatter_coalesced (deepspeed/runtime/comm/coalesced_collectives.py:124).
  • For samples_per_microbatch=1 (PipelineRL standard), gas × world_size = batch_size, so the gradient buffer effectively has 1/batch_size² while the loss metric has 1/batch_size.

Fast-LLM cancels DS's /(gas × world_size) factor structurally via grad_output = data_parallel × grad_scale (runner.py:318) interacting with FSDP's RS-AVG over data_parallel ranks (fsdp.py:396). So we need to apply the second 1/batch_size factor explicitly only to the gradient — keeping the loss metric matched to DS.
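In factor form, for samples_per_microbatch = 1:

```math
\underbrace{\tfrac{1}{\text{batch\_size}}}_{\text{tokens\_weights}}
\times
\underbrace{\tfrac{1}{\text{gas}\,\times\,\text{world\_size}}}_{\text{scale\_wrt\_gas}\;+\;\text{RS div}}
= \tfrac{1}{\text{batch\_size}^2},
\qquad \text{gas}\times\text{world\_size}=\text{batch\_size},
```

whereas fast-LLM's grad_output × RS-AVG cancellation leaves only the first factor on the gradient, so the missing 1/batch_size must be applied explicitly.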

Fix: add a grad_divisor parameter to fused_gspo_loss_forward_backward, fused_grpo_loss_forward_backward, and triton_grpo_loss_forward_backward. When normalize_by_documents=true:

  • loss divisor = num_documents_in_batch (matches DS rl/loss)
  • gradient divisor = num_documents_in_batch² (matches DS grad_norm)

This is independent of TP/PP/SDP/DP parallelism and the microbatching schedule, because batch_size is invariant under all of them.
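A sketch of the resulting wiring (loss-layer variable names hypothetical; grad_divisor defaults to the loss divisor, so existing callers are unaffected):

```python
# Decoupled loss/gradient divisors under normalize_by_documents.
if normalize_by_documents:
    divisor = num_documents_in_batch            # loss metric: matches DS rl/loss
    grad_divisor = num_documents_in_batch ** 2  # gradient:    matches DS grad_norm
else:
    divisor = grad_divisor = num_labels_in_batch  # existing behaviour

# Both values are passed down to fused_gspo_loss_forward_backward,
# fused_grpo_loss_forward_backward and triton_grpo_loss_forward_backward.
```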

SDP loss double-counting:
After the SDP allreduce of lrn_sum/adv_sum/tok_sum in fused_gspo_loss_forward_backward, both SDP ranks compute IDENTICAL per-segment loss values. When LossDef.reduce SUMs across data_group (which includes SDP ranks), the loss metric is double-counted by sdp_size. The gradient is NOT double-counted — each SDP rank contributes gradient from its own LOCAL tokens, with different contributions for different tokens of the same segment.

Fix: divide loss by sdp_size when sdp_group is active. Gradient unaffected.
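In the kernel, this is a one-line correction on the loss side only (sketch; sdp_group as a torch.distributed process group):

```python
# Pre-cancel the SDP double-count: after the allreduce every SDP rank holds
# identical segment sums, so the downstream SUM over data_group would count
# the loss sdp_size times. The gradient is left untouched, since each rank
# contributes gradient only for its own local tokens.
if sdp_group is not None:
    loss = loss / sdp_group.size()
```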

Verification

End-to-end 7B math run on 4 nodes, GSPO, gradient_norm_clipping=0.3 (default), normalize_by_documents=true, fp32_lm_head=true, temperature=0.7:

| Metric | Before unit-6 fix | After unit-6 fix | DS GSPO reference |
| --- | --- | --- | --- |
| step 1 grad_norm | 141 (1000× DS) | 0.135 | 0.145 |
| step 1 lm_head_loss | -13.7 | ~-1.7 (sign varies per sample) | -1.7 |
| step 1 clip_coeff | 0.002 (severe over-clip) | 1.000 (no clip) | no clip |
| step 50 newlp | trapped at -0.17 | -0.103 | -0.105 |

The newlp trajectory tracks DS step-by-step. Both systems show the same gradient-spike pattern during the warmup ramp-up at steps 14-20 (DS step 16 grad_norm=6.365, fast-LLM step 15=9.005); the match is within data variance.

Test plan

  • pytest tests/layers/test_gspo_loss.py — GSPO unit tests pass
  • pytest tests/layers/test_docs_per_step.py — docs_per_step unit tests pass
  • pytest tests/layers/test_lm_losses.py — existing GRPO loss tests unaffected
  • pytest tests/layers/test_grpo_metrics.py — metrics tests unaffected
  • End-to-end: 4-node Qwen2.5-7B math run with full config (docs_per_step=1024, temperature=0.7, normalize_by_documents=true, fp32_lm_head=true, default gradient_norm_clipping=0.3) — grad_norm matches DS at step 1, training trajectory matches DS step-by-step through step 50+ (ongoing run validates through step ~410).

Implements GSPO (geometric-mean sequence-level policy-gradient loss) as
an alternative to the existing per-token GRPO clipping. Controlled via
LanguageModelGRPOLossConfig.policy_loss = "gspo".

Key changes:
- data pipeline: expose per-token document_index when return_document_index=True
- LanguageModelKwargs.document_index: new kwarg constant
- LanguageModelLoss: store SDP dim for cross-rank segment aggregation
- grpo.py: fused_gspo_loss_forward_backward with all_reduce(SUM) across
  SDP ranks before computing segment-level R_s and A_s; gradient derivation
  exploits tok_count cancellation so every token in a segment gets the
  same gradient factor R_s * clip_indicator_s
- tests/layers/test_gspo_loss.py: 8 unit tests (single-segment, packed,
  ratio-1 equivalence, clipping, masking, SDP mock, gradient finite-diff,
  per-token metrics unchanged)
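
For reference, the tok_count cancellation reads as follows (one consistent
reading of the derivation above; n_s = tok_count of segment s, \ell_t the
per-token log-probs):

  R_s = \exp\big( \tfrac{1}{n_s} \sum_{t \in s} (\ell_t^{new} - \ell_t^{old}) \big)
  \partial R_s / \partial \ell_t^{new} = R_s / n_s   (identical for every t in s)

so the geometric mean's 1/n_s is the only per-token weight and is shared by
all tokens of the segment; after clipping, every token carries the same
factor R_s * clip_indicator_s (times A_s and the global divisor).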
Adds ScheduleConfig.rollouts_per_step (default 0). When >0, TrainerConfig._from_dict
computes depth_first_micro_batches = rollouts_per_step // (batch_data_parallel ×
breadth_first_micro_batches) before sub-configs are created (and frozen).

Matches DeepSpeed gradient_accumulation_passes semantics for RL: with train_batch_size=1
each microbatch holds one rollout, so setting rollouts_per_step=1024 with data_parallel=8
gives depth_first_micro_batches=128 → exactly 1024 rollouts per optimizer step globally.

YAML usage:
  schedule:
    rollouts_per_step: 1024   # replaces manual depth_first_micro_batches
  model:
    distributed:
      data_parallel: 8        # used for the division

- Rename rollouts_per_step → docs_per_step in ScheduleConfig; depth_first
  is now determined at runtime rather than statically in _from_dict
- Add Schedule._depth_first_override and _eff_{depth_first,sequential,num_inputs}
  properties so per-step schedules share the same config object as the runner
- Add Trainer._prefetch_to_doc_target: fetches microbatches one at a time,
  all-reduces doc count per microbatch, stops when global total ≥ docs_per_step,
  then resets num_documents_in_batch to the step total on all inputs
- Add Trainer._get_or_build_schedule: builds and caches per-N Schedule with
  _depth_first_override=N//breadth_first_micro_batches
- Add normalize_by_documents flag to LanguageModelGRPOLossConfig; when True
  both GRPO and GSPO paths divide by num_documents_in_batch instead of
  num_labels_in_batch (matches DeepSpeed's per-rollout normalization)
- Add tests/layers/test_docs_per_step.py: 13 unit tests covering divisor
  scaling, normalize_by_documents layer routing, Schedule._eff_* properties,
  and _prefetch_to_doc_target accumulation logic

Add temperature field to LanguageModelGRPOLossConfig. When set to match
the actor's sampling temperature (e.g. 0.7), new log-probs are computed
at the same temperature as the stored old log-probs, so the IS ratio
starts near 1.0 instead of ~1.08.

Implementation: _effective_logits_scale = logits_scale_factor / temperature,
substituted for logits_scale_factor at all three callsites in
_forward_backward (GRPO path, GSPO path, _register_pg_metrics). Default
temperature=1.0 preserves existing behaviour exactly.
@bigximik bigximik requested a review from jlamypoirier April 29, 2026 08:04
bigximik added 3 commits May 4, 2026 07:14
Add fp32_lm_head to LanguageModelHeadConfig. When enabled, input hidden
states and output_weights are cast to float32 before the lm_head linear,
producing FP32 logits. This matches vLLM's bf16_last_layer_fp32
quantization (pipelinerl/vllm_quantization.py) and the DeepSpeed trainer's
apply_fp32_lm_head() patch, so new_logprobs and old_logprobs are computed
at the same numerical precision and the IS ratio starts near 1.0 at init.

The gradient flowing back through the linear is cast to the original
input dtype (bf16) before returning, keeping the transformer backward pass
in its native dtype.
…accumulation

Detaching the FP32 weight copy (requires_grad=False) prevents
output_parallel_linear_backward from trying to write to a non-existent
grad_buffer on the copy. Weight grad is then computed explicitly from
the FP32 matmul and accumulated into the original BF16 param's grad_buffer
via accumulate_gradient, restoring the correct FSDP gradient contract.

When normalize_by_documents=true, fast-LLM's reported grad_norm was ~1024×
larger than DeepSpeed's for the equivalent loss, causing the default
gradient_norm_clipping=0.3 to over-clip by ~500× and making training ~10
reward points slower than DS GSPO at the same step count. The lm_head_loss
metric was also off — 1024× smaller than DS's rl/loss in the previous
divisor=num_documents² formulation, then 2× too large from SDP doubling.

Root cause analysis
-------------------

DeepSpeed has TWO 1/batch_size factors with different sources:

  1. Loss reported (rl/loss) uses /batch_size via tokens_weights = 1/batch_size
     (pipelinerl/finetune/rl/__init__.py:246). The reported `rl/loss = -1.7`
     value is the raw policy_loss_total, divided once by batch_size.

  2. Gradient buffer uses an ADDITIONAL /(gas × world_size) factor that comes
     from `scale_wrt_gas=True` in engine.backward()
     (deepspeed/runtime/engine.py:1995-1996) and `tensor.div_(world_sz)` in
     reduce_scatter_coalesced (deepspeed/runtime/comm/coalesced_collectives.py:124).

For DS with samples_per_microbatch=1 (PipelineRL standard), gas × world_size
= batch_size, so DS's effective gradient buffer factor is 1/batch_size² while
the loss metric factor is 1/batch_size. Loss and gradient have asymmetric
scaling.

Fast-LLM's existing implementation used a single `divisor` for both loss and
gradient. Worse, the data_parallel × grad_scale factor in grad_output
(runner.py:318) cancels with FSDP's RS-AVG /world_size, structurally removing
DS's /(gas × world_size) factor from the gradient. So fast-LLM's gradient
buffer ended up at 1/batch_size while DS's was at 1/batch_size² — a
~batch_size = 1024× mismatch.

Additionally, GSPO's SDP allreduce of lrn_sum/adv_sum/tok_sum makes both SDP
ranks compute IDENTICAL per-segment loss values. When LossDef.reduce sums
over the data_group (which includes SDP ranks), the loss metric is
double-counted by sdp_size. The gradient buffer is NOT double-counted —
each SDP rank contributes gradient from its own LOCAL tokens, with different
contributions for different tokens of the same segment.

Fixes
-----

1. Add a `grad_divisor` parameter to `fused_gspo_loss_forward_backward`,
   `fused_grpo_loss_forward_backward`, and `triton_grpo_loss_forward_backward`,
   defaulting to `divisor` (existing behavior). Allows the gradient to use a
   different divisor than the loss.

2. In `LanguageModelGRPOLoss._forward_backward`, when normalize_by_documents
   is True, set:
     loss divisor      = num_documents_in_batch     (matches DS rl/loss)
     gradient divisor  = num_documents_in_batch²    (matches DS grad_norm)
   This is independent of TP/PP/SDP/DP parallelism and microbatching schedule
   because batch_size is invariant under all of these.

3. In the GSPO path, divide the loss by sdp_size when sdp_group is active
   (`fused_gspo_loss_forward_backward`). This pre-cancels the SDP doubling
   that LossDef.reduce's SUM over data_group introduces. The gradient is
   unaffected — different SDP ranks naturally contribute gradient from
   different LOCAL token positions, no double-counting at any layer.

Verification
------------

Tested on 7B math run with 4 nodes, GSPO, gradient_norm_clipping=0.3:

  Metric               | Before fix        | After fix                       | DS GSPO reference
  -------------------- | ----------------- | ------------------------------- | ------------------
  step 1 grad_norm     | 141               | 0.135                           | 0.145
  step 1 lm_head_loss  | -13.7             | ~ -1.7 (sign varies per sample) | rl/loss = -1.7
  step 1 clip_coeff    | 0.002             | 1.000                           | no clipping
  step 50 newlp        | trapped at -0.17  | -0.103                          | -0.105

newlp trajectory tracks DS step-by-step: step 1 within 3%, step 50 within 2%.
Both systems show grad_norm spikes at the same training phase (steps 14-20)
during warmup ramp-up — DS step 16 grad_norm=6.365 vs Fast-LLM 6.093.

Files changed
-------------

- fast_llm/layers/language_model/loss/grpo.py:
  - LanguageModelGRPOLoss._forward_backward: split divisor and grad_divisor
    based on normalize_by_documents flag, with detailed comments referencing
    the corresponding lines in DeepSpeed and PipelineRL.
  - fused_gspo_loss_forward_backward: add grad_divisor parameter; divide loss
    by sdp_size when sdp_group is active.
  - fused_grpo_loss_forward_backward: add grad_divisor parameter.

- fast_llm/functional/triton/grpo_loss.py:
  - triton_grpo_loss_forward_backward: add grad_divisor parameter.
@bigximik bigximik changed the title gspo: GSPO loss, docs_per_step accumulation, normalize_by_documents, temperature scaling gspo: GSPO loss + DeepSpeed parity fixes (loss/grad divisors, SDP, fp32_lm_head, docs_per_step, temperature) May 5, 2026