Rajan committed on
Commit
7d64919
·
verified ·
1 Parent(s): be38eaf

Upload 3 files

Browse files

Add basic requirements

Files changed (3) hide show
  1. requirements.txt +4 -0
  2. schema_extractor.py +132 -0
  3. ui.py +236 -0
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ gradio==5.23.1
2
+ peft==0.15.1
3
+ torch==2.6.0+cu118
4
+ transformers==4.50.3
schema_extractor.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sqlite3
2
+ from typing import Optional
3
+
4
+
5
class SQLiteSchemaExtractor:
    """
    Extract the schema of an SQLite database file and return it as formatted text.

    The schema is rendered as a sequence of simplified ``CREATE TABLE``
    statements in which every column type is reduced to one of ``number``,
    ``text`` or ``others``.
    """

    # Substrings that mark a declared column type as numeric / textual.
    # Numeric markers are checked first, matching the original precedence.
    _NUMBER_MARKERS = ("int", "real", "floa", "doub", "num", "dec")
    # "varchar" is covered by "char", so it needs no separate entry.
    _TEXT_MARKERS = ("text", "char", "clob")

    def __init__(self, db_path: str):
        """
        Initialize the extractor with the path to the SQLite database file.

        Args:
            db_path: Path to the SQLite database file
        """
        self.db_path = db_path
        self.connection = None

    def connect(self):
        """
        Establish a connection to the SQLite database.

        Returns:
            True if the connection was opened, False if sqlite3 raised an error
            (the error is printed).
        """
        try:
            self.connection = sqlite3.connect(self.db_path)
            return True
        except sqlite3.Error as e:
            print(f"Error connecting to database: {e}")
            return False

    def close(self):
        """Close the database connection if it exists and forget the handle."""
        if self.connection:
            self.connection.close()
            # Drop the reference so a closed connection is never reused.
            self.connection = None

    @staticmethod
    def _quote_identifier(name: str) -> str:
        """Return *name* as a double-quoted SQL identifier (embedded quotes doubled)."""
        return '"' + name.replace('"', '""') + '"'

    def get_schema(self) -> Optional[str]:
        """
        Extract the schema from the SQLite database and return it as formatted text.

        Returns:
            A string containing the formatted schema (empty if the database has
            no user tables), or None if an error occurred
        """
        if not self.connect():
            return None

        try:
            cursor = self.connection.cursor()

            # Get the list of all tables
            cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
            table_names = [row[0] for row in cursor.fetchall()]

            statements = []
            for table_name in table_names:
                # Skip SQLite internal tables
                if table_name.startswith("sqlite_"):
                    continue

                # Quote the identifier: names containing spaces, punctuation or
                # SQL keywords would otherwise make the PRAGMA a syntax error.
                cursor.execute(
                    f"PRAGMA table_info({self._quote_identifier(table_name)});"
                )

                # Each row is (cid, name, type, notnull, dflt_value, pk);
                # only the name and declared type are rendered.
                column_lines = [
                    f" {col_name} {self._simplify_type((col_type or '').lower())}"
                    for _cid, col_name, col_type, *_rest in cursor.fetchall()
                ]

                parts = [f"CREATE TABLE {table_name} ("]
                if column_lines:
                    parts.append(",\n".join(column_lines))
                parts.append(")")
                statements.append("\n".join(parts))

            if not statements:
                return ""

            # Statements are separated by a blank line; the text ends with a
            # single trailing newline.
            return "\n\n".join(statements) + "\n"

        except sqlite3.Error as e:
            print(f"Error extracting schema: {e}")
            return None
        finally:
            self.close()

    def _simplify_type(self, sqlite_type: str) -> str:
        """
        Convert SQLite types to simplified types (number, text, others).

        Args:
            sqlite_type: The SQLite data type, already lower-cased

        Returns:
            A simplified type name
        """
        if any(marker in sqlite_type for marker in self._NUMBER_MARKERS):
            return "number"
        if any(marker in sqlite_type for marker in self._TEXT_MARKERS):
            return "text"
        return "others"
