aarohanverma committed
Commit 451c534 · verified · 1 parent: c5f31f4

Updated app.py

Files changed (1): app.py (+17 -5)
app.py CHANGED
@@ -6,13 +6,15 @@ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 # Load the fine-tuned model and tokenizer
-model_name = "aarohanverma/text2sql-flan-t5-base-qlora-finetuned" # Replace with your model repo name
+model_name = "aarohanverma/text2sql_flant5base_finetuned" # Replace with your model repository name
 model = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch.bfloat16).to(device)
 tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
 
 def generate_sql(context: str, query: str) -> str:
     """
-    Constructs a prompt using the user-provided context and query, then generates a SQL query.
+    Generates a SQL query given the provided context and natural language query.
+    Constructs a prompt from the inputs, then performs deterministic generation
+    with beam search.
     """
     prompt = f"""Context:
 {context}
@@ -22,17 +24,27 @@ Query:
 
 Response:
 """
+    # Tokenize the prompt and move it to the target device
     inputs = tokenizer(prompt, return_tensors="pt").to(device)
+
+    # Ensure decoder_start_token_id is set for encoder-decoder generation
+    if model.config.decoder_start_token_id is None:
+        model.config.decoder_start_token_id = tokenizer.pad_token_id
+
+    # Generate the SQL output
     generated_ids = model.generate(
         input_ids=inputs["input_ids"],
+        decoder_start_token_id=model.config.decoder_start_token_id,
         max_new_tokens=250,
         temperature=0.0, # Deterministic output
-        num_beams=3, # Beam search for quality output
-        early_stopping=True,
+        num_beams=3, # Beam search for improved quality
+        early_stopping=True, # Stop when output is complete
     )
+
+    # Decode and return the generated SQL statement
     return tokenizer.decode(generated_ids[0], skip_special_tokens=True)
 
-# Create a Gradio interface with two input boxes: one for context, one for query.
+# Create Gradio interface with two input boxes: one for context and one for query
 iface = gr.Interface(
     fn=generate_sql,
     inputs=[
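
The new decoder_start_token_id guard is worth a note: encoder-decoder generation needs a token to seed the decoder, and for T5-family models that token is conventionally the pad token, so falling back to tokenizer.pad_token_id matches the stock configuration. A minimal sketch confirming this, not part of the commit (it only assumes the public google/flan-t5-base checkpoint):

from transformers import AutoConfig, AutoTokenizer

# Inspect the stock FLAN-T5 config; both values resolve to token id 0 (<pad>),
# so the app's fallback to tokenizer.pad_token_id matches the default.
config = AutoConfig.from_pretrained("google/flan-t5-base")
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
print(config.decoder_start_token_id)  # 0
print(tokenizer.pad_token_id)         # 0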
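
A quick way to exercise the updated generate_sql is a smoke test along these lines (the schema and question are illustrative, not from the commit):

# Hypothetical smoke test; the sample context and query are made up.
context = "CREATE TABLE employees (id INT, name TEXT, salary INT);"
query = "List the names of employees earning more than 50000."
print(generate_sql(context, query))
# Beam search with do_sample left at its False default is deterministic,
# so repeated calls should print the same SQL string.

One caveat: because do_sample defaults to False, recent transformers releases ignore temperature (and may warn that it only applies to sampling), so the determinism here comes from beam search rather than from temperature=0.0.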