https://zhuanlan.zhihu.com/p/24816372882https://zhuanlan.zhihu.com/p/24816372882

import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from scipy.special import softmax
from scipy.special import softmax, kl_div

# 词汇表及其分组
vocab = ["ant", "bear", "cat", "dog"]
is_vowel = lambda x: x[0].lower() in "aeiou"

class GRPO:
    def __init__(self, vocab, beta=0.1, epsilon=1e-8):
        self.vocab = vocab
        self.vocab_size = len(vocab)
        self.beta = beta
        self.epsilon = epsilon
        
        np.random.seed(42)
        self.theta = np.random.randn(self.vocab_size)
        
        # 添加用于存储训练历史的列表
        self.prob_history = []
        self.reward_history = []
        self.kl_div_history = []
        
    def get_policy_probs(self, logits):
        return softmax(logits)
    
    def sample_word(self, probs):
        word_idx = np.random.choice(len(self.vocab), p=probs)
        return word_idx
    
    def compute_reward(self, word_idx):
        word = self.vocab[word_idx]
        return 1.0 if is_vowel(word) else 0.0
    
    def compute_kl_divergence(self, old_probs, new_probs):
        return np.sum(kl_div(old_probs, new_probs))
    
    def plot_training_progress(self):
        # 创建三个子图
        fig = make_subplots(
            rows=3, cols=1,
            subplot_titles=('Word Probabilities Over Time', 'Mean Reward', 'KL Divergence'),
            vertical_spacing=0.1,
            row_heights=[0.4, 0.3, 0.3]
        )
        
        # 1. 词汇概率随时间变化
        prob_data = np.array(self.prob_history)
        for i, word in enumerate(self.vocab):
            fig.add_trace(
                go.Scatter(x=list(range(len(prob_data))), 
                          y=prob_data[:, i],
                          name=word,
                          mode='lines'),
                row=1, col=1
            )
            
        # 2. 平均奖励变化
        fig.add_trace(
            go.Scatter(x=list(range(len(self.reward_history))),
                      y=self.reward_history,
                      name='Mean Reward',
                      line=dict(color='green')),
            row=2, col=1
        )
        
        # 3. KL散度变化
        fig.add_trace(
            go.Scatter(x=list(range(len(self.kl_div_history))),
                      y=self.kl_div_history,
                      name='KL Divergence',
                      line=dict(color='red')),
            row=3, col=1
        )
        
        # 更新布局
        fig.update_layout(
            height=900,
            showlegend=True,
            title_text="GRPO Training Progress"
        )
        
        # 更新坐标轴标签
        fig.update_xaxes(title_text="Iteration", row=3, col=1)
        fig.update_yaxes(title_text="Probability", row=1, col=1)
        fig.update_yaxes(title_text="Reward", row=2, col=1)
        fig.update_yaxes(title_text="KL Divergence", row=3, col=1)
        
        fig.show()
    
    def train(self, num_iterations=50, num_samples=20):
        for iteration in range(1, num_iterations + 1):
            current_probs = self.get_policy_probs(self.theta)
            
            outputs = []
            rewards = []
            
            for _ in range(num_samples):
                word_idx = self.sample_word(current_probs)
                reward = self.compute_reward(word_idx)
                outputs.append((word_idx, reward))
                rewards.append(reward)
                
            rewards = np.array(rewards)
            mu_G = np.mean(rewards)
            sigma_G = np.std(rewards) + self.epsilon
            
            advantages = [(r - mu_G) / (sigma_G) for _, r in outputs]
            
            grad = np.zeros_like(self.theta)
            
            for (word_idx, _), advantage in zip(outputs, advantages):
                grad_sample = -current_probs.copy()
                grad_sample[word_idx] += 1.0
                grad += grad_sample * advantage
            
            learning_rate = 0.01
            self.theta += learning_rate * grad
            
            # 记录训练历史
            new_probs = self.get_policy_probs(self.theta)
            kl_div_value = self.compute_kl_divergence(current_probs, new_probs)
            
            self.prob_history.append(new_probs)
            self.reward_history.append(mu_G)
            self.kl_div_history.append(kl_div_value)
            
            if iteration % 10 == 0:
                print(f"\nIteration {iteration}")
                print("Probabilities:")
                for word, prob in zip(vocab, new_probs):
                    print(f"  {word:12s}: {prob:.3f}")
                print(f"Mean Reward: {mu_G:.3f}")
                print(f"KL Divergence: {kl_div_value:.3f}")
        
        # 训练结束后显示可视化结果
        self.plot_training_progress()