119
+
120
+
121
+ # # Example usage:
122
+ # if __name__ == "__main__":
123
+ # # Replace with your SQLite database file path
124
+ # db_file = "path/to/your/database.db"
125
+
126
+ # extractor = SQLiteSchemaExtractor(db_file)
127
+ # schema = extractor.get_schema()
128
+
129
+ # if schema:
130
+ # print(schema)
131
+ # else:
132
+ # print("Failed to extract schema.")
ui.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import tempfile
3
+
4
+ import gradio as gr
5
+ import torch
6
+ from peft import PeftConfig, PeftModel
7
+ from transformers import AutoModelForCausalLM, AutoTokenizer
8
+
9
+ from schema_extractor import SQLiteSchemaExtractor
10
+
11
+
12
+ # Load model and tokenizer
13
def load_model():
    """
    Return the fine-tuned model and its tokenizer, loading them only once.

    The base model ``NumbersStation/nsql-350M`` is wrapped with the PEFT
    adapter ``Rajan/training_run``. The loaded pair is cached on the function
    object so that later calls do not reload the weights on every request.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    cached = getattr(load_model, "_cached", None)
    if cached is None:
        tokenizer = AutoTokenizer.from_pretrained("NumbersStation/nsql-350M")
        base_model = AutoModelForCausalLM.from_pretrained("NumbersStation/nsql-350M")
        model = PeftModel.from_pretrained(base_model, "Rajan/training_run")
        cached = (model, tokenizer)
        load_model._cached = cached
    return cached
19
+
20
+
21
+ # Extract and correct SQL from generated text
22
def extract_and_correct_sql(text, correct=False):
    """
    Extract the SQL statement from generated text and optionally tidy it.

    Everything from the first line that starts with ``SELECT``
    (case-insensitive, ignoring leading whitespace) to the end of the text is
    kept. If no such line exists, the whole text is kept.

    Args:
        text: Raw generated text; ``None`` is treated as an empty string.
        correct: When True, surrounding whitespace is stripped and a
            terminating ``;`` is appended if missing. Empty results are left
            empty rather than being turned into a bare ``;``.

    Returns:
        The extracted (and optionally corrected) SQL string.
    """
    lines = (text or "").splitlines()

    # Index of the first SELECT line; fall back to the start of the text.
    start_index = next(
        (
            i
            for i, line in enumerate(lines)
            if line.strip().upper().startswith("SELECT")
        ),
        0,
    )

    generated_sql = "\n".join(lines[start_index:])

    if correct:
        generated_sql = generated_sql.strip()
        if generated_sql and not generated_sql.endswith(";"):
            generated_sql += ";"

    return generated_sql
38
+
39
+
40
+ # Function to handle file upload and schema extraction
41
def upload_and_extract_schema(db_file):
    """
    Extract the schema text from an uploaded SQLite database file.

    Args:
        db_file: The uploaded file as passed by the file component. Either an
            object exposing the path as ``.name`` or a plain path string;
            ``None`` when nothing was uploaded.

    Returns:
        A ``(schema_or_message, db_path)`` tuple. On success the first item is
        the schema text and the second the database path. On any failure the
        first item is a human-readable message and the second is ``None``.
    """
    if db_file is None:
        return "Please upload a database file", None

    try:
        # Accept both file-like objects (path in ``.name``) and plain paths.
        temp_db_path = getattr(db_file, "name", db_file)
        schema = SQLiteSchemaExtractor(temp_db_path).get_schema()
    except Exception as e:
        return f"Error extracting schema: {str(e)}", None

    # The extractor signals failure by returning None instead of raising;
    # surface that to the user rather than showing an empty schema box.
    if schema is None:
        return (
            "Error extracting schema: the file could not be read as an SQLite database",
            None,
        )
    if not schema:
        return "No tables found in the uploaded database", None

    return schema, temp_db_path
54
+
55
+
56
+ # Function to handle chat interaction with streaming effect
57
# Upper bound on the number of tokens generated *after* the prompt.
_MAX_NEW_TOKENS = 256

# Reply used when no database/schema is available yet.
_UPLOAD_FIRST_MESSAGE = "Please upload a database file first"

# Prompt layout: schema, instruction line, question, then the start of the
# answer ("SELECT") which the model is asked to continue.
_PROMPT_TEMPLATE = """
{}
-- Using valid SQLite, answer the following questions for the tables provided above.
{}
SELECT"""


def _generate_sql_text(question, schema):
    """Run the model on the schema/question prompt and return the SQL query text."""
    model, tokenizer = load_model()
    prompt = _PROMPT_TEMPLATE.format(schema, question)

    encoded = tokenizer(prompt, return_tensors="pt")
    prompt_length = encoded.input_ids.shape[1]

    # max_new_tokens bounds only the continuation, so a long schema does not
    # use up the whole generation budget.
    generated_ids = model.generate(
        input_ids=encoded.input_ids,
        attention_mask=encoded.attention_mask,
        max_new_tokens=_MAX_NEW_TOKENS,
    )

    # Decode only the continuation: scanning the full text for a line starting
    # with "SELECT" can match schema columns or the question itself.
    continuation = tokenizer.decode(
        generated_ids[0][prompt_length:], skip_special_tokens=True
    )
    return extract_and_correct_sql("SELECT" + continuation, correct=True)


def generate_sql(question, schema, db_path, chat_history):
    """
    Generate an SQL query for a question and append the exchange to the chat.

    Args:
        question: Natural-language question from the user.
        schema: Schema text of the uploaded database.
        db_path: Path of the uploaded database, or None if none was uploaded.
        chat_history: Existing list of ``{"role", "content"}`` messages
            (``None`` is treated as empty).

    Returns:
        A ``(new_history, sql_query)`` tuple; ``sql_query`` is None when no
        database is available or generation failed (the assistant message then
        carries the explanation).
    """
    history = list(chat_history or [])
    user_turn = {"role": "user", "content": question}

    if db_path is None or not schema:
        return (
            history + [user_turn, {"role": "assistant", "content": _UPLOAD_FIRST_MESSAGE}],
            None,
        )

    try:
        sql_query = _generate_sql_text(question, schema)
    except Exception as e:
        # UI boundary: report the failure in the chat instead of crashing.
        error_message = f"Error: {str(e)}"
        return (
            history + [user_turn, {"role": "assistant", "content": error_message}],
            None,
        )

    return history + [user_turn, {"role": "assistant", "content": sql_query}], sql_query


# Function for streaming SQL generation effect
def stream_sql(question, schema, db_path, chat_history):
    """
    Generate an SQL query and yield it character by character.

    Args:
        question: Natural-language question from the user.
        schema: Schema text of the uploaded database.
        db_path: Path of the uploaded database, or None if none was uploaded.
        chat_history: Existing list of ``{"role", "content"}`` messages
            (``None`` is treated as empty).

    Yields:
        ``(history, sql_text)`` tuples: first the history with only the user
        message and an empty SQL text, then one update per character of the
        generated query (or a single update carrying an error/notice message).
    """
    import time

    delay = 0.03  # seconds between characters for the typing effect

    history = list(chat_history or [])
    user_turn = {"role": "user", "content": question}

    # First add the user message to chat
    yield history + [user_turn], ""

    if db_path is None or not schema:
        yield history + [
            user_turn,
            {"role": "assistant", "content": _UPLOAD_FIRST_MESSAGE},
        ], _UPLOAD_FIRST_MESSAGE
        return

    try:
        sql_query = _generate_sql_text(question, schema)

        # Stream the SQL query character by character for effect
        partial_sql = ""
        for char in sql_query:
            partial_sql += char
            yield history + [
                user_turn,
                {"role": "assistant", "content": partial_sql},
            ], partial_sql
            time.sleep(delay)

    except Exception as e:
        # UI boundary: report the failure in the chat instead of crashing.
        error_message = f"Error: {str(e)}"
        yield history + [
            user_turn,
            {"role": "assistant", "content": error_message},
        ], error_message
164
+
165
+
166
+ # Main application
167
def create_app():
    """
    Build the Gradio Blocks interface for the SQL query generator.

    The left column holds the database upload, the "Extract Schema" button and
    a read-only schema box. The right column holds the chat, the question box
    with its "Generate SQL" button, and a read-only code view of the latest
    query.

    Returns:
        The constructed ``gr.Blocks`` application (not yet launched).
    """
    with gr.Blocks(title="SQL Query Generator", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# SQL Query Generator")
        gr.Markdown(
            "Upload a SQLite database, ask questions, and get SQL queries automatically generated"
        )

        # Per-session holder for the uploaded database path
        stored_db_path = gr.State(value=None)

        with gr.Row():
            with gr.Column(scale=1):
                # Upload widgets and schema preview
                upload_box = gr.File(label="Upload SQLite Database (.db file)")
                extract_button = gr.Button("Extract Schema", variant="primary")
                schema_box = gr.Textbox(
                    label="Database Schema", lines=10, interactive=False
                )

            with gr.Column(scale=2):
                # Conversation view
                conversation = gr.Chatbot(
                    label="Query Conversation", height=400, type="messages"
                )

                with gr.Row():
                    question_box = gr.Textbox(
                        label="Ask a question about your data",
                        placeholder="e.g., Show me the top 10 most sold items",
                    )
                    generate_button = gr.Button("Generate SQL", variant="primary")

                # Latest generated query
                sql_view = gr.Code(
                    language="sql", label="Generated SQL Query", interactive=False
                )

        # Schema extraction fills the schema box and remembers the path
        extract_button.click(
            fn=upload_and_extract_schema,
            inputs=[upload_box],
            outputs=[schema_box, stored_db_path],
        )

        # The button click and the Enter key share the same streaming handler
        # and wiring; only the exposed API name differs.
        generation_wiring = {
            "fn": stream_sql,
            "inputs": [question_box, schema_box, stored_db_path, conversation],
            "outputs": [conversation, sql_view],
            "queue": True,
        }
        generate_button.click(api_name="generate", **generation_wiring)
        question_box.submit(api_name="generate_on_submit", **generation_wiring)

    return demo
231
+
232
+
233
+ # Launch the app
234
# Script entry point: build the interface and serve it. share=True asks
# Gradio to also create a publicly reachable share link.
if __name__ == "__main__":
    app = create_app()
    app.launch(share=True)