aarohanverma committed on
Commit
32aace6
·
verified ·
1 Parent(s): c9b612a

Updated README.md

Browse files
Files changed (1) hide show
  1. README.md +36 -37
README.md CHANGED
@@ -203,51 +203,55 @@ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
203
  import logging
204
 
205
  # Set up logging
206
- logging.basicConfig(
207
- level=logging.INFO,
208
- format="%(asctime)s - %(levelname)s - %(message)s",
209
- )
210
  logger = logging.getLogger(__name__)
211
 
212
- # Set device (GPU if available)
213
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
214
 
215
  # Load the fine-tuned model and tokenizer
216
- model_name = "aarohanverma/text2sql-flan-t5-base-qlora-finetuned"
217
  model = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch.bfloat16).to(device)
218
  tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
219
 
 
 
 
 
220
  def run_inference(prompt_text: str) -> str:
221
- """
222
- Runs inference using deterministic decoding with beam search.
223
- """
224
- inputs = tokenizer(prompt_text, return_tensors="pt").to(device)
225
- generated_ids = model.generate(
226
- input_ids=inputs["input_ids"],
227
- max_new_tokens=250,
228
- temperature=0.0,
229
- num_beams=3,
230
- early_stopping=True,
231
- )
232
- return tokenizer.decode(generated_ids[0], skip_special_tokens=True)
 
 
 
 
 
 
 
 
 
233
 
234
  # Example usage:
235
  context = (
236
- "CREATE TABLE customers (id INT PRIMARY KEY, name VARCHAR(100), country VARCHAR(50)); "
237
- "CREATE TABLE orders (order_id INT PRIMARY KEY, customer_id INT, total_amount DECIMAL(10,2), "
238
- "order_date DATE, FOREIGN KEY (customer_id) REFERENCES customers(id)); "
239
- "INSERT INTO customers (id, name, country) VALUES (1, 'Alice', 'USA'), (2, 'Bob', 'UK'), "
240
- "(3, 'Charlie', 'Canada'), (4, 'David', 'USA'); "
241
- "INSERT INTO orders (order_id, customer_id, total_amount, order_date) VALUES "
242
- "(101, 1, 500, '2024-01-15'), (102, 2, 300, '2024-01-20'), "
243
- "(103, 1, 700, '2024-02-10'), (104, 3, 450, '2024-02-15'), "
244
- "(105, 4, 900, '2024-03-05');"
245
- )
246
- query = (
247
- "Retrieve the total order amount for each customer, showing only customers from the USA, "
248
- "and sort the result by total order amount in descending order."
249
  )
250
 
 
 
 
251
  # Construct the prompt
252
  sample_prompt = f"""Context:
253
  {context}
@@ -269,12 +273,7 @@ print(query)
269
  print("\nResponse:")
270
  print(generated_sql)
271
 
272
- # Expected Output:
273
- # SELECT customers.name, SUM(orders.total_amount) as total_amount FROM customers
274
- # INNER JOIN orders ON customers.id = orders.customer_id
275
- # WHERE customers.country = 'USA'
276
- # GROUP BY customers.name
277
- # ORDER BY total_amount DESC;
278
  ```
279
 
280
  ## Citation
 
203
  import logging
204
 
205
  # Set up logging
206
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
 
 
 
207
  logger = logging.getLogger(__name__)
208
 
209
+ # Ensure device is set (GPU if available)
210
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
211
 
212
  # Load the fine-tuned model and tokenizer
213
+ model_name = "aarohanverma/text2sql-flan-t5-base-qlora-finetuned"
214
  model = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch.bfloat16).to(device)
215
  tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
216
 
217
+ # Ensure decoder start token is set
218
+ if model.config.decoder_start_token_id is None:
219
+ model.config.decoder_start_token_id = tokenizer.pad_token_id
220
+
221
  def run_inference(prompt_text: str) -> str:
222
+ """
223
+ Runs inference on the fine-tuned model using beam search with fixes for repetition.
224
+ """
225
+ inputs = tokenizer(prompt_text, return_tensors="pt", truncation=True, max_length=512).to(device)
226
+
227
+ generated_ids = model.generate(
228
+ input_ids=inputs["input_ids"],
229
+ decoder_start_token_id=model.config.decoder_start_token_id, # ✅ Ensure decoder start token
230
+ max_new_tokens=100, # ✅ Limit to prevent excessive output
231
+ temperature=0.1, # ✅ Adds slight randomness to avoid repetition
232
+ num_beams=5, # ✅ Increases quality
233
+ repetition_penalty=1.2, # ✅ Penalizes repetition
234
+ early_stopping=True, # ✅ Stops generation once complete
235
+ )
236
+
237
+ generated_sql = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
238
+
239
+ # ✅ Post-processing to remove repeated text
240
+ generated_sql = generated_sql.split(";")[0] + ";" # Keep only the first valid SQL query
241
+
242
+ return generated_sql
243
 
244
  # Example usage:
245
  context = (
246
+ "CREATE TABLE students (id INT PRIMARY KEY, name VARCHAR(100), age INT, grade CHAR(1)); "
247
+ "INSERT INTO students (id, name, age, grade) VALUES "
248
+ "(1, 'Alice', 14, 'A'), (2, 'Bob', 15, 'B'), "
249
+ "(3, 'Charlie', 14, 'A'), (4, 'David', 16, 'C'), (5, 'Eve', 15, 'B');"
 
 
 
 
 
 
 
 
 
250
  )
251
 
252
+ query = ("Retrieve the names of students who are 15 years old.")
253
+
254
+
255
  # Construct the prompt
256
  sample_prompt = f"""Context:
257
  {context}
 
273
  print("\nResponse:")
274
  print(generated_sql)
275
 
276
+ # EXPECTED RESPONSE: SELECT name FROM students WHERE age = 15;
 
 
 
 
 
277
  ```
278
 
279
  ## Citation