GRPO 算法核心公式解析(附代码详解)
【代码】GRPO 算法核心公式解析(附代码详解)
·
https://zhuanlan.zhihu.com/p/24816372882
https://zhuanlan.zhihu.com/p/24816372882
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from scipy.special import softmax
from scipy.special import softmax, kl_div
# 词汇表及其分组
vocab = ["ant", "bear", "cat", "dog"]
is_vowel = lambda x: x[0].lower() in "aeiou"
class GRPO:
def __init__(self, vocab, beta=0.1, epsilon=1e-8):
self.vocab = vocab
self.vocab_size = len(vocab)
self.beta = beta
self.epsilon = epsilon
np.random.seed(42)
self.theta = np.random.randn(self.vocab_size)
# 添加用于存储训练历史的列表
self.prob_history = []
self.reward_history = []
self.kl_div_history = []
def get_policy_probs(self, logits):
return softmax(logits)
def sample_word(self, probs):
word_idx = np.random.choice(len(self.vocab), p=probs)
return word_idx
def compute_reward(self, word_idx):
word = self.vocab[word_idx]
return 1.0 if is_vowel(word) else 0.0
def compute_kl_divergence(self, old_probs, new_probs):
return np.sum(kl_div(old_probs, new_probs))
def plot_training_progress(self):
# 创建三个子图
fig = make_subplots(
rows=3, cols=1,
subplot_titles=('Word Probabilities Over Time', 'Mean Reward', 'KL Divergence'),
vertical_spacing=0.1,
row_heights=[0.4, 0.3, 0.3]
)
# 1. 词汇概率随时间变化
prob_data = np.array(self.prob_history)
for i, word in enumerate(self.vocab):
fig.add_trace(
go.Scatter(x=list(range(len(prob_data))),
y=prob_data[:, i],
name=word,
mode='lines'),
row=1, col=1
)
# 2. 平均奖励变化
fig.add_trace(
go.Scatter(x=list(range(len(self.reward_history))),
y=self.reward_history,
name='Mean Reward',
line=dict(color='green')),
row=2, col=1
)
# 3. KL散度变化
fig.add_trace(
go.Scatter(x=list(range(len(self.kl_div_history))),
y=self.kl_div_history,
name='KL Divergence',
line=dict(color='red')),
row=3, col=1
)
# 更新布局
fig.update_layout(
height=900,
showlegend=True,
title_text="GRPO Training Progress"
)
# 更新坐标轴标签
fig.update_xaxes(title_text="Iteration", row=3, col=1)
fig.update_yaxes(title_text="Probability", row=1, col=1)
fig.update_yaxes(title_text="Reward", row=2, col=1)
fig.update_yaxes(title_text="KL Divergence", row=3, col=1)
fig.show()
def train(self, num_iterations=50, num_samples=20):
for iteration in range(1, num_iterations + 1):
current_probs = self.get_policy_probs(self.theta)
outputs = []
rewards = []
for _ in range(num_samples):
word_idx = self.sample_word(current_probs)
reward = self.compute_reward(word_idx)
outputs.append((word_idx, reward))
rewards.append(reward)
rewards = np.array(rewards)
mu_G = np.mean(rewards)
sigma_G = np.std(rewards) + self.epsilon
advantages = [(r - mu_G) / (sigma_G) for _, r in outputs]
grad = np.zeros_like(self.theta)
for (word_idx, _), advantage in zip(outputs, advantages):
grad_sample = -current_probs.copy()
grad_sample[word_idx] += 1.0
grad += grad_sample * advantage
learning_rate = 0.01
self.theta += learning_rate * grad
# 记录训练历史
new_probs = self.get_policy_probs(self.theta)
kl_div_value = self.compute_kl_divergence(current_probs, new_probs)
self.prob_history.append(new_probs)
self.reward_history.append(mu_G)
self.kl_div_history.append(kl_div_value)
if iteration % 10 == 0:
print(f"\nIteration {iteration}")
print("Probabilities:")
for word, prob in zip(vocab, new_probs):
print(f" {word:12s}: {prob:.3f}")
print(f"Mean Reward: {mu_G:.3f}")
print(f"KL Divergence: {kl_div_value:.3f}")
# 训练结束后显示可视化结果
self.plot_training_progress()
# 运行训练
grpo = GRPO(vocab)
grpo.train()
import torch
import torch.nn.functional as F
import random
import copy
def selective_log_softmax(logits, input_ids):
"""
作用:
1. 计算模型对实际生成的token的预测概率
2. 这些概率后续用于计算策略梯度和KL散度
3. 通过对数概率进行计算可以提高数值稳定性
参数:
logits (torch.Tensor): 张量,形状为 (batch_size, seq_len, vocab_size),表示模型的原始logits输出。
input_ids (torch.Tensor): 张量,形状为 (batch_size, seq_len),表示需要计算log概率的tokens索引。
返回:
torch.Tensor: 张量,形状为 (batch_size, seq_len),表示input_ids中每个token对应的log概率。
解释:
1. 使用F.log_softmax在词汇表维度(dim=-1)上将logits转换为log概率。
2. 将input_ids通过unsqueeze增加一个额外维度,以便用作log_probs中的索引。
3. 使用torch.gather提取log_probs中每个位置对应input_ids的log概率。
4. 最后使用squeeze(-1)移除多余的维度,返回与input_ids形状相同的张量。
"""
# 将原始logits转换为log概率,计算维度为词汇表维度。
log_probs = F.log_softmax(logits, dim=-1) # 形状: (batch_size, seq_len, vocab_size)
# 将input_ids从 (batch_size, seq_len) 转换为 (batch_size, seq_len, 1),以便使用gather获取指定位置的log概率。
selected_log_probs = log_probs.gather(dim=-1, index=input_ids.unsqueeze(-1))
# 去掉最后多余的维度,返回形状为 (batch_size, seq_len) 的张量。
return selected_log_probs.squeeze(-1)
def compute_log_probs(model, input_ids, attention_mask, logits_to_keep):
"""
作用:
1. 计算模型对部分tokens(通常是补全部分)的逐token log概率
2. 只关注completion部分的token,因为这是我们要优化的部分
3. 处理了自回归生成中的对齐问题
4. 为后续计算策略梯度和KL散度提供必要的概率值
参数:
model: 使用的语言模型。
input_ids (torch.Tensor): 张量,形状为 (batch_size, total_seq_len),包含提示和补全的token ids。
attention_mask (torch.Tensor): 张量,形状为 (batch_size, total_seq_len),用于指示哪些token是有效的(1)或填充的(0)。
logits_to_keep (int): 需要计算log概率的token数量(通常是completion部分的长度)
返回:
torch.Tensor: 张量,表示每个序列最后 `logits_to_keep` 个tokens的log概率。
解释:
1. 调用模型时请求logits_to_keep + 1个logits,以便支持下一token预测。
2. 删除序列维度上的最后一个logit,因为它与任何输入token都不对应。
3. 将input_ids和logits限制为最后的logits_to_keep个tokens,即补全部分。
4. 使用selective_log_softmax计算这些tokens的log概率。
"""
# 前向传播模型并获取logits。
"""关于logits_to_keep + 1的解释:
input_ids = [X, Y1, Y2, Y3] # 长度 4
logits = [
[预测 Y1], # 位置 0 对应 X 的下一个 token
[预测 Y2], # 位置 1 对应 Y1 的下一个 token
[预测 Y3], # 位置 2 对应 Y2 的下一个 token
[预测 ???] # 位置 3 对应 Y3 的下一个 token(无意义)
]
删除最后一个 logit 后,保留 logits[0], logits[1], logits[2]。
切片 input_ids 取 [Y1, Y2, Y3]。
此时 logits[0] 对应 Y1,logits[1] 对应 Y2,logits[2] 对应 Y3,完美对齐。
"""
logits = model(
input_ids=input_ids,
attention_mask=attention_mask,
logits_to_keep=logits_to_keep + 1 # 请求比需要的多一个logit,这是因为在自回归生成中,每个token预测下一个token
).logits # 形状: (batch_size, total_seq_len, vocab_size)
# 删除最后一个logit,因为它没有对应的目标token。
logits = logits[:, :-1, :] # 新形状: (batch_size, total_seq_len - 1, vocab_size)
# 从input_ids中切出最后logits_to_keep个tokens,表示补全部分。
input_ids = input_ids[:, -logits_to_keep:] # 形状: (batch_size, logits_to_keep)
# 同样切出对应的logits,仅保留补全部分。
logits = logits[:, -logits_to_keep:, :] # 形状: (batch_size, logits_to_keep, vocab_size)
# 计算并返回这些tokens的log概率。
return selective_log_softmax(logits, input_ids)
def create_completion_mask(completion_ids, eos_token_id):
"""
作用:
1. 处理变长序列:
不同序列可能在不同位置结束
通过mask确保只考虑有效的token
2. EOS处理:
包含EOS token本身
忽略EOS之后的所有token
3. 批处理支持:
同时处理一个batch中的多个序列
每个序列可以有不同的有效长度
参数:
completion_ids (torch.Tensor): 张量,形状为 (batch_size, seq_len),包含生成的token ids。
eos_token_id (int): 表示序列结束的EOS token的id。
返回:
torch.Tensor: 掩码张量,形状为 (batch_size, seq_len),在EOS之前(包括EOS)为1,之后为0。
解释:
# 假设我们有一个batch_size=2的样本,seq_len=6,eos_token_id=2
completion_ids = torch.tensor([
[5, 7, 2, 8, 9, 1], # 序列1:在索引2处有EOS
[3, 4, 6, 8, 1, 5] # 序列2:没有EOS
])
# 1. 找出EOS位置
is_eos = completion_ids == 2
# is_eos =
# [[False, False, True, False, False, False],
# [False, False, False, False, False, False]]
# 2. 初始化eos_idx(默认为序列长度6)
eos_idx = torch.tensor([6, 6])
# 3. 检查哪些序列包含EOS
mask_exists = is_eos.any(dim=1) # [True, False]
# 4. 更新存在EOS的序列的eos_idx
# 序列1的eos_idx更新为2,序列2保持为6
eos_idx = tensor([2, 6])
# 5. 创建序列索引
sequence_indices = tensor([
[0, 1, 2, 3, 4, 5],
[0, 1, 2, 3, 4, 5]
])
# 6. 最终的mask
completion_mask = tensor([
[1, 1, 1, 0, 0, 0], # 序列1:只保留到EOS(包括EOS)
[1, 1, 1, 1, 1, 1] # 序列2:保留所有token
])
"""
# 确定序列中哪些位置是EOS token。
is_eos = completion_ids == eos_token_id # 布尔张量,形状为 (batch_size, seq_len)
# 初始化张量,用于存储每个序列中第一个EOS的索引。如果没有找到EOS,默认使用序列长度。
eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=completion_ids.device)
# 找到包含至少一个EOS的序列。
mask_exists = is_eos.any(dim=1)
# 对于包含EOS的序列,更新eos_idx为第一个EOS的位置索引。
eos_idx[mask_exists] = is_eos.int().argmax(dim=1)[mask_exists]
# 创建张量,包含每个序列位置的索引 [0, 1, 2, ..., seq_len-1]。
# 并将其扩展到与序列数量一致。
sequence_indices = torch.arange(is_eos.size(1), device=completion_ids.device).expand(is_eos.size(0), -1)
# 构建掩码:位置索引小于等于第一个EOS位置的标记为1。
completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
return completion_mask
def generate_completions(model, tokenizer, prompts, num_generations=4, max_completion_length=32):
"""
为每个 prompt 生成多个 completion,并创建相应的 attention 掩码。
参数:
model: 用于生成文本的语言模型。
tokenizer: 用于处理 prompt 和解码输出的分词器。
prompts (list of str): 输入的 prompt 字符串列表。
num_generations (int): 每个 prompt 要生成的 completion 数量。
max_completion_length (int): 每个 completion 生成的新 token 的最大数量。
返回:
tuple: 包含以下张量:
- prompt_ids: (batch_size * num_generations, prompt_seq_len)
- prompt_mask: (batch_size * num_generations, prompt_seq_len)
- completion_ids: (batch_size * num_generations, completion_seq_len)
- completion_mask: (batch_size * num_generations, completion_seq_len)
解释:
1. 对 prompt 进行分词,并左侧填充以保持右对齐。
2. 每个 prompt 重复 num_generations 次,以便生成多个 completion。
3. 调用 model.generate() 函数生成新 token。
4. 生成的输出包含了 prompt 和 completion,通过移除 prompt 部分即可获得 completion。
5. 使用 create_completion_mask 创建掩码,确保只处理第一个 EOS 之前的 token。
具体例子:
# 假设我们有:
prompts = ["请写一首诗", "讲个故事"]
num_generations = 2 # 每个prompt生成2个完成
# 1. Tokenize后可能的结果:
prompt_ids = tensor([
[ 1, 45, 678, 234], # "请写一首诗"的token
[ 1, 123, 456, 789] # "讲个故事"的token
])
# 2. 复制每个prompt(repeat_interleave操作):
prompt_ids_repeated = tensor([
[ 1, 45, 678, 234], # prompt 1 - 复制 1
[ 1, 45, 678, 234], # prompt 1 - 复制 2
[ 1, 123, 456, 789], # prompt 2 - 复制 1
[ 1, 123, 456, 789] # prompt 2 - 复制 2
])
# 3. 生成完成文本(model.generate的结果):
outputs = tensor([
[ 1, 45, 678, 234, 111, 222, 2], # 完成 1-1
[ 1, 45, 678, 234, 333, 444, 2], # 完成 1-2
[ 1, 123, 456, 789, 555, 666, 2], # 完成 2-1
[ 1, 123, 456, 789, 777, 888, 2] # 完成 2-2
])
# 4. 分离completion部分:
completion_ids = tensor([
[111, 222, 2], # 完成 1-1
[333, 444, 2], # 完成 1-2
[555, 666, 2], # 完成 2-1
[777, 888, 2] # 完成 2-2
])
"""
device = next(model.parameters()).device
# 对 prompt 列表进行分词,并进行 padding,padding_side="left" 确保右侧对齐。
inputs = tokenizer(prompts, return_tensors="pt", padding=True, padding_side="left")
prompt_ids = inputs["input_ids"].to(device) # 形状: (batch_size, prompt_seq_len)
prompt_mask = inputs["attention_mask"].to(device) # 形状: (batch_size, prompt_seq_len)
prompt_length = prompt_ids.size(1) # 保存 prompt 的长度,便于后续分离 prompt 和 completion。
# 每个 prompt 重复 num_generations 次。
prompt_ids = prompt_ids.repeat_interleave(num_generations, dim=0) # 新形状: (batch_size * num_generations, prompt_seq_len)
prompt_mask = prompt_mask.repeat_interleave(num_generations, dim=0) # 新形状: (batch_size * num_generations, prompt_seq_len)
# 为每个 prompt 生成新 token,生成的结果包含了原始 prompt 及生成的部分。
outputs = model.generate(
prompt_ids,
attention_mask=prompt_mask,
max_new_tokens=max_completion_length,
do_sample=True,
temperature=1.0,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id
)
# 从输出中去掉 prompt 部分,提取出生成的 completion tokens。
completion_ids = outputs[:, prompt_length:] # 形状: (batch_size * num_generations, completion_seq_len)
# 创建二值掩码,忽略第一个 EOS 后面的 token。
completion_mask = create_completion_mask(completion_ids, tokenizer.eos_token_id)
return prompt_ids, prompt_mask, completion_ids, completion_mask
def grpo_loss(model, ref_model, tokenizer, batch_samples, reward_function,
beta=0.1, num_generations=4, max_completion_length=32):
"""
计算 GRPO 损失,该损失结合了策略梯度损失和 KL 散度惩罚项。
参数:
model: 当前的语言模型(策略)。
ref_model: 用于计算 KL 散度的参考模型(基线)。
tokenizer: 用于解码 completion 的分词器。
batch_samples (list): 一批样本,每个样本至少包含 "prompt" 和 "answer"。
reward_function: 一个函数,接受 prompts、completions 和 answers 并返回奖励列表。
beta (float): KL 散度部分的权重。
num_generations (int): 每个 prompt 生成 completion 数量。
max_completion_length (int): 每个生成的 completion 的最大 token 数量。
返回:
torch.Tensor: 标量形式的损失张量。
解释:
1. 从 batch_samples 中提取 prompt。
2. 为每个 prompt 生成多个 completion。
3. 将 prompt 和 completion tokens 拼接成完整输入序列。
4. 分别使用当前模型和参考模型计算 completion 部分的对数概率。
5. 将生成的 completion 格式化为文本以进行奖励评估。
6. 为每个 completion 计算奖励,并在每个 prompt 组内归一化(计算优势)。
7. 计算参考模型与当前模型对数概率之间每个 token 的 KL 散度。
8. 结合策略梯度损失与 KL 惩罚项,计算最终的损失值。
具体例子:
# 假设我们有:
batch_size = 2
seq_len = 3
# 策略log概率
policy_log_probs = tensor([
[-1.2, -0.8, -1.5],
[-0.9, -1.1, -0.7]
])
# 参考策略log概率
ref_log_probs = tensor([
[-1.0, -0.9, -1.4],
[-1.0, -1.0, -0.8]
])
# 优势值
advantages = tensor([0.5, -0.3])
# completion mask
completion_mask = tensor([
[1, 1, 1],
[1, 1, 0] # 最后一个token无效
])
# 计算过程:
# 1. KL散度
kl_div = tensor([
[-0.2, 0.1, -0.1],
[0.1, -0.1, 0.1]
])
# 2. 策略梯度(带优势)
policy_gradient = tensor([
[-0.6, -0.4, -0.75],
[0.27, 0.33, 0.21]
])
# 3. 组合损失(假设beta=0.1)
combined_loss = -(policy_gradient - 0.1 * kl_div)* completion_mask
"""
device = next(model.parameters()).device
# 从每个样本中提取 prompt 文本。
prompts = [sample["prompt"] if isinstance(sample, dict) else sample[0] for sample in batch_samples]
# 生成 completion 及其对应的掩码。
prompt_ids, prompt_mask, completion_ids, completion_mask = generate_completions(
model, tokenizer, prompts, num_generations, max_completion_length
)
# 将 prompt 和 completion tokens 拼接成完整输入序列。
input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
# 仅对 completion 部分计算对数概率。
logits_to_keep = completion_ids.size(1)
# 使用参考模型计算 completion tokens 的对数概率,使用 torch.no_grad() 避免梯度传播。
with torch.no_grad():
ref_token_log_probs = compute_log_probs(ref_model, input_ids, attention_mask, logits_to_keep)
# 使用当前模型计算 completion tokens 的对数概率。
token_log_probs = compute_log_probs(model, input_ids, attention_mask, logits_to_keep)
# 将 completion tokens 解码为文本以便进行奖励评估。
# 每个解码后的 completion 被包装在一个字典中(以兼容奖励函数)。
formatted_completions = [
[{'content': tokenizer.decode(ids, skip_special_tokens=True)}]
for ids in completion_ids
]
# 每个 prompt 重复生成 num_generations 个 completion。
repeated_prompts = [p for p in prompts for _ in range(num_generations)]
# 从每个样本中提取 answer,并重复以匹配每个生成的 completion 数。
answers = [sample["answer"] if isinstance(sample, dict) else sample[1]
for sample in batch_samples for _ in range(num_generations)]
# 使用 reward_function 计算奖励。
rewards = torch.tensor(
reward_function(prompts=repeated_prompts, completions=formatted_completions, answer=answers),
dtype=torch.float32,
device=device
)
# 为了监控,打印平均奖励。
avg_reward = rewards.mean().item()
print("Average Reward:", avg_reward)
# 对奖励进行 reshape,按照每个 prompt 的生成组组织,
# 并计算每组的均值和标准差。
mean_rewards = rewards.view(-1, num_generations).mean(dim=1)
std_rewards = rewards.view(-1, num_generations).std(dim=1)
# 将均值和标准差扩展为与原扁平奖励张量相同形状。
mean_rewards = mean_rewards.repeat_interleave(num_generations, dim=0)
std_rewards = std_rewards.repeat_interleave(num_generations, dim=0)
# 对奖励做归一化,计算优势。
advantages = (rewards - mean_rewards) / (std_rewards + 1e-4)
# 计算参考模型和当前模型对数概率之间每个 token 的 KL 散度。
per_token_kl = torch.exp(ref_token_log_probs - token_log_probs) - (ref_token_log_probs - token_log_probs) - 1
# 计算策略梯度损失部分,
# 使用 token_log_probs.detach() 阻止梯度传入基线值。
per_token_loss = torch.exp(token_log_probs - token_log_probs.detach()) * advantages.unsqueeze(1)
# 将基于 token 的损失与 KL 惩罚项(乘上 beta)相结合,并取负值。
per_token_loss = -(per_token_loss - beta * per_token_kl)
# 结合 completion 掩码计算每个序列的平均损失:
# - 用掩码乘上损失,仅有效 token 参与计算;
# - 对每个序列的损失求和后除以有效 token 数目;
# - 最后对所有序列求平均。
loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
return loss
def train_with_grpo(model, tokenizer, train_data, num_steps=500, batch_size=4,
num_generations=4, max_completion_length=128, beta=0.1,
learning_rate=5e-6):
"""
利用 GRPO 算法对模型进行微调。
该函数实现的训练流程包括:
1. 创建一个参考模型(当前模型的深拷贝),其参数被冻结,不进行梯度更新。
2. 在每个训练步骤中:
- 从 train_data 中随机采样一批样本。
- 为每个 prompt 生成多个 completion。
- 计算 GRPO 损失(结合了基于优势的策略梯度项和当前模型与参考模型之间 KL 散度惩罚)。
- 进行反向传播并更新模型参数。
- 更新参考模型,使其与当前模型参数同步。
参数:
model: 待微调的语言模型。
tokenizer: 用于编码 prompt 和解码 completion 的分词器。
train_data (list): 训练样本列表,每个样本至少包含 "prompt" 和 "answer"。
num_steps (int): 总训练步数。
batch_size (int): 每步训练的样本数量。
num_generations (int): 每个 prompt 生成的 completion 数量。
max_completion_length (int): 每个生成的 completion 的最大 token 数。
beta (float): 损失中 KL 散度惩罚项的权重。
learning_rate (float): 优化器学习率。
返回:
微调后的模型。
"""
# 获取模型参数所在设备(如 CPU 或 GPU)。
device = next(model.parameters()).device
# 深拷贝当前模型创建参考模型,参考模型用于计算 KL 散度,其参数不参与梯度更新。
ref_model = copy.deepcopy(model)
for param in ref_model.parameters():
param.requires_grad = False
# 使用 Adam 优化器,并设置学习率。
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# 将模型设置为训练模式(启用 dropout 等)。
model.train()
# 用于统计已处理样本数量(便于日志记录)。
examples_processed = 0
# 训练循环:迭代指定数量的训练步数。
for step in range(num_steps):
# 从 train_data 中随机采样一批样本,每个样本应为包含 "prompt" 和 "answer" 的字典或元组。
batch_samples = random.sample(train_data, batch_size)
# 计算当前批次的 GRPO 损失。
# grpo_loss 函数执行的操作:
# - 从样本中提取 prompt;
# - 为每个 prompt 生成多个 completion;
# - 将 prompt 和生成的 completion tokens 拼接;
# - 分别计算当前模型和参考模型对 completion tokens 的对数概率;
# - 解码 completion 并通过奖励函数计算奖励;
# - 在每个 prompt 组内归一化奖励(计算优势);
# - 计算当前模型与参考模型之间每个 token 的 KL 散度;
# - 结合策略梯度损失和 KL 惩罚项得到最终标量损失。
loss = grpo_loss(
model, # 当前正在微调的模型(策略)。
ref_model, # 用于计算 KL 散度的参考模型。
tokenizer, # 文本编码与解码的分词器。
batch_samples, # 当前批次的训练样本。
combined_reward, # 奖励函数(需在其他地方定义),返回奖励列表。
beta=beta, # KL 散度的权重。
num_generations=num_generations,
max_completion_length=max_completion_length
)
# 反向传播及参数更新:
optimizer.zero_grad() # 清除上一步梯度。
loss.backward() # 计算梯度。
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.1) # 可选:裁剪梯度以防梯度爆炸。
optimizer.step() # 更新模型参数。
# 更新参考模型,使其与当前模型保持一致。
ref_model.load_state_dict(model.state_dict())
# 每 5 步打印一次损失,用于监控训练进度。
if step % 5 == 0:
print(f"Step {step}/{num_steps}, loss: {loss.item():.4f}")
# 累计处理的样本数量。
examples_processed += batch_size
# 清理 GPU 缓存,帮助内存管理。
torch.cuda.empty_cache()
# 完成所有训练步骤后返回微调后的模型。
return model

更多推荐

所有评论(0)