aarohanverma commited on
Commit
816ccb1
·
verified ·
1 Parent(s): c81583e

Added app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -0
app.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
4
+
5
+ # Set up device (GPU if available)
6
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
7
+
8
+ # Load the fine-tuned model and tokenizer
9
+ model_name = "aarohanverma/text2sql-flan-t5-base-qlora-finetuned" # Replace with your model repo name
10
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch.bfloat16).to(device)
11
+ tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
12
+
13
+ def generate_sql(context: str, query: str) -> str:
14
+ """
15
+ Constructs a prompt using the user-provided context and query, then generates a SQL query.
16
+ """
17
+ prompt = f"""Context:
18
+ {context}
19
+
20
+ Query:
21
+ {query}
22
+
23
+ Response:
24
+ """
25
+ inputs = tokenizer(prompt, return_tensors="pt").to(device)
26
+ generated_ids = model.generate(
27
+ input_ids=inputs["input_ids"],
28
+ max_new_tokens=250,
29
+ temperature=0.0, # Deterministic output
30
+ num_beams=3, # Beam search for quality output
31
+ early_stopping=True,
32
+ )
33
+ return tokenizer.decode(generated_ids[0], skip_special_tokens=True)
34
+
35
+ # Create a Gradio interface with two input boxes: one for context, one for query.
36
+ iface = gr.Interface(
37
+ fn=generate_sql,
38
+ inputs=[
39
+ gr.Textbox(lines=8, label="Context", placeholder="Enter table schema, sample data, etc."),
40
+ gr.Textbox(lines=2, label="Query", placeholder="Enter your natural language query here...")
41
+ ],
42
+ outputs="text",
43
+ title="Text-to-SQL Generator",
44
+ description="Enter your own context (e.g., database schema and sample data) and a natural language query. The model will generate the corresponding SQL statement."
45
+ )
46
+
47
+ iface.launch()