
GRPO (Group Relative Policy Optimization): Formula Quick Reference

1,600 lines of source code condensed into a one-page formula sheet you can copy straight into a paper.


1 Group-Relative Advantage (in-group reward normalization)

| Symbol | Meaning | Code variable |
|---|---|---|
| q | prompt | prompt |
| G | group size | num_generations |
| o_i | the i-th completion | completions[i] |
| r(q, o_i) | reward | rewards[i] |

$$
\boxed{
\begin{aligned}
\mu_r(q) &= \frac{1}{G}\sum_{i=1}^{G} r(q, o_i) \\
\sigma_r(q) &= \text{std}_{i=1..G}\, r(q, o_i) \\
A_{q,o_i} &= \frac{r(q, o_i) - \mu_r(q)}{\sigma_r(q) + \varepsilon} \qquad (\varepsilon = 1\times10^{-4})
\end{aligned}
}
$$
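As a sanity check, the normalization above can be reproduced in a few lines of PyTorch. This is an illustrative sketch (the function name and shapes are assumptions, not the TRL source):

```python
import torch

def group_relative_advantages(rewards: torch.Tensor, num_generations: int, eps: float = 1e-4) -> torch.Tensor:
    """Group-relative advantage: rewards has shape (B*G,) with G = num_generations."""
    grouped = rewards.view(-1, num_generations)        # (B, G)
    mean = grouped.mean(dim=1, keepdim=True)           # mu_r(q)
    std = grouped.std(dim=1, keepdim=True)             # sigma_r(q)
    return ((grouped - mean) / (std + eps)).view(-1)   # A_{q,o_i}, back to (B*G,)

# Example: 2 prompts, G = 3 completions each
print(group_relative_advantages(torch.tensor([1.0, 0.0, 0.0, 2.0, 2.0, 1.0]), num_generations=3))
```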


2 Clipped Policy Objective (Token-Level)

$$
\boxed{
\begin{aligned}
r_{i,t}(\theta) &= \frac{\pi_\theta(o_{i,t}\mid q, o_{i,<t})}{\pi_{\theta_{\text{old}}}(o_{i,t}\mid q, o_{i,<t})} \\[6pt]
L_{\text{clip}}(\theta) &= \sum_{i=1}^{G}\sum_{t=1}^{|o_i|} \min\!\Bigl(
r_{i,t}(\theta)\,A_{q,o_i},\;
\text{clip}\bigl(r_{i,t}(\theta),\,1-\varepsilon_{\text{low}},\,1+\varepsilon_{\text{high}}\bigr)\,A_{q,o_i}
\Bigr)
\end{aligned}
}
$$

Corresponding code: per_token_logps, old_per_token_logps, coef_1, coef_2
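A minimal sketch of this token-level objective, mirroring the variable names listed above (illustrative only, not the trainer code; shapes are assumed):

```python
import torch

def clipped_surrogate(per_token_logps, old_per_token_logps, advantages, eps_low=0.2, eps_high=0.2):
    """per_token_logps, old_per_token_logps: (B*G, T); advantages: (B*G,)."""
    ratio = torch.exp(per_token_logps - old_per_token_logps)      # r_{i,t}(theta)
    adv = advantages.unsqueeze(1)                                 # broadcast over the token dimension
    coef_1 = ratio * adv                                          # unclipped term
    coef_2 = torch.clamp(ratio, 1 - eps_low, 1 + eps_high) * adv  # clipped term
    return torch.min(coef_1, coef_2)                              # per-token surrogate, masked/summed later
```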


3 KL Regularization (optional, enabled when β > 0)

$$
\boxed{
\mathrm{KL}_{\mathrm{reg}} = \beta \; D_{\mathrm{KL}}\!\bigl[\pi_{\theta}\,\|\,\pi_{\mathrm{ref}}\bigr]
}
$$

Corresponding code: per_token_kl, beta
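TRL computes this term per token from the policy and reference log-probabilities; to my understanding it uses the non-negative k3 estimator, so a sketch under that assumption looks like:

```python
import torch

def per_token_kl(per_token_logps: torch.Tensor, ref_per_token_logps: torch.Tensor) -> torch.Tensor:
    """k3 estimator of KL(pi_theta || pi_ref), elementwise per token; always >= 0."""
    delta = ref_per_token_logps - per_token_logps
    return torch.exp(delta) - delta - 1.0

# The regularizer added to the per-token loss is then beta * per_token_kl(...).
```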


4 Final Loss (Token-Level)

| Variant | Total loss |
|---|---|
| GRPO | $L_{\text{total}} = -\dfrac{1}{G}\sum_{i=1}^{G}\dfrac{1}{\lvert o_i\rvert}\sum_{t=1}^{\lvert o_i\rvert}\min(\cdot)_{i,t} + \mathrm{KL}_{\text{reg}}$ — token mean per sequence, then mean over sequences ($\min(\cdot)_{i,t}$ is the bracketed term of $L_{\text{clip}}$) |
| BNPO | $L_{\text{total}} = -\dfrac{L_{\text{clip}}}{\sum_{i,t} 1} + \mathrm{KL}_{\text{reg}}$ — normalized by the total number of unmasked tokens in the batch |
| DR-GRPO | $L_{\text{total}} = -\dfrac{L_{\text{clip}}}{B\cdot T} + \mathrm{KL}_{\text{reg}}$ — with $T$ = max_completion_length |

  • In code the variant is selected by the loss_type argument (sketched below).
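The three rows differ only in how the masked per-token loss is normalized. A sketch of the switch, assuming per_token_loss already contains -min(coef_1, coef_2) plus the optional β·KL term and completion_mask is 0/1 (illustrative, not the exact trainer code):

```python
import torch

def total_loss(per_token_loss, completion_mask, loss_type, max_completion_length):
    if loss_type == "grpo":       # token mean per sequence, then mean over sequences
        return ((per_token_loss * completion_mask).sum(-1)
                / completion_mask.sum(-1).clamp(min=1.0)).mean()
    if loss_type == "bnpo":       # normalize by the total number of unmasked tokens
        return (per_token_loss * completion_mask).sum() / completion_mask.sum().clamp(min=1.0)
    if loss_type == "dr_grpo":    # normalize by B*T, T = max_completion_length
        return (per_token_loss * completion_mask).sum() / (per_token_loss.size(0) * max_completion_length)
    raise ValueError(f"Unknown loss_type: {loss_type}")
```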

**BNPO vs GRPO: one-line mnemonic**

BNPO = GRPO with a "reward-normalization add-on".
Both share the same "group-relative advantage + KL + clip" framework; BNPO simply swaps the static mean-variance normalization for a dynamic Beta normalization.


✅ Key differences

| Dimension | GRPO | BNPO |
|---|---|---|
| Normalization | group mean / std (static) | adaptive Beta-distribution normalization |
| Reward assumption | any numeric reward | binary (Bernoulli) reward |
| Baseline update | μ, σ recomputed per batch | α, β updated online |
| Gradient variance | fixed | shrinks dynamically with the policy |
| Superset of GRPO? | — | yes: GRPO is the special case with fixed β |

✅ Formula comparison (at a glance)

| Algorithm | Advantage function |
|---|---|
| GRPO | $A_{\text{GRPO}} = \dfrac{r - \mu_r}{\sigma_r + \varepsilon}$ |
| BNPO | $A_{\text{BNPO}} = \dfrac{r - \hat{\mu}}{\sqrt{\hat{\mu}(1-\hat{\mu}) + \delta}}$, where $\hat{\mu}\sim\text{Beta}(\alpha,\beta)$ and the parameters α, β are estimated online from the rewards of the last N steps |
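A rough illustration of the Beta-normalized advantage described in the table, under the assumption that α and β are simply running counts of successes and failures over recent steps (hypothetical helper, not TRL code):

```python
import math

def bnpo_advantage(r: float, alpha: float, beta: float, delta: float = 1e-4) -> float:
    mu_hat = alpha / (alpha + beta)                 # mean of Beta(alpha, beta)
    return (r - mu_hat) / math.sqrt(mu_hat * (1.0 - mu_hat) + delta)

def update_beta_counts(reward: float, alpha: float, beta: float) -> tuple[float, float]:
    """Binary reward assumed: 1 = success, 0 = failure."""
    return alpha + reward, beta + (1.0 - reward)
```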

✅ When to use which

  • GRPO: general-purpose, simple, fast to implement.
  • BNPO:
    • rewards are binary 0/1 (correct / incorrect)
    • the reward distribution drifts a lot early in training (e.g. math-reasoning tasks)
    • you need lower gradient variance and higher stability

✅ Takeaway

  • BNPO does not overturn GRPO; it only upgrades the "static mean baseline" to a "dynamic Beta baseline".
  • In TRL you just set loss_type='bnpo' to enable it; the rest of the pipeline (sampling, clip, KL) is identical (see the config sketch below).
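For reference, a hedged configuration sketch (argument names follow recent TRL releases; check your installed version before copying):

```python
from trl import GRPOConfig

args = GRPOConfig(
    output_dir="grpo-bnpo-run",
    num_generations=8,        # G, the group size
    beta=0.04,                # KL coefficient; 0 disables the KL term
    loss_type="bnpo",         # switch the loss normalization
)
```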

How is this implemented in Hugging Face TRL?

Computing rewards

def _calculate_rewards(self, inputs, prompts, completions, completion_ids_list):
    device = self.accelerator.device
    rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)

    # Repeat all input columns (but "prompt", "completion", and "completion_ids") to match the num of generations
    keys = [key for key in inputs[0] if key not in ["prompt", "completion", "completion_ids"]]
    reward_kwargs = {key: [example[key] for example in inputs] for key in keys}

    # This allows for dynamic reward shaping based on training progress.
    reward_kwargs["trainer_state"] = self.state

    for i, (reward_func, reward_processing_class, reward_func_name) in enumerate(
        zip(self.reward_funcs, self.reward_processing_classes, self.reward_func_names)
    ):
        with profiling_context(self, reward_func_name):
            if isinstance(reward_func, nn.Module):  # Module (no PretrainedModel) for compat with compiled models
                if is_conversational(inputs[0]):
                    messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
                    texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
                else:
                    texts = [p + c for p, c in zip(prompts, completions)]
                reward_inputs = reward_processing_class(
                    text=texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False
                )
                reward_inputs = super()._prepare_inputs(reward_inputs)
                with torch.inference_mode():
                    rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0]  # Shape (B*G,)
            else:
                output_reward_func = reward_func(
                    prompts=prompts, completions=completions, completion_ids=completion_ids_list, **reward_kwargs
                )
                # Convert None values to NaN
                output_reward_func = [reward if reward is not None else torch.nan for reward in output_reward_func]

                rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)

    # If all reward functions return None for a given row, issue a detailed warning
    if torch.isnan(rewards_per_func).all(dim=1).any():
        nan_row_idx = torch.isnan(rewards_per_func).all(dim=1).nonzero(as_tuple=True)[0][0]
        row_reward_kwargs = {key: value[nan_row_idx] for key, value in reward_kwargs.items()}
        row_reward_kwargs["prompt"] = prompts[nan_row_idx]
        row_reward_kwargs["completion"] = completions[nan_row_idx]
        warnings.warn(
            f"All reward functions returned None for the following kwargs: {row_reward_kwargs}. "
            "Please ensure that at least one reward function returns a valid reward."
        )

    # Gather the reward per function: this part is crucial, because the rewards are normalized per group and the
    # completions may be distributed across processes
    rewards_per_func = gather(rewards_per_func)
    return rewards_per_func

_calculate_rewards feeds the batch of completions to every reward function, returns a (B×G, F) reward matrix, and gathers it across processes in preparation for GRPO's group normalization.


✅ Inputs and outputs

| Argument | Shape / meaning |
|---|---|
| prompts | List[str] or List[Messages], length = B×G |
| completions | List[str] or List[Messages], length = B×G |
| completion_ids_list | List[List[int]] of token ids, length = B×G |
| return value | Tensor of shape (B×G, F), where F = number of reward functions |

✅ Core steps

  • Allocate the container
    rewards_per_func = zeros(B×G, F) reserves one slot per (completion, reward function) pair.

  • Pack extra columns into kwargs
    Every field of inputs[0] except "prompt"/"completion"/"completion_ids" is repeated row by row and passed on, so custom reward functions can use it.

  • Loop over the F reward functions

    • If the reward is a model (nn.Module):
      • build prompt+completion text → tokenize → forward pass → take logits[:, 0] as the scalar reward.
    • If the reward is a plain callable:
      • call it directly; it may return None, which is converted to a NaN placeholder (see the sketch after the summary below).
  • Synchronize across processes
    gather(rewards_per_func) gives every GPU the global (N×G, F) reward matrix, so the subsequent group normalization stays consistent.

  • Anomaly detection
    If an entire row is NaN, a detailed warning is printed so you can track down the reward function that returned nothing.


✅ Summary

"Feed B×G completions to F reward functions, gather the results across processes, and produce a (B×G, F) reward tensor for GRPO's group normalization."
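A sketch of the "plain callable" reward path described above: a custom reward function receives prompts, completions, completion_ids plus any extra dataset columns as keyword arguments, and may return None for samples it cannot score (the name and heuristic here are made up for illustration):

```python
def length_penalty_reward(prompts, completions, completion_ids, **kwargs):
    rewards = []
    for completion, ids in zip(completions, completion_ids):
        text = completion if isinstance(completion, str) else completion[0]["content"]
        if not text.strip():
            rewards.append(None)               # no signal for empty outputs -> becomes NaN
        else:
            rewards.append(-0.001 * len(ids))  # mild penalty on completion length
    return rewards
```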

    _generate_and_score_completions

    def _generate_and_score_completions(
    self, inputs: list[dict[str, Union[torch.Tensor, Any]]]
    ) -> dict[str, Union[torch.Tensor, Any]]:
    device = self.accelerator.device
    mode = "train" if self.model.training else "eval"

    prompts = [x["prompt"] for x in inputs]

    # We don't yet support visual reward models/function, so we keep a copy of the original text-only prompts for
    # later use in the reward computation. If images are present, we insert {"type": "image"} as required by the
    # VLM chat template.
    original_prompts = copy.deepcopy(prompts)

    # If the prompts are conversational and the inputs contain images, we need to convert the prompts from
    # [{"role": "user", "content": "What color is the sky?"}] to
    # [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]}]
    kwargs = {}
    has_images = "image" in inputs[0]
    if has_images:
    images = [example.get("image") for example in inputs]
    kwargs = {"images": [[img] for img in images]}
    for prompt in prompts:
    if isinstance(prompt, list):
    for message in prompt:
    if not isinstance(message, dict):
    continue
    content = message.get("content")
    role = message.get("role")
    if isinstance(content, str):
    if role == "user":
    message["content"] = [{"type": "image"}, {"type": "text", "text": content}]
    elif role == "system":
    message["content"] = [{"type": "text", "text": content}]

    prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]

    prompt_inputs = self.processing_class(
    text=prompts_text,
    return_tensors="pt",
    padding=True,
    padding_side="left",
    add_special_tokens=False,
    **kwargs,
    )
    prompt_inputs = super()._prepare_inputs(prompt_inputs)
    prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]

    if self.max_prompt_length is not None:
    # If max_prompt_length is set, we trim the prompt to keep only the last `max_prompt_length` tokens.
    # Then we decode those tokens back into text. We manually remove leading pad tokens from the decoded text,
    # because we can't use `skip_special_tokens=True` (some special tokens are still needed for generation).
    protected = [self.image_token_id, self.vision_start_token_id, self.vision_end_token_id]
    protected = [token for token in protected if token is not None]
    prompt_ids, prompt_mask = truncate_with_protected_tokens(
    prompt_ids, prompt_mask, self.max_prompt_length, protected
    )

    prompts_text = self.processing_class.batch_decode(
    prompt_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
    )
    prompts_text = [re.sub(rf"^({re.escape(self.pad_token)})+", "", text) for text in prompts_text]

    # The chat template sometimes inserts a single image token into the prompt text. However, when this text is
    # later tokenized, the single image token string is expanded into multiple image token IDs, depending on the
    # image size. Since we're detokenizing here, we may see repeated image tokens in the decoded text. We
    # collapse them back into a single token string to match the original chat template in case it originally
    # applies it. Otherwise, it assumes that the chat template uses only vision_start_token_id to indicate images
    # (e.g. Gemma 3) and removes all image_token instances and vision_end_token_id as well, leaving only
    # the vision_start_token_id (e.g. <start_of_image>).
    if self.image_token is not None:
    escaped_img_token = re.escape(self.image_token)
    # Search for the image token in the chat template
    if re.search(escaped_img_token, self.processing_class.chat_template):
    prompts_text = [
    re.sub(rf"({escaped_img_token})+", self.image_token, text) for text in prompts_text
    ]
    else:
    # If the chat template doesn't use the image token, we remove all instances of it + vision_end_token_id
    if self.vision_end_token_id is not None:
    escaped_eoi_token = re.escape(
    self.processing_class.tokenizer.decode([self.vision_end_token_id])
    )
    prompts_text = [
    re.sub(rf"({escaped_img_token})+{escaped_eoi_token}", "", text) for text in prompts_text
    ]
    else:
    # If vision_end_token_id is None, just remove the image tokens
    prompts_text = [re.sub(rf"({escaped_img_token})+", "", text) for text in prompts_text]

    # Generate completions using either vLLM or regular generation
    if self.use_vllm:
    # First, update the vLLM weights if needed
    if self.state.global_step != self._last_loaded_step:
    self._move_model_to_vllm()
    self._last_loaded_step = self.state.global_step

    # Generate completions using vLLM: gather all prompts and use them in a single call in the main process
    if self.vllm_mode == "server":
    all_prompts_text = gather_object(prompts_text)
    if has_images:
    all_images = gather_object(images)

    if self.accelerator.is_main_process:
    # Since 'prompts' contains 'num_generations' duplicates, we first take unique prompts, and generate
    # num_generations outputs for each one. This is faster than generating outputs for each duplicate
    # prompt individually.
    ordered_set_of_prompts = all_prompts_text[:: self.num_generations]

    if has_images:
    ordered_set_of_images = all_images[:: self.num_generations]
    else:
    ordered_set_of_images = None

    with profiling_context(self, "vLLM.generate"):
    completion_ids = self.vllm_client.generate(
    prompts=ordered_set_of_prompts,
    images=ordered_set_of_images,
    n=self.num_generations,
    repetition_penalty=self.repetition_penalty,
    temperature=self.temperature,
    top_p=self.top_p,
    top_k=-1 if self.top_k is None else self.top_k,
    min_p=0.0 if self.min_p is None else self.min_p,
    max_tokens=self.max_completion_length,
    guided_decoding_regex=self.guided_decoding_regex,
    generation_kwargs=self.args.generation_kwargs,
    )
    else:
    completion_ids = [None] * len(all_prompts_text)
    # Broadcast the completions from the main process to all processes, ensuring each process receives its
    # corresponding slice.
    completion_ids = broadcast_object_list(completion_ids, from_process=0)
    process_slice = slice(
    self.accelerator.process_index * len(prompts),
    (self.accelerator.process_index + 1) * len(prompts),
    )
    completion_ids = completion_ids[process_slice]

    # Generate completions using colocated vLLM instances: each device holds vLLM copy and work on their own batch of prompts
    elif self.vllm_mode == "colocate":
    if self.guided_decoding_regex:
    guided_decoding = GuidedDecodingParams(regex=self.guided_decoding_regex)
    else:
    guided_decoding = None

    generation_kwargs = {
    "n": 1, # vLLM on each GPU generates only 1 in colocate mode
    "repetition_penalty": self.repetition_penalty,
    "temperature": self.temperature,
    "top_p": self.top_p,
    "top_k": 1 if self.top_k is None else self.top_k,
    "min_p": 0.0 if self.min_p is None else self.min_p,
    "max_tokens": self.max_completion_length,
    "guided_decoding": guided_decoding,
    }
    if self.args.generation_kwargs is not None:
    generation_kwargs.update(self.args.generation_kwargs)
    sampling_params = SamplingParams(**generation_kwargs)

    if self.vllm_tensor_parallel_size > 1:
    # Gather prompts from all ranks in the TP group and flatten.
    # Each rank starts with its own prompts; after gathering, all ranks see the full group set.
    orig_size = len(prompts_text)
    gathered_prompts = [None for _ in range(self.vllm_tensor_parallel_size)]
    torch.distributed.all_gather_object(gathered_prompts, prompts_text, group=self.tp_group)
    all_prompts_text = [p for sublist in gathered_prompts for p in sublist]

    if has_images:
    gathered_images = [None for _ in range(self.vllm_tensor_parallel_size)]
    torch.distributed.all_gather_object(gathered_images, images, group=self.tp_group)
    all_images = [img for sublist in gathered_images for img in sublist]
    else:
    all_images = None
    else:
    all_prompts_text = prompts_text
    all_images = images if has_images else None

    if has_images and all_images:
    vllm_inputs = []
    for prompt, image in zip(all_prompts_text, all_images):
    if image is not None:
    vllm_inputs.append({"prompt": prompt, "multi_modal_data": {"image": image}})
    else:
    vllm_inputs.append(prompt)
    else:
    vllm_inputs = all_prompts_text

    with profiling_context(self, "vLLM.generate"):
    all_outputs = self.llm.generate(vllm_inputs, sampling_params=sampling_params, use_tqdm=False)

    completion_ids = [output.token_ids for outputs in all_outputs for output in outputs.outputs]

    if self.vllm_tensor_parallel_size > 1:
    # Slice completions for this rank within its TP group.
    # Each rank generates all outputs — we keep only our share.
    local_rank_in_group = torch.distributed.get_rank(group=self.tp_group)
    tp_slice = slice(local_rank_in_group * orig_size, (local_rank_in_group + 1) * orig_size)
    completion_ids = completion_ids[tp_slice]

    # Pad the completions, and concatenate them with the prompts
    completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
    completion_ids = pad(completion_ids, padding_value=self.pad_token_id)
    prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)

    elif self.use_transformers_paged:
    # Re-process inputs for paged generation if needed
    # Note: images are already validated and preprocessed above
    paged_prompt_inputs = self.processing_class(text=prompts_text, **kwargs)
    previous_attn = self.model_wrapped.config._attn_implementation

    if is_flash_attn_2_available():
    self.model_wrapped.config._attn_implementation = "paged_attention"
    else:
    self.model_wrapped.config._attn_implementation = "sdpa_paged"
    with (
    profiling_context(self, "transformers.generate_batch"),
    unwrap_model_for_generation(
    self.model_wrapped, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
    ) as unwrapped_model,
    torch.no_grad(),
    FSDP.summon_full_params(self.model_wrapped, recurse=False) if self.is_fsdp_enabled else nullcontext(),
    ):
    # Cast to the appropriate dtype based on training configuration
    if self.args.bf16:
    unwrapped_model.to(torch.bfloat16)
    elif self.args.fp16:
    unwrapped_model.to(torch.float16)
    with torch.inference_mode():
    all_outputs = unwrapped_model.generate_batch(
    paged_prompt_inputs.input_ids, generation_config=self.generation_config, progress_bar=False
    )
    completion_ids = [output.generated_tokens for output in all_outputs.values()]
    completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
    completion_ids = pad(completion_ids, padding_value=self.pad_token_id, padding_side="right")
    prompt_ids = [torch.tensor(ids, device=device) for ids in paged_prompt_inputs.input_ids]
    prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left")
    prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
    # Restore the original attention implementation, training mode
    self.model_wrapped.config._attn_implementation = previous_attn
    else:
    # Regular generation path
    with (
    profiling_context(self, "transformers.generate"),
    unwrap_model_for_generation(
    self.model_wrapped, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
    ) as unwrapped_model,
    torch.no_grad(),
    FSDP.summon_full_params(self.model_wrapped, recurse=False) if self.is_fsdp_enabled else nullcontext(),
    ):
    prompt_inputs["input_ids"], prompt_inputs["attention_mask"] = prompt_ids, prompt_mask
    prompt_completion_ids = unwrapped_model.generate(
    **prompt_inputs, generation_config=self.generation_config, disable_compile=True
    )
    # Compute prompt length and extract completion ids
    prompt_length = prompt_ids.size(1)
    prompt_ids = prompt_completion_ids[:, :prompt_length]
    completion_ids = prompt_completion_ids[:, prompt_length:]

    # Mask everything after the first EOS token
    is_eos = completion_ids == self.eos_token_id
    eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
    eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
    sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
    completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()

    # Convert tensor to a list of lists of token IDs. This will be passed to the reward function, avoiding the need
    # to re-tokenize completions if the reward is computed from tokens.
    completion_ids_list = [
    [id.item() for id, m in zip(row, mask_row) if m] for row, mask_row in zip(completion_ids, completion_mask)
    ]

    # Sum along sequence dimension (dim=1) to get completion length per sequence, used for logging
    completion_lengths = completion_mask.sum(1)

    # If mask_truncated_completions is enabled, zero out truncated completions in completion_mask
    if self.mask_truncated_completions:
    truncated_completions = ~is_eos.any(dim=1)
    completion_mask = completion_mask * (~truncated_completions).unsqueeze(1).int()

    # Concatenate prompt_mask with completion_mask for logit computation
    attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B, P+C)

    logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
    batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size

    with torch.no_grad():
    # If the generation and optimization steps are misaligned—i.e., if generation does not occur at the end of
    # a full optimizer step (when gradient_accumulation_steps is not a multiple of generate_every)—then the
    # samples may come from an earlier version of the model. In that case, we need to track old_per_token_logps
    # for importance sampling. If the steps are aligned, importance sampling isn't necessary and we set
    # old_per_token_logps to None.
    generate_every = self.args.steps_per_generation * self.num_iterations # generation frequency
    if self.args.gradient_accumulation_steps % generate_every != 0:
    old_per_token_logps, _ = self._get_per_token_logps_and_entropies(
    self.model,
    prompt_completion_ids,
    attention_mask,
    logits_to_keep,
    batch_size,
    pixel_values=prompt_inputs.get("pixel_values"),
    image_grid_thw=prompt_inputs.get("image_grid_thw"),
    pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"),
    image_sizes=prompt_inputs.get("image_sizes"),
    )
    else:
    old_per_token_logps = None

    # Compute the per-token log probabilities for the reference model
    if self.beta != 0.0:
    if self.ref_model is not None:
    ref_per_token_logps, _ = self._get_per_token_logps_and_entropies(
    self.ref_model,
    prompt_completion_ids,
    attention_mask,
    logits_to_keep,
    batch_size=batch_size,
    pixel_values=prompt_inputs.get("pixel_values"),
    image_grid_thw=prompt_inputs.get("image_grid_thw"),
    pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"),
    image_sizes=prompt_inputs.get("image_sizes"),
    )
    else:
    with self.accelerator.unwrap_model(self.model).disable_adapter():
    ref_per_token_logps, _ = self._get_per_token_logps_and_entropies(
    self.model,
    prompt_completion_ids,
    attention_mask,
    logits_to_keep,
    batch_size=batch_size,
    pixel_values=prompt_inputs.get("pixel_values"),
    image_grid_thw=prompt_inputs.get("image_grid_thw"),
    pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"),
    image_sizes=prompt_inputs.get("image_sizes"),
    )
    else:
    ref_per_token_logps = None

    # Decode the generated completions
    completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
    if is_conversational(inputs[0]):
    completions = []
    for prompt, completion in zip(prompts, completions_text):
    bootstrap = prompt.pop()["content"] if prompt[-1]["role"] == "assistant" else ""
    completions.append([{"role": "assistant", "content": bootstrap + completion}])
    else:
    completions = completions_text

    # Calculate rewards for each reward function. rewards_per_func aggregates rewards across all processes. This is
    # important because rewards will be normalized per group, and completions are distributed. We will later slice
    # rewards_per_func to extract each process's subset.
    rewards_per_func = self._calculate_rewards(inputs, original_prompts, completions, completion_ids_list)

    # Apply weights to each reward function's output and sum
    rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1)

    # Compute grouped-wise rewards
    mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
    std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)
    is_std_zero = torch.isclose(std_grouped_rewards, torch.zeros_like(std_grouped_rewards))

    # Normalize the rewards to compute the advantages
    mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
    std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
    advantages = rewards - mean_grouped_rewards
    if self.scale_rewards:
    advantages = advantages / (std_grouped_rewards + 1e-4)

    # Slice to keep only the local part of the data
    process_slice = slice(
    self.accelerator.process_index * len(prompts),
    (self.accelerator.process_index + 1) * len(prompts),
    )
    all_process_advantages = advantages.clone() # keep the aggregated advantages for logging
    advantages = advantages[process_slice]

    # Log the metrics
    if mode == "train":
    self.state.num_input_tokens_seen += self.accelerator.gather(attention_mask.sum()).sum().item()
    self._metrics[mode]["num_tokens"] = [self.state.num_input_tokens_seen]

    # Log completion lengths, mean, min, max
    agg_completion_lengths = self.accelerator.gather(completion_lengths)
    self._metrics[mode]["completions/mean_length"].append(agg_completion_lengths.float().mean().item())
    self._metrics[mode]["completions/min_length"].append(agg_completion_lengths.float().min().item())
    self._metrics[mode]["completions/max_length"].append(agg_completion_lengths.float().max().item())

    # Identify sequences that terminated with EOS and log their lengths
    agg_terminated_with_eos = self.accelerator.gather(is_eos.any(dim=1))
    term_completion_lengths = agg_completion_lengths[agg_terminated_with_eos]
    clipped_completions_ratio = 1 - len(term_completion_lengths) / len(agg_completion_lengths)
    self._metrics[mode]["completions/clipped_ratio"].append(clipped_completions_ratio)
    if len(term_completion_lengths) == 0: # edge case where no terminated sequences are found
    term_completion_lengths = torch.zeros(1, device=device)
    self._metrics[mode]["completions/mean_terminated_length"].append(term_completion_lengths.float().mean().item())
    self._metrics[mode]["completions/min_terminated_length"].append(term_completion_lengths.float().min().item())
    self._metrics[mode]["completions/max_terminated_length"].append(term_completion_lengths.float().max().item())

    # Calculate mean reward per function, but only for samples where the function was applied (non-NaN values)
    for i, reward_func_name in enumerate(self.reward_func_names):
    mean_rewards = torch.nanmean(rewards_per_func[:, i]).item()
    self._metrics[mode][f"rewards/{reward_func_name}/mean"].append(mean_rewards)
    std_rewards = nanstd(rewards_per_func[:, i]).item()
    self._metrics[mode][f"rewards/{reward_func_name}/std"].append(std_rewards)
    self._metrics[mode]["reward"].append(mean_grouped_rewards.mean().item())
    self._metrics[mode]["reward_std"].append(std_grouped_rewards.mean().item())
    self._metrics[mode]["frac_reward_zero_std"].append(is_std_zero.float().mean().item())

    # Log prompt and completion texts
    self._logs["prompt"].extend(gather_object(prompts_text))
    self._logs["completion"].extend(gather_object(completions_text))
    for i, name in enumerate(self.reward_func_names):
    self._logs["rewards"][name].extend(rewards_per_func[:, i].tolist())
    self._logs["advantages"].extend(all_process_advantages.tolist())

    if has_images:
    self._logs["image"].extend(gather_object(images))

    output = {
    "prompt_ids": prompt_ids,
    "prompt_mask": prompt_mask,
    "completion_ids": completion_ids,
    "completion_mask": completion_mask,
    "advantages": advantages,
    }
    if old_per_token_logps is not None:
    output["old_per_token_logps"] = old_per_token_logps
    if ref_per_token_logps is not None:
    output["ref_per_token_logps"] = ref_per_token_logps
    if "pixel_values" in prompt_inputs:
    output["pixel_values"] = prompt_inputs["pixel_values"]
    if "image_grid_thw" in prompt_inputs:
    output["image_grid_thw"] = prompt_inputs["image_grid_thw"]
    if "pixel_attention_mask" in prompt_inputs:
    output["pixel_attention_mask"] = prompt_inputs["pixel_attention_mask"]
    if "image_sizes" in prompt_inputs:
    output["image_sizes"] = prompt_inputs["image_sizes"]
    return output

_generate_and_score_completions is the heart of GRPOTrainer:
in one pass it handles prompt processing → multi-backend generation → reward scoring → group normalization → and returns every tensor the training loss needs.


✅ One-line summary

"Turn a batch of prompts into B×G completions, score them, compute group-relative advantages, and pack everything into a training dict that can be fed directly to the loss."


✅ Core flow (8 steps)

| Step | Key action | Code / variables |
|---|---|---|
| 1️⃣ Input preparation | extract prompts, handle text+image inputs | prompts, has_images |
| 2️⃣ Tokenization | left padding, truncation, protected special tokens | truncate_with_protected_tokens |
| 3️⃣ Generation | vLLM / transformers / paged attention (one of three) | completion_ids |
| 4️⃣ Post-processing | truncate at EOS, build masks | completion_mask, completion_lengths |
| 5️⃣ Reward scoring | call _calculate_rewards | rewards_per_func |
| 6️⃣ Weighted sum | combine multiple reward functions into a single reward | rewards |
| 7️⃣ Group normalization | mean / std → advantages | advantages |
| 8️⃣ Cross-process sync | gather & slice for distributed consistency | gather, process_slice |

✅ Output dict (ready to feed the loss)

{
    "prompt_ids"          : Tensor,  # (B×G, P)
    "prompt_mask"         : Tensor,  # (B×G, P)
    "completion_ids"      : Tensor,  # (B×G, C)
    "completion_mask"     : Tensor,  # (B×G, C)
    "advantages"          : Tensor,  # (B×G,) group-normalized advantages
    "old_per_token_logps" : Tensor,  # optional, for importance sampling
    "ref_per_token_logps" : Tensor,  # optional, for the KL term
    ...                              # image-related fields (if multimodal)
}


✅ Recap

A single call to _generate_and_score_completions turns "prompts" into "training samples with advantages" (see the end-to-end sketch below).
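To tie the pieces together, a hedged end-to-end sketch of wiring GRPOTrainer with a custom reward function (the model, dataset, and reward are placeholders; adapt them and verify against your installed TRL version):

```python
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

# Toy reward: prefer completions with more unique words (purely illustrative).
def unique_word_reward(prompts, completions, completion_ids, **kwargs):
    texts = [c if isinstance(c, str) else c[0]["content"] for c in completions]
    return [float(len(set(t.split()))) for t in texts]

dataset = load_dataset("trl-lib/tldr", split="train")  # any dataset with a "prompt" column

trainer = GRPOTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    reward_funcs=unique_word_reward,
    args=GRPOConfig(output_dir="grpo-demo", num_generations=8, max_completion_length=128),
    train_dataset=dataset,
)
trainer.train()  # each step: generate -> score -> group-normalize -> clipped loss
```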
