import json
import os

import torch
from datasets import Dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)


def load_table_schemas(tables_file):
    """
    Load table schemas from the tables.jsonl file.
    
    Args:
        tables_file: Path to the tables.jsonl file.
    
    Returns:
        A dictionary mapping table IDs to their column names.
    """
    table_schemas = {}
    with open(tables_file, 'r') as f:
        for line in f:
            table_data = json.loads(line)
            table_id = table_data["id"]
            table_columns = table_data["header"]
            table_schemas[table_id] = table_columns
    return table_schemas
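
# Illustrative example (values are made up, but follow the WikiSQL tables.jsonl
# layout, where each JSON line carries at least "id" and "header" fields):
#   {"id": "1-1000181-1", "header": ["State/territory", "Current slogan", ...], ...}
# load_table_schemas would then map "1-1000181-1" -> ["State/territory", "Current slogan", ...].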


# Step 1: Load and Preprocess WikiSQL Data
def load_wikisql(data_dir):
    """
    Load WikiSQL data and prepare it for training.

    Args:
        data_dir: Path to the WikiSQL dataset directory.

    Returns:
        A tuple (train_examples, dev_examples), each a list of dicts with
        "input" and "target" text.
    """
    def parse_file(file_path):
        with open(file_path, 'r') as f:
            return [json.loads(line) for line in f]

    train_data = parse_file(os.path.join(data_dir, "train.jsonl"))
    dev_data = parse_file(os.path.join(data_dir, "dev.jsonl"))

    # Column names are needed to turn WikiSQL's index-based SQL dicts into text.
    train_schemas = load_table_schemas(os.path.join(data_dir, "train.tables.jsonl"))
    dev_schemas = load_table_schemas(os.path.join(data_dir, "dev.tables.jsonl"))

    def format_data(data, schemas):
        formatted = []
        for item in data:
            table_columns = schemas[item["table_id"]]
            question = item["question"]
            sql_query = sql_to_text(item["sql"], table_columns)
            formatted.append({"input": f"Question: {question}", "target": sql_query})
        return formatted

    return format_data(train_data, train_schemas), format_data(dev_data, dev_schemas)
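
# Illustrative example of one formatted training record (made-up values):
#   {"input": "Question: What position does the player from Duke play?",
#    "target": "SELECT Position WHERE School/Club Team = 'Duke'"}
# Note that the input currently contains only the question; a common extension
# (not implemented here) is to also serialize the table's column names into the input.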


def sql_to_text(sql, table_columns):
    """
    Convert SQL dictionary from WikiSQL to text representation.
    
    Args:
        sql: SQL dictionary from WikiSQL (e.g., {"sel": 5, "conds": [[3, 0, "value"]], "agg": 0}).
        table_columns: List of column names corresponding to the table.
        
    Returns:
        SQL query as a string.
    """
    # Aggregation functions mapping
    agg_functions = ["", "MAX", "MIN", "COUNT", "SUM", "AVG"]
    operators = ["=", ">", "<"]

    # Get selected column
    sel_column = table_columns[sql["sel"]]
    agg_func = agg_functions[sql["agg"]]
    select_clause = f"SELECT {agg_func}({sel_column})" if agg_func else f"SELECT {sel_column}"

    # Get conditions
    if sql["conds"]:
        conditions = []
        for cond in sql["conds"]:
            col_idx, operator, value = cond
            col_name = table_columns[col_idx]
            conditions.append(f"{col_name} {operators[operator]} '{value}'")
        where_clause = " WHERE " + " AND ".join(conditions)
    else:
        where_clause = ""

    # Combine clauses into a full query
    return select_clause + where_clause
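
# Illustrative example (made-up values):
#   sql_to_text({"sel": 0, "agg": 0, "conds": [[2, 0, "Guard"]]},
#               ["Player", "No.", "Position"])
#   -> "SELECT Player WHERE Position = 'Guard'"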

# Step 2: Tokenize the Data
def tokenize_data(data, tokenizer, max_length=128):
    """
    Tokenize the input and target text.

    Args:
        data: List of examples with "input" and "target".
        tokenizer: Pretrained tokenizer.
        max_length: Maximum sequence length for the model.

    Returns:
        Tokenized dataset dictionary with input_ids, attention_mask, and labels.
    """
    inputs = [item["input"] for item in data]
    targets = [item["target"] for item in data]

    tokenized = tokenizer(
        inputs,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    labels = tokenizer(
        targets,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )

    # Replace padding token ids in the labels with -100 so they are ignored
    # by the cross-entropy loss during training.
    label_ids = labels["input_ids"]
    label_ids[label_ids == tokenizer.pad_token_id] = -100

    tokenized["labels"] = label_ids
    return tokenized
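
# Illustrative usage: for a list of N examples, tokenize_data returns a dict with
# "input_ids", "attention_mask", and "labels", each a tensor of shape
# (N, max_length); padded label positions are set to -100 as noted above.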


# Step 3: Load Model and Tokenizer
model_name = "t5-small"  # Use "t5-small", "t5-base", or "t5-large"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# Step 4: Prepare Training and Validation Data
data_dir = "data"  # Path to the WikiSQL dataset
train_data, dev_data = load_wikisql(data_dir)

# Tokenize Data
train_dataset = tokenize_data(train_data, tokenizer)
dev_dataset = tokenize_data(dev_data, tokenizer)

# Convert to Hugging Face Dataset format
train_dataset = Dataset.from_dict(train_dataset)
dev_dataset = Dataset.from_dict(dev_dataset)

# Step 5: Define Training Arguments
# training_args = Seq2SeqTrainingArguments(
#     output_dir="./t5_sql_finetuned",
#     evaluation_strategy="steps",
#     save_steps=1000,
#     eval_steps=100,
#     logging_steps=100,
#     per_device_train_batch_size=16,
#     per_device_eval_batch_size=16,
#     num_train_epochs=3,
#     save_total_limit=2,
#     learning_rate=5e-5,
#     predict_with_generate=True,
#     fp16=torch.cuda.is_available(),  # Enable mixed precision for faster training
#     logging_dir="./logs",
# )

# Step 6: Define Trainer
# trainer = Seq2SeqTrainer(
#     model=model,
#     args=training_args,
#     train_dataset=train_dataset,
#     eval_dataset=dev_dataset,
#     tokenizer=tokenizer,
# )

# Step 7: Train the Model
# trainer.train()

# Step 8: Save the Model
# trainer.save_model("./t5_sql_finetuned")
# tokenizer.save_pretrained("./t5_sql_finetuned")

# Step 9: Test the Model
test_question = "Find all orders with product_id greater than 5."
input_text = f"Question: {test_question}"
inputs = tokenizer(input_text, return_tensors="pt", truncation=True, padding=True)

outputs = model.generate(**inputs, max_length=128)
generated_sql = tokenizer.decode(outputs[0], skip_special_tokens=True)
print("Generated SQL:", generated_sql)