adamboom111 committed
Commit 05ad593 · verified · 1 Parent(s): 5606827

Update app.py

Files changed (1)
  1. app.py +28 -26
app.py CHANGED
@@ -2,51 +2,53 @@ import gradio as gr
 from transformers import AutoModelForCausalLM, AutoTokenizer
 import torch
 
-# Load the DeepSeek model
-model_name = "deepseek-ai/DeepSeek-V3"  # Or "deepseek-ai/DeepSeek-R1-0528"
-tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
-model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto", trust_remote_code=True)
+model_path = "defog/sqlcoder-7b-2"
+tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.float16, device_map="auto")
+
 
-# Function to handle JSON prompt for SQL generation
 def generate_sql(payload):
+
     question = payload.get("question", "")
     schema = payload.get("schema", "")
     sample_rows = payload.get("sample_rows", [])
 
     sample_str = "\n".join([str(row) for row in sample_rows]) if sample_rows else ""
-
+
     prompt = f"""
-You are a text-to-SQL data analyst.
-Based on the following information, write a clean SQL query that works with DuckDB. Do not hallucinate tables or fields.
-Schema: {schema}
-Sample Rows:
+### Task
+Generate a SQL query to answer [QUESTION]{question}[/QUESTION]
+
+### Database Schema
+The query will run on a database with the following schema:
+{schema}
+
+### Sample Rows
 {sample_str}
-Question: {question}
 
-SQL:"""
+### Answer
+Given the database schema, here is the SQL query that [QUESTION]{question}[/QUESTION]
+[SQL]
+""".strip()
 
     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
     outputs = model.generate(
         **inputs,
-        max_new_tokens=128,
-        do_sample=True,
-        temperature=0.6,
-        top_p=0.95,
-        pad_token_id=tokenizer.eos_token_id
+        max_length=512,
+        do_sample=False,
+        num_beams=4,
+        early_stopping=True
     )
-
-    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
-    generated_sql = response.split("SQL:")[-1].strip()
-
-    return generated_sql
+    sql = tokenizer.decode(outputs[0], skip_special_tokens=True)
+    return sql.split("[SQL]")[-1].strip()
+
 
-# Launch Gradio interface
 demo = gr.Interface(
     fn=generate_sql,
     inputs=gr.JSON(label="Input JSON (question, schema, sample_rows)"),
     outputs="text",
-    title="Text-to-SQL (DeepSeek)",
-    description="Use DeepSeek to convert a natural language question and schema into SQL."
+    title="SQLCoder - Text to SQL",
+    description="Enter a JSON object with 'question', 'schema', and optional 'sample_rows'. The model will generate SQL using Defog's sqlcoder-7b-2."
 )
 
-demo.launch()
+demo.launch()
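
For a quick sanity check of the new interface, a payload in the shape the app expects can be passed straight to generate_sql. The question, schema, and sample rows below are invented for illustration; they are not part of the commit.

# Hypothetical example input; the schema and rows are made up.
example_payload = {
    "question": "How many orders were placed in 2023?",
    "schema": "CREATE TABLE orders (id INTEGER, order_date DATE, total NUMERIC);",
    "sample_rows": [
        {"id": 1, "order_date": "2023-03-14", "total": 19.99},
        {"id": 2, "order_date": "2022-12-01", "total": 5.00},
    ],
}
# Should print a single SELECT statement extracted after the [SQL] tag.
print(generate_sql(example_payload))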
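One caveat with the new decoding settings: max_length=512 counts the prompt tokens as well as the generated ones, so a long schema plus sample rows can leave little or no room for the query itself. A sketch of a variant that bounds only the continuation, assuming the same model, tokenizer, and inputs as above:

# Variant sketch: max_new_tokens limits only the generated continuation,
# so a long prompt cannot consume the output budget.
outputs = model.generate(
    **inputs,
    max_new_tokens=256,
    do_sample=False,
    num_beams=4,
    early_stopping=True,
)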