aarohanverma commited on
Commit
9d0df30
·
verified ·
1 Parent(s): 27e057c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -22
app.py CHANGED
@@ -2,53 +2,61 @@ import gradio as gr
2
  import torch
3
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
4
 
5
- # Set up device (GPU if available)
6
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
7
 
8
- # Load the fine-tuned model and tokenizer
9
- model_name = "aarohanverma/text2sql-flan-t5-base-qlora-finetuned" # Replace with your model repository name
10
  model = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch.bfloat16).to(device)
11
  tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
12
 
 
 
 
 
 
 
 
 
 
 
13
  def generate_sql(context: str, query: str) -> str:
14
  """
15
  Generates a SQL query given the provided context and natural language query.
16
  Constructs a prompt from the inputs, then performs deterministic generation
17
- with beam search and repetition handling.
18
  """
19
  prompt = f"""Context:
20
  {context}
21
-
22
  Query:
23
  {query}
24
-
25
  Response:
26
  """
27
- # Tokenize the prompt and move to device
28
  inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).to(device)
29
 
30
- # Ensure decoder_start_token_id is set for encoder-decoder generation
31
  if model.config.decoder_start_token_id is None:
32
  model.config.decoder_start_token_id = tokenizer.pad_token_id
33
 
34
- # Generate the SQL output with optimized parameters
35
- generated_ids = model.generate(
36
- input_ids=inputs["input_ids"],
37
- decoder_start_token_id=model.config.decoder_start_token_id,
38
- max_new_tokens=100,
39
- temperature=0.1,
40
- num_beams=5,
41
- repetition_penalty=1.2,
42
- early_stopping=True,
43
- )
 
44
 
45
- # Decode and clean the generated SQL statement
46
  generated_sql = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
47
- generated_sql = generated_sql.split(";")[0] + ";" # Ensures only the first valid SQL query is returned
48
-
49
  return generated_sql
50
 
51
- # Create Gradio interface with two input boxes: one for context and one for query
52
  iface = gr.Interface(
53
  fn=generate_sql,
54
  inputs=[
 
2
  import torch
3
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
4
 
5
+ # Set up device: use GPU if available, else CPU.
6
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
7
 
8
+ # Load the fine-tuned model and tokenizer.
9
+ model_name = "aarohanverma/text2sql-flan-t5-base-qlora-finetuned"
10
  model = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch.bfloat16).to(device)
11
  tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
12
 
13
+ # For CPU inference, convert the model to FP32 for better compatibility.
14
+ if device.type == "cpu":
15
+ model = model.float()
16
+
17
+ # Optionally compile the model for speed improvements (requires PyTorch 2.0+).
18
+ try:
19
+ model = torch.compile(model)
20
+ except Exception as e:
21
+ print("torch.compile optimization failed:", e)
22
+
23
  def generate_sql(context: str, query: str) -> str:
24
  """
25
  Generates a SQL query given the provided context and natural language query.
26
  Constructs a prompt from the inputs, then performs deterministic generation
27
+ using beam search with repetition handling.
28
  """
29
  prompt = f"""Context:
30
  {context}
 
31
  Query:
32
  {query}
 
33
  Response:
34
  """
35
+ # Tokenize the prompt with truncation and max length; move to device.
36
  inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).to(device)
37
 
38
+ # Ensure the decoder start token is set.
39
  if model.config.decoder_start_token_id is None:
40
  model.config.decoder_start_token_id = tokenizer.pad_token_id
41
 
42
+ # Generate SQL output with no_grad to optimize CPU usage.
43
+ with torch.no_grad():
44
+ generated_ids = model.generate(
45
+ input_ids=inputs["input_ids"],
46
+ decoder_start_token_id=model.config.decoder_start_token_id,
47
+ max_new_tokens=100,
48
+ temperature=0.0, # Deterministic output
49
+ num_beams=5,
50
+ repetition_penalty=1.2,
51
+ early_stopping=True,
52
+ )
53
 
54
+ # Decode and clean the generated SQL statement.
55
  generated_sql = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
56
+ generated_sql = generated_sql.split(";")[0].strip() + ";" # Keep only the first valid SQL query
 
57
  return generated_sql
58
 
59
+ # Create Gradio interface with two input boxes: one for context and one for query.
60
  iface = gr.Interface(
61
  fn=generate_sql,
62
  inputs=[