DRO for LLM Reliability
Using distributionally robust optimization to make LLM responses consistently good, not just good on average
Standard RL for LLMs (GRPO, best-of-N, rejection sampling) optimizes average single-sample performance. A prompt solved 3 out of \(k\) times gets the same training signal as one solved \(k-1\) out of \(k\). We’re applying CVaR-style distributionally robust optimization to optimize the tail of the per-prompt success distribution instead.
The gap
Current methods exploit variance rather than reduce it. Best-of-N sampling benefits from high variance; GRPO is indifferent to per-prompt consistency. We want models that reliably succeed, not models that sometimes succeed brilliantly.
Method
For each batch, compute per-sample losses and apply CVaR: concentrate the gradient on the worst-performing fraction \(\alpha\) of the batch. The parameter \(\alpha \in (0,1]\) interpolates between the expected loss (\(\alpha = 1\), standard ERM) and the worst case (\(\alpha \to 0\)).
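As a minimal sketch of the hard version of this objective (NumPy; the function name is mine, not from our training code), note that \(\alpha = 1\) recovers the plain ERM mean while small \(\alpha\) focuses entirely on the worst losses:

```python
import numpy as np

def cvar_loss(losses, alpha):
    """Hard CVaR: mean of the worst ceil(alpha * n) losses in the batch."""
    losses = np.asarray(losses, dtype=float)
    k = max(1, int(np.ceil(alpha * losses.size)))
    worst = np.sort(losses)[-k:]  # the k largest per-sample losses
    return worst.mean()

losses = [0.1, 0.2, 0.3, 4.0]
print(cvar_loss(losses, alpha=1.0))   # 1.15 (plain ERM mean)
print(cvar_loss(losses, alpha=0.25))  # 4.0  (worst sample only)
```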
We tested four CVaR variants on GSM8K SFT with Qwen2.5-Math-1.5B-Instruct (1,319 test problems, 2,000 training steps):
- Hard CVaR. Select the worst-\(\alpha\) fraction of the batch via top-\(k\) and average only those losses. Simple, but noisy with small batches.
- SoftCVaR (Rockafellar-Uryasev). Differentiable CVaR: \(\text{CVaR}_{\alpha} = \tau + \frac{1}{\alpha} \mathbb{E}[\max(L - \tau, 0)]\). Learnable threshold \(\tau\) optimized with separate SGD (not AdamW, whose gradient normalization makes the \(\tau\) updates too small).
- StreamCVaR. EMA estimate of the VaR threshold. Single-pass, no extra forward pass needed.
- DORO. Skip top-\(\epsilon\) outliers (presumed mislabeled), then apply CVaR on the rest.
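A runnable sketch of the Rockafellar-Uryasev objective (NumPy rather than our training framework; function names are mine). The subgradient of \(\tau + \frac{1}{\alpha}\mathbb{E}[\max(L - \tau, 0)]\) with respect to \(\tau\) is \(1 - \frac{1}{\alpha}\Pr(L > \tau)\), so plain SGD drives \(\tau\) toward the \((1-\alpha)\)-quantile of the losses (the VaR), at which point the objective equals the hard CVaR:

```python
import numpy as np

def soft_cvar(losses, tau, alpha):
    """Rockafellar-Uryasev objective: tau + (1/alpha) * E[max(L - tau, 0)]."""
    return tau + np.maximum(losses - tau, 0.0).mean() / alpha

def fit_tau(losses, alpha, lr=0.05, steps=2000):
    """Minimize the R-U objective over tau with plain SGD (subgradient)."""
    tau = float(np.mean(losses))
    for _ in range(steps):
        grad = 1.0 - (losses > tau).mean() / alpha  # d/dtau of the objective
        tau -= lr * grad
    return tau

rng = np.random.default_rng(0)
losses = rng.exponential(scale=1.0, size=10_000)
alpha = 0.1
tau = fit_tau(losses, alpha)
# tau should approximate the 90th-percentile loss (the VaR at alpha = 0.1)
print(tau, np.quantile(losses, 1 - alpha))
```

In training, \(\tau\) updates alongside the model parameters rather than in an inner loop, which is why it gets its own optimizer.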
Results
SoftCVaR at \(\alpha=0.1\) beats ERM on every test-loss metric: mean, CVaR@5% (mean of the worst 5% of per-sample losses), and max.
Training dynamics
The chart below shows test loss over training for ERM vs SoftCVaR.
ERM reaches its best mean loss around step 400, then the tail metrics (CVaR@5%, max) start climbing while mean stays flat. The model overfits to easy examples, worsening its worst-case performance. SoftCVaR avoids this by zeroing gradients on samples below the learned threshold \(\tau\), preventing the model from distorting representations to fit already-easy examples.
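The zeroed-gradient behavior falls out of the objective directly: each sample's loss gradient is weighted by \(\frac{1}{\alpha n}\,\mathbb{1}[L_i > \tau]\), so examples already below the threshold contribute nothing. A toy check (NumPy; names are mine):

```python
import numpy as np

def per_sample_grad_weights(losses, tau, alpha):
    """Weight on each sample's loss gradient under the R-U objective.

    d/dL_i [tau + (1/alpha) * mean(max(L - tau, 0))] = 1[L_i > tau] / (alpha * n)
    """
    losses = np.asarray(losses, dtype=float)
    return (losses > tau) / (alpha * losses.size)

losses = np.array([0.05, 0.1, 0.2, 2.5, 3.0])
w = per_sample_grad_weights(losses, tau=1.0, alpha=0.4)
print(w)  # easy samples get exactly zero weight; hard ones get 1/(alpha*n) = 0.5
```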
Full method comparison (all metrics)
Few-shot ablation
Next steps
- Generation-time eval (pass@k, not just loss)
- Scale to 7B model and MATH dataset
- Apply CVaR to the RL objective (GRPO/REINFORCE), not just SFT
Related work
- Oren et al. (2019) – topic-based CVaR for language modeling
- Group DRO (Sagawa et al., 2019) – worst-case subpopulation performance
- Polychromic (Hamid et al., 2025) – set-level objectives for LLM robustness
- RiskPO (Chen et al., 2025) – CVaR-based RL
- G-Pass@k (Liu et al., 2024) – stable reasoning metrics
Links
- Paper draft (in progress)
- Code