adamboom111 committed on
Commit
cc6747b
·
verified ·
1 Parent(s): de0a20e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -28
app.py CHANGED
@@ -1,38 +1,32 @@
1
  import gradio as gr
2
- from transformers import pipeline
3
 
4
- # Load model stick with this one for now, later we can upgrade
5
- generator = pipeline("text2text-generation", model="mrm8488/t5-base-finetuned-wikiSQL")
 
 
6
 
7
- def convert_to_sql(payload):
 
8
  question = payload.get("question", "")
9
  schema = payload.get("schema", "")
10
- sample_rows = payload.get("sample_rows", [])
11
-
12
- # Craft prompt
13
- prompt = f"""
14
- You are an AI that converts natural language into SQL for a DuckDB database.
15
- Given a table with the following schema:
16
- {schema}
17
-
18
- Here are some sample rows:
19
- {sample_rows}
20
-
21
- Write a syntactically correct SQL query (DuckDB-compatible) to answer this question: "{question}"
22
-
23
- Only return the SQL query — no explanation, no markdown.
24
- """.strip()
25
-
26
- result = generator(prompt, max_length=256)[0]["generated_text"]
27
- return result.strip()
28
-
29
- # Define inputs/outputs for interactive mode (not used by FastAPI)
30
  demo = gr.Interface(
31
- fn=convert_to_sql,
32
- inputs=gr.JSON(label="question + schema + sample_rows"),
33
  outputs="text",
34
- title="Text-to-SQL Generator (DuckDB)",
35
- description="Send a JSON payload with question, schema, and sample rows"
36
  )
37
 
38
  demo.launch()
 
1
  import gradio as gr
2
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
3
 
4
+ # Load the GaussAlgo model
5
+ model_path = "gaussalgo/T5-LM-Large-text2sql-spider"
6
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
7
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
8
 
9
+ def generate_sql(payload):
10
+ # Extract components from payload
11
  question = payload.get("question", "")
12
  schema = payload.get("schema", "")
13
+
14
+ # Build model input
15
+ full_prompt = f"Question: {question} Schema: {schema}"
16
+
17
+ inputs = tokenizer(full_prompt, return_tensors="pt")
18
+ outputs = model.generate(**inputs, max_length=512)
19
+ generated_sql = tokenizer.decode(outputs[0], skip_special_tokens=True)
20
+
21
+ return generated_sql
22
+
23
+ # Define expected input as a JSON object (dict)
 
 
 
 
 
 
 
 
 
24
  demo = gr.Interface(
25
+ fn=generate_sql,
26
+ inputs=gr.JSON(label="Input JSON (with 'question' and 'schema')"),
27
  outputs="text",
28
+ title="Text-to-SQL Generator",
29
+ description="Input a JSON with your natural language question and database schema. Output is SQL."
30
  )
31
 
32
  demo.launch()