File size: 13,005 Bytes
8b561c4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b7b3694
 
 
8b561c4
 
b7b3694
8b561c4
 
b7b3694
8b561c4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
import os
import sqlite3
__import__('pysqlite3')
import sys
sys.modules['sqlite3'] = sys.modules.pop('pysqlite3')
import streamlit as st
import pandas as pd
import tempfile
import shutil
import glob
import plotly.graph_objs as go
import plotly.io as pio
import json

from vanna.openai import OpenAI_Chat
from vanna.chromadb import ChromaDB_VectorStore

class MyVanna(ChromaDB_VectorStore, OpenAI_Chat):
    """Vanna assistant backed by a local ChromaDB vector store and OpenAI chat.

    ChromaDB data is persisted under a ``temp_talk2table`` directory located
    next to this script, so repeated runs reuse the same embeddings.
    """

    def __init__(self, config=None):
        """Initialise both mixins with a config pointing at a local ChromaDB path.

        Args:
            config: Optional dict of Vanna/OpenAI settings (api_key, model, ...).
                It is copied, not mutated, before the persist path is injected.
        """
        # Resolve paths relative to this script so the app works regardless
        # of the current working directory.
        script_dir = os.path.dirname(os.path.abspath(__file__))

        # Create temp directories in the script's parent directory.
        temp_dir = os.path.join(script_dir, 'temp_talk2table')
        os.makedirs(temp_dir, exist_ok=True)

        # ChromaDB path
        chroma_path = os.path.join(temp_dir, 'chromadb')

        # BUG FIX: copy the caller's config instead of mutating it in place.
        config = dict(config) if config is not None else {}
        config['persist_directory'] = chroma_path

        ChromaDB_VectorStore.__init__(self, config=config)
        OpenAI_Chat.__init__(self, config=config)

def clear_existing_databases():
    """Delete the script-local ``temp_talk2table`` directory.

    Reports the outcome through Streamlit status widgets: success on
    removal, an error message on failure, or an info note when nothing
    exists to clear.
    """
    base_dir = os.path.dirname(os.path.abspath(__file__))
    target = os.path.join(base_dir, 'temp_talk2table')

    if not os.path.exists(target):
        st.info("No temporary databases found.")
        return

    try:
        shutil.rmtree(target)
    except Exception as exc:
        st.error(f"Error clearing databases: {exc}")
    else:
        st.success("Temporary databases and directories cleared successfully.")

@st.cache_resource(ttl=3600)
def setup_vanna(openai_api_key):
    """Build a MyVanna instance, cached for one hour per API key.

    Caching via ``st.cache_resource`` prevents recreating the vector
    store and OpenAI client on every Streamlit rerun.
    """
    return MyVanna(config={
        'api_key': openai_api_key,
        'model': 'gpt-3.5-turbo-0125',
        'allow_llm_to_see_data': True,
    })

@st.cache_data(ttl=3600)
def load_csv_to_sqlite(csv_file, table_name='user_data'):
    """Load a CSV into a SQLite database stored next to this script.

    Args:
        csv_file: Path to the CSV file to read.
        table_name: Table to (re)create in the database; defaults to 'user_data'.

    Returns:
        Tuple of (path to the SQLite database file, loaded DataFrame).
    """
    # Get the directory of the current script and ensure the temp dir exists.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    temp_dir = os.path.join(script_dir, 'temp_talk2table')
    os.makedirs(temp_dir, exist_ok=True)

    # Create SQLite database in the temp directory.
    db_path = os.path.join(temp_dir, 'vanna_user_database.sqlite')

    # Tolerate malformed bytes in user-supplied files rather than failing.
    df = pd.read_csv(csv_file, encoding_errors='ignore')

    # BUG FIX: close the connection even if to_sql raises, so the DB file
    # is never left with a dangling handle.
    conn = sqlite3.connect(db_path)
    try:
        df.to_sql(table_name, conn, if_exists='replace', index=False)
    finally:
        conn.close()

    return db_path, df

@st.cache_data(ttl=3600)
def convert_to_information_schema_df(input_df):
    """Reshape a ``PRAGMA table_info()`` result into an
    INFORMATION_SCHEMA-style DataFrame suitable for Vanna's training plan.

    Expects columns name/type/notnull/dflt_value/pk as produced by SQLite's
    table_info pragma; catalog, schema and table name are fixed constants.
    """
    records = [
        {
            'TABLE_CATALOG': 'main',
            'TABLE_SCHEMA': 'public',
            'TABLE_NAME': 'user_data',
            'COLUMN_NAME': col['name'],
            'DATA_TYPE': col['type'],
            'IS_NULLABLE': 'NO' if col['notnull'] else 'YES',
            'COLUMN_DEFAULT': col['dflt_value'],
            'IS_PRIMARY_KEY': 'YES' if col['pk'] else 'NO',
        }
        for _, col in input_df.iterrows()
    ]
    return pd.DataFrame(records)

def generate_followup_questions_cached(vn, prompt, sql=None, df=None):
    """Ask Vanna for follow-up questions, tolerating either API shape.

    Uses the (prompt, sql, df) call when both sql and df are supplied,
    otherwise the prompt-only variant. Dict results are normalised to plain
    strings, empties dropped, duplicates removed (first occurrence wins),
    and at most five questions are returned. Any error is surfaced as a
    Streamlit warning and yields an empty list.
    """
    try:
        if sql is None or df is None:
            raw = vn.generate_followup_questions(prompt)
        else:
            raw = vn.generate_followup_questions(prompt, sql, df)

        if not isinstance(raw, list):
            return []

        # A list of dicts means each entry wraps its text under 'question'.
        if raw and isinstance(raw[0], dict):
            raw = [item.get('question', '') for item in raw if isinstance(item, dict)]

        # dict.fromkeys keeps first-seen order while dropping duplicates;
        # filter(bool, ...) removes empty strings beforehand.
        unique = list(dict.fromkeys(filter(bool, raw)))
        return unique[:5]
    except Exception as e:
        st.warning(f"Error getting similar questions: {e}")
        return []

def main():
    """Streamlit entry point for Talk2Table.

    Renders the chat UI, wires an uploaded CSV into a local SQLite database,
    trains a cached Vanna instance on the table schema, and answers
    natural-language questions with SQL, tables, charts and summaries
    depending on the sidebar toggles.
    """
    st.set_page_config(page_title="Talk2Table", layout="wide")
    st.title("🤖 Talk2Table")

    # Sidebar for configuration
    st.sidebar.header("OpenAI Configuration")
    openai_api_key = st.sidebar.text_input(label="OpenAI API KEY", placeholder="sk-...", type="password")

    # Output toggles: each controls one section of the assistant's reply.
    show_sql = st.sidebar.checkbox("Show SQL Query", value=False)
    show_table = st.sidebar.checkbox("Show Data Table", value=True)
    show_chart = st.sidebar.checkbox("Show Plotly Chart", value=True)
    show_summary = st.sidebar.checkbox("Show Summary", value=True)

    # Initialise session state so Streamlit reruns keep chat history / plot.
    if 'messages' not in st.session_state:
        st.session_state.messages = []
    if 'last_plot' not in st.session_state:
        st.session_state.last_plot = None

    # CSV File Upload
    uploaded_file = st.file_uploader("Upload a CSV file", type=['csv'])

    # Chat container
    chat_container = st.container()

    if uploaded_file is not None and openai_api_key:
        # Save the upload into the script-local temp dir so the cached CSV
        # loader can read it from a stable path.
        script_dir = os.path.dirname(os.path.abspath(__file__))
        temp_dir = os.path.join(script_dir, 'temp_talk2table')
        os.makedirs(temp_dir, exist_ok=True)

        temp_csv_path = os.path.join(temp_dir, uploaded_file.name)
        with open(temp_csv_path, 'wb') as f:
            f.write(uploaded_file.getbuffer())

        # Load CSV to SQLite
        db_path, df = load_csv_to_sqlite(temp_csv_path)

        if db_path and df is not None:
            # Setup Vanna instance with caching
            vn = setup_vanna(openai_api_key)

            # Connect to SQLite and train
            vn.connect_to_sqlite(db_path)

            # Train Vanna with table schema
            df_information_schema = vn.run_sql("PRAGMA table_info('user_data');")
            plan_df = convert_to_information_schema_df(df_information_schema)

            # Enhanced training
            plan = vn.get_training_plan_generic(plan_df)
            vn.train(plan=plan)

            # Replay existing chat history (and any stored plots).
            with chat_container:
                for message in st.session_state.messages:
                    with st.chat_message(message["role"]):
                        st.markdown(message["content"])

                        # Re-render the stored plot if the chart toggle is on.
                        if message["role"] == "assistant" and 'plot' in message and show_chart:
                            try:
                                # Plots are stored as JSON so they survive reruns.
                                plot_fig = pio.from_json(message['plot'])
                                st.plotly_chart(plot_fig, use_container_width=True)
                            except Exception as e:
                                st.error(f"Error rendering plot: {e}")

            # Sidebar for suggested questions
            st.sidebar.header("Suggested Questions")
            for q in st.session_state.get('similar_questions', []):
                st.sidebar.markdown("* "+q)

            prompt = st.chat_input("Ask a question about your data...")

            if prompt:
                st.session_state.messages.append({"role": "user", "content": prompt})
                with st.chat_message("user"):
                    st.markdown(prompt)

                with st.chat_message("assistant"):
                    with st.spinner("Generating answer..."):
                        try:
                            # Generate SQL with explicit allow_llm_to_see_data
                            sql, results_df, fig = vn.ask(
                                question=prompt,
                                print_results=False,
                                auto_train=True,
                                visualize=show_chart,
                                allow_llm_to_see_data=True
                            )

                            # Prepare response
                            response = ""

                            # Prepare message with plot
                            assistant_message = {
                                "role": "assistant",
                                "content": "",
                                "plot": None
                            }

                            # Remember the last successful query for follow-ups.
                            if sql:
                                st.session_state.last_prompt = prompt
                                st.session_state.last_sql = sql
                                st.session_state.last_df = results_df

                            if show_sql and sql:
                                response += f"**Generated SQL:**\n```sql\n{sql}\n```\n\n"

                            if show_summary and results_df is not None:
                                try:
                                    summary = vn.generate_summary(prompt, results_df)
                                    response += f"**Summary:**\n{summary}\n\n"
                                except Exception as sum_error:
                                    st.warning(f"Could not generate summary: {sum_error}")

                            if show_table and results_df is not None:
                                try:
                                    response += "**Data Results:**\n" + results_df.to_markdown() + "\n\n"
                                except Exception as table_error:
                                    st.warning(f"Could not display table: {table_error}")
                                    response += "**Data Results:** Unable to display table\n\n"

                            # Store the plot only if charting is on and a figure exists.
                            if show_chart and fig is not None:
                                # Serialize so history replay can re-render it later.
                                plot_json = pio.to_json(fig, remove_uids=True)
                                assistant_message['plot'] = plot_json
                                st.session_state.last_plot = plot_json
                                st.plotly_chart(fig, use_container_width=True)
                            else:
                                # Fall back to the last successful plot if available.
                                if st.session_state.last_plot and show_chart:
                                    try:
                                        last_plot_fig = pio.from_json(st.session_state.last_plot)
                                        st.plotly_chart(last_plot_fig, use_container_width=True)
                                    except Exception as e:
                                        st.warning(f"Could not render previous plot: {e}")

                            # Generate follow-up questions
                            similar_questions = generate_followup_questions_cached(
                                vn,
                                prompt,
                                sql=st.session_state.get('last_sql'),
                                df=st.session_state.get('last_df')
                            )
                            st.session_state.similar_questions = similar_questions

                            # Finalize the assistant message
                            assistant_message['content'] = response
                            st.session_state.messages.append(assistant_message)

                            st.markdown(response)

                        except Exception as e:
                            error_message = f"Error generating answer: {str(e)}"
                            st.error(error_message)
                            st.session_state.messages.append({"role": "assistant", "content": error_message})

    else:
        # BUG FIX: this else was previously attached to the inner
        # `if db_path and df is not None`, so the hint never showed when the
        # API key or CSV file was actually missing (and showed a misleading
        # message when the DB load failed instead).
        st.info("Please provide both OpenAI API Key and upload a CSV file to enable chat.")

# Script entry point (run with `streamlit run <this file>`).
if __name__ == "__main__":
    main()