# graphql/app.py — Streamlit app: natural-language questions → SQL (Qwen2.5-72B-Instruct) → SQLite.
# Provenance: Hugging Face Space by nileshhanotia, revision 39485d9 ("Update app.py"), 5.01 kB.
import logging
import os
import sqlite3

import pandas as pd
import streamlit as st
from dotenv import load_dotenv
from transformers import AutoTokenizer, AutoModelForCausalLM
# Module-level logging: INFO threshold, one logger named after this module
# so records can be filtered per-file in a larger deployment.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class LLMService:
    """Translate natural-language questions to SQL with a causal LM and run them on SQLite.

    Holds the tokenizer/model pair and the path of the SQLite database that
    generated queries are executed against.
    """

    def __init__(self, db_path):
        """Remember the SQLite path and load the Qwen2.5-72B-Instruct tokenizer/model.

        NOTE(review): a 72B-parameter model will not load on typical hardware
        without quantization and/or `device_map` — confirm deployment target.
        """
        self.db_path = db_path
        self.tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-72B-Instruct")
        self.model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-72B-Instruct")

    def convert_to_sql_query(self, natural_query):
        """Generate a SQL query string for *natural_query*.

        Returns {"success": True, "query": str} on success, or
        {"success": False, "error": str} when tokenization/generation raises.
        """
        try:
            inputs = self.tokenizer(
                f"Translate this to SQL: {natural_query}", return_tensors="pt"
            )
            outputs = self.model.generate(**inputs, max_length=512, num_beams=5)
            # BUGFIX: decoder-only models echo the prompt at the start of the
            # generated sequence; decode only the tokens *after* the prompt so
            # the returned string is the SQL alone, not
            # "Translate this to SQL: ..." followed by the SQL.
            prompt_len = inputs["input_ids"].shape[1]
            sql_query = self.tokenizer.decode(
                outputs[0][prompt_len:], skip_special_tokens=True
            )
            return {"success": True, "query": sql_query}
        except Exception as e:
            logger.error(f"Error generating SQL query: {e}")
            return {"success": False, "error": str(e)}

    def execute_query(self, sql_query):
        """Execute *sql_query* against the SQLite database at ``self.db_path``.

        Returns {"success": True, "results": list[tuple], "columns": list[str]}
        or {"success": False, "error": str}. The connection is always closed,
        even when execution raises.
        """
        try:
            conn = sqlite3.connect(self.db_path)
            try:
                cursor = conn.cursor()
                cursor.execute(sql_query)
                results = cursor.fetchall()
                # cursor.description is None for statements that return no
                # rows (INSERT/UPDATE/DDL); report empty columns instead of
                # raising TypeError. (Original also leaked the connection when
                # execute() raised — the finally below fixes that.)
                columns = (
                    [desc[0] for desc in cursor.description]
                    if cursor.description is not None
                    else []
                )
            finally:
                conn.close()
            return {"success": True, "results": results, "columns": columns}
        except Exception as e:
            logger.error(f"Error executing SQL query: {e}")
            return {"success": False, "error": str(e)}
def main():
    """Streamlit entry point: accept a natural-language question, generate SQL,
    run it, and render the results (plus product cards when applicable)."""
    st.title("Natural Language to SQL Query Converter")
    st.write("Enter your question about the database in natural language.")

    # Database location comes from the environment (.env supported).
    load_dotenv()
    db_path = os.getenv("DB_PATH")
    if not db_path:
        st.error("Missing database path in environment variables.")
        logger.error("DB path not found in environment variables.")
        return

    # Model loading can fail (download, memory) — surface it in the UI.
    try:
        llm_service = LLMService(db_path=db_path)
    except Exception as e:
        st.error(f"Error initializing service: {str(e)}")
        return

    natural_query = st.text_area(
        "Enter your query", "Show me all albums by artist 'Queen'", height=100
    )
    if st.button("Generate and Execute Query"):
        if not natural_query.strip():
            st.warning("Please enter a valid query.")
            return

        with st.spinner("Generating SQL query..."):
            sql_result = llm_service.convert_to_sql_query(natural_query)
        if not sql_result["success"]:
            st.error(f"Error generating SQL query: {sql_result['error']}")
            return

        st.subheader("Generated SQL Query:")
        st.code(sql_result["query"], language="sql")

        with st.spinner("Executing query..."):
            query_result = llm_service.execute_query(sql_result["query"])
        if not query_result["success"]:
            st.error(f"Error executing query: {query_result['error']}")
            return

        if query_result["results"]:
            df = pd.DataFrame(query_result["results"], columns=query_result["columns"])
            with st.expander("Click to view query results as a DataFrame"):
                st.dataframe(df)
            _display_products(df.to_dict(orient="records"))
        else:
            st.info("No results found.")


def _display_products(json_results):
    """Render product cards when the result rows carry title/images/price columns.

    *json_results* is a non-empty list of row dicts (DataFrame records).
    Does nothing when the rows don't look like product data.
    """
    first = json_results[0]
    if not ("title" in first and "images" in first and "price" in first):
        return
    st.subheader("Product Details:")
    for product in json_results:
        price = product.get("price", "N/A")
        # BUGFIX: the original read product["handle"] and product["src"] —
        # keys that were never part of the check above — so title and image
        # always fell back to "N/A". Read the keys actually verified.
        title = product.get("title", "N/A")
        # NOTE(review): assumes "images" holds a single image URL string; if
        # it is a list/JSON structure, extract the src field here — confirm
        # against the database schema.
        src = product.get("images", "N/A")
        with st.container():
            col1, col2, col3 = st.columns([1, 2, 3])  # layout: image | text | link
            with col1:
                # Guard: st.image raises on a non-URL placeholder.
                if src != "N/A":
                    st.image(src, use_container_width=True)
            with col2:
                st.write(f"**Price:** {price}")
                st.write(f"**Title:** {title}")
            with col3:
                st.write(f"**Image Source:** [Link]( {src} )")
# Entry point when executed directly (e.g. `streamlit run app.py`).
if __name__ == "__main__":
    main()