# 运行训练
grpo = GRPO(vocab)
grpo.train()
import torch
import torch.nn.functional as F
import random
import copy

def selective_log_softmax(logits, input_ids):
    """
    作用:
        1. 计算模型对实际生成的token的预测概率
        2. 这些概率后续用于计算策略梯度和KL散度
        3. 通过对数概率进行计算可以提高数值稳定性

    参数:
        logits (torch.Tensor): 张量,形状为 (batch_size, seq_len, vocab_size),表示模型的原始logits输出。
        input_ids (torch.Tensor): 张量,形状为 (batch_size, seq_len),表示需要计算log概率的tokens索引。

    返回:
        torch.Tensor: 张量,形状为 (batch_size, seq_len),表示input_ids中每个token对应的log概率。

    解释:
        1. 使用F.log_softmax在词汇表维度(dim=-1)上将logits转换为log概率。
        2. 将input_ids通过unsqueeze增加一个额外维度,以便用作log_probs中的索引。
        3. 使用torch.gather提取log_probs中每个位置对应input_ids的log概率。
        4. 最后使用squeeze(-1)移除多余的维度,返回与input_ids形状相同的张量。
    """
    # 将原始logits转换为log概率,计算维度为词汇表维度。
    log_probs = F.log_softmax(logits, dim=-1)  # 形状: (batch_size, seq_len, vocab_size)
    
    # 将input_ids从 (batch_size, seq_len) 转换为 (batch_size, seq_len, 1),以便使用gather获取指定位置的log概率。
    selected_log_probs = log_probs.gather(dim=-1, index=input_ids.unsqueeze(-1))
    
    # 去掉最后多余的维度,返回形状为 (batch_size, seq_len) 的张量。
    return selected_log_probs.squeeze(-1)

