from transformers import AutoTokenizer, AutoModelForCausalLM
|
|
|
|
|
# Checkpoint identifier shared by the tokenizer and the causal language model.
model_name = "EleutherAI/gpt-neo-2.7B"

# Both objects are created once, at module import time, and are used as
# module-level globals by the code that builds prompts and generates text.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
|
|
|
|
|
# Hard-coded description of the database the assistant is asked about.
# Each table name maps to:
#   "columns":   list of column names
#   "relations": a "local_column -> table.column" string, or None when the
#                table declares no relation
schema = {
    "products": {
        "columns": ["product_id", "name", "price", "category_id"],
        "relations": "category_id -> categories.id",
    },
    "categories": {
        "columns": ["id", "category_name"],
        "relations": None,
    },
    "orders": {
        "columns": ["order_id", "customer_name", "product_id", "order_date"],
        "relations": "product_id -> products.product_id",
    },
}
|
|
|
def generate_context(schema):
    """Build the prompt preamble describing the schema and the task.

    Args:
        schema: Mapping of table name to a dict holding a "columns" list of
            column-name strings and, optionally, a "relations" string.

    Returns:
        A string made of a "### Database Schema ###" section (one entry per
        table: its columns, plus a "Relations:" line when the relation value
        is truthy, each entry followed by a blank line) and then a
        "### Instructions ###" section.

    Raises:
        KeyError: If a table entry has no "columns" key.
    """
    # Collect the pieces and join once instead of growing a string with +=.
    parts = ["### Database Schema ###\n\n"]

    for table, details in schema.items():
        parts.append(
            f"Table: {table}\nColumns: {', '.join(details['columns'])}\n"
        )
        # A missing key, None or "" all mean "no relation": emit no line.
        relations = details.get("relations")
        if relations:
            parts.append(f"Relations: {relations}\n")
        # Blank line separating this table from the next section.
        parts.append("\n")

    parts.append("### Instructions ###\n")
    parts.append(
        "Generate SQL queries based on the user's question. "
        "Understand the schema to identify relevant tables and relationships. "
        "If the question involves multiple tables, use appropriate joins.\n"
    )
    return "".join(parts)
|
|
|
|
|
# Prompt preamble built once from the module-level schema and reused for
# every question.
context = generate_context(schema)
|
|
|
def answer_question(context, question):
    """Generate an SQL query or database-related answer with the model.

    Uses the module-level ``tokenizer`` and ``model``.

    Args:
        context: Prompt preamble placed before the question.
        question: The user's natural-language question.

    Returns:
        The decoded text the model produced after the prompt, with special
        tokens removed. The prompt itself is not included.
    """
    prompt = f"{context}\n\nUser Question: {question}\nSQL Query or Answer:"
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True)

    outputs = model.generate(
        inputs.input_ids,
        # Pass the mask explicitly rather than letting generate() infer it.
        attention_mask=inputs.attention_mask,
        # max_length would also count the prompt tokens, so the room left for
        # the answer would shrink as the schema grows; budget the answer only.
        max_new_tokens=256,
        num_beams=5,
        early_stopping=True,
        # Explicit padding id for beams that finish early; uses the
        # tokenizer's end-of-sequence token.
        pad_token_id=tokenizer.eos_token_id,
    )

    # The generated sequence begins with the prompt tokens; skip them so the
    # caller receives only the model's continuation.
    prompt_length = inputs.input_ids.shape[1]
    return tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
|
|
|
|
|
print("Database Assistant is ready. Ask your questions!")

# Sample questions answered in order when this module runs.
questions = [
    "describe the product table for me, what kind of data it is storing and all",
]

for user_question in questions:
    print(f"Question: {user_question}")
    response = answer_question(context, user_question)
    print("\nGenerated Response:\n", response, "\n")
|
|
|
|