adamboom111 committed on
Commit
542b9fc
·
verified ·
1 Parent(s): 05ad593

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -33
app.py CHANGED
@@ -1,54 +1,39 @@
1
  import gradio as gr
2
- from transformers import AutoModelForCausalLM, AutoTokenizer
3
- import torch
4
 
5
# defog/sqlcoder-7b-2: a causal LM fine-tuned for text-to-SQL generation.
model_path = "defog/sqlcoder-7b-2"
# NOTE(review): trust_remote_code executes code shipped with the model repo —
# acceptable only because the model source is trusted; confirm before reuse.
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
# fp16 weights; device_map="auto" places layers on available devices (GPU/CPU).
model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.float16, device_map="auto")
8
 
 
 
 
 
9
 
10
def generate_sql(payload):
    """Turn a JSON payload into a SQL query via the sqlcoder model.

    *payload* is a dict carrying 'question' (str), 'schema' (str) and an
    optional 'sample_rows' list; the decoded SQL text is returned.
    """
    question = payload.get("question", "")
    schema = payload.get("schema", "")
    sample_rows = payload.get("sample_rows", [])

    # One sample row per line; empty string when no rows were supplied.
    sample_str = "\n".join(str(row) for row in sample_rows) if sample_rows else ""

    prompt = f"""
### Task
Generate a SQL query to answer [QUESTION]{question}[/QUESTION]

### Database Schema
The query will run on a database with the following schema:
{schema}

### Sample Rows
{sample_str}

### Answer
Given the database schema, here is the SQL query that [QUESTION]{question}[/QUESTION]
[SQL]
""".strip()

    # Move the encoded prompt to wherever the model's weights live.
    encoded = tokenizer(prompt, return_tensors="pt").to(model.device)
    generated = model.generate(
        **encoded,
        max_length=512,
        do_sample=False,
        num_beams=4,
        early_stopping=True
    )
    decoded = tokenizer.decode(generated[0], skip_special_tokens=True)
    # The model echoes the prompt; keep only the text after the [SQL] tag.
    return decoded.split("[SQL]")[-1].strip()
44
 
 
45
 
 
46
# Wire the generator into a simple JSON-in / text-out Gradio app.
demo = gr.Interface(
    generate_sql,
    gr.JSON(label="Input JSON (question, schema, sample_rows)"),
    "text",
    title="SQLCoder - Text to SQL",
    description="Enter a JSON object with 'question', 'schema', and optional 'sample_rows'. The model will generate SQL using Defog's sqlcoder-7b-2."
)

demo.launch()
 
1
  import gradio as gr
2
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
 
3
 
 
 
 
4
 
5
# Load the GaussAlgo model: a T5 encoder-decoder fine-tuned on the Spider
# text-to-SQL dataset. Loaded once at import time so every request
# served by the app reuses the same weights.
model_path = "gaussalgo/T5-LM-Large-text2sql-spider"
model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)
9
 
10
def generate_sql(payload):
    """Generate a SQL query from a JSON payload.

    Parameters
    ----------
    payload : dict | None
        Expected keys: 'question' (str), 'schema' (str) and optional
        'sample_rows' (list). ``None`` is tolerated and treated as an
        empty payload.

    Returns
    -------
    str
        The SQL text decoded from the model output.
    """
    # Gradio's JSON component hands the function None when the input
    # field is left empty; without this guard payload.get(...) raises
    # AttributeError and the app shows an opaque error.
    payload = payload or {}

    # Extract parts from the JSON payload.
    question = payload.get("question", "")
    schema = payload.get("schema", "")
    sample_rows = payload.get("sample_rows", [])

    # Flatten the sample rows into a single space-separated string.
    sample_str = " ".join(str(row) for row in sample_rows) if sample_rows else ""

    # Build the model input prompt.
    prompt = f"Question: {question} Schema: {schema} Sample Rows: {sample_str}"

    # Tokenize and generate (greedy decoding, decoder output capped at 512).
    inputs = tokenizer(prompt, return_tensors="pt")
    outputs = model.generate(**inputs, max_length=512)
    generated_sql = tokenizer.decode(outputs[0], skip_special_tokens=True)

    return generated_sql
29
 
30
# Gradio interface: JSON payload in, generated SQL text out.
demo = gr.Interface(
    generate_sql,
    gr.JSON(label="Input JSON (question, schema, sample_rows)"),
    "text",
    title="Text-to-SQL Generator",
    description="Enter a JSON object with 'question', 'schema', and optional 'sample_rows'. The model will generate SQL."
)

demo.launch()