adamboom111 committed on
Commit
aa4a308
·
verified ·
1 Parent(s): 24c01c3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -37
app.py CHANGED
@@ -1,51 +1,37 @@
1
  import gradio as gr
2
- from transformers import AutoModelForCausalLM, AutoTokenizer
3
- import torch
4
 
5
- model_path = "defog/sqlcoder-7b-2"
6
- tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
7
- model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.float16, device_map="auto")
 
8
 
9
  def generate_sql(payload):
 
10
  question = payload.get("question", "")
11
  schema = payload.get("schema", "")
12
  sample_rows = payload.get("sample_rows", [])
13
-
14
- sample_str = "\n".join([str(row) for row in sample_rows]) if sample_rows else ""
15
-
16
- prompt = f"""
17
- ### Task
18
- Generate a SQL query to answer [QUESTION]{question}[/QUESTION]
19
-
20
- ### Database Schema
21
- The query will run on a database with the following schema:
22
- {schema}
23
-
24
- ### Sample Rows
25
- {sample_str}
26
-
27
- ### Answer
28
- Given the database schema, here is the SQL query that [QUESTION]{question}[/QUESTION]
29
- [SQL]
30
- """.strip()
31
-
32
- inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
33
- outputs = model.generate(
34
- **inputs,
35
- max_length=512,
36
- do_sample=False,
37
- num_beams=4,
38
- early_stopping=True
39
- )
40
- sql = tokenizer.decode(outputs[0], skip_special_tokens=True)
41
- return sql.split("[SQL]")[-1].strip()
42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  demo = gr.Interface(
44
  fn=generate_sql,
45
  inputs=gr.JSON(label="Input JSON (question, schema, sample_rows)"),
46
  outputs="text",
47
- title="SQLCoder - Text to SQL",
48
- description="Enter a JSON object with 'question', 'schema', and optional 'sample_rows'. The model will generate SQL using Defog's sqlcoder-7b-2."
49
  )
50
 
51
- demo.launch()
 
1
  import gradio as gr
2
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
 
3
 
4
+ # Load the GaussAlgo model
5
+ model_path = "gaussalgo/T5-LM-Large-text2sql-spider"
6
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
7
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
8
 
9
  def generate_sql(payload):
10
+ # Extract parts from the JSON payload
11
  question = payload.get("question", "")
12
  schema = payload.get("schema", "")
13
  sample_rows = payload.get("sample_rows", [])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
+ # Convert sample rows into a single string
16
+ sample_str = " ".join([str(row) for row in sample_rows]) if sample_rows else ""
17
+
18
+ # Build model input prompt
19
+ prompt = f"Question: {question} Schema: {schema} Sample Rows: {sample_str}"
20
+
21
+ # Tokenize and generate
22
+ inputs = tokenizer(prompt, return_tensors="pt")
23
+ outputs = model.generate(**inputs, max_length=512)
24
+ generated_sql = tokenizer.decode(outputs[0], skip_special_tokens=True)
25
+
26
+ return generated_sql
27
+
28
+ # Gradio interface
29
  demo = gr.Interface(
30
  fn=generate_sql,
31
  inputs=gr.JSON(label="Input JSON (question, schema, sample_rows)"),
32
  outputs="text",
33
+ title="Text-to-SQL Generator",
34
+ description="Enter a JSON object with 'question', 'schema', and optional 'sample_rows'. The model will generate SQL."
35
  )
36
 
37
+ demo.launch()