File size: 5,156 Bytes
84c630e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
# 引言
[Rain's SQLCoder](https://huggingface.co/SuanChang/rain-SQLCoder) 是自然语言生成 SparkSQL 的 SOTA 大型语言模型(LLM),拥有 32B 参数,基于 [Qwen2.5-Coder-32B-Instruct](https://huggingface.co/Qwen/Qwen2.5-Coder-32B-Instruct) 微调。 Rain's SQLCoder 针对自然语言到 SparkSQL 转换任务进行了优化,能够有效处理最长达 32k 个 token 的上下文,尤其适用于复杂且大规模的 SQL 查询生成任务。

<p align="center">
          🤗 <a href="https://huggingface.co/SuanChang/rain-SQLCoder">Hugging Face</a> | 🖥️ <a href="https://www.suan-chang.com/">演示</a> | 💬 <a href="./figures/wechat.png">微信</a> | <a href="https://github.com/suan-chang/rain-SQLCoder">GitHub</a>
</p>

[English](./README.md) | [中文](./README-zh.md)

# 提示词
Rain's SQLCoder 采用了 [Alpaca](https://github.com/tatsu-lab/stanford_alpaca) 模板,使用的提示词如下。
````
Below is an instruction that describes a task. 
Write a response that appropriately completes the request.

### Instruction:
[BEGIN OF TASK INSTRUCTION]
You are an expert in composing Spark SQL queries. You are given a user query and a set of table schemas.
Based on the user query, you need to generate one Spark SQL query to achieve the purpose.
{task description for date hint and related question and sqls}
[END OF TASK INSTRUCTION]

[BEGIN OF TABLE SCHEMAS]
{schemas}
[END OF TABLE SCHEMAS]

[BEGIN OF GENERATION HINT]
{date hint}
[END OF GENERATION HINT]

[BEGIN OF RELATED QUERIES]
{related question and sqls}
[END OF RELATED QUERIES]

[BEGIN OF FORMAT INSTRUCTION]
The output MUST strictly adhere to the following format, and NO other text MUST be included.
```sql
your output Spark SQL query
``` 
[END OF FORMAT INSTRUCTION]

[BEGIN OF QUERY]
User Query: {user question}
[END OF QUERY]

### Response:
````

# 评估
我们沿用了 [SQL-Eval](https://github.com/defog-ai/sql-eval) 中评估预测结果与标准结果的逻辑:
1. 如果预测的数据块和标准数据块完全一致,则预测结果正确;
2. 标准SQL中不包含排序逻辑,且预测数据块和标准数据块在排序之后完全一致,则预测结果正确;
3. 如果标准数据块的列是预测数据块的子集,则预测结果正确;
4. 其余情况均认为预测结果错误。

# 实验结果
我们在两个测试集上对比了Rain's SQLCoder与国内外先进自然语言大模型的生成准确率。其中,基准测试集(Benchmark Dataset)包含基础样本,而增强测试集(Enhanced Dataset)则是在基准测试集的基础上,通过分层抽样方法选取20%的样本,并补充了相关的用户查询及对应的SparkSQL语句,以评估模型在增强上下文信息下的性能表现。实验结果表明,Rain's SQLCoder在查询意图理解、SQL语法准确性和复杂查询处理等方面均展现出显著优势。

## 基准测试集
<img src="./figures/benchmark_dataset_result.png" alt="benchmark" width=800>

## 增强测试集
<img src="./figures/enhanced_dataset_result.png" alt="enhanced" width=800>

# 快速开始
我们在此处提供示例,帮助您快速掌握如何加载并使用我们的模型。
>注意: Rain's SQLCoder 只被训练用于生成 `SELECT` 语句,当表结构无法支持回答用户问题时,模型会拒绝回答。

````python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from utils.prompt import SQLGeneratePrompt

model_name = "SuanChang/rain-SQLCoder"

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

question = "What is the name of the department that offers a course that has a description including the word 'Statistics'?"
schemas = [
'''CREATE TABLE `course` (
    `crs_code` STRING,
    `dept_code` STRING,
    `crs_description` STRING,
    `crs_credit` DOUBLE
);''',
'''CREATE TABLE `department` (
    `dept_code` STRING,
    `dept_name` STRING,
    `school_code` STRING,
    `emp_num` INT,
    `dept_address` STRING,
    `dept_extension` INT
);''',
'''CREATE TABLE `student` (
    `stu_num` INT,
    `stu_lname` STRING,
    `stu_fname` STRING,
    `stu_init` STRING,
    `stu_dob` STRING,
    `stu_hrs` INT,
    `stu_class` STRING,
    `stu_gpa` DOUBLE,
    `stu_transfer` INT,
    `dept_code` STRING,
    `stu_phone` INT,
    `prof_num` INT
);'''
]
hint = "- Today is 2025-02-01."
data = dict(
    question=question,
    schema="\n\n".join(schemas),
    hint=hint,
    related_question_sqls=None,
)
text, _, _ = SQLGeneratePrompt.prompt(data)

model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

generated_ids = model.generate(
    **model_inputs,
    max_new_tokens=32768
)
generated_ids = [
    output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
]
response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

print(response)

'''
```sql
SELECT d.dept_name FROM department d JOIN course c ON d.dept_code = c.dept_code WHERE c.crs_description LIKE '%Statistics%';
```
'''
````