# chronos/database.py
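"""Database assistant: builds a prompt describing an example SQL schema and uses a
causal language model to generate SQL queries or answers for user questions."""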
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load model and tokenizer
model_name = "EleutherAI/gpt-neo-2.7B"  # Replace with a suitable model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
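# Note: gpt-neo-2.7B is a large checkpoint and may need several GB of memory to load;
# a smaller model such as "EleutherAI/gpt-neo-125M" can be substituted for quick tests.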

# Example schema
schema = {
    "products": {
        "columns": ["product_id", "name", "price", "category_id"],
        "relations": "category_id -> categories.id",
    },
    "categories": {
        "columns": ["id", "category_name"],
        "relations": None,
    },
    "orders": {
        "columns": ["order_id", "customer_name", "product_id", "order_date"],
        "relations": "product_id -> products.product_id",
    },
}
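
# Each "relations" entry uses the form "<column> -> <table>.<column>" to describe a
# foreign-key link; generate_context() copies the string verbatim into the prompt.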


def generate_context(schema):
    """
    Generate context dynamically from the schema.
    """
    context = "### Database Schema ###\n\n"
    for table, details in schema.items():
        context += f"Table: {table}\nColumns: {', '.join(details['columns'])}\n"
        if details.get("relations"):
            context += f"Relations: {details['relations']}\n"
        context += "\n"
    context += "### Instructions ###\n"
    context += (
        "Generate SQL queries based on the user's question. "
        "Understand the schema to identify relevant tables and relationships. "
        "If the question involves multiple tables, use appropriate joins.\n"
    )
    return context


# Generate dynamic context
context = generate_context(schema)
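
# For the schema above, generate_context() produces text beginning:
#   ### Database Schema ###
#
#   Table: products
#   Columns: product_id, name, price, category_id
#   Relations: category_id -> categories.id
#   ...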


def answer_question(context, question):
    """
    Generate an SQL query or database-related response using the model.
    """
    prompt = f"{context}\n\nUser Question: {question}\nSQL Query or Answer:"
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
    # max_new_tokens budgets the generated text rather than prompt + output.
    outputs = model.generate(inputs.input_ids, attention_mask=inputs.attention_mask,
                             max_new_tokens=256, num_beams=5, early_stopping=True,
                             pad_token_id=tokenizer.eos_token_id)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


# Run a few example questions through the assistant
print("Database Assistant is ready. Ask your questions!")

# Example questions
questions = [
    "describe the product table for me, what kind of data it is storing and all",
]

for user_question in questions:
    print(f"Question: {user_question}")
    response = answer_question(context, user_question)
    print("\nGenerated Response:\n", response, "\n")