Rajan committed
Commit c3b3775 · verified · 1 parent: 7d64919

Add app.py

Files changed (1)
  1. app.py +236 -0
app.py ADDED
@@ -0,0 +1,236 @@
import os
import tempfile

import gradio as gr
import torch
from peft import PeftConfig, PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

from schema_extractor import SQLiteSchemaExtractor


# Load model and tokenizer
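# The PEFT adapter "Rajan/training_run" is applied on top of the
# NumbersStation/nsql-350M base model; note that the handlers below call this
# on every request, so the weights are re-loaded for each question.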
def load_model():
    config = PeftConfig.from_pretrained("Rajan/training_run")
    tokenizer = AutoTokenizer.from_pretrained("NumbersStation/nsql-350M")
    base_model = AutoModelForCausalLM.from_pretrained("NumbersStation/nsql-350M")
    model = PeftModel.from_pretrained(base_model, "Rajan/training_run")
    return model, tokenizer


# Extract and correct SQL from generated text
def extract_and_correct_sql(text, correct=False):
    lines = text.splitlines()

    start_index = 0
    for i, line in enumerate(lines):
        if line.strip().upper().startswith("SELECT"):
            start_index = i
            break

    generated_sql = "\n".join(lines[start_index:])

    if correct:
        if not generated_sql.strip().endswith(";"):
            generated_sql = generated_sql.strip() + ";"

    return generated_sql


# Function to handle file upload and schema extraction
def upload_and_extract_schema(db_file):
    if db_file is None:
        return "Please upload a database file", None

    try:
        # Get the file path directly from Gradio
        temp_db_path = db_file.name

        extractor = SQLiteSchemaExtractor(temp_db_path)
        schema = extractor.get_schema()
        return schema, temp_db_path
    except Exception as e:
        return f"Error extracting schema: {str(e)}", None


# Function to handle chat interaction (non-streaming)
def generate_sql(question, schema, db_path, chat_history):
    if db_path is None or not schema:
        return (
            chat_history
            + [
                {"role": "user", "content": question},
                {"role": "assistant", "content": "Please upload a database file first"},
            ],
            None,
        )

    try:
        # Load model
        model, tokenizer = load_model()

        # Format prompt
        prompt_format = """
{}
-- Using valid SQLite, answer the following questions for the tables provided above.
{}
SELECT"""

        # Format the prompt with schema and question
        prompt = prompt_format.format(schema, question)
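        # For example, with a (hypothetical) schema "CREATE TABLE items (id INTEGER,
        # name TEXT, qty INTEGER)" and the question "How many items are there?",
        # the formatted prompt ends with an unfinished SELECT for the model to complete:
        #
        #   CREATE TABLE items (id INTEGER, name TEXT, qty INTEGER)
        #   -- Using valid SQLite, answer the following questions for the tables provided above.
        #   How many items are there?
        #   SELECT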

        # Generate SQL
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids
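        # Note: max_length caps the total sequence length (prompt plus generated
        # tokens), so a long schema leaves correspondingly less room for the query.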
        generated_ids = model.generate(input_ids, max_length=500)
        generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)

        # Extract SQL
        sql_query = extract_and_correct_sql(generated_text, correct=True)

        # Update history using dictionary format
        new_history = chat_history + [
            {"role": "user", "content": question},
            {"role": "assistant", "content": sql_query},
        ]
        return new_history, sql_query
    except Exception as e:
        error_message = f"Error: {str(e)}"
        return (
            chat_history
            + [
                {"role": "user", "content": question},
                {"role": "assistant", "content": error_message},
            ],
            None,
        )


# Function for streaming SQL generation effect
def stream_sql(question, schema, db_path, chat_history):
    # First add the user message to chat
    yield chat_history + [{"role": "user", "content": question}], ""

    if db_path is None or not schema:
        yield chat_history + [
            {"role": "user", "content": question},
            {"role": "assistant", "content": "Please upload a database file first"},
        ], "Please upload a database file first"
        return

    try:
        # Load model
        model, tokenizer = load_model()

        # Format prompt
        prompt_format = """
{}
-- Using valid SQLite, answer the following questions for the tables provided above.
{}
SELECT"""

        # Format the prompt with schema and question
        prompt = prompt_format.format(schema, question)

        # Generate SQL
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids
        generated_ids = model.generate(input_ids, max_length=500)
        generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)

        # Extract SQL
        sql_query = extract_and_correct_sql(generated_text, correct=True)

        # Fixed medium speed (0.03 seconds delay)
        import time

        delay = 0.03  # 30ms - normal typing speed

        # Stream the SQL query character by character for effect
        partial_sql = ""
        for char in sql_query:
            partial_sql += char
            # Update chat history and SQL display with each character
            yield chat_history + [
                {"role": "user", "content": question},
                {"role": "assistant", "content": partial_sql},
            ], partial_sql
            time.sleep(delay)  # Medium speed typing effect

    except Exception as e:
        error_message = f"Error: {str(e)}"
        yield chat_history + [
            {"role": "user", "content": question},
            {"role": "assistant", "content": error_message},
        ], error_message


# Main application
def create_app():
    with gr.Blocks(title="SQL Query Generator", theme=gr.themes.Soft()) as app:
        gr.Markdown("# SQL Query Generator")
        gr.Markdown(
            "Upload a SQLite database, ask questions, and get SQL queries automatically generated"
        )

        # Store database path
        db_path = gr.State(value=None)

        with gr.Row():
            with gr.Column(scale=1):
                # File upload section
                file_input = gr.File(label="Upload SQLite Database (.db file)")
                extract_btn = gr.Button("Extract Schema", variant="primary")

                # Schema display
                schema_output = gr.Textbox(
                    label="Database Schema", lines=10, interactive=False
                )

            with gr.Column(scale=2):
                # Chat interface
                chatbot = gr.Chatbot(
                    label="Query Conversation", height=400, type="messages"
                )

                with gr.Row():
                    question_input = gr.Textbox(
                        label="Ask a question about your data",
                        placeholder="e.g., Show me the top 10 most sold items",
                    )
                    submit_btn = gr.Button("Generate SQL", variant="primary")

                # SQL output display
                sql_output = gr.Code(
                    language="sql", label="Generated SQL Query", interactive=False
                )

        # Event handlers
        extract_btn.click(
            fn=upload_and_extract_schema,
            inputs=[file_input],
            outputs=[schema_output, db_path],
        )

        submit_btn.click(
            fn=stream_sql,
            inputs=[question_input, schema_output, db_path, chatbot],
            outputs=[chatbot, sql_output],
            api_name="generate",
            queue=True,
        )

        # Also trigger on enter key
        question_input.submit(
            fn=stream_sql,
            inputs=[question_input, schema_output, db_path, chatbot],
            outputs=[chatbot, sql_output],
            api_name="generate_on_submit",
            queue=True,
        )

    return app


# Launch the app
if __name__ == "__main__":
    app = create_app()
    app.launch(share=True)
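
Note: schema_extractor.py is not included in this commit, so the SQLiteSchemaExtractor import above must be provided separately. Based only on how app.py uses it (a constructor taking a SQLite file path and a get_schema() method returning the schema as text), a minimal stand-in might look like the sketch below; the implementation details are an assumption, not the committed module.

import sqlite3


class SQLiteSchemaExtractor:
    """Hypothetical stand-in for the schema_extractor module expected by app.py."""

    def __init__(self, db_path):
        self.db_path = db_path

    def get_schema(self):
        # Return the CREATE TABLE statements recorded in sqlite_master,
        # skipping SQLite's internal tables.
        connection = sqlite3.connect(self.db_path)
        try:
            cursor = connection.execute(
                "SELECT sql FROM sqlite_master "
                "WHERE type = 'table' AND name NOT LIKE 'sqlite_%' AND sql IS NOT NULL"
            )
            return "\n\n".join(row[0] for row in cursor.fetchall())
        finally:
            connection.close()

app.py only calls get_schema(), so any implementation that returns the table definitions as plain text would satisfy this interface.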