aarohanverma committed
Commit 27e057c (verified) · Parent: 1b52ffd

Fixed hallucinations

Files changed (1)
1. app.py +16 -10

app.py CHANGED
@@ -14,7 +14,7 @@ def generate_sql(context: str, query: str) -> str:
     """
     Generates a SQL query given the provided context and natural language query.
     Constructs a prompt from the inputs, then performs deterministic generation
-    with beam search.
+    with beam search and repetition handling.
     """
     prompt = f"""Context:
 {context}
@@ -25,24 +25,28 @@ Query:
 Response:
 """
     # Tokenize the prompt and move to device
-    inputs = tokenizer(prompt, return_tensors="pt").to(device)
+    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).to(device)
 
     # Ensure decoder_start_token_id is set for encoder-decoder generation
     if model.config.decoder_start_token_id is None:
         model.config.decoder_start_token_id = tokenizer.pad_token_id
 
-    # Generate the SQL output
+    # Generate the SQL output with optimized parameters
     generated_ids = model.generate(
         input_ids=inputs["input_ids"],
         decoder_start_token_id=model.config.decoder_start_token_id,
-        max_new_tokens=250,
-        temperature=0.0,      # Deterministic output
-        num_beams=3,          # Beam search for improved quality
-        early_stopping=True,  # Stop when output is complete
+        max_new_tokens=100,
+        temperature=0.1,
+        num_beams=5,
+        repetition_penalty=1.2,
+        early_stopping=True,
     )
 
-    # Decode and return the generated SQL statement
-    return tokenizer.decode(generated_ids[0], skip_special_tokens=True)
+    # Decode and clean the generated SQL statement
+    generated_sql = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
+    generated_sql = generated_sql.split(";")[0] + ";"  # ✅ Ensures only the first valid SQL query is returned
+
+    return generated_sql
 
 # Create Gradio interface with two input boxes: one for context and one for query
 iface = gr.Interface(
@@ -53,7 +57,9 @@ iface = gr.Interface(
     ],
     outputs="text",
     title="Text-to-SQL Generator",
-    description="Enter your own context (e.g., database schema and sample data) and a natural language query. The model will generate the corresponding SQL statement."
+    description="Enter your own context (e.g., database schema and sample data) and a natural language query. The model will generate the corresponding SQL statement.",
+    theme="compact",
+    allow_flagging="never"
 )
 
 iface.launch()
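
A note on the new decoding settings: because the call still uses beam search (num_beams=5) without do_sample=True, generation remains deterministic and the added temperature=0.1 has no effect; recent transformers releases emit a warning about unused sampling flags in that situation. If the intent is purely the repetition fix, a variant that drops the unused flag should behave identically. This is a minimal sketch, not taken from the actual commit, and it assumes the same tokenizer, model, and inputs as in the diff above.

    # Sketch of an equivalent generate() call without the unused temperature flag
    generated_ids = model.generate(
        input_ids=inputs["input_ids"],
        decoder_start_token_id=model.config.decoder_start_token_id,
        max_new_tokens=100,
        num_beams=5,              # deterministic beam search
        repetition_penalty=1.2,   # penalizes repeated tokens, the likely source of the reported hallucinations
        early_stopping=True,
    )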
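
For context, the diff assumes that tokenizer, model, device, and gr are defined earlier in app.py; those lines are untouched and therefore not shown here. Below is a minimal sketch of that assumed setup plus a direct call to the updated generate_sql. The checkpoint ID, example schema, and example query are placeholders for illustration, not the Space's actual values.

    import torch
    import gradio as gr
    from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

    MODEL_ID = "your-org/your-text2sql-checkpoint"  # placeholder; the real checkpoint is not part of this diff

    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID).to(device)
    model.eval()

    # Quick sanity check of the updated post-processing: only the first
    # semicolon-terminated statement should be returned.
    example_context = "CREATE TABLE users (id INT, name TEXT, signup_date DATE);"
    example_query = "How many users signed up after 2023-01-01?"
    print(generate_sql(example_context, example_query))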