v0.53.11 — GRPO variants live, PRM live, LongLoRA live, weighted preference

This release closes out the deferred-stub debt from v0.50 / v0.49 / v0.40: five deep TRL trainer-subclassing items are now live.

Six GRPO objective variants live

apply_variant_loss ships real math kernels for every non-standard variant:

  • gspo — group-stabilised importance ratio
  • dapo — decoupled asymmetric clip
  • dr_grpo — no length normalisation
  • bnpo — length-normalised PPO
  • two_sided — symmetric grpo_delta clip
  • rft — rejection-sampling fine-tuning
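
As a flavour of what these kernels change, here is a minimal sketch of the reduction step for two of the variants. The real apply_variant_loss signature is not shown above, so the per_token_loss / completion_mask parameters and the exact reductions are illustrative assumptions, not the shipped implementation.

```python
import torch

def apply_variant_loss(variant: str, per_token_loss: torch.Tensor,
                       completion_mask: torch.Tensor) -> torch.Tensor:
    """Illustrative reduction step only; the shipped kernels cover all six variants."""
    if variant == "dr_grpo":
        # no length normalisation: sum token losses, divide by a constant
        # generation budget rather than each sequence's own length
        return ((per_token_loss * completion_mask).sum(-1)
                / completion_mask.size(-1)).mean()
    if variant == "bnpo":
        # length-normalised: average over every unmasked token in the batch
        return (per_token_loss * completion_mask).sum() / completion_mask.sum().clamp(min=1)
    # standard GRPO baseline: per-sequence mean, then batch mean
    per_seq = (per_token_loss * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1)
    return per_seq.mean()
```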

Subclassing is handled by make_grpo_trainer_variant, an lru_cache factory over _GRPOTrainerVariant, which overrides compute_loss to route through the kernel and falls back defensively to the original loss if TRL renames input attributes.
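
A sketch of the factory shape, assuming TRL's GRPOTrainer as the base; the _per_token_loss helper below is a hypothetical stand-in for the internal glue that feeds the kernel.

```python
from functools import lru_cache

from trl import GRPOTrainer

@lru_cache(maxsize=None)
def make_grpo_trainer_variant(variant: str):
    """Return (and cache) a GRPOTrainer subclass that routes its loss through
    the variant kernel, falling back to the stock loss on attribute drift."""

    class _GRPOTrainerVariant(GRPOTrainer):
        def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
            try:
                # hypothetical glue: obtain the per-token loss and mask the way
                # the parent computes them, then defer the reduction to the
                # kernel sketched above (apply_variant_loss)
                per_token_loss, completion_mask = self._per_token_loss(model, inputs)
                return apply_variant_loss(variant, per_token_loss, completion_mask)
            except (AttributeError, KeyError):
                # defensive fallback: if TRL renames the attributes we read,
                # use the unmodified GRPO loss instead of crashing
                return super().compute_loss(model, inputs,
                                            return_outputs=return_outputs, **kwargs)

    _GRPOTrainerVariant.__qualname__ = f"GRPOTrainerVariant[{variant}]"
    return _GRPOTrainerVariant
```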

`task: prm` Process Reward Model — live

PRMTrainerWrapper loads AutoModelForCausalLM with an nn.Linear(hidden, 1) reward head. The make_prm_trainer_class factory takes the HF Trainer class, subclasses it, and overrides compute_loss to:

1. Gather hidden states at step-boundary tokens

2. Project to scalars through the linear head

3. Compute the MSE loss against the per-step labels

_prepare_prm_dataset tokenizes {prompt, completions, labels} with reserved-token truncation; _build_collator pads with a -1 sentinel for missing step positions.
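
A compressed sketch of those three steps, assuming the collator delivers hypothetical step_positions / step_labels fields and that the wrapper exposes the linear head as model.reward_head; the real field and attribute names may differ.

```python
import torch.nn.functional as F
from transformers import Trainer

def make_prm_trainer_class(trainer_cls=Trainer):
    """Sketch of the factory described above; field and attribute names are assumptions."""

    class _PRMTrainer(trainer_cls):
        def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
            labels = inputs.pop("step_labels")          # (batch, max_steps), -1 = padding
            positions = inputs.pop("step_positions")    # (batch, max_steps) token indices
            outputs = model(**inputs, output_hidden_states=True)

            # 1. gather hidden states at the step-boundary tokens
            hidden = outputs.hidden_states[-1]          # (batch, seq, hidden)
            idx = positions.clamp(min=0).unsqueeze(-1).expand(-1, -1, hidden.size(-1))
            step_hidden = hidden.gather(1, idx)         # (batch, max_steps, hidden)

            # 2. project each step to a scalar through the linear reward head
            rewards = model.reward_head(step_hidden).squeeze(-1)

            # 3. MSE against the per-step labels, ignoring the -1 sentinel
            mask = labels.ne(-1)
            loss = F.mse_loss(rewards[mask], labels[mask].float())
            return (loss, outputs) if return_outputs else loss

    return _PRMTrainer
```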

```yaml
task: prm
base: deepseek-ai/deepseek-math-7b-base
data:
  train: prm_steps.jsonl
  format: prm   # {prompt, completions, labels}
```

GRPOStabilityCallback live

  • EMA reference-model update fires post-step: (1-α) * ref + α * policy per parameter (load_state_dict(strict=False) so LoRA state stays valid); see the sketch after this list
  • Bounded-deque replay buffer
  • TIS (truncated importance sampling) alert counter
  • ema_alpha surfaced via state.log_history so `soup why` can flag instability
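
The moving parts, sketched under assumptions about constructor arguments and hook placement; the real callback may wire the EMA update and the TIS counter differently.

```python
from collections import deque

import torch
from transformers import TrainerCallback

class GRPOStabilityCallback(TrainerCallback):
    def __init__(self, ref_model, ema_alpha=0.01, buffer_size=512, tis_threshold=10.0):
        self.ref_model = ref_model
        self.ema_alpha = ema_alpha
        self.replay_buffer = deque(maxlen=buffer_size)  # bounded replay buffer
        self.tis_threshold = tis_threshold
        self.tis_alerts = 0                             # TIS alert counter

    def record_importance_ratio(self, ratio: float):
        # hypothetical hook called from the loss: count truncated-importance-sampling spikes
        if ratio > self.tis_threshold:
            self.tis_alerts += 1

    def on_step_end(self, args, state, control, model=None, **kwargs):
        # EMA reference-model update: (1 - alpha) * ref + alpha * policy, per parameter
        policy_state = model.state_dict()
        with torch.no_grad():
            ema_state = {
                name: (1 - self.ema_alpha) * ref_param
                      + self.ema_alpha * policy_state[name].to(ref_param.device, ref_param.dtype)
                for name, ref_param in self.ref_model.state_dict().items()
                if name in policy_state
            }
        # strict=False so adapter-only (LoRA) state dicts stay valid
        self.ref_model.load_state_dict(ema_state, strict=False)

        # surface ema_alpha in the log history so downstream tooling can flag instability
        state.log_history.append({"step": state.global_step, "ema_alpha": self.ema_alpha})
        return control
```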

LongLoRA S² forward override live

The v0.49.0 schema is now backed by a runtime. shift_heads_for_s2 rolls the second half of the attention heads by group_size // 2 along the sequence axis (LongLoRA paper §3.2). LongLoRAForwardOverride is a context manager that monkey-patches the forward of every Llama / Mistral / Qwen / Phi attention module and restores it on exit. Idempotent __del__ cleanup and best-effort exception swallowing ensure training never crashes on a shape mismatch.
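
The core shift, sketched under an assumed (batch, heads, seq, head_dim) layout; the shift direction, the constructor arguments in the usage comment, and the exact patch points inside each attention forward are simplifications.

```python
import torch

def shift_heads_for_s2(attn_states: torch.Tensor, group_size: int) -> torch.Tensor:
    """Roll the second half of the attention heads by group_size // 2 along
    the sequence axis (shifted sparse attention, LongLoRA paper §3.2)."""
    num_heads = attn_states.size(1)
    shifted = attn_states.clone()
    shifted[:, num_heads // 2:] = torch.roll(
        attn_states[:, num_heads // 2:], shifts=-(group_size // 2), dims=2
    )
    return shifted

# usage sketch: patch the attention forwards for the duration of training only
# with LongLoRAForwardOverride(model, group_size=2048):
#     trainer.train()
```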

True weighted-sum preference combine

attach_weighted_preference_combine now reads the four logprob tensors from TRL inputs (policy_chosen_logps / policy_rejected_logps / reference_chosen_logps / reference_rejected_logps) and computes each requested loss via the matching kernel:

  • compute_dpo_term
  • compute_ipo_term
  • compute_simpo_term
  • compute_orpo_term

The terms are then combined via combine_losses(terms, weights). Mixing BCO with paired losses is still rejected at validation time. As defence in depth, the combine falls back to the v0.40.1 primary-loss scaling if TRL renames the logps attributes.
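
A sketch of the combine, with the standard DPO term written out; the signatures of compute_dpo_term and combine_losses are assumptions, and the other kernels follow the same pattern over the same four tensors.

```python
import torch
import torch.nn.functional as F

def compute_dpo_term(policy_chosen_logps, policy_rejected_logps,
                     reference_chosen_logps, reference_rejected_logps, beta=0.1):
    # standard DPO: -log sigmoid(beta * ((pi_c - pi_r) - (ref_c - ref_r)))
    logits = (policy_chosen_logps - policy_rejected_logps) \
             - (reference_chosen_logps - reference_rejected_logps)
    return -F.logsigmoid(beta * logits).mean()

def combine_losses(terms: dict, weights: dict) -> torch.Tensor:
    # true weighted sum over the requested losses
    return sum(weights[name] * terms[name] for name in weights)

# e.g. with preference_loss_weights {dpo: 0.7, simpo: 0.3}:
# terms = {"dpo": compute_dpo_term(...), "simpo": compute_simpo_term(...)}
# loss  = combine_losses(terms, {"dpo": 0.7, "simpo": 0.3})
```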

```yaml
task: preference
training:
  preference_loss_weights:
    dpo: 0.7
    simpo: 0.3
```

Tests

  • 8,330 → 8,405 (+75) in test_v05311.py (54 initial + 21 review-fix coverage gaps from the python / code / security / tdd review agents).
  • Math kernels are real and unit-tested. Full GPU smoke runs are documented in the v0.53.11.1 follow-up plan.

See also

  • [GRPO Plus reference](/docs/grpo-plus)
  • [Long Context](/docs/long-context)
  • [Preference Variety](/docs/preference-variety)