def compute_log_probs(model, input_ids, attention_mask, logits_to_keep):
    """
    作用:
        1. 计算模型对部分tokens(通常是补全部分)的逐token log概率
        2. 只关注completion部分的token,因为这是我们要优化的部分
        3. 处理了自回归生成中的对齐问题
        4. 为后续计算策略梯度和KL散度提供必要的概率值

    参数:
        model: 使用的语言模型。
        input_ids (torch.Tensor): 张量,形状为 (batch_size, total_seq_len),包含提示和补全的token ids。
        attention_mask (torch.Tensor): 张量,形状为 (batch_size, total_seq_len),用于指示哪些token是有效的(1)或填充的(0)。
        logits_to_keep (int): 需要计算log概率的token数量(通常是completion部分的长度)

    返回:
        torch.Tensor: 张量,表示每个序列最后 `logits_to_keep` 个tokens的log概率。

    解释:
        1. 调用模型时请求logits_to_keep + 1个logits,以便支持下一token预测。
        2. 删除序列维度上的最后一个logit,因为它与任何输入token都不对应。
        3. 将input_ids和logits限制为最后的logits_to_keep个tokens,即补全部分。
        4. 使用selective_log_softmax计算这些tokens的log概率。
    """
    # 前向传播模型并获取logits。
    """关于logits_to_keep + 1的解释:
    input_ids = [X, Y1, Y2, Y3]  # 长度 4
    logits = [
        [预测 Y1],  # 位置 0 对应 X 的下一个 token
        [预测 Y2],  # 位置 1 对应 Y1 的下一个 token
        [预测 Y3],  # 位置 2 对应 Y2 的下一个 token
        [预测 ???]  # 位置 3 对应 Y3 的下一个 token(无意义)
    ]
    删除最后一个 logit 后,保留 logits[0], logits[1], logits[2]。
    切片 input_ids 取 [Y1, Y2, Y3]。
    此时 logits[0] 对应 Y1,logits[1] 对应 Y2,logits[2] 对应 Y3,完美对齐。
    """
    logits = model(
        input_ids=input_ids, 
        attention_mask=attention_mask, 
        logits_to_keep=logits_to_keep + 1  # 请求比需要的多一个logit,这是因为在自回归生成中,每个token预测下一个token
    ).logits  # 形状: (batch_size, total_seq_len, vocab_size)

    # 删除最后一个logit,因为它没有对应的目标token。
    logits = logits[:, :-1, :]  # 新形状: (batch_size, total_seq_len - 1, vocab_size)
    
    # 从input_ids中切出最后logits_to_keep个tokens,表示补全部分。
    input_ids = input_ids[:, -logits_to_keep:]  # 形状: (batch_size, logits_to_keep)
    
    # 同样切出对应的logits,仅保留补全部分。
    logits = logits[:, -logits_to_keep:, :]  # 形状: (batch_size, logits_to_keep, vocab_size)
    
    # 计算并返回这些tokens的log概率。
    return selective_log_softmax(logits, input_ids)

def create_completion_mask(completion_ids, eos_token_id):
    """
    作用:
        1. 处理变长序列:
            不同序列可能在不同位置结束
            通过mask确保只考虑有效的token
        2. EOS处理:
            包含EOS token本身
            忽略EOS之后的所有token
        3. 批处理支持:
            同时处理一个batch中的多个序列
            每个序列可以有不同的有效长度

    参数:
        completion_ids (torch.Tensor): 张量,形状为 (batch_size, seq_len),包含生成的token ids。
        eos_token_id (int): 表示序列结束的EOS token的id。

    返回:
        torch.Tensor: 掩码张量,形状为 (batch_size, seq_len),在EOS之前(包括EOS)为1,之后为0。

    解释:
        # 假设我们有一个batch_size=2的样本,seq_len=6,eos_token_id=2
        completion_ids = torch.tensor([
            [5, 7, 2, 8, 9, 1],  # 序列1:在索引2处有EOS
            [3, 4, 6, 8, 1, 5]   # 序列2:没有EOS
        ])

        # 1. 找出EOS位置
        is_eos = completion_ids == 2
        # is_eos = 
        # [[False, False, True,  False, False, False],
        #  [False, False, False, False, False, False]]

        # 2. 初始化eos_idx(默认为序列长度6)
        eos_idx = torch.tensor([6, 6])

        # 3. 检查哪些序列包含EOS
        mask_exists = is_eos.any(dim=1)  # [True, False]

        # 4. 更新存在EOS的序列的eos_idx
        # 序列1的eos_idx更新为2,序列2保持为6
        eos_idx = tensor([2, 6])

        # 5. 创建序列索引
        sequence_indices = tensor([
            [0, 1, 2, 3, 4, 5],
            [0, 1, 2, 3, 4, 5]
        ])

        # 6. 最终的mask
        completion_mask = tensor([
            [1, 1, 1, 0, 0, 0],  # 序列1:只保留到EOS(包括EOS)
            [1, 1, 1, 1, 1, 1]   # 序列2:保留所有token
        ])
    """
    # 确定序列中哪些位置是EOS token。
    is_eos = completion_ids == eos_token_id  # 布尔张量,形状为 (batch_size, seq_len)

    # 初始化张量,用于存储每个序列中第一个EOS的索引。如果没有找到EOS,默认使用序列长度。
    eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=completion_ids.device)
    
    # 找到包含至少一个EOS的序列。
    mask_exists = is_eos.any(dim=1)
    # 对于包含EOS的序列,更新eos_idx为第一个EOS的位置索引。
    eos_idx[mask_exists] = is_eos.int().argmax(dim=1)[mask_exists]
    
    # 创建张量,包含每个序列位置的索引 [0, 1, 2, ..., seq_len-1]。
    # 并将其扩展到与序列数量一致。
    sequence_indices = torch.arange(is_eos.size(1), device=completion_ids.device).expand(is_eos.size(0), -1)
    
    # 构建掩码:位置索引小于等于第一个EOS位置的标记为1。
    completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
    
    return completion_mask


