HusnaManakkot commited on
Commit
704b569
·
verified ·
1 Parent(s): e88c923

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -0
app.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
+ from datasets import load_dataset
4
+
5
+ # Load the Spider dataset
6
+ spider_dataset = load_dataset("HusnaManakkot/new-spider-HM", split='train') # Load a subset of the dataset
7
+ # Extract schema information from the Spider dataset
8
+ table_names = set()
9
+ column_names = set()
10
+ for item in spider_dataset:
11
+ for table in item['db_id']:
12
+ table_names.add(table)
13
+ for column in item['question']:
14
+ column_names.add(column)
15
+
16
+ # Load tokenizer and model
17
+ tokenizer = AutoTokenizer.from_pretrained("mrm8488/t5-base-finetuned-wikiSQL") # Update this to a model fine-tuned on Spider if available
18
+ model = AutoModelForSeq2SeqLM.from_pretrained("mrm8488/t5-base-finetuned-wikiSQL") # Update this to a model fine-tuned on Spider if available
19
+
20
+ def generate_sql_from_user_input(query):
21
+ # Generate SQL for the user's query
22
+ input_text = "translate English to SQL: " + query
23
+ inputs = tokenizer(input_text, return_tensors="pt", padding=True)
24
+ outputs = model.generate(**inputs, max_length=512)
25
+ sql_query = tokenizer.decode(outputs[0], skip_special_tokens=True)
26
+
27
+ # Post-process the SQL query to match the dataset's schema
28
+ for table_name in table_names:
29
+ if "TABLE" in sql_query:
30
+ sql_query = sql_query.replace("TABLE", table_name)
31
+ break # Assuming only one table is referenced in the query
32
+ for column_name in column_names:
33
+ if "COLUMN" in sql_query:
34
+ sql_query = sql_query.replace("COLUMN", column_name, 1)
35
+ return sql_query
36
+
37
+ # Create a Gradio interface
38
+ interface = gr.Interface(
39
+ fn=generate_sql_from_user_input,
40
+ inputs=gr.Textbox(label="Enter your natural language query"),
41
+ outputs=gr.Textbox(label="Generated SQL Query"),
42
+ title="NL to SQL with T5 using Spider Dataset",
43
+ description="This model generates an SQL query for your natural language input based on the Spider dataset."
44
+ )
45
+
46
+ # Launch the app
47
+ if __name__ == "__main__":
48
+ interface.launch()