adamboom111 committed on
Commit
f63709a
·
verified ·
1 Parent(s): aa4a308

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -22
app.py CHANGED
@@ -1,37 +1,52 @@
1
  import gradio as gr
2
- from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
 
3
 
4
- # Load the GaussAlgo model
5
- model_path = "gaussalgo/T5-LM-Large-text2sql-spider"
6
- model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
7
- tokenizer = AutoTokenizer.from_pretrained(model_path)
8
 
 
9
  def generate_sql(payload):
10
- # Extract parts from the JSON payload
11
  question = payload.get("question", "")
12
  schema = payload.get("schema", "")
13
  sample_rows = payload.get("sample_rows", [])
14
-
15
- # Convert sample rows into a single string
16
- sample_str = " ".join([str(row) for row in sample_rows]) if sample_rows else ""
17
-
18
- # Build model input prompt
19
- prompt = f"Question: {question} Schema: {schema} Sample Rows: {sample_str}"
20
-
21
- # Tokenize and generate
22
- inputs = tokenizer(prompt, return_tensors="pt")
23
- outputs = model.generate(**inputs, max_length=512)
24
- generated_sql = tokenizer.decode(outputs[0], skip_special_tokens=True)
25
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  return generated_sql
27
 
28
- # Gradio interface
29
  demo = gr.Interface(
30
  fn=generate_sql,
31
  inputs=gr.JSON(label="Input JSON (question, schema, sample_rows)"),
32
  outputs="text",
33
- title="Text-to-SQL Generator",
34
- description="Enter a JSON object with 'question', 'schema', and optional 'sample_rows'. The model will generate SQL."
35
  )
36
 
37
- demo.launch()
 
1
  import gradio as gr
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
+ import torch
4
 
5
# Load the DeepSeek tokenizer and causal-LM weights once, at import time.
# NOTE(review): trust_remote_code=True executes code shipped with the model
# repository — confirm the checkpoint is trusted before deploying.
model_name = "deepseek-ai/DeepSeek-V3"  # Or "deepseek-ai/DeepSeek-R1-0528"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map="auto",
    trust_remote_code=True,
)
9
 
10
# Function to handle JSON prompt for SQL generation
def generate_sql(payload):
    """Generate a SQL query for DuckDB from a JSON-style payload.

    Args:
        payload: Mapping with the keys ``question`` and ``schema`` and the
            optional key ``sample_rows`` (an iterable of rows; each row is
            rendered with ``str``). A missing payload (``None``) is treated
            as an empty mapping instead of raising ``AttributeError``.

    Returns:
        The text the model produced after the prompt, with surrounding
        whitespace removed.
    """
    payload = payload or {}
    question = payload.get("question", "")
    schema = payload.get("schema", "")
    # ``or []`` also covers an explicit ``"sample_rows": null`` in the JSON.
    sample_rows = payload.get("sample_rows") or []

    sample_str = "\n".join(str(row) for row in sample_rows)

    # The prompt deliberately ends with "SQL:" so the model continues with
    # the query itself.
    prompt = "\n".join(
        [
            "",
            "You are a text-to-SQL data analyst.",
            "Based on the following information, write a clean SQL query that works with DuckDB. Do not hallucinate tables or fields.",
            f"Schema: {schema}",
            "Sample Rows:",
            sample_str,
            f"Question: {question}",
            "",
            "SQL:",
        ]
    )

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=128,
        do_sample=True,
        temperature=0.6,
        top_p=0.95,
        pad_token_id=tokenizer.eos_token_id,
    )

    # A causal LM returns the prompt tokens followed by the completion.
    # Decode only the completion: splitting the full text on "SQL:" would
    # pick the wrong segment whenever the model emits that marker again.
    prompt_length = inputs["input_ids"].shape[-1]
    completion_ids = outputs[0][prompt_length:]
    generated_sql = tokenizer.decode(completion_ids, skip_special_tokens=True).strip()

    return generated_sql
42
 
43
# Gradio interface: takes a JSON payload, shows the generated SQL as text.
demo = gr.Interface(
    fn=generate_sql,
    inputs=gr.JSON(label="Input JSON (question, schema, sample_rows)"),
    outputs="text",
    title="Text-to-SQL (DeepSeek)",
    description="Use DeepSeek to convert a natural language question and schema into SQL.",
)

# Start the web server only when this file is executed as a script, so that
# importing the module (e.g. to reuse ``demo``) does not block on launch().
if __name__ == "__main__":
    demo.launch()