def generate_completions(model, tokenizer, prompts, num_generations=4, max_completion_length=32):
    """
    为每个 prompt 生成多个 completion,并创建相应的 attention 掩码。

    参数:
        model: 用于生成文本的语言模型。
        tokenizer: 用于处理 prompt 和解码输出的分词器。
        prompts (list of str): 输入的 prompt 字符串列表。
        num_generations (int): 每个 prompt 要生成的 completion 数量。
        max_completion_length (int): 每个 completion 生成的新 token 的最大数量。

    返回:
        tuple: 包含以下张量:
            - prompt_ids: (batch_size * num_generations, prompt_seq_len)
            - prompt_mask: (batch_size * num_generations, prompt_seq_len)
            - completion_ids: (batch_size * num_generations, completion_seq_len)
            - completion_mask: (batch_size * num_generations, completion_seq_len)
    
    解释:
        1. 对 prompt 进行分词,并左侧填充以保持右对齐。
        2. 每个 prompt 重复 num_generations 次,以便生成多个 completion。
        3. 调用 model.generate() 函数生成新 token。
        4. 生成的输出包含了 prompt 和 completion,通过移除 prompt 部分即可获得 completion。
        5. 使用 create_completion_mask 创建掩码,确保只处理第一个 EOS 之前的 token。

    具体例子:
        # 假设我们有:
        prompts = ["请写一首诗", "讲个故事"]
        num_generations = 2  # 每个prompt生成2个完成

        # 1. Tokenize后可能的结果:
        prompt_ids = tensor([
            [  1,  45, 678, 234],  # "请写一首诗"的token
            [  1, 123, 456, 789]   # "讲个故事"的token
        ])

        # 2. 复制每个prompt(repeat_interleave操作):
        prompt_ids_repeated = tensor([
            [  1,  45, 678, 234],  # prompt 1 - 复制 1
            [  1,  45, 678, 234],  # prompt 1 - 复制 2
            [  1, 123, 456, 789],  # prompt 2 - 复制 1
            [  1, 123, 456, 789]   # prompt 2 - 复制 2
        ])

        # 3. 生成完成文本(model.generate的结果):
        outputs = tensor([
            [  1,  45, 678, 234, 111, 222, 2],    # 完成 1-1
            [  1,  45, 678, 234, 333, 444, 2],    # 完成 1-2
            [  1, 123, 456, 789, 555, 666, 2],    # 完成 2-1
            [  1, 123, 456, 789, 777, 888, 2]     # 完成 2-2
        ])

        # 4. 分离completion部分:
        completion_ids = tensor([
            [111, 222, 2],    # 完成 1-1
            [333, 444, 2],    # 完成 1-2
            [555, 666, 2],    # 完成 2-1
            [777, 888, 2]     # 完成 2-2
        ])
    """
    device = next(model.parameters()).device

    # 对 prompt 列表进行分词,并进行 padding,padding_side="left" 确保右侧对齐。
    inputs = tokenizer(prompts, return_tensors="pt", padding=True, padding_side="left")
    prompt_ids = inputs["input_ids"].to(device)      # 形状: (batch_size, prompt_seq_len)
    prompt_mask = inputs["attention_mask"].to(device)  # 形状: (batch_size, prompt_seq_len)
    prompt_length = prompt_ids.size(1)  # 保存 prompt 的长度,便于后续分离 prompt 和 completion。

    # 每个 prompt 重复 num_generations 次。
    prompt_ids = prompt_ids.repeat_interleave(num_generations, dim=0)   # 新形状: (batch_size * num_generations, prompt_seq_len)
    prompt_mask = prompt_mask.repeat_interleave(num_generations, dim=0) # 新形状: (batch_size * num_generations, prompt_seq_len)

    # 为每个 prompt 生成新 token,生成的结果包含了原始 prompt 及生成的部分。
    outputs = model.generate(
        prompt_ids,
        attention_mask=prompt_mask,
        max_new_tokens=max_completion_length,
        do_sample=True,
        temperature=1.0,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id
    )
    
    # 从输出中去掉 prompt 部分,提取出生成的 completion tokens。
    completion_ids = outputs[:, prompt_length:]  # 形状: (batch_size * num_generations, completion_seq_len)

    # 创建二值掩码,忽略第一个 EOS 后面的 token。
    completion_mask = create_completion_mask(completion_ids, tokenizer.eos_token_id)

    return prompt_ids, prompt_mask, completion_ids, completion_mask

