DRO for LLM Reliability

Using distributionally robust optimization to make LLM responses consistently good, not just good on average

Updated 2026-02-14

Standard RL for LLMs (GRPO, best-of-N, rejection sampling) optimizes average single-sample performance. A prompt solved 3 out of \(k\) times gets the same training signal as one solved \(k-1\) out of \(k\). We apply CVaR-style distributionally robust optimization to target the tail of the per-prompt success distribution instead.

The gap

Current methods exploit variance rather than reduce it. Best-of-N sampling benefits from high variance; GRPO is indifferent to per-prompt consistency. We want models that reliably succeed, not models that sometimes succeed brilliantly.

Method

For each training sample, compute the loss and apply CVaR: focus gradient on the worst-performing fraction \(\alpha\) of the batch. Parameter \(\alpha \in (0,1]\) interpolates between expected loss (\(\alpha = 1\), standard ERM) and worst-case (\(\alpha \to 0\)).
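A minimal sketch of the batch-level CVaR loss described above (the hard/topk form; function and argument names are ours, not from a released codebase):

```python
import torch


def hard_cvar_loss(per_sample_losses: torch.Tensor, alpha: float) -> torch.Tensor:
    """Average only the worst alpha-fraction of per-sample losses.

    alpha = 1.0 recovers the plain batch mean (ERM); small alpha
    concentrates the gradient on the hardest samples in the batch.
    """
    batch_size = per_sample_losses.numel()
    k = max(1, int(alpha * batch_size))   # number of tail samples to keep
    worst, _ = torch.topk(per_sample_losses, k)  # k largest losses
    return worst.mean()
```

With `alpha = 0.5` and losses `[1, 2, 3, 4]`, only `[4, 3]` contribute, giving a loss of 3.5 instead of the ERM mean of 2.5.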

We tested four CVaR variants on GSM8K SFT with Qwen2.5-Math-1.5B-Instruct (1,319 test problems, 2,000 training steps):

  • Hard CVaR. Select worst-\(\alpha\) fraction of batch via topk, average only those losses. Simple but noisy with small batches.
  • SoftCVaR (Rockafellar-Uryasev). Differentiable CVaR: \(\text{CVaR}_{\alpha} = \tau + \frac{1}{\alpha} \mathbb{E}[\max(L - \tau, 0)]\). Learnable threshold \(\tau\) optimized with separate SGD (not AdamW—Adam normalizes the gradient, making \(\tau\) updates too small).
  • StreamCVaR. EMA estimate of the VaR threshold. Single-pass, no extra forward pass needed.
  • DORO. Skip top-\(\epsilon\) outliers (presumed mislabeled), then apply CVaR on the rest.
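The SoftCVaR variant can be sketched as follows, assuming per-sample losses are already computed; the module and optimizer wiring are illustrative, not the post's released code:

```python
import torch


class SoftCVaR(torch.nn.Module):
    """Rockafellar-Uryasev surrogate: tau + E[max(L - tau, 0)] / alpha.

    tau is a learnable scalar; at the optimum it approaches the
    (1 - alpha)-quantile (VaR) of the loss distribution, and the
    objective value approaches CVaR_alpha. Samples with loss below
    tau pass zero gradient through the relu, which is what prevents
    the model from over-fitting already-easy examples.
    """

    def __init__(self, alpha: float):
        super().__init__()
        self.alpha = alpha
        self.tau = torch.nn.Parameter(torch.zeros(()))

    def forward(self, per_sample_losses: torch.Tensor) -> torch.Tensor:
        excess = torch.relu(per_sample_losses - self.tau)
        return self.tau + excess.mean() / self.alpha


# Usage sketch: tau gets its own plain-SGD optimizer, because Adam's
# per-parameter gradient normalization shrinks tau's effective step.
#   model_opt = torch.optim.AdamW(model.parameters(), lr=1e-5)
#   cvar = SoftCVaR(alpha=0.1)
#   tau_opt = torch.optim.SGD(cvar.parameters(), lr=1e-2)
#   loss = cvar(per_sample_losses)
#   loss.backward(); model_opt.step(); tau_opt.step()
```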

Results

CVaR vs ERM results on GSM8K

SoftCVaR at \(\alpha=0.1\) beats ERM on every metric: mean, CVaR@5%, and max.

Training dynamics

The chart below shows test loss over training for ERM vs SoftCVaR.

ERM reaches its best mean loss around step 400, then the tail metrics (CVaR@5%, max) start climbing while mean stays flat. The model overfits to easy examples, worsening its worst-case performance. SoftCVaR avoids this by zeroing gradients on samples below the learned threshold \(\tau\), preventing the model from distorting representations to fit already-easy examples.

Full method comparison (all metrics)

Few-shot ablation

Few-shot ablation on verification F1
Two retrieved examples give perfect F1 on GSM8K and the best average. More examples help hard splits (MATH: 85.7→90.4 at n=16) but hurt easy ones—likely because the 1,500-char truncation limit crowds out the actual solution being verified.

Next steps

  • Generation-time eval (pass@k, not just loss)
  • Scale to 7B model and MATH dataset
  • Apply CVaR to the RL objective (GRPO/REINFORCE), not just SFT
  • Paper draft (in progress)
  • Code