anindya-hf-2002 committed on
Commit 8b561c4 · verified · 1 Parent(s): fa96510

upload 2 files

Files changed (2)
  1. app.py +312 -0
  2. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,312 @@
+ # Swap the stdlib sqlite3 for pysqlite3-binary *before* anything imports sqlite3:
+ # ChromaDB requires a newer SQLite than many base images ship with.
+ __import__('pysqlite3')
+ import sys
+ sys.modules['sqlite3'] = sys.modules.pop('pysqlite3')
+
+ import os
+ import sqlite3
+ import streamlit as st
+ import pandas as pd
+ import tempfile
+ import shutil
+ import glob
+ import plotly.graph_objs as go
+ import plotly.io as pio
+ import json
+
+ from vanna.openai import OpenAI_Chat
+ from vanna.chromadb import ChromaDB_VectorStore
+
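+ # MyVanna composes Vanna from a ChromaDB vector store and an OpenAI chat backend,
+ # pinning ChromaDB's persist directory to a local temp folder next to this script.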
+ class MyVanna(ChromaDB_VectorStore, OpenAI_Chat):
+     def __init__(self, config=None):
+         # Get the directory of the current script
+         script_dir = os.path.dirname(os.path.abspath(__file__))
+
+         # Create the temp directory alongside the script
+         temp_dir = os.path.join(script_dir, 'temp_talk2table')
+         os.makedirs(temp_dir, exist_ok=True)
+
+         # ChromaDB path
+         chroma_path = os.path.join(temp_dir, 'chromadb')
+
+         # Update config with local paths
+         if config is None:
+             config = {}
+         config['persist_directory'] = chroma_path
+
+         ChromaDB_VectorStore.__init__(self, config=config)
+         OpenAI_Chat.__init__(self, config=config)
+
+ def clear_existing_databases():
+     """
+     Clear existing temporary databases and directories
+     """
+     script_dir = os.path.dirname(os.path.abspath(__file__))
+     temp_dir = os.path.join(script_dir, 'temp_talk2table')
+
+     if os.path.exists(temp_dir):
+         try:
+             shutil.rmtree(temp_dir)
+             st.success("Temporary databases and directories cleared successfully.")
+         except Exception as e:
+             st.error(f"Error clearing databases: {e}")
+     else:
+         st.info("No temporary databases found.")
+
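+ # st.cache_resource keeps a single Vanna instance (and its ChromaDB client) alive
+ # across Streamlit reruns for a given API key, instead of rebuilding it on every interaction.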
+ @st.cache_resource(ttl=3600)
+ def setup_vanna(openai_api_key):
+     """
+     Set up Vanna instance with caching to prevent recreation on every rerun
+     """
+     vn = MyVanna(config={
+         'api_key': openai_api_key,
+         'model': 'gpt-3.5-turbo-0125',
+         'allow_llm_to_see_data': True
+     })
+     return vn
+
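+ # Note: st.cache_data keys on the csv_file path argument, so within the one-hour TTL a
+ # re-uploaded file with the same name may return the previously cached DataFrame
+ # rather than being re-read from disk.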
+ @st.cache_data(ttl=3600)
+ def load_csv_to_sqlite(csv_file, table_name='user_data'):
+     """
+     Cache the CSV to SQLite conversion with local temp directory
+     """
+     # Get the directory of the current script
+     script_dir = os.path.dirname(os.path.abspath(__file__))
+     temp_dir = os.path.join(script_dir, 'temp_talk2table')
+     os.makedirs(temp_dir, exist_ok=True)
+
+     # Create SQLite database in the temp directory
+     db_path = os.path.join(temp_dir, 'vanna_user_database.sqlite')
+
+     df = pd.read_csv(csv_file, encoding_errors='ignore')
+
+     conn = sqlite3.connect(db_path)
+     df.to_sql(table_name, conn, if_exists='replace', index=False)
+     conn.close()
+
+     return db_path, df
+
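+ # PRAGMA table_info returns one row per column with the fields used below:
+ # name, type, notnull, dflt_value, and pk.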
+ @st.cache_data(ttl=3600)
+ def convert_to_information_schema_df(input_df):
+     """
+     Convert PRAGMA table_info output into an information-schema-style DataFrame
+     """
+     rows = []
+     database = 'main'
+     schema = 'public'
+     table_name = 'user_data'
+
+     for _, row in input_df.iterrows():
+         row_data = {
+             'TABLE_CATALOG': database,
+             'TABLE_SCHEMA': schema,
+             'TABLE_NAME': table_name,
+             'COLUMN_NAME': row['name'],
+             'DATA_TYPE': row['type'],
+             'IS_NULLABLE': 'NO' if row['notnull'] else 'YES',
+             'COLUMN_DEFAULT': row['dflt_value'],
+             'IS_PRIMARY_KEY': 'YES' if row['pk'] else 'NO'
+         }
+         rows.append(row_data)
+
+     return pd.DataFrame(rows)
+
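+ # Despite the "_cached" suffix, this helper is not wrapped in a Streamlit cache decorator;
+ # it simply guards the Vanna follow-up call and normalizes its output to a list of strings.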
+ def generate_followup_questions_cached(vn, prompt, sql=None, df=None):
+     """
+     Safely generate follow-up questions with optional SQL and DataFrame
+     """
+     try:
+         # If both SQL and DataFrame are provided, use the method that requires them
+         if sql is not None and df is not None:
+             similar_questions = vn.generate_followup_questions(prompt, sql, df)
+         else:
+             # Fallback to method without SQL and DataFrame
+             similar_questions = vn.generate_followup_questions(prompt)
+
+         # Ensure we're working with a list of questions
+         if isinstance(similar_questions, list):
+             # If list of dicts, extract questions
+             if similar_questions and isinstance(similar_questions[0], dict):
+                 similar_questions = [q.get('question', '') for q in similar_questions if isinstance(q, dict)]
+
+             # Remove empty strings and duplicates
+             similar_questions = list(dict.fromkeys(filter(bool, similar_questions)))
+         else:
+             similar_questions = []
+
+         return similar_questions[:5]  # Limit to 5 follow-up questions
+     except Exception as e:
+         st.warning(f"Error getting similar questions: {e}")
+         return []
+
+ def main():
+     st.set_page_config(page_title="Talk2Table", layout="wide")
+     st.title("🤖 Talk2Table")
+
+     # Sidebar for configuration
+     st.sidebar.header("OpenAI Configuration")
+     openai_api_key = st.sidebar.text_input(label="OpenAI API KEY", placeholder="sk-...", type="password")
+
+     # Add a button to clear existing databases
+     if st.sidebar.button("Clear Temp Databases"):
+         clear_existing_databases()
+
+     # Configuration checkboxes
+     show_sql = st.sidebar.checkbox("Show SQL Query", value=True)
+     show_table = st.sidebar.checkbox("Show Data Table", value=True)
+     show_chart = st.sidebar.checkbox("Show Plotly Chart", value=True)
+     show_summary = st.sidebar.checkbox("Show Summary", value=False)
+
+     # Initialize or reset session state
+     if 'messages' not in st.session_state:
+         st.session_state.messages = []
+
+     # Ensure these session state variables exist
+     if 'last_plot' not in st.session_state:
+         st.session_state.last_plot = None
+
+     # CSV File Upload
+     uploaded_file = st.file_uploader("Upload a CSV file", type=['csv'])
+
+     # Chat container
+     chat_container = st.container()
+
+     if uploaded_file is not None and openai_api_key:
+         # Save uploaded file temporarily and load to SQLite
+         script_dir = os.path.dirname(os.path.abspath(__file__))
+         temp_dir = os.path.join(script_dir, 'temp_talk2table')
+         os.makedirs(temp_dir, exist_ok=True)
+
+         temp_csv_path = os.path.join(temp_dir, uploaded_file.name)
+         with open(temp_csv_path, 'wb') as f:
+             f.write(uploaded_file.getbuffer())
+
+         # Load CSV to SQLite
+         db_path, df = load_csv_to_sqlite(temp_csv_path)
+
+         if db_path and df is not None:
+             # Setup Vanna instance with caching
+             vn = setup_vanna(openai_api_key)
+
+             # Connect to SQLite and train
+             vn.connect_to_sqlite(db_path)
+
+             # Train Vanna with table schema
+             df_information_schema = vn.run_sql("PRAGMA table_info('user_data');")
+             plan_df = convert_to_information_schema_df(df_information_schema)
+
+             # Enhanced training
+             plan = vn.get_training_plan_generic(plan_df)
+             vn.train(plan=plan)
+
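+             # Streamlit reruns this script from the top on every interaction, so the
+             # conversation is replayed from st.session_state.messages on each pass.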
+             # Display existing messages and their plots
+             with chat_container:
+                 for message in st.session_state.messages:
+                     with st.chat_message(message["role"]):
+                         st.markdown(message["content"])
+
+                         # If the message has a plot and chart is enabled, display it
+                         if message["role"] == "assistant" and 'plot' in message and show_chart:
+                             try:
+                                 # Use plotly.io to parse the JSON figure
+                                 plot_fig = pio.from_json(message['plot'])
+                                 st.plotly_chart(plot_fig, use_container_width=True)
+                             except Exception as e:
+                                 st.error(f"Error rendering plot: {e}")
+
+             # Sidebar for suggested questions
+             st.sidebar.header("Suggested Questions")
+             for q in st.session_state.get('similar_questions', []):
+                 st.sidebar.markdown("* " + q)
+
+             prompt = st.chat_input("Ask a question about your data...")
+
+             if prompt:
+                 st.session_state.messages.append({"role": "user", "content": prompt})
+                 with st.chat_message("user"):
+                     st.markdown(prompt)
+
+                 with st.chat_message("assistant"):
+                     with st.spinner("Generating answer..."):
+                         try:
+                             # Generate SQL with explicit allow_llm_to_see_data
+                             sql, results_df, fig = vn.ask(
+                                 question=prompt,
+                                 print_results=False,
+                                 auto_train=True,
+                                 visualize=show_chart,
+                                 allow_llm_to_see_data=True
+                             )
+
+                             # Prepare response
+                             response = ""
+
+                             # Prepare message with plot
+                             assistant_message = {
+                                 "role": "assistant",
+                                 "content": "",
+                                 "plot": None
+                             }
+
+                             # Update last successful query state
+                             if sql:
+                                 st.session_state.last_prompt = prompt
+                                 st.session_state.last_sql = sql
+                                 st.session_state.last_df = results_df
+
+                             if show_sql and sql:
+                                 response += f"**Generated SQL:**\n```sql\n{sql}\n```\n\n"
+
+                             if show_summary and results_df is not None:
+                                 try:
+                                     summary = vn.generate_summary(prompt, results_df)
+                                     response += f"**Summary:**\n{summary}\n\n"
+                                 except Exception as sum_error:
+                                     st.warning(f"Could not generate summary: {sum_error}")
+
+                             if show_table and results_df is not None:
+                                 try:
+                                     # DataFrame.to_markdown needs the optional 'tabulate' package;
+                                     # if it is missing we fall back to the warning below.
+                                     response += "**Data Results:**\n" + results_df.to_markdown() + "\n\n"
+                                 except Exception as table_error:
+                                     st.warning(f"Could not display table: {table_error}")
+                                     response += "**Data Results:** Unable to display table\n\n"
+
+                             # Store the plot in the message only if chart is enabled and fig is not None
+                             if show_chart and fig is not None:
+                                 # Use plotly.io to convert figure to JSON
+                                 plot_json = pio.to_json(fig, remove_uids=True)
+                                 assistant_message['plot'] = plot_json
+                                 st.session_state.last_plot = plot_json
+                                 st.plotly_chart(fig, use_container_width=True)
+                             else:
+                                 # If chart is disabled or fig is None, use the last successful plot if available
+                                 if st.session_state.last_plot and show_chart:
+                                     try:
+                                         last_plot_fig = pio.from_json(st.session_state.last_plot)
+                                         st.plotly_chart(last_plot_fig, use_container_width=True)
+                                     except Exception as e:
+                                         st.warning(f"Could not render previous plot: {e}")
+
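+                             # Follow-up suggestions use the last successful SQL and DataFrame stored in
+                             # session state, so they still work if the current question produced no SQL.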
+                             # Generate follow-up questions
+                             similar_questions = generate_followup_questions_cached(
+                                 vn,
+                                 prompt,
+                                 sql=st.session_state.get('last_sql'),
+                                 df=st.session_state.get('last_df')
+                             )
+                             st.session_state.similar_questions = similar_questions
+
+                             # Finalize the assistant message
+                             assistant_message['content'] = response
+                             st.session_state.messages.append(assistant_message)
+
+                             st.markdown(response)
+
+                         except Exception as e:
+                             error_message = f"Error generating answer: {str(e)}"
+                             st.error(error_message)
+                             st.session_state.messages.append({"role": "assistant", "content": error_message})
+
+     else:
+         st.info("Please provide an OpenAI API key and upload a CSV file to enable chat.")
+
+ if __name__ == "__main__":
+     main()
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ vanna
+ streamlit
+ pandas
+ plotly
+ vanna[chromadb,openai]
+ pysqlite3-binary