def grpo_loss(model, ref_model, tokenizer, batch_samples, reward_function,
              beta=0.1, num_generations=4, max_completion_length=32):
    """
    计算 GRPO 损失,该损失结合了策略梯度损失和 KL 散度惩罚项。

    参数:
        model: 当前的语言模型(策略)。
        ref_model: 用于计算 KL 散度的参考模型(基线)。
        tokenizer: 用于解码 completion 的分词器。
        batch_samples (list): 一批样本,每个样本至少包含 "prompt" 和 "answer"。
        reward_function: 一个函数,接受 prompts、completions 和 answers 并返回奖励列表。
        beta (float): KL 散度部分的权重。
        num_generations (int): 每个 prompt 生成 completion 数量。
        max_completion_length (int): 每个生成的 completion 的最大 token 数量。

    返回:
        torch.Tensor: 标量形式的损失张量。

    解释:
        1. 从 batch_samples 中提取 prompt。
        2. 为每个 prompt 生成多个 completion。
        3. 将 prompt 和 completion tokens 拼接成完整输入序列。
        4. 分别使用当前模型和参考模型计算 completion 部分的对数概率。
        5. 将生成的 completion 格式化为文本以进行奖励评估。
        6. 为每个 completion 计算奖励,并在每个 prompt 组内归一化(计算优势)。
        7. 计算参考模型与当前模型对数概率之间每个 token 的 KL 散度。
        8. 结合策略梯度损失与 KL 惩罚项,计算最终的损失值。

    具体例子:
        # 假设我们有:
        batch_size = 2
        seq_len = 3

        # 策略log概率
        policy_log_probs = tensor([
            [-1.2, -0.8, -1.5],
            [-0.9, -1.1, -0.7]
        ])

        # 参考策略log概率
        ref_log_probs = tensor([
            [-1.0, -0.9, -1.4],
            [-1.0, -1.0, -0.8]
        ])

        # 优势值
        advantages = tensor([0.5, -0.3])

        # completion mask
        completion_mask = tensor([
            [1, 1, 1],
            [1, 1, 0]  # 最后一个token无效
        ])

        # 计算过程:
        # 1. KL散度
        kl_div = tensor([
            [-0.2, 0.1, -0.1],
            [0.1, -0.1, 0.1]
        ])

        # 2. 策略梯度(带优势)
        policy_gradient = tensor([
            [-0.6, -0.4, -0.75],
            [0.27, 0.33, 0.21]
        ])

        # 3. 组合损失(假设beta=0.1)
        combined_loss = -(policy_gradient - 0.1 * kl_div)* completion_mask
        
    """
    device = next(model.parameters()).device

    # 从每个样本中提取 prompt 文本。
    prompts = [sample["prompt"] if isinstance(sample, dict) else sample[0] for sample in batch_samples]

    # 生成 completion 及其对应的掩码。
    prompt_ids, prompt_mask, completion_ids, completion_mask = generate_completions(
        model, tokenizer, prompts, num_generations, max_completion_length
    )

    # 将 prompt 和 completion tokens 拼接成完整输入序列。
    input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
    attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)

    # 仅对 completion 部分计算对数概率。
    logits_to_keep = completion_ids.size(1)

    # 使用参考模型计算 completion tokens 的对数概率,使用 torch.no_grad() 避免梯度传播。
    with torch.no_grad():
        ref_token_log_probs = compute_log_probs(ref_model, input_ids, attention_mask, logits_to_keep)
    
    # 使用当前模型计算 completion tokens 的对数概率。
    token_log_probs = compute_log_probs(model, input_ids, attention_mask, logits_to_keep)

    # 将 completion tokens 解码为文本以便进行奖励评估。
    # 每个解码后的 completion 被包装在一个字典中(以兼容奖励函数)。
    formatted_completions = [
        [{'content': tokenizer.decode(ids, skip_special_tokens=True)}]
        for ids in completion_ids
    ]
    # 每个 prompt 重复生成 num_generations 个 completion。
    repeated_prompts = [p for p in prompts for _ in range(num_generations)]
    # 从每个样本中提取 answer,并重复以匹配每个生成的 completion 数。
    answers = [sample["answer"] if isinstance(sample, dict) else sample[1]
               for sample in batch_samples for _ in range(num_generations)]

    # 使用 reward_function 计算奖励。
    rewards = torch.tensor(
        reward_function(prompts=repeated_prompts, completions=formatted_completions, answer=answers),
        dtype=torch.float32,
        device=device
    )

    # 为了监控,打印平均奖励。
    avg_reward = rewards.mean().item()
    print("Average Reward:", avg_reward)

    # 对奖励进行 reshape,按照每个 prompt 的生成组组织,
    # 并计算每组的均值和标准差。
    mean_rewards = rewards.view(-1, num_generations).mean(dim=1)
    std_rewards = rewards.view(-1, num_generations).std(dim=1)
    # 将均值和标准差扩展为与原扁平奖励张量相同形状。
    mean_rewards = mean_rewards.repeat_interleave(num_generations, dim=0)
    std_rewards = std_rewards.repeat_interleave(num_generations, dim=0)
    # 对奖励做归一化,计算优势。
    advantages = (rewards - mean_rewards) / (std_rewards + 1e-4)

    # 计算参考模型和当前模型对数概率之间每个 token 的 KL 散度。
    per_token_kl = torch.exp(ref_token_log_probs - token_log_probs) - (ref_token_log_probs - token_log_probs) - 1
    
    # 计算策略梯度损失部分,
    # 使用 token_log_probs.detach() 阻止梯度传入基线值。
    per_token_loss = torch.exp(token_log_probs - token_log_probs.detach()) * advantages.unsqueeze(1)
    # 将基于 token 的损失与 KL 惩罚项(乘上 beta)相结合,并取负值。
    per_token_loss = -(per_token_loss - beta * per_token_kl)
    
    # 结合 completion 掩码计算每个序列的平均损失:
    # - 用掩码乘上损失,仅有效 token 参与计算;
    # - 对每个序列的损失求和后除以有效 token 数目;
    # - 最后对所有序列求平均。
    loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()

    return loss

def train_with_grpo(model, tokenizer, train_data, num_steps=500, batch_size=4,
                    num_generations=4, max_completion_length=128, beta=0.1,
                    learning_rate=5e-6):
    """
    利用 GRPO 算法对模型进行微调。

    该函数实现的训练流程包括:
      1. 创建一个参考模型(当前模型的深拷贝),其参数被冻结,不进行梯度更新。
      2. 在每个训练步骤中:
           - 从 train_data 中随机采样一批样本。
           - 为每个 prompt 生成多个 completion。
           - 计算 GRPO 损失(结合了基于优势的策略梯度项和当前模型与参考模型之间 KL 散度惩罚)。
           - 进行反向传播并更新模型参数。
           - 更新参考模型,使其与当前模型参数同步。
      
    参数:
        model: 待微调的语言模型。
        tokenizer: 用于编码 prompt 和解码 completion 的分词器。
        train_data (list): 训练样本列表,每个样本至少包含 "prompt" 和 "answer"。
        num_steps (int): 总训练步数。
        batch_size (int): 每步训练的样本数量。
        num_generations (int): 每个 prompt 生成的 completion 数量。
        max_completion_length (int): 每个生成的 completion 的最大 token 数。
        beta (float): 损失中 KL 散度惩罚项的权重。
        learning_rate (float): 优化器学习率。

    返回:
        微调后的模型。
    """
    # 获取模型参数所在设备(如 CPU 或 GPU)。
    device = next(model.parameters()).device

    # 深拷贝当前模型创建参考模型,参考模型用于计算 KL 散度,其参数不参与梯度更新。
    ref_model = copy.deepcopy(model)
    for param in ref_model.parameters():
        param.requires_grad = False

    # 使用 Adam 优化器,并设置学习率。
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # 将模型设置为训练模式(启用 dropout 等)。
    model.train()

    # 用于统计已处理样本数量(便于日志记录)。
    examples_processed = 0

    # 训练循环:迭代指定数量的训练步数。
    for step in range(num_steps):
        # 从 train_data 中随机采样一批样本,每个样本应为包含 "prompt" 和 "answer" 的字典或元组。
        batch_samples = random.sample(train_data, batch_size)

        # 计算当前批次的 GRPO 损失。
        # grpo_loss 函数执行的操作:
        #   - 从样本中提取 prompt;
        #   - 为每个 prompt 生成多个 completion;
        #   - 将 prompt 和生成的 completion tokens 拼接;
        #   - 分别计算当前模型和参考模型对 completion tokens 的对数概率;
        #   - 解码 completion 并通过奖励函数计算奖励;
        #   - 在每个 prompt 组内归一化奖励(计算优势);
        #   - 计算当前模型与参考模型之间每个 token 的 KL 散度;
        #   - 结合策略梯度损失和 KL 惩罚项得到最终标量损失。
        loss = grpo_loss(
            model,            # 当前正在微调的模型(策略)。
            ref_model,        # 用于计算 KL 散度的参考模型。
            tokenizer,        # 文本编码与解码的分词器。
            batch_samples,    # 当前批次的训练样本。
            combined_reward,  # 奖励函数(需在其他地方定义),返回奖励列表。
            beta=beta,        # KL 散度的权重。
            num_generations=num_generations,
            max_completion_length=max_completion_length
        )

        # 反向传播及参数更新:
        optimizer.zero_grad()           # 清除上一步梯度。
        loss.backward()                 # 计算梯度。
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.1)  # 可选:裁剪梯度以防梯度爆炸。
        optimizer.step()                # 更新模型参数。

        # 更新参考模型,使其与当前模型保持一致。
        ref_model.load_state_dict(model.state_dict())

        # 每 5 步打印一次损失,用于监控训练进度。
        if step % 5 == 0:
            print(f"Step {step}/{num_steps}, loss: {loss.item():.4f}")

        # 累计处理的样本数量。
        examples_processed += batch_size

        # 清理 GPU 缓存,帮助内存管理。
        torch.cuda.empty_cache()

    # 完成所有训练步骤后返回微调后的模型。
    return model


Logo

立足具身智能前沿赛道,致力于搭建全球化、开源化、全栈式技术交流与实践共创平台。

更多推荐