adamboom111 committed
Commit 42749e6 · verified · 1 Parent(s): 542b9fc

Update app.py

Files changed (1)
  1. app.py +24 -22
app.py CHANGED
@@ -1,39 +1,41 @@
  import gradio as gr
- from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

-
- # Load the GaussAlgo model
- model_path = "gaussalgo/T5-LM-Large-text2sql-spider"
- model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
- tokenizer = AutoTokenizer.from_pretrained(model_path)

  def generate_sql(payload):
-     # Extract parts from the JSON payload
      question = payload.get("question", "")
      schema = payload.get("schema", "")
      sample_rows = payload.get("sample_rows", [])

-     # Convert sample rows into a single string
-     sample_str = " ".join([str(row) for row in sample_rows]) if sample_rows else ""
-
-     # Build model input prompt
-     prompt = f"Question: {question} Schema: {schema} Sample Rows: {sample_str}"

-     # Tokenize and generate
-     inputs = tokenizer(prompt, return_tensors="pt")
-     outputs = model.generate(**inputs, max_length=512)
-     generated_sql = tokenizer.decode(outputs[0], skip_special_tokens=True)

-     return generated_sql

- # Gradio interface
  demo = gr.Interface(
      fn=generate_sql,
-     inputs=gr.JSON(label="Input JSON (question, schema, sample_rows)"),
      outputs="text",
-     title="Text-to-SQL Generator",
-     description="Enter a JSON object with 'question', 'schema', and optional 'sample_rows'. The model will generate SQL."
  )

- demo.launch()
 
  import gradio as gr
+ from transformers import T5Tokenizer, T5ForConditionalGeneration

+ # Load FLAN-T5-small
+ model_name = "google/flan-t5-small"
+ tokenizer = T5Tokenizer.from_pretrained(model_name)
+ model = T5ForConditionalGeneration.from_pretrained(model_name)

  def generate_sql(payload):
      question = payload.get("question", "")
      schema = payload.get("schema", "")
      sample_rows = payload.get("sample_rows", [])

+     # Convert sample rows into flat string
+     rows_text = " ".join([str(row) for row in sample_rows]) if sample_rows else ""

+     # Construct prompt for instruction tuning
+     prompt = (
+         f"You are a SQL expert.\n"
+         f"Schema: {schema}\n"
+         f"Sample Rows: {rows_text}\n"
+         f"Question: {question}\n"
+         f"Generate SQL:"
+     )

+     # Tokenize and generate SQL
+     input_ids = tokenizer(prompt, return_tensors="pt").input_ids
+     outputs = model.generate(input_ids, max_length=256, temperature=0.6)
+     sql = tokenizer.decode(outputs[0], skip_special_tokens=True)

+     return sql

  demo = gr.Interface(
      fn=generate_sql,
+     inputs=gr.JSON(label="JSON (question, schema, sample_rows)"),
      outputs="text",
+     title="FLAN-T5 Text-to-SQL",
+     description="Using FLAN-T5 to generate SQL from natural language and tabular schema."
  )

+ demo.launch()
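For reference, a minimal sketch of the kind of payload the updated generate_sql expects from the gr.JSON input; the table, columns, question, and sample rows below are made-up for illustration and are not part of this commit:

# Hypothetical payload for the gr.JSON input; schema and rows are illustrative only.
example_payload = {
    "question": "How many customers signed up in 2023?",
    "schema": "CREATE TABLE customers (id INT, name TEXT, signup_date DATE)",
    "sample_rows": [
        {"id": 1, "name": "Alice", "signup_date": "2023-04-02"},
        {"id": 2, "name": "Bob", "signup_date": "2022-11-19"},
    ],
}

# Calling the function directly (outside the Gradio UI) returns the generated SQL string:
# print(generate_sql(example_payload))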