adamboom111 committed on
Commit
8f665dd
·
verified ·
1 Parent(s): cb04666

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -22
app.py CHANGED
@@ -1,37 +1,51 @@
1
  import gradio as gr
2
- from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
 
3
 
4
# Load the GaussAlgo text-to-SQL checkpoint (a T5-Large fine-tuned on Spider).
model_path = "gaussalgo/T5-LM-Large-text2sql-spider"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
8
 
9
def generate_sql(payload):
    """Generate a SQL query from a JSON payload via the seq2seq model.

    Args:
        payload: dict with keys:
            - "question" (str): natural-language question.
            - "schema" (str): description of the database schema.
            - "sample_rows" (list, optional): example rows for extra context.

    Returns:
        str: the SQL query decoded from the model output.
    """
    # gr.JSON can deliver None when the input box is empty; guard it.
    payload = payload or {}
    question = payload.get("question", "")
    schema = payload.get("schema", "")
    sample_rows = payload.get("sample_rows", [])

    # Flatten the sample rows into one space-separated string for the prompt.
    sample_str = " ".join(str(row) for row in sample_rows) if sample_rows else ""

    # Build the model input prompt.
    prompt = f"Question: {question} Schema: {schema} Sample Rows: {sample_str}"

    inputs = tokenizer(prompt, return_tensors="pt")
    # max_new_tokens bounds only the generated SQL. The previous
    # max_length=512 counted prompt tokens as well, so a long schema
    # could leave no budget for the answer.
    outputs = model.generate(**inputs, max_new_tokens=512)
    generated_sql = tokenizer.decode(outputs[0], skip_special_tokens=True)

    return generated_sql
27
-
28
# Wire the generator into a simple JSON-in / text-out Gradio UI.
demo = gr.Interface(
    fn=generate_sql,
    inputs=gr.JSON(label="Input JSON (question, schema, sample_rows)"),
    outputs="text",
    title="Text-to-SQL Generator",
    description=(
        "Enter a JSON object with 'question', 'schema', and optional "
        "'sample_rows'. The model will generate SQL."
    ),
)

demo.launch()
 
1
  import gradio as gr
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
+ import torch
4
 
5
# Defog's SQLCoder-7B-2, a causal LM specialized for text-to-SQL.
# fp16 weights with device_map="auto" lets transformers place the
# 7B parameters across whatever devices are available.
model_path = "defog/sqlcoder-7b-2"
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    trust_remote_code=True,
    torch_dtype=torch.float16,
    device_map="auto",
)
 
8
 
9
def generate_sql(payload):
    """Generate a SQL query with Defog's sqlcoder-7b-2.

    Args:
        payload: dict with keys:
            - "question" (str): natural-language question.
            - "schema" (str): schema the query will run against.
            - "sample_rows" (list, optional): example rows for extra context.

    Returns:
        str: the SQL emitted by the model after the final "[SQL]" marker.
    """
    # gr.JSON can deliver None when the input box is empty; guard it.
    payload = payload or {}
    question = payload.get("question", "")
    schema = payload.get("schema", "")
    sample_rows = payload.get("sample_rows", [])

    sample_str = "\n".join(str(row) for row in sample_rows) if sample_rows else ""

    # Prompt layout follows the sqlcoder-7b-2 model card's template.
    prompt = f"""
### Task
Generate a SQL query to answer [QUESTION]{question}[/QUESTION]

### Database Schema
The query will run on a database with the following schema:
{schema}

### Sample Rows
{sample_str}

### Answer
Given the database schema, here is the SQL query that [QUESTION]{question}[/QUESTION]
[SQL]
""".strip()

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # max_new_tokens bounds only the generated SQL. The previous
    # max_length=512 counted the prompt too, so a realistic
    # schema + sample rows could exhaust the whole budget before any
    # SQL was produced.
    outputs = model.generate(
        **inputs,
        max_new_tokens=512,
        do_sample=False,
        num_beams=4,
        early_stopping=True,
    )
    sql = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # The decoded text includes the prompt; keep only what follows the
    # final [SQL] marker.
    return sql.split("[SQL]")[-1].strip()
42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
# Expose the generator through a JSON-in / text-out Gradio UI.
demo = gr.Interface(
    fn=generate_sql,
    inputs=gr.JSON(label="Input JSON (question, schema, sample_rows)"),
    outputs="text",
    title="SQLCoder - Text to SQL",
    description=(
        "Enter a JSON object with 'question', 'schema', and optional "
        "'sample_rows'. The model will generate SQL using Defog's "
        "sqlcoder-7b-2."
    ),
)

demo.launch()