File size: 5,014 Bytes
92cbbe3
39485d9
 
 
 
92cbbe3
 
39485d9
92cbbe3
39485d9
92cbbe3
39485d9
 
 
 
 
 
92cbbe3
39485d9
 
 
 
 
 
 
 
 
 
 
 
92cbbe3
39485d9
 
 
 
 
 
 
 
 
 
 
 
 
92cbbe3
39485d9
 
 
92cbbe3
39485d9
 
 
92cbbe3
39485d9
 
 
 
92cbbe3
39485d9
92cbbe3
39485d9
92cbbe3
39485d9
 
92cbbe3
39485d9
 
92cbbe3
39485d9
 
 
 
92cbbe3
39485d9
 
 
92cbbe3
39485d9
 
 
92cbbe3
39485d9
 
 
92cbbe3
39485d9
 
 
92cbbe3
39485d9
 
 
92cbbe3
39485d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92cbbe3
39485d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import streamlit as st
import pandas as pd
import os
from dotenv import load_dotenv
from transformers import AutoTokenizer, AutoModelForCausalLM
import logging

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class LLMService:
    def __init__(self, db_path):
        self.db_path = db_path
        # Load tokenizer and model
        self.tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-72B-Instruct")
        self.model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-72B-Instruct")

    def convert_to_sql_query(self, natural_query):
        try:
            # Tokenize input
            inputs = self.tokenizer(f"Translate this to SQL: {natural_query}", return_tensors="pt")
            # Generate output
            outputs = self.model.generate(**inputs, max_length=512, num_beams=5)
            # Decode output
            sql_query = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            return {"success": True, "query": sql_query}
        except Exception as e:
            logger.error(f"Error generating SQL query: {e}")
            return {"success": False, "error": str(e)}

    def execute_query(self, sql_query):
        try:
            import sqlite3
            conn = sqlite3.connect(self.db_path)
            cursor = conn.cursor()
            cursor.execute(sql_query)
            results = cursor.fetchall()
            columns = [desc[0] for desc in cursor.description]
            conn.close()
            return {"success": True, "results": results, "columns": columns}
        except Exception as e:
            logger.error(f"Error executing SQL query: {e}")
            return {"success": False, "error": str(e)}

def main():
    st.title("Natural Language to SQL Query Converter")
    st.write("Enter your question about the database in natural language.")
    
    # Load environment variables
    load_dotenv()
    db_path = os.getenv("DB_PATH")
    
    if not db_path:
        st.error("Missing database path in environment variables.")
        logger.error("DB path not found in environment variables.")
        return
    
    # Initialize LLM Service
    try:
        llm_service = LLMService(db_path=db_path)
    except Exception as e:
        st.error(f"Error initializing service: {str(e)}")
        return

    # Input for natural language query
    natural_query = st.text_area("Enter your query", "Show me all albums by artist 'Queen'", height=100)

    if st.button("Generate and Execute Query"):
        if not natural_query.strip():
            st.warning("Please enter a valid query.")
            return

        # Convert to SQL
        with st.spinner("Generating SQL query..."):
            sql_result = llm_service.convert_to_sql_query(natural_query)

        if not sql_result["success"]:
            st.error(f"Error generating SQL query: {sql_result['error']}")
            return

        # Display generated SQL
        st.subheader("Generated SQL Query:")
        st.code(sql_result["query"], language="sql")

        # Execute query
        with st.spinner("Executing query..."):
            query_result = llm_service.execute_query(sql_result["query"])

        if not query_result["success"]:
            st.error(f"Error executing query: {query_result['error']}")
            return

        # Check if there are results
        if query_result["results"]:
            df = pd.DataFrame(query_result["results"], columns=query_result["columns"])

            # Create a collapsible DataFrame using Streamlit's expander
            with st.expander("Click to view query results as a DataFrame"):
                st.dataframe(df)

            # Extract product details from the JSON result and display them
            json_results = df.to_dict(orient='records')
            if "title" in json_results[0] and "images" in json_results[0] and "price" in json_results[0]:
                st.subheader("Product Details:")
                for product in json_results:
                    price = product.get("price", "N/A")
                    title = product.get("handle", "N/A")
                    src = product.get("src", "N/A")
                    
                    # Display product details in a neat format using columns for alignment
                    with st.container():
                        col1, col2, col3 = st.columns([1, 2, 3])  # Adjust column widths as needed
                        
                        with col1:
                            st.image(src, use_container_width=True)  # Display product image with container width
                        with col2:
                            st.write(f"**Price:** {price}")  # Display price
                            st.write(f"**Title:** {title}")  # Display title
                        with col3:
                            st.write(f"**Image Source:** [Link]( {src} )")  # Link to the image if needed
        else:
            st.info("No results found.")

if __name__ == "__main__":
    main()