SQL-R1:通过强化学习训练自然语言到SQL推理模型
本文提出SQL-R1,一种基于强化学习(RL)训练的自然语言到SQL(NL2SQL)推理模型,旨在提升复杂场景下的性能。针对现有监督微调(SFT)方法在跨领域适应性和可解释性上的局限,研究设计了专用RL奖励函数,探讨冷启动训练影响,并利用少量合成数据(SynSQL-2.5M)增强训练。实验表明,SQL-R1在Spider和BIRD基准上分别达88.6%和67.1%的执行准确率。核心贡献包括:1)结
马沛贤 12{ }^{12}12,庄夏烈 13{ }^{13}13,许成锦 14∗{ }^{14 *}14∗,蒋旭辉 14{ }^{14}14,陈然 1{ }^{1}1,郭健 1{ }^{1}1
1{ }^{1}1 IDEA研究院,国际数字经济研究院
2{ }^{2}2 香港科技大学(广州)
3{ }^{3}3 中国科学院大学
4{ }^{4}4 数据弧科技有限公司
{mapeixian, zhuangxialie, xuchengjin, jiangxuhui, chenran, guojian} @idea.edu.cn
(1) https://github.com/IDEA-FinAI/SQL-R1
(2) https://huggingface.co/MPX0222forHF/SQL-R1-7B
摘要
自然语言到SQL(NL2SQL)能够通过将自然语言查询转换为结构化SQL语句来实现与数据库的直观交互。尽管最近在增强数据库应用中人机交互方面取得了进展,但在复杂场景下(如多表连接和嵌套查询)推理性能仍然存在显著挑战。当前的方法主要使用监督微调(SFT)来训练NL2SQL模型,这可能会限制其在新环境(例如金融和医疗)中的适应性和可解释性。为了提高NL2SQL模型在上述复杂情况下的推理性能,我们引入了SQL-R1,这是一种通过强化学习(RL)算法训练的新型NL2SQL推理模型。我们设计了一种专门针对NL2SQL任务的基于RL的奖励函数,并讨论了冷启动对密集训练效果的影响。此外,我们仅使用少量合成的NL2SQL数据进行增强训练,实现了具有竞争力的准确性,并进一步探索了RL的数据工程。在现有实验中,SQL-R1在基准Spider和BIRD上分别达到了88.6%和67.1%的执行准确率。
1 引言
自然语言到SQL(NL2SQL,或Text2SQL)将自然语言问题(NL)转换为结构化SQL语句,从而简化了无需数据库专业知识即可进行的数据库交互 [1, 2]。近年来,NL2SQL的进步显著提升了数据库查询应用中人机交互的水平,并为广泛的数据科学分析任务做出了贡献 [3, 4]。目前的NL2SQL模型主要集中在优化工作流程及其组件上,例如模式链接 [5, 6]、内容检索 [7]、生成校正 [8–12]。
尽管有这些进步,但提高NL2SQL系统在复杂数据库场景中的推理性能仍然是一个相当大的挑战。如图1所示,模式复杂性可能导致处理多表连接和嵌套查询时生成错误,而单独训练的模型难以思考和处理复杂的语义。目前,大量NL2SQL研究致力于通过监督微调(SFT)[13–15]来训练开源大型语言模型(LLMs),以在较小的模型规模上达到准确性,相对于使用闭源LLMs(例如GPT-4,GPT-4o)的方法 [8, 10, 16]。然而,SFT依赖于数据库模式结构和训练数据规模。这可能导致现有模型在域适应和新数据库环境中的泛化能力不稳定。此外,NL2SQL推理逻辑缺乏可解释性,限制了该模型在高风险领域(如金融和医疗)的应用。
最近,强化学习(RL)在训练LLMs的推理能力方面显示出了巨大潜力。与监督微调相比,强化学习可以通过与环境的交互动态调整LLMs的决策策略,从而在复杂推理任务中表现出色 [17]。基于RL的方法已被证明能有效提升模型在金融推理 [18]、搜索引擎 [19] 和数学推理 [20, 21] 中的推理和泛化能力。
基于上述启发,我们提出了SQL-R1,一种通过强化学习算法训练的NL2SQL推理模型。图1展示了我们的工作概述。在接下来的部分中,我们将重点回答以下关键问题:
Q1: 我们能否为NL2SQL任务设计特定的强化学习算法并成功训练出NL2SQL推理模型? 与SFT相比,RL算法优先直接优化NL2SQL推理,具体来说是生成准确反映用户查询意图的SQL查询。为强化学习设计有效的反馈机制是开发NL2SQL推理模型的重要挑战。在强化学习框架内设计适当的奖励可以显著提升其性能。
Q2: 对于基于RL的NL2SQL推理模型,是否需要对其进行某种形式的冷启动? 对于现有的基础模型,有效的冷启动可以加强模型遵循指令的能力并激活其NL2SQL生成能力,从而促使其在强化学习探索中生成高质量的SQL查询。设计冷启动的形式也将是一个重大挑战。
Q3: 我们能否部署可持续的数据工程来训练强大且高效的NL2SQL推理模型? RL训练依赖于高质量的训练数据,而当前的NL2SQL任务缺乏大量真实数据用于训练。如何基于现有的数据工程技术开发NL2SQL推理模型的数据支持,解决模型训练问题,提高模型的鲁棒性和泛化能力,是一项重要挑战。
综上所述,本工作的贡献如下:
- 明确的NL2SQL推理模型:我们提出了SQL-R1,一种基于当前少量NL2SQL数据(例如5K)训练的NL2SQL推理模型,该模型可以在领先的基准测试Spider-Test和BIRD上分别达到88.6%和66.6%的准确率,并能输出详细的明确推理过程。
-
- NL2SQL推理模型的训练策略:我们广泛探讨了冷启动训练对SQL-R1的影响,开发了一种结合SFT和RL的训练策略。我们的研究结果突出了使用合成数据提升模型性能和鲁棒性的策略,为优化NL2SQL推理模型训练提供了关键见解。
2 SQL-R1
2.1 概述
本节主要介绍两种通过RL算法训练NL2SQL模型的形式:直接强化训练和通过冷启动后的强化训练。其中,冷启动指的是使用特定数据首先通过SFT训练基础模型,使其具备一定的思维和指令跟随能力。此外,由于真实数据有限,我们使用最新的合成数据来支持上述训练过程。第2.2节将介绍我们当前的数据工程解决方案,第2.3节将介绍专为NL2SQL设计的SFT算法和RL算法。
2.2 数据准备
2.2.1 来源
目前,我们利用SynSQL-2.5M [22] 数据集作为主要数据来源,这是首个百万级合成NL2SQL数据集,包含超过250万个多样且高质量的数据样本。每个样本由四元组组成,包括数据库、自然语言问题、SQL查询和链式推理(CoT)解决方案。该数据集包含来自多个领域的超过16,000个合成数据库,从而确保覆盖广泛的现实场景。SynSQL-2.5M包括从简单单表查询到复杂多表连接、函数和通用表达式的各种SQL复杂度级别。
2.2.2 预处理
SFT数据集。在这项研究中,我们调查了SFT中冷启动条件对RL训练的影响。目前,我们使用了从SynSQL-2.5M中抽取的200,000个样本来进行SFT训练,不同难度级别的样本数量均匀分布,每个级别包含50,000个样本。为了清晰起见,我们将在后续部分中将此子集称为SynSQL-200K。必须强调的是,所有SQL真值查询结果都是非空值。对于SFT数据集V中的每个样本v=(x, t, y^),x表示自然语言,t表示封装在 标签内的推理过程,y^ 表示封装在 . . 标签内的SQL。
RL数据集。当前的NL2SQL基础模型在生成简单到中等复杂度的SQL查询方面表现出了强大的能力。然而,在生成更复杂的SQL查询时,它表现出局限性。因此,在训练过程中使用包含更具挑战性样本的数据集可能有助于解决这些缺陷并提高模型在生成复杂SQL方面的整体性能。我们从SynSQL-2.5M中随机抽样了5K个NL-SQL对,其复杂度较高。对于RL数据集V中的每个NL-SQL对v=(x, y*),x表示自然语言,y* 表示模型生成的SQL候选。强化学习的目标是提高答案的准确性并确保它们符合预期格式。RL数据集在下一节中被引入为SynSQL-Complex-5K。值得注意的是,RL数据集的输入不包括SynSQL-2.5M的原始CoT数据。
2.3 训练
2.3.1 监督微调
在本研究中,我们在Qwen2.5-Coder-7B-Instruct模型上进行SFT,以增强模型在NL2SQL领域中的指令遵守能力和生成能力。我们研究了两种不同的SFT冷启动训练策略。第一种策略使用专注于SQL生成的原始指令。我们参考了现有的OmniSQL-7B [22] 检查点。第二种策略利用完整的微调和推理生成指令,促进合规思维过程和最终答案的发展。
2.3.2 强化训练
在强化学习阶段,我们采用Group Relative Policy Optimization (GRPO) 算法来增强我们的训练协议,这种方法不需要价值模型,内存需求较低,并且可以明确定义奖励目标,使其成为有效优化NL2SQL策略模型的最佳选择 [23]。
对于每个与其对应的数据库模式对齐的自然语言问题,策略模型会从旧策略π_old生成一组G个SQL候选{ o_1, o_2 …, o_G },这些候选会通过一个复合奖励函数进行仔细评估,该函数分配特定的奖励分数。通过关注组内SQL候选的相对表现,GRPO能够巧妙地计算每个输出的奖励,从而根据我们设定的目标引导策略更新。
JGRPO(θ)=Ev∼P(V),{oi}i=1G∼πθold(O∣v)[1G∑i=1G(min(riratio Ai,clip(riratio ,1−ϵ,1+ϵ)Ai)−βDKL(πθ∥πref))] \begin{aligned} \mathcal{J}_{\mathrm{GRPO}}(\theta)= & \mathbb{E}_{\mathbf{v} \sim P(\mathbf{V}),\left\{o_{i}\right\}_{i=1}^{G} \sim \pi_{\theta_{\mathrm{old}}}(O \mid \mathbf{v})} \\ & {\left[\frac{1}{G} \sum_{i=1}^{G}\left(\min \left(r_{i}^{\text {ratio }} A_{i}, \operatorname{clip}\left(r_{i}^{\text {ratio }}, 1-\epsilon, 1+\epsilon\right) A_{i}\right)-\beta D_{\mathrm{KL}}\left(\pi_{\theta} \| \pi_{\mathrm{ref}}\right)\right)\right] } \end{aligned} JGRPO(θ)=Ev∼P(V),{oi}i=1G∼πθold(O∣v)[G1i=1∑G(min(riratio Ai,clip(riratio ,1−ϵ,1+ϵ)Ai)−βDKL(πθ∥πref))]
其中 riratio =πθ(oi∣V)πold(oi∣V)r_{i}^{\text {ratio }}=\frac{\pi_{\theta}\left(o_{i} \mid V\right)}{\pi_{o l d}\left(o_{i} \mid V\right)}riratio =πold(oi∣V)πθ(oi∣V) 表示重要性采样比率,量化在新策略 πθ\pi_{\theta}πθ 下生成输出 oi 的相对可能性与 πold\pi_{o l d}πold 相比; AiA_{i}Ai 表示每个输出的组相对优势;剪切算子、超参数 ϵ\epsilonϵ 和 β\betaβ 控制更新步长和发散正则化; πref \pi_{\text {ref }}πref 表示参考策略。
2.3.3 奖励函数设计
在使用强化学习训练NL2SQL奖励模型时,我们利用了一个渐进的反馈机制,包含四种类型的奖励:格式奖励、执行奖励、结果奖励和长度奖励。这种分层方法通过在各个阶段提供详细反馈来增强模型的学习。
格式奖励。我们鼓励模型将NL2SQL推理过程包含在
参考论文:https://arxiv.org/pdf/2504.08600
更多推荐


所有评论(0)