Updated README.md
Browse files
README.md
CHANGED
@@ -203,51 +203,55 @@ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
|
203 |
import logging
|
204 |
|
205 |
# Set up logging
|
206 |
-
logging.basicConfig(
|
207 |
-
level=logging.INFO,
|
208 |
-
format="%(asctime)s - %(levelname)s - %(message)s",
|
209 |
-
)
|
210 |
logger = logging.getLogger(__name__)
|
211 |
|
212 |
-
#
|
213 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
214 |
|
215 |
# Load the fine-tuned model and tokenizer
|
216 |
-
model_name = "aarohanverma/text2sql-flan-t5-base-qlora-finetuned"
|
217 |
model = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch.bfloat16).to(device)
|
218 |
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
|
219 |
|
|
|
|
|
|
|
|
|
220 |
def run_inference(prompt_text: str) -> str:
|
221 |
-
|
222 |
-
|
223 |
-
|
224 |
-
|
225 |
-
|
226 |
-
|
227 |
-
|
228 |
-
|
229 |
-
|
230 |
-
|
231 |
-
|
232 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
233 |
|
234 |
# Example usage:
|
235 |
context = (
|
236 |
-
|
237 |
-
|
238 |
-
|
239 |
-
|
240 |
-
"(3, 'Charlie', 'Canada'), (4, 'David', 'USA'); "
|
241 |
-
"INSERT INTO orders (order_id, customer_id, total_amount, order_date) VALUES "
|
242 |
-
"(101, 1, 500, '2024-01-15'), (102, 2, 300, '2024-01-20'), "
|
243 |
-
"(103, 1, 700, '2024-02-10'), (104, 3, 450, '2024-02-15'), "
|
244 |
-
"(105, 4, 900, '2024-03-05');"
|
245 |
-
)
|
246 |
-
query = (
|
247 |
-
"Retrieve the total order amount for each customer, showing only customers from the USA, "
|
248 |
-
"and sort the result by total order amount in descending order."
|
249 |
)
|
250 |
|
|
|
|
|
|
|
251 |
# Construct the prompt
|
252 |
sample_prompt = f"""Context:
|
253 |
{context}
|
@@ -269,12 +273,7 @@ print(query)
|
|
269 |
print("\nResponse:")
|
270 |
print(generated_sql)
|
271 |
|
272 |
-
#
|
273 |
-
# SELECT customers.name, SUM(orders.total_amount) as total_amount FROM customers
|
274 |
-
# INNER JOIN orders ON customers.id = orders.customer_id
|
275 |
-
# WHERE customers.country = 'USA'
|
276 |
-
# GROUP BY customers.name
|
277 |
-
# ORDER BY total_amount DESC;
|
278 |
```
|
279 |
|
280 |
## Citation
|
|
|
203 |
import logging
|
204 |
|
205 |
# Set up logging
|
206 |
+
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
|
|
|
|
|
|
207 |
logger = logging.getLogger(__name__)
|
208 |
|
209 |
+
# Select the compute device: GPU if available, otherwise CPU
|
210 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
211 |
|
212 |
# Load the fine-tuned model and tokenizer
|
213 |
+
model_name = "aarohanverma/text2sql-flan-t5-base-qlora-finetuned"
|
214 |
model = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch.bfloat16).to(device)
|
215 |
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
|
216 |
|
217 |
+
# Fall back to the pad token as the decoder start token if the model config does not define one
|
218 |
+
if model.config.decoder_start_token_id is None:
|
219 |
+
model.config.decoder_start_token_id = tokenizer.pad_token_id
|
220 |
+
|
221 |
def run_inference(prompt_text: str) -> str:
    """Generate one SQL statement for ``prompt_text`` using beam search.

    Uses the module-level ``tokenizer``, ``model`` and ``device``.

    Args:
        prompt_text: Full prompt to feed to the model. Inputs longer than
            512 tokens are truncated.

    Returns:
        The text generated before the first ``;``, stripped of surrounding
        whitespace and terminated by a single ``;``. An empty string is
        returned when the model produced no text before the first ``;``.
    """
    inputs = tokenizer(
        prompt_text,
        return_tensors="pt",
        truncation=True,
        max_length=512,
    ).to(device)

    # Deterministic beam search. ``temperature`` is deliberately not passed:
    # it only takes effect when sampling (``do_sample=True``), so with beam
    # search it was ignored and only triggered a configuration warning.
    generated_ids = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],  # pass the tokenizer's mask explicitly
        decoder_start_token_id=model.config.decoder_start_token_id,
        max_new_tokens=100,  # cap the length of the generated query
        num_beams=5,
        repetition_penalty=1.2,  # discourage repeated tokens
        early_stopping=True,  # stop once enough finished beams exist
    )

    generated_sql = tokenizer.decode(generated_ids[0], skip_special_tokens=True)

    # Keep only the first statement so repeated trailing output is dropped.
    # NOTE: a ';' inside a quoted SQL literal would also end the statement here.
    first_statement = generated_sql.partition(";")[0].strip()
    return f"{first_statement};" if first_statement else ""
|
243 |
|
244 |
# Example usage:
|
245 |
# Schema and seed rows that are interpolated into the prompt as context.
context = (
    "CREATE TABLE students (id INT PRIMARY KEY, name VARCHAR(100), age INT, grade CHAR(1)); "
    "INSERT INTO students (id, name, age, grade) VALUES "
    "(1, 'Alice', 14, 'A'), (2, 'Bob', 15, 'B'), "
    "(3, 'Charlie', 14, 'A'), (4, 'David', 16, 'C'), (5, 'Eve', 15, 'B');"
)

# Natural-language question to translate into SQL. A plain string: the
# parentheses around a single literal were redundant and looked like a tuple.
query = "Retrieve the names of students who are 15 years old."
|
253 |
+
|
254 |
+
|
255 |
# Construct the prompt
|
256 |
sample_prompt = f"""Context:
|
257 |
{context}
|
|
|
273 |
print("\nResponse:")
|
274 |
print(generated_sql)
|
275 |
|
276 |
+
# EXPECTED RESPONSE: SELECT name FROM students WHERE age = 15;
|
|
|
|
|
|
|
|
|
|
|
277 |
```
|
278 |
|
279 |
## Citation
|