GRPO (Group Relative Policy Optimization): Formula Cheat Sheet
Condensing 1,600 lines of source code into a one-page formula sheet you can copy straight into a paper.
1 Group-Relative Advantage
Symbol | Meaning | Code name |
q | prompt | prompt |
G | group size (completions per prompt) | num_generations |
o_i | the i-th completion | completions[i] |
r(q,o_i) | reward | rewards[i] |
$$
\boxed{
\begin{aligned}
\mu_r(q) &= \frac{1}{G}\sum_{i=1}^{G} r(q, o_i) \\
\sigma_r(q) &= \operatorname{std}_{i=1..G}\, r(q, o_i) \\
A_{q,o_i} &= \frac{r(q, o_i) - \mu_r(q)}{\sigma_r(q) + \varepsilon} \quad (\varepsilon = 1\times10^{-4})
\end{aligned}
}
$$
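A minimal pure-Python sketch of the formula above (no torch), useful for checking numbers by hand. group_relative_advantages is a hypothetical helper name. It assumes the layout the TRL code below relies on, where the G completions of each prompt are contiguous. Two details are taken from that code: σ is a sample standard deviation (torch's Tensor.std applies Bessel's correction by default, so statistics.stdev is used here), and the division by σ + ε only happens when scale_rewards is enabled.

```python
import statistics


def group_relative_advantages(rewards, num_generations, eps=1e-4, scale_rewards=True):
    """A = (r - group mean) / (group std + eps) for a flat list of B*G rewards.

    Assumes the G completions of each prompt are contiguous, i.e. group k is
    rewards[k*G:(k+1)*G], the layout that rewards.view(-1, num_generations) relies on.
    """
    if len(rewards) % num_generations != 0:
        raise ValueError("len(rewards) must be a multiple of num_generations")
    advantages = []
    for start in range(0, len(rewards), num_generations):
        group = rewards[start:start + num_generations]
        mu = statistics.fmean(group)
        # stdev divides by G - 1 (Bessel's correction), like torch.Tensor.std by default
        sigma = statistics.stdev(group)
        for r in group:
            centered = r - mu
            advantages.append(centered / (sigma + eps) if scale_rewards else centered)
    return advantages


# Two prompts with G = 4 binary rewards each; the second group has zero spread
print(group_relative_advantages([1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0], num_generations=4))
```

A group in which every completion receives the same reward has σ = 0, so all of its advantages are 0; the TRL code logs the share of such groups as frac_reward_zero_std.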
2 Clipped Policy Objective (Token-Level)
$$
\boxed{
\begin{aligned}
r_{i,t}(\theta) &= \frac{\pi_\theta(o_{i,t}\mid q, o_{i,<t})}{\pi_{\theta_{\text{old}}}(o_{i,t}\mid q, o_{i,<t})} \\[6pt]
L_{\text{clip}}(\theta) &= \sum_{i=1}^{G}\sum_{t=1}^{|o_i|} \min\!\Bigl(
r_{i,t}(\theta)\,A_{q,o_i},\;
\operatorname{clip}\bigl(r_{i,t}(\theta),\,1-\varepsilon_{\text{low}},\,1+\varepsilon_{\text{high}}\bigr)\,A_{q,o_i}
\Bigr)
\end{aligned}
}
$$
Code counterparts: per_token_logps, old_per_token_logps, coef_1 (the ratio r_{i,t}), coef_2 (the clipped ratio)
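A stdlib sketch of one token's clipped term and of the double sum L_clip, computing the ratio from log-probabilities the way coef_1 and coef_2 are named above. The function names and the default ε_low = ε_high = 0.2 are assumptions for illustration. The training loss minimizes the negative of this objective.

```python
import math


def clipped_token_term(logp, old_logp, advantage, eps_low=0.2, eps_high=0.2):
    """min(r * A, clip(r, 1 - eps_low, 1 + eps_high) * A) for one token."""
    ratio = math.exp(logp - old_logp)                          # r_{i,t}, computed like coef_1
    clipped = min(max(ratio, 1.0 - eps_low), 1.0 + eps_high)   # clipped ratio, like coef_2
    return min(ratio * advantage, clipped * advantage)


def clipped_objective(logps, old_logps, advantages, eps_low=0.2, eps_high=0.2):
    """L_clip: double sum over completions i and their tokens t.

    logps / old_logps hold one list of per-token log-probs per completion;
    advantages holds one A_{q,o_i} per completion, shared by all of its tokens.
    """
    total = 0.0
    for seq_logps, seq_old, adv in zip(logps, old_logps, advantages):
        for lp, old_lp in zip(seq_logps, seq_old):
            total += clipped_token_term(lp, old_lp, adv, eps_low, eps_high)
    return total


# A ratio of 1.5 with a positive advantage is capped at 1 + eps_high = 1.2
print(clipped_token_term(math.log(1.5), 0.0, advantage=1.0))
```

With a negative advantage the same ratio is not capped (the min picks the more pessimistic unclipped value), which is what makes the bound one-sided in each direction.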
3 KL Regularization Term (optional, enabled when β > 0)
$$
\boxed{\mathrm{KL}_{\mathrm{reg}} = \beta \; D_{\mathrm{KL}}\!\bigl[\pi_{\theta}\,\|\,\pi_{\mathrm{ref}}\bigr]}
$$
Code counterparts: per_token_kl, beta
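In the TRL versions we are aware of, per_token_kl is computed inside compute_loss (not reproduced in this article) with the "k3" estimator exp(ref − logp) − (ref − logp) − 1 and then added to the per-token loss as beta * per_token_kl; treat that as our reading of the library, not a guarantee. The stdlib sketch below (per_token_kl here is a hypothetical stand-alone function) shows the estimator and checks on a toy 3-token vocabulary that its expectation under π_θ equals the exact KL.

```python
import math


def per_token_kl(logp, ref_logp):
    """k3 estimator for one token sampled from pi_theta:
    pi_ref/pi_theta - log(pi_ref/pi_theta) - 1, computed from log-probs. Always >= 0."""
    log_ratio = ref_logp - logp
    return math.exp(log_ratio) - log_ratio - 1.0


# Toy vocabulary: the expectation of the estimator under pi_theta equals the exact KL
p_theta = [0.7, 0.2, 0.1]
p_ref = [0.5, 0.3, 0.2]
exact_kl = sum(p * math.log(p / q) for p, q in zip(p_theta, p_ref))
expected_k3 = sum(p * per_token_kl(math.log(p), math.log(q)) for p, q in zip(p_theta, p_ref))
print(exact_kl, expected_k3)
```

Unlike the plain log-ratio, every individual k3 sample is non-negative, which keeps the per-token penalty well behaved.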
4 Final Loss (Token-Level)
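All three variants average the same per-token loss, which combines the negated clipped term of section 2 with the KL penalty of section 3 (D_{i,t} is the per-token KL estimate, per_token_kl in code):

$$
\ell_{i,t} = -\min\!\Bigl(r_{i,t}(\theta)\,A_{q,o_i},\; \operatorname{clip}\bigl(r_{i,t}(\theta),\,1-\varepsilon_{\text{low}},\,1+\varepsilon_{\text{high}}\bigr)\,A_{q,o_i}\Bigr) + \beta\,D_{i,t}
$$

They differ only in the denominator. Below, B is the number of completions in the local batch, |o_i| the number of unmasked tokens of completion o_i, and T = max_completion_length.
loss_type | Final loss |
grpo (GRPO) | $L_{\text{total}} = \frac{1}{B}\sum_{i=1}^{B}\frac{1}{\lvert o_i\rvert}\sum_{t=1}^{\lvert o_i\rvert}\ell_{i,t}$ (average within each sequence, then across sequences) |
bnpo (BNPO) | $L_{\text{total}} = \frac{\sum_{i=1}^{B}\sum_{t=1}^{\lvert o_i\rvert}\ell_{i,t}}{\sum_{i=1}^{B}\lvert o_i\rvert}$ (one average over all active tokens) |
dr_grpo (DR-GRPO) | $L_{\text{total}} = \frac{1}{B\cdot T}\sum_{i=1}^{B}\sum_{t=1}^{\lvert o_i\rvert}\ell_{i,t}$ (constant denominator) |
- The variant is selected in code by the loss_type argument.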
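A list-based sketch of the three reductions in the table, assuming the per-token loss ℓ_{i,t} has already been computed and padded to a common length T. aggregate_loss is a hypothetical name; the string values mirror the loss_type options.

```python
def aggregate_loss(per_token_loss, mask, loss_type, max_completion_length=None):
    """Reduce per-token losses (B rows of T padded tokens) to one scalar.

    mask[i][t] is 1 for a real completion token and 0 for padding, like completion_mask.
    """
    masked = [[x * m for x, m in zip(row, mrow)] for row, mrow in zip(per_token_loss, mask)]
    if loss_type == "grpo":
        # average inside each sequence, then across sequences
        per_seq = [sum(row) / max(sum(mrow), 1) for row, mrow in zip(masked, mask)]
        return sum(per_seq) / len(per_seq)
    if loss_type == "bnpo":
        # a single average over every active token in the batch
        return sum(map(sum, masked)) / max(sum(map(sum, mask)), 1)
    if loss_type == "dr_grpo":
        # constant denominator B * T, with T = max_completion_length
        return sum(map(sum, masked)) / (len(masked) * max_completion_length)
    raise ValueError(f"unknown loss_type: {loss_type!r}")


loss = [[1.0, 1.0, 1.0, 1.0], [2.0, 5.0, 5.0, 5.0]]
mask = [[1, 1, 1, 1], [1, 0, 0, 0]]  # second completion has a single real token
for kind in ("grpo", "bnpo", "dr_grpo"):
    print(kind, aggregate_loss(loss, mask, kind, max_completion_length=4))
```

With these numbers the one-token completion weighs as much as the four-token one under grpo, while bnpo and dr_grpo give every token the same weight.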
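**BNPO (Beta Normalization Policy Optimization) vs GRPO in one line**
BNPO = GRPO plus a plug-in reward normalizer.
Both share the same "group-relative advantage + KL + clip" framework; BNPO only replaces the static mean-variance normalization with an adaptive Beta normalization.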
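✅ Key differences
Aspect | GRPO | BNPO |
Normalization | group mean-variance (static) | adaptive Beta distribution |
Reward assumption | any real value | binary (Bernoulli) rewards |
Baseline update | μ, σ recomputed for every batch | α, β parameters updated online |
Gradient variance | fixed | shrinks dynamically as the policy evolves |
Superset of GRPO? | ✗ | Yes: GRPO is the special case with β held fixed (β here is a Beta-distribution parameter, not the KL coefficient of section 3) |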
✅ Formulas side by side
Method | Advantage |
GRPO | $A_{\text{GRPO}} = \frac{r - \mu_r}{\sigma_r + \varepsilon}$ |
BNPO | $A_{\text{BNPO}} = \frac{r - \hat{\mu}}{\sqrt{\hat{\mu}(1-\hat{\mu}) + \delta}}$, where $\hat{\mu} \sim \text{Beta}(\alpha, \beta)$ and α, β are estimated online from the rewards of the most recent N steps |
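The BNPO formula in the table can be evaluated directly. How α and β are estimated is only summarized above, so beta_mean_from_window below is a hypothetical stand-in (the mean α/(α+β) of a Beta posterior built from a window of recent 0/1 rewards), not the update rule of the BNPO paper. The sketch also shows why the two formulas are close for binary rewards: with μ̂ equal to the group's success rate, √(μ̂(1−μ̂)) is exactly the group's population standard deviation.

```python
import math
import statistics


def bnpo_advantage(reward, mu_hat, delta=1e-4):
    """A_BNPO = (r - mu_hat) / sqrt(mu_hat * (1 - mu_hat) + delta), as in the table above."""
    return (reward - mu_hat) / math.sqrt(mu_hat * (1.0 - mu_hat) + delta)


def beta_mean_from_window(recent_rewards, prior_alpha=1.0, prior_beta=1.0):
    """HYPOTHETICAL estimate of mu_hat: posterior mean alpha / (alpha + beta) after
    observing a window of recent 0/1 rewards under a Beta(prior_alpha, prior_beta) prior."""
    successes = sum(recent_rewards)
    alpha = prior_alpha + successes
    beta = prior_beta + len(recent_rewards) - successes
    return alpha / (alpha + beta)


group = [1.0, 0.0, 0.0, 1.0]
p = statistics.fmean(group)
# Bernoulli variance p(1 - p) matches the population variance of a 0/1 group
print(math.sqrt(p * (1 - p)), statistics.pstdev(group))
print(bnpo_advantage(1.0, beta_mean_from_window([1, 1, 0, 0])))
```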
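✅ When to use which
- GRPO: general-purpose, simple, quick to implement.
- BNPO:
  - rewards are only 0/1 (correct/incorrect);
  - the reward distribution drifts a lot early in training (e.g. math reasoning tasks);
  - you need lower gradient variance and more stable training.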
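✅ Conclusion
- BNPO does not overturn GRPO; it upgrades the "static mean baseline" to a "dynamic Beta baseline".
- Do not confuse this with TRL's loss_type='bnpo': in TRL that option only changes how token losses are aggregated (one average over all active completion tokens in the local batch, see section 4). The advantage computation in _generate_and_score_completions below uses the group mean and std regardless of loss_type, so sampling, clipping, and KL stay identical, but no Beta normalization is applied.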
How Hugging Face TRL implements it
Computing the rewards
```python
def _calculate_rewards(self, inputs, prompts, completions, completion_ids_list):
    device = self.accelerator.device
    rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)

    # Repeat all input columns (but "prompt", "completion", and "completion_ids") to match the num of generations
    keys = [key for key in inputs[0] if key not in ["prompt", "completion", "completion_ids"]]
    reward_kwargs = {key: [example[key] for example in inputs] for key in keys}

    # This allows for dynamic reward shaping based on training progress.
    reward_kwargs["trainer_state"] = self.state

    for i, (reward_func, reward_processing_class, reward_func_name) in enumerate(
        zip(self.reward_funcs, self.reward_processing_classes, self.reward_func_names)
    ):
        with profiling_context(self, reward_func_name):
            if isinstance(reward_func, nn.Module):  # Module (no PretrainedModel) for compat with compiled models
                if is_conversational(inputs[0]):
                    messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
                    texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
                else:
                    texts = [p + c for p, c in zip(prompts, completions)]
                reward_inputs = reward_processing_class(
                    text=texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False
                )
                reward_inputs = super()._prepare_inputs(reward_inputs)
                with torch.inference_mode():
                    rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0]  # Shape (B*G,)
            else:
                output_reward_func = reward_func(
                    prompts=prompts, completions=completions, completion_ids=completion_ids_list, **reward_kwargs
                )
                # Convert None values to NaN
                output_reward_func = [reward if reward is not None else torch.nan for reward in output_reward_func]

                rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)

    # If all reward functions return None for a given row, issue a detailed warning
    if torch.isnan(rewards_per_func).all(dim=1).any():
        nan_row_idx = torch.isnan(rewards_per_func).all(dim=1).nonzero(as_tuple=True)[0][0]
        row_reward_kwargs = {key: value[nan_row_idx] for key, value in reward_kwargs.items()}
        row_reward_kwargs["prompt"] = prompts[nan_row_idx]
        row_reward_kwargs["completion"] = completions[nan_row_idx]
        warnings.warn(
            f"All reward functions returned None for the following kwargs: {row_reward_kwargs}. "
            "Please ensure that at least one reward function returns a valid reward."
        )

    # Gather the reward per function: this part is crucial, because the rewards are normalized per group and the
    # completions may be distributed across processes
    rewards_per_func = gather(rewards_per_func)
    return rewards_per_func
```
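_calculate_rewards feeds the B×G completions to each of the F reward functions, returns a (B×G, F) reward matrix, and synchronizes it across processes so that the GRPO group-wise normalization that follows sees every group in full.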
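✅ Inputs and outputs
Argument | Type and size |
prompts | List[str] or List[Messages], length B×G |
completions | List[str] or List[Messages], length B×G |
completion_ids_list | List[List[int]] of token IDs, length B×G |
Return value | Tensor of shape (B×G, F), F = number of reward functions; gather then stacks the rows of all processes |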
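Based on the function branch of the code above, a custom reward function receives prompts, completions, and completion_ids as keyword arguments, plus every other dataset column (one value per row) and trainer_state. The sketch below is hypothetical: exact_match_reward and the answer column are assumptions for illustration. Returning None marks a row where the reward does not apply; the trainer turns it into NaN.

```python
def exact_match_reward(prompts, completions, completion_ids, answer, **kwargs):
    """Hypothetical reward: 1.0 if the completion text contains the reference answer.

    `answer` is an assumed extra dataset column; other columns and trainer_state
    arrive through **kwargs and are ignored here. None means "not applicable".
    """
    rewards = []
    for completion, ref in zip(completions, answer):
        # Conversational datasets yield [{"role": "assistant", "content": ...}]
        text = completion[0]["content"] if isinstance(completion, list) else completion
        if ref is None:
            rewards.append(None)
        else:
            rewards.append(1.0 if str(ref) in text else 0.0)
    return rewards


print(exact_match_reward(
    prompts=["1+1=", "2+2=", "Describe the sky."],
    completions=["The answer is 2", "5", [{"role": "assistant", "content": "blue"}]],
    completion_ids=[[1], [2], [3]],
    answer=["2", "4", None],
    trainer_state=None,
))
```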
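✅ Core steps
1. Allocate the container: rewards_per_func = zeros(B×G, F) reserves one slot per (completion, reward function) pair.
2. Pack the extra columns into kwargs: every field of inputs[0] other than "prompt", "completion", and "completion_ids" is collected row by row for custom reward functions, and trainer_state is added as well.
3. Loop over the F reward functions:
   - If it is a model (nn.Module): build the prompt + completion text, tokenize it, run a forward pass, and take logits[:, 0] as the scalar reward.
   - If it is a plain function (Callable): call it directly; it may return None, which is converted to NaN as a placeholder.
4. Detect missing rewards: if every reward function returned None for some row, emit a detailed warning so you can find the reward function that failed to return a value.
5. Synchronize across processes: gather(rewards_per_func) gives every GPU the global reward matrix, so the group-wise normalization that follows is consistent.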
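The container, the None → NaN conversion, and the all-NaN check described above can be mimicked without torch or multiple processes. assemble_rewards is a hypothetical helper; it takes the raw outputs of F reward functions (each a list of length N) and returns a row-major (N, F) matrix plus the indices of rows where every function returned None.

```python
import math


def assemble_rewards(outputs_per_func):
    """Build an (N, F) reward matrix from F outputs of length N (None -> NaN)
    and report the rows for which every reward function returned None."""
    num_rows = len(outputs_per_func[0])
    matrix = [
        [math.nan if out[row] is None else float(out[row]) for out in outputs_per_func]
        for row in range(num_rows)
    ]
    all_nan_rows = [i for i, row in enumerate(matrix) if all(math.isnan(v) for v in row)]
    return matrix, all_nan_rows


matrix, bad_rows = assemble_rewards([[1.0, None, None], [0.5, 0.2, None]])
print(matrix, bad_rows)  # row 2 got no reward from either function
```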
✅ Summary
"Feed B×G completions to F reward functions, collect the results across processes, and produce a (B×G, F) reward tensor for GRPO's group-wise normalization."
_generate_and_score_completions
```python
def _generate_and_score_completions(
    self, inputs: list[dict[str, Union[torch.Tensor, Any]]]
) -> dict[str, Union[torch.Tensor, Any]]:
    device = self.accelerator.device
    mode = "train" if self.model.training else "eval"

    prompts = [x["prompt"] for x in inputs]

    # We don't yet support visual reward models/function, so we keep a copy of the original text-only prompts for
    # later use in the reward computation. If images are present, we insert {"type": "image"} as required by the
    # VLM chat template.
    original_prompts = copy.deepcopy(prompts)

    # If the prompts are conversational and the inputs contain images, we need to convert the prompts from
    # [{"role": "user", "content": "What color is the sky?"}] to
    # [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]}]
    kwargs = {}
    has_images = "image" in inputs[0]
    if has_images:
        images = [example.get("image") for example in inputs]
        kwargs = {"images": [[img] for img in images]}
        for prompt in prompts:
            if isinstance(prompt, list):
                for message in prompt:
                    if not isinstance(message, dict):
                        continue
                    content = message.get("content")
                    role = message.get("role")
                    if isinstance(content, str):
                        if role == "user":
                            message["content"] = [{"type": "image"}, {"type": "text", "text": content}]
                        elif role == "system":
                            message["content"] = [{"type": "text", "text": content}]

    prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]

    prompt_inputs = self.processing_class(
        text=prompts_text,
        return_tensors="pt",
        padding=True,
        padding_side="left",
        add_special_tokens=False,
        **kwargs,
    )
    prompt_inputs = super()._prepare_inputs(prompt_inputs)
    prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]

    if self.max_prompt_length is not None:
        # If max_prompt_length is set, we trim the prompt to keep only the last `max_prompt_length` tokens.
        # Then we decode those tokens back into text. We manually remove leading pad tokens from the decoded text,
        # because we can't use `skip_special_tokens=True` (some special tokens are still needed for generation).
        protected = [self.image_token_id, self.vision_start_token_id, self.vision_end_token_id]
        protected = [token for token in protected if token is not None]
        prompt_ids, prompt_mask = truncate_with_protected_tokens(
            prompt_ids, prompt_mask, self.max_prompt_length, protected
        )

        prompts_text = self.processing_class.batch_decode(
            prompt_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
        )
        prompts_text = [re.sub(rf"^({re.escape(self.pad_token)})+", "", text) for text in prompts_text]

        # The chat template sometimes inserts a single image token into the prompt text. However, when this text is
        # later tokenized, the single image token string is expanded into multiple image token IDs, depending on the
        # image size. Since we're detokenizing here, we may see repeated image tokens in the decoded text. We
        # collapse them back into a single token string to match the original chat template in case it originally
        # applies it. Otherwise, it assumes that the chat template uses only vision_start_token_id to indicate images
        # (e.g. Gemma 3) and removes all image_token instances and vision_end_token_id as well, leaving only
        # the vision_start_token_id (e.g. <start_of_image>).
        if self.image_token is not None:
            escaped_img_token = re.escape(self.image_token)
            # Search for the image token in the chat template
            if re.search(escaped_img_token, self.processing_class.chat_template):
                prompts_text = [
                    re.sub(rf"({escaped_img_token})+", self.image_token, text) for text in prompts_text
                ]
            else:
                # If the chat template doesn't use the image token, we remove all instances of it + vision_end_token_id
                if self.vision_end_token_id is not None:
                    escaped_eoi_token = re.escape(
                        self.processing_class.tokenizer.decode([self.vision_end_token_id])
                    )
                    prompts_text = [
                        re.sub(rf"({escaped_img_token})+{escaped_eoi_token}", "", text) for text in prompts_text
                    ]
                else:
                    # If vision_end_token_id is None, just remove the image tokens
                    prompts_text = [re.sub(rf"({escaped_img_token})+", "", text) for text in prompts_text]

    # Generate completions using either vLLM or regular generation
    if self.use_vllm:
        # First, update the vLLM weights if needed
        if self.state.global_step != self._last_loaded_step:
            self._move_model_to_vllm()
            self._last_loaded_step = self.state.global_step

        # Generate completions using vLLM: gather all prompts and use them in a single call in the main process
        if self.vllm_mode == "server":
            all_prompts_text = gather_object(prompts_text)
            if has_images:
                all_images = gather_object(images)

            if self.accelerator.is_main_process:
                # Since 'prompts' contains 'num_generations' duplicates, we first take unique prompts, and generate
                # num_generations outputs for each one. This is faster than generating outputs for each duplicate
                # prompt individually.
                ordered_set_of_prompts = all_prompts_text[:: self.num_generations]

                if has_images:
                    ordered_set_of_images = all_images[:: self.num_generations]
                else:
                    ordered_set_of_images = None

                with profiling_context(self, "vLLM.generate"):
                    completion_ids = self.vllm_client.generate(
                        prompts=ordered_set_of_prompts,
                        images=ordered_set_of_images,
                        n=self.num_generations,
                        repetition_penalty=self.repetition_penalty,
                        temperature=self.temperature,
                        top_p=self.top_p,
                        top_k=-1 if self.top_k is None else self.top_k,
                        min_p=0.0 if self.min_p is None else self.min_p,
                        max_tokens=self.max_completion_length,
                        guided_decoding_regex=self.guided_decoding_regex,
                        generation_kwargs=self.args.generation_kwargs,
                    )
            else:
                completion_ids = [None] * len(all_prompts_text)
            # Broadcast the completions from the main process to all processes, ensuring each process receives its
            # corresponding slice.
            completion_ids = broadcast_object_list(completion_ids, from_process=0)
            process_slice = slice(
                self.accelerator.process_index * len(prompts),
                (self.accelerator.process_index + 1) * len(prompts),
            )
            completion_ids = completion_ids[process_slice]

        # Generate completions using colocated vLLM instances: each device holds vLLM copy and work on their own batch of prompts
        elif self.vllm_mode == "colocate":
            if self.guided_decoding_regex:
                guided_decoding = GuidedDecodingParams(regex=self.guided_decoding_regex)
            else:
                guided_decoding = None

            generation_kwargs = {
                "n": 1,  # vLLM on each GPU generates only 1 in colocate mode
                "repetition_penalty": self.repetition_penalty,
                "temperature": self.temperature,
                "top_p": self.top_p,
                "top_k": -1 if self.top_k is None else self.top_k,
                "min_p": 0.0 if self.min_p is None else self.min_p,
                "max_tokens": self.max_completion_length,
                "guided_decoding": guided_decoding,
            }
            if self.args.generation_kwargs is not None:
                generation_kwargs.update(self.args.generation_kwargs)
            sampling_params = SamplingParams(**generation_kwargs)

            if self.vllm_tensor_parallel_size > 1:
                # Gather prompts from all ranks in the TP group and flatten.
                # Each rank starts with its own prompts; after gathering, all ranks see the full group set.
                orig_size = len(prompts_text)
                gathered_prompts = [None for _ in range(self.vllm_tensor_parallel_size)]
                torch.distributed.all_gather_object(gathered_prompts, prompts_text, group=self.tp_group)
                all_prompts_text = [p for sublist in gathered_prompts for p in sublist]

                if has_images:
                    gathered_images = [None for _ in range(self.vllm_tensor_parallel_size)]
                    torch.distributed.all_gather_object(gathered_images, images, group=self.tp_group)
                    all_images = [img for sublist in gathered_images for img in sublist]
                else:
                    all_images = None
            else:
                all_prompts_text = prompts_text
                all_images = images if has_images else None

            if has_images and all_images:
                vllm_inputs = []
                for prompt, image in zip(all_prompts_text, all_images):
                    if image is not None:
                        vllm_inputs.append({"prompt": prompt, "multi_modal_data": {"image": image}})
                    else:
                        vllm_inputs.append(prompt)
            else:
                vllm_inputs = all_prompts_text

            with profiling_context(self, "vLLM.generate"):
                all_outputs = self.llm.generate(vllm_inputs, sampling_params=sampling_params, use_tqdm=False)

            completion_ids = [output.token_ids for outputs in all_outputs for output in outputs.outputs]

            if self.vllm_tensor_parallel_size > 1:
                # Slice completions for this rank within its TP group.
                # Each rank generates all outputs; we keep only our share.
                local_rank_in_group = torch.distributed.get_rank(group=self.tp_group)
                tp_slice = slice(local_rank_in_group * orig_size, (local_rank_in_group + 1) * orig_size)
                completion_ids = completion_ids[tp_slice]

        # Pad the completions, and concatenate them with the prompts
        completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
        completion_ids = pad(completion_ids, padding_value=self.pad_token_id)
        prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)

    elif self.use_transformers_paged:
        # Re-process inputs for paged generation if needed
        # Note: images are already validated and preprocessed above
        paged_prompt_inputs = self.processing_class(text=prompts_text, **kwargs)
        previous_attn = self.model_wrapped.config._attn_implementation

        if is_flash_attn_2_available():
            self.model_wrapped.config._attn_implementation = "paged_attention"
        else:
            self.model_wrapped.config._attn_implementation = "sdpa_paged"
        with (
            profiling_context(self, "transformers.generate_batch"),
            unwrap_model_for_generation(
                self.model_wrapped, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
            ) as unwrapped_model,
            torch.no_grad(),
            FSDP.summon_full_params(self.model_wrapped, recurse=False) if self.is_fsdp_enabled else nullcontext(),
        ):
            # Cast to the appropriate dtype based on training configuration
            if self.args.bf16:
                unwrapped_model.to(torch.bfloat16)
            elif self.args.fp16:
                unwrapped_model.to(torch.float16)
            with torch.inference_mode():
                all_outputs = unwrapped_model.generate_batch(
                    paged_prompt_inputs.input_ids, generation_config=self.generation_config, progress_bar=False
                )
        completion_ids = [output.generated_tokens for output in all_outputs.values()]
        completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
        completion_ids = pad(completion_ids, padding_value=self.pad_token_id, padding_side="right")
        prompt_ids = [torch.tensor(ids, device=device) for ids in paged_prompt_inputs.input_ids]
        prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left")
        prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
        # Restore the original attention implementation, training mode
        self.model_wrapped.config._attn_implementation = previous_attn
    else:
        # Regular generation path
        with (
            profiling_context(self, "transformers.generate"),
            unwrap_model_for_generation(
                self.model_wrapped, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
            ) as unwrapped_model,
            torch.no_grad(),
            FSDP.summon_full_params(self.model_wrapped, recurse=False) if self.is_fsdp_enabled else nullcontext(),
        ):
            prompt_inputs["input_ids"], prompt_inputs["attention_mask"] = prompt_ids, prompt_mask
            prompt_completion_ids = unwrapped_model.generate(
                **prompt_inputs, generation_config=self.generation_config, disable_compile=True
            )
        # Compute prompt length and extract completion ids
        prompt_length = prompt_ids.size(1)
        prompt_ids = prompt_completion_ids[:, :prompt_length]
        completion_ids = prompt_completion_ids[:, prompt_length:]

    # Mask everything after the first EOS token
    is_eos = completion_ids == self.eos_token_id
    eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
    eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
    sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
    completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()

    # Convert tensor to a list of lists of token IDs. This will be passed to the reward function, avoiding the need
    # to re-tokenize completions if the reward is computed from tokens.
    completion_ids_list = [
        [id.item() for id, m in zip(row, mask_row) if m] for row, mask_row in zip(completion_ids, completion_mask)
    ]

    # Sum along sequence dimension (dim=1) to get completion length per sequence, used for logging
    completion_lengths = completion_mask.sum(1)

    # If mask_truncated_completions is enabled, zero out truncated completions in completion_mask
    if self.mask_truncated_completions:
        truncated_completions = ~is_eos.any(dim=1)
        completion_mask = completion_mask * (~truncated_completions).unsqueeze(1).int()

    # Concatenate prompt_mask with completion_mask for logit computation
    attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)  # (B, P+C)

    logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens
    batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size

    with torch.no_grad():
        # If the generation and optimization steps are misaligned, i.e., if generation does not occur at the end of
        # a full optimizer step (when gradient_accumulation_steps is not a multiple of generate_every), then the
        # samples may come from an earlier version of the model. In that case, we need to track old_per_token_logps
        # for importance sampling. If the steps are aligned, importance sampling isn't necessary and we set
        # old_per_token_logps to None.
        generate_every = self.args.steps_per_generation * self.num_iterations  # generation frequency
        if self.args.gradient_accumulation_steps % generate_every != 0:
            old_per_token_logps, _ = self._get_per_token_logps_and_entropies(
                self.model,
                prompt_completion_ids,
                attention_mask,
                logits_to_keep,
                batch_size,
                pixel_values=prompt_inputs.get("pixel_values"),
                image_grid_thw=prompt_inputs.get("image_grid_thw"),
                pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"),
                image_sizes=prompt_inputs.get("image_sizes"),
            )
        else:
            old_per_token_logps = None

        # Compute the per-token log probabilities for the reference model
        if self.beta != 0.0:
            if self.ref_model is not None:
                ref_per_token_logps, _ = self._get_per_token_logps_and_entropies(
                    self.ref_model,
                    prompt_completion_ids,
                    attention_mask,
                    logits_to_keep,
                    batch_size=batch_size,
                    pixel_values=prompt_inputs.get("pixel_values"),
                    image_grid_thw=prompt_inputs.get("image_grid_thw"),
                    pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"),
                    image_sizes=prompt_inputs.get("image_sizes"),
                )
            else:
                with self.accelerator.unwrap_model(self.model).disable_adapter():
                    ref_per_token_logps, _ = self._get_per_token_logps_and_entropies(
                        self.model,
                        prompt_completion_ids,
                        attention_mask,
                        logits_to_keep,
                        batch_size=batch_size,
                        pixel_values=prompt_inputs.get("pixel_values"),
                        image_grid_thw=prompt_inputs.get("image_grid_thw"),
                        pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"),
                        image_sizes=prompt_inputs.get("image_sizes"),
                    )
        else:
            ref_per_token_logps = None

    # Decode the generated completions
    completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
    if is_conversational(inputs[0]):
        completions = []
        for prompt, completion in zip(prompts, completions_text):
            bootstrap = prompt.pop()["content"] if prompt[-1]["role"] == "assistant" else ""
            completions.append([{"role": "assistant", "content": bootstrap + completion}])
    else:
        completions = completions_text

    # Calculate rewards for each reward function. rewards_per_func aggregates rewards across all processes. This is
    # important because rewards will be normalized per group, and completions are distributed. We will later slice
    # rewards_per_func to extract each process's subset.
    rewards_per_func = self._calculate_rewards(inputs, original_prompts, completions, completion_ids_list)

    # Apply weights to each reward function's output and sum
    rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1)

    # Compute grouped-wise rewards
    mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
    std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)
    is_std_zero = torch.isclose(std_grouped_rewards, torch.zeros_like(std_grouped_rewards))

    # Normalize the rewards to compute the advantages
    mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
    std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
    advantages = rewards - mean_grouped_rewards
    if self.scale_rewards:
        advantages = advantages / (std_grouped_rewards + 1e-4)

    # Slice to keep only the local part of the data
    process_slice = slice(
        self.accelerator.process_index * len(prompts),
        (self.accelerator.process_index + 1) * len(prompts),
    )
    all_process_advantages = advantages.clone()  # keep the aggregated advantages for logging
    advantages = advantages[process_slice]

    # Log the metrics
    if mode == "train":
        self.state.num_input_tokens_seen += self.accelerator.gather(attention_mask.sum()).sum().item()
    self._metrics[mode]["num_tokens"] = [self.state.num_input_tokens_seen]

    # Log completion lengths, mean, min, max
    agg_completion_lengths = self.accelerator.gather(completion_lengths)
    self._metrics[mode]["completions/mean_length"].append(agg_completion_lengths.float().mean().item())
    self._metrics[mode]["completions/min_length"].append(agg_completion_lengths.float().min().item())
    self._metrics[mode]["completions/max_length"].append(agg_completion_lengths.float().max().item())

    # Identify sequences that terminated with EOS and log their lengths
    agg_terminated_with_eos = self.accelerator.gather(is_eos.any(dim=1))
    term_completion_lengths = agg_completion_lengths[agg_terminated_with_eos]
    clipped_completions_ratio = 1 - len(term_completion_lengths) / len(agg_completion_lengths)
    self._metrics[mode]["completions/clipped_ratio"].append(clipped_completions_ratio)
    if len(term_completion_lengths) == 0:  # edge case where no terminated sequences are found
        term_completion_lengths = torch.zeros(1, device=device)
    self._metrics[mode]["completions/mean_terminated_length"].append(term_completion_lengths.float().mean().item())
    self._metrics[mode]["completions/min_terminated_length"].append(term_completion_lengths.float().min().item())
    self._metrics[mode]["completions/max_terminated_length"].append(term_completion_lengths.float().max().item())

    # Calculate mean reward per function, but only for samples where the function was applied (non-NaN values)
    for i, reward_func_name in enumerate(self.reward_func_names):
        mean_rewards = torch.nanmean(rewards_per_func[:, i]).item()
        self._metrics[mode][f"rewards/{reward_func_name}/mean"].append(mean_rewards)
        std_rewards = nanstd(rewards_per_func[:, i]).item()
        self._metrics[mode][f"rewards/{reward_func_name}/std"].append(std_rewards)
    self._metrics[mode]["reward"].append(mean_grouped_rewards.mean().item())
    self._metrics[mode]["reward_std"].append(std_grouped_rewards.mean().item())
    self._metrics[mode]["frac_reward_zero_std"].append(is_std_zero.float().mean().item())

    # Log prompt and completion texts
    self._logs["prompt"].extend(gather_object(prompts_text))
    self._logs["completion"].extend(gather_object(completions_text))
    for i, name in enumerate(self.reward_func_names):
        self._logs["rewards"][name].extend(rewards_per_func[:, i].tolist())
    self._logs["advantages"].extend(all_process_advantages.tolist())

    if has_images:
        self._logs["image"].extend(gather_object(images))

    output = {
        "prompt_ids": prompt_ids,
        "prompt_mask": prompt_mask,
        "completion_ids": completion_ids,
        "completion_mask": completion_mask,
        "advantages": advantages,
    }
    if old_per_token_logps is not None:
        output["old_per_token_logps"] = old_per_token_logps
    if ref_per_token_logps is not None:
        output["ref_per_token_logps"] = ref_per_token_logps
    if "pixel_values" in prompt_inputs:
        output["pixel_values"] = prompt_inputs["pixel_values"]
    if "image_grid_thw" in prompt_inputs:
        output["image_grid_thw"] = prompt_inputs["image_grid_thw"]
    if "pixel_attention_mask" in prompt_inputs:
        output["pixel_attention_mask"] = prompt_inputs["pixel_attention_mask"]
    if "image_sizes" in prompt_inputs:
        output["image_sizes"] = prompt_inputs["image_sizes"]
    return output
```
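The "Mask everything after the first EOS token" block and the mask_truncated_completions option can be restated on plain lists. eos_completion_mask is a hypothetical helper: positions up to and including the first EOS get 1, later positions get 0, rows without any EOS stay fully unmasked unless truncated completions are masked out.

```python
def eos_completion_mask(completion_ids, eos_token_id, mask_truncated=False):
    """Return (masks, terminated) for a batch of padded token-id rows."""
    masks, terminated = [], []
    for row in completion_ids:
        has_eos = eos_token_id in row
        # like eos_idx: first EOS position, or the row length when there is none
        eos_idx = row.index(eos_token_id) if has_eos else len(row)
        mask = [1 if t <= eos_idx else 0 for t in range(len(row))]
        if mask_truncated and not has_eos:
            mask = [0] * len(row)  # mask_truncated_completions=True drops unfinished rows
        masks.append(mask)
        terminated.append(has_eos)
    return masks, terminated


ids = [[5, 6, 2, 0, 0], [7, 8, 9, 9, 9]]  # EOS = 2, pad = 0
masks, terminated = eos_completion_mask(ids, eos_token_id=2)
clipped_ratio = 1 - sum(terminated) / len(terminated)  # like completions/clipped_ratio
print(masks, terminated, clipped_ratio)
```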
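_generate_and_score_completions is the heart of GRPOTrainer: in a single call it goes from prompt processing → generation on one of several backends → reward scoring → group-wise normalization → output of every tensor the training step needs.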
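✅ One-sentence summary
"Turn a batch of prompts into B×G completions, score them, compute the group-relative advantages, and pack everything into a dictionary the loss function can consume directly."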
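In the server-mode vLLM branch, each prompt already appears G times in a row, so the main process keeps every G-th prompt, requests G samples for each, and every process then takes back its own contiguous slice. The stand-alone sketch below mimics that bookkeeping with a fake generator; dedupe_and_expand, local_slice, and fake_generate are hypothetical names.

```python
def dedupe_and_expand(all_prompts, num_generations, generate):
    """Take every G-th prompt (the G copies are contiguous), ask for G samples each,
    and flatten them in the same order as the duplicated prompt list."""
    completions = []
    for prompt in all_prompts[::num_generations]:
        completions.extend(generate(prompt, num_generations))
    return completions


def local_slice(items, process_index, local_len):
    """Keep this process's rows, like slice(process_index * len(prompts), (process_index + 1) * len(prompts))."""
    return items[process_index * local_len:(process_index + 1) * local_len]


def fake_generate(prompt, n):
    return [f"{prompt}#{k}" for k in range(n)]


# Two processes, each holding 4 rows (2 prompts x G = 2), gathered in process order
gathered = ["a", "a", "b", "b", "c", "c", "d", "d"]
completions = dedupe_and_expand(gathered, num_generations=2, generate=fake_generate)
print(local_slice(completions, process_index=1, local_len=4))
```

Because the order is preserved, row k of each process's completions still lines up with row k of its local prompts.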
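✅ Core flow in 8 steps
Step | What happens | Key names |
1️⃣ Input preparation | extract prompts, handle image + text inputs | prompts, has_images |
2️⃣ Tokenization | left padding, truncation that keeps protected special tokens | truncate_with_protected_tokens |
3️⃣ Generation | one of vLLM / transformers / paged attention | completion_ids |
4️⃣ Post-processing | mask everything after the first EOS, build the masks | completion_mask, completion_lengths |
5️⃣ Reward scoring | call _calculate_rewards | rewards_per_func |
6️⃣ Weighted sum | combine several reward functions into one reward per completion | rewards |
7️⃣ Group normalization | mean-std per group → advantages | advantages |
8️⃣ Cross-process sync | gather and slice keep distributed runs consistent | gather, process_slice |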
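Steps 6 and 7 in the table contain two details that are easy to miss: NaN rewards are skipped by the weighted nansum (so a reward function that returned None contributes 0), and groups whose rewards are all equal get zero advantage, which the code logs as frac_reward_zero_std. A stdlib sketch with hypothetical helper names:

```python
import math
import statistics


def weighted_nansum(rewards_per_func, weights):
    """Per-row weighted sum over reward functions, skipping NaN entries
    (same effect as multiplying by the weights and calling nansum)."""
    return [sum(w * v for w, v in zip(weights, row) if not math.isnan(v)) for row in rewards_per_func]


def frac_zero_std_groups(rewards, num_generations):
    """Share of prompt groups whose rewards are all equal (std ~ 0); their advantages
    are all zero, so only the KL term (if any) acts on them."""
    groups = [rewards[i:i + num_generations] for i in range(0, len(rewards), num_generations)]
    zero = [math.isclose(statistics.stdev(g), 0.0, abs_tol=1e-8) for g in groups]
    return sum(zero) / len(zero)


rows = [[1.0, math.nan], [0.0, 0.5], [math.nan, math.nan], [1.0, 1.0]]  # (B*G, F) with F = 2
rewards = weighted_nansum(rows, weights=[1.0, 2.0])
print(rewards, frac_zero_std_groups(rewards, num_generations=2))
```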
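✅ Output dictionary (ready for the loss)
```
{
    "prompt_ids": Tensor,            # (B×G, P)
    "prompt_mask": Tensor,           # (B×G, P)
    "completion_ids": Tensor,        # (B×G, C)
    "completion_mask": Tensor,       # (B×G, C)
    "advantages": Tensor,            # (B×G,) group-normalized advantages
    "old_per_token_logps": Tensor,   # (B×G, C), optional, for importance sampling
    "ref_per_token_logps": Tensor,   # (B×G, C), optional, for the KL term (only when beta != 0)
    ...                              # image-related fields (multimodal inputs only)
}
```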
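Whether old_per_token_logps is present depends only on the schedule check in the code above: it is stored when gradient_accumulation_steps is not a multiple of steps_per_generation × num_iterations. needs_old_logps is a hypothetical restatement of that condition. When the entry is missing, our understanding is that TRL's compute_loss substitutes per_token_logps.detach(), so the ratio equals 1 in value while still carrying the gradient; treat that detail as an assumption.

```python
def needs_old_logps(gradient_accumulation_steps, steps_per_generation, num_iterations):
    """True when generation and optimizer steps are misaligned, so the samples may come
    from an older policy and old log-probs must be kept for importance sampling."""
    generate_every = steps_per_generation * num_iterations  # generation frequency
    return gradient_accumulation_steps % generate_every != 0


for config in [(8, 8, 1), (8, 4, 1), (8, 8, 2), (4, 8, 1)]:
    print(config, needs_old_logps(*config))
```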
✅ Recap
A single call to _generate_and_score_completions turns prompts into training samples with advantages attached.