# Scraped Hugging Face Spaces page header (status: Sleeping) - not part of the app.
import logging
import os
import re
import sqlite3
from contextlib import closing
from pathlib import Path

import pandas as pd
import streamlit as st
from dotenv import load_dotenv
from transformers import AutoModelForCausalLM, AutoTokenizer
# Logging: a logger named after this module, plus root configuration at INFO
# level (basicConfig is a no-op if the root logger already has handlers).
logger = logging.getLogger(__name__)
logging.basicConfig(level="INFO")
class LLMService:
    """Turn natural-language questions into SQL with a causal LM and run them on SQLite.

    The public methods never raise: they return a dict with ``"success": True``
    plus payload keys, or ``"success": False`` and an ``"error"`` message.
    """

    # Checkpoint used when no model_name is given (the original hard-coded value).
    DEFAULT_MODEL = "Qwen/Qwen2.5-72B-Instruct"

    # System instruction for chat-tuned models so the reply is just the query.
    SYSTEM_PROMPT = (
        "You translate questions into a single SQLite SQL query. "
        "Reply with the SQL query only, without any explanation."
    )

    # A Markdown code fence, with an optional language tag line (e.g. ```sql).
    _FENCED_SQL = re.compile(r"```(?:[\w+-]*\n)?(.*?)```", re.DOTALL)

    def __init__(self, db_path, model_name=DEFAULT_MODEL):
        """Store the database path and load the tokenizer and model.

        Args:
            db_path: Path of the SQLite database file used by ``execute_query``.
            model_name: Hugging Face model id (or local path) for tokenizer and model.
        """
        self.db_path = db_path
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(model_name)

    def convert_to_sql_query(self, natural_query, max_new_tokens=512):
        """Generate a SQL query answering *natural_query*.

        Args:
            natural_query: The user's question in natural language.
            max_new_tokens: Upper bound on generated tokens (prompt not counted).

        Returns:
            ``{"success": True, "query": sql}`` or ``{"success": False, "error": msg}``.
        """
        try:
            prompt, templated = self._build_prompt(natural_query)
            # A rendered chat template already contains the model's special tokens.
            inputs = self.tokenizer(
                prompt, return_tensors="pt", add_special_tokens=not templated
            ).to(self.model.device)
            outputs = self.model.generate(
                **inputs, max_new_tokens=max_new_tokens, num_beams=5
            )
            # For causal LMs generate() returns prompt + continuation; decode only
            # the new tokens, otherwise the prompt text leaks into the SQL string.
            prompt_length = inputs["input_ids"].shape[-1]
            completion = self.tokenizer.decode(
                outputs[0][prompt_length:], skip_special_tokens=True
            )
            return {"success": True, "query": self._extract_sql(completion)}
        except Exception as e:
            logger.exception(f"Error generating SQL query: {e}")
            return {"success": False, "error": str(e)}

    def _build_prompt(self, natural_query):
        """Return ``(prompt_text, used_chat_template)`` for *natural_query*."""
        request = f"Translate this to SQL: {natural_query}"
        # Instruct checkpoints expect their chat format; tokenizers without a
        # chat template get the plain instruction.
        if not getattr(self.tokenizer, "chat_template", None):
            return request, False
        messages = [
            {"role": "system", "content": self.SYSTEM_PROMPT},
            {"role": "user", "content": request},
        ]
        prompt = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        return prompt, True

    @classmethod
    def _extract_sql(cls, text):
        """Return the contents of the first Markdown code fence in *text*, else *text*, stripped."""
        match = cls._FENCED_SQL.search(text)
        return (match.group(1) if match else text).strip()

    def execute_query(self, sql_query, read_only=True):
        """Execute a single SQL statement and fetch all resulting rows.

        Args:
            sql_query: SQL text to run. It comes from the model, so treat it as untrusted.
            read_only: Open the database read-only (default) so generated
                statements such as ``DROP TABLE`` cannot modify it.

        Returns:
            ``{"success": True, "results": rows, "columns": names}`` - ``names`` is
            empty for statements without a result set - or
            ``{"success": False, "error": msg}``.
        """
        try:
            # closing() releases the connection even when execute() raises.
            with closing(self._connect(read_only)) as conn:
                cursor = conn.execute(sql_query)
                results = cursor.fetchall()
                # description is None for statements that produce no result set.
                columns = [desc[0] for desc in cursor.description or ()]
            return {"success": True, "results": results, "columns": columns}
        except Exception as e:
            logger.exception(f"Error executing SQL query: {e}")
            return {"success": False, "error": str(e)}

    def _connect(self, read_only):
        """Open a SQLite connection to ``self.db_path``."""
        if not read_only:
            return sqlite3.connect(self.db_path)
        # URI mode=ro rejects writes and errors on a missing file instead of
        # silently creating an empty database; as_uri() percent-encodes the path.
        uri = f"{Path(self.db_path).resolve().as_uri()}?mode=ro"
        return sqlite3.connect(uri, uri=True)
# Result columns that can supply a product title / image, in preference order.
_TITLE_COLUMNS = ("title", "handle")
_IMAGE_COLUMNS = ("src", "images")


def _create_llm_service(db_path):
    """Construct the LLMService (loads the model); main() caches the result."""
    return LLMService(db_path=db_path)


def _first_value(record, keys, default="N/A"):
    """Return the value of the first key in *keys* that is present and non-blank in *record*."""
    for key in keys:
        value = record.get(key)
        if value is None or (isinstance(value, str) and not value.strip()):
            continue
        return value
    return default


def _is_image_url(value):
    """Return True when *value* is an http(s) URL string."""
    return isinstance(value, str) and value.startswith(("http://", "https://"))


def _has_product_columns(columns):
    """Return True when the result has a price, a title-like and an image-like column."""
    present = set(columns)
    return (
        "price" in present
        and any(name in present for name in _TITLE_COLUMNS)
        and any(name in present for name in _IMAGE_COLUMNS)
    )


def _render_products(records):
    """Render each result row as a product card: image, price, title and image link."""
    st.subheader("Product Details:")
    for product in records:
        price = product.get("price", "N/A")
        title = _first_value(product, _TITLE_COLUMNS)
        src = _first_value(product, _IMAGE_COLUMNS)
        # st.image treats non-URL strings as local file paths, so a placeholder
        # such as "N/A" would raise; only hand it real URLs.
        has_image = _is_image_url(src)
        with st.container():
            col1, col2, col3 = st.columns([1, 2, 3])  # Adjust column widths as needed
            with col1:
                if has_image:
                    st.image(src, use_container_width=True)
                else:
                    st.write("No image available")
            with col2:
                st.write(f"**Price:** {price}")
                st.write(f"**Title:** {title}")
            with col3:
                if has_image:
                    st.write(f"**Image Source:** [Link]( {src} )")
                else:
                    st.write(f"**Image Source:** {src}")


def main():
    """Streamlit page: turn a natural-language question into SQL, run it, show the results.

    The SQLite path is read from the ``DB_PATH`` environment variable after
    ``load_dotenv()``. Every failure is reported in the UI instead of raising.
    """
    st.title("Natural Language to SQL Query Converter")
    st.write("Enter your question about the database in natural language.")

    load_dotenv()
    db_path = os.getenv("DB_PATH")
    if not db_path:
        st.error("Missing database path in environment variables.")
        logger.error("DB path not found in environment variables.")
        return

    try:
        # Streamlit re-executes this script on every interaction; caching keeps
        # one service (and one loaded model) per process instead of reloading it.
        load_service = st.cache_resource(show_spinner="Loading model...")(_create_llm_service)
        llm_service = load_service(db_path)
    except Exception as e:
        st.error(f"Error initializing service: {str(e)}")
        return

    natural_query = st.text_area("Enter your query", "Show me all albums by artist 'Queen'", height=100)
    if not st.button("Generate and Execute Query"):
        return
    if not natural_query.strip():
        st.warning("Please enter a valid query.")
        return

    with st.spinner("Generating SQL query..."):
        sql_result = llm_service.convert_to_sql_query(natural_query)
    if not sql_result["success"]:
        st.error(f"Error generating SQL query: {sql_result['error']}")
        return

    st.subheader("Generated SQL Query:")
    st.code(sql_result["query"], language="sql")

    with st.spinner("Executing query..."):
        query_result = llm_service.execute_query(sql_result["query"])
    if not query_result["success"]:
        st.error(f"Error executing query: {query_result['error']}")
        return

    if not query_result["results"]:
        st.info("No results found.")
        return

    df = pd.DataFrame(query_result["results"], columns=query_result["columns"])
    with st.expander("Click to view query results as a DataFrame"):
        st.dataframe(df)

    if _has_product_columns(df.columns):
        _render_products(df.to_dict(orient="records"))
# Script entry point: render the app when this file is executed directly
# (presumably launched with `streamlit run`).
if __name__ == "__main__":
    main()