Ari commited on
Commit
5cf07df
·
verified ·
1 Parent(s): 99fe04f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +72 -17
app.py CHANGED
@@ -2,9 +2,15 @@ import os
2
  import streamlit as st
3
  import pandas as pd
4
  import sqlite3
 
5
  from langchain import OpenAI, LLMChain, PromptTemplate
6
  import sqlparse
7
  import logging
 
 
 
 
 
8
 
9
  # Initialize conversation history
10
  if 'history' not in st.session_state:
@@ -66,24 +72,58 @@ def process_input():
66
  st.session_state.history.append({"role": "assistant", "content": assistant_response})
67
  else:
68
  columns = ', '.join(valid_columns)
69
- generated_sql = sql_generation_chain.run({
70
  'question': user_prompt,
71
  'table_name': table_name,
72
  'columns': columns
73
  })
74
 
75
- # Debug: Display generated SQL query for inspection
76
- st.write(f"Generated SQL Query:\n{generated_sql}")
77
-
78
- # Attempt to execute SQL query and handle exceptions
79
- try:
80
- result = pd.read_sql_query(generated_sql, conn)
81
- assistant_response = f"Generated SQL Query:\n{generated_sql}"
82
- st.session_state.history.append({"role": "assistant", "content": assistant_response})
83
- st.session_state.history.append({"role": "assistant", "content": result})
84
- except Exception as e:
85
- logging.error(f"An error occurred during SQL execution: {e}")
86
- assistant_response = f"Error executing SQL query: {e}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  st.session_state.history.append({"role": "assistant", "content": assistant_response})
88
 
89
  except Exception as e:
@@ -94,16 +134,31 @@ def process_input():
94
  # Reset the user_input in session state
95
  st.session_state['user_input'] = ''
96
 
 
 
 
 
 
 
 
 
 
 
 
97
  # Display the conversation history
98
  for message in st.session_state.history:
99
  if message['role'] == 'user':
100
  st.markdown(f"**User:** {message['content']}")
101
  elif message['role'] == 'assistant':
102
- if isinstance(message['content'], pd.DataFrame):
103
- st.markdown("**Assistant:** Query Results:")
104
- st.dataframe(message['content'])
 
 
 
105
  else:
106
- st.markdown(f"**Assistant:** {message['content']}")
 
107
 
108
  # Place the input field at the bottom with the callback
109
  st.text_input("Enter your message:", key='user_input', on_change=process_input)
 
2
  import streamlit as st
3
  import pandas as pd
4
  import sqlite3
5
+ import numpy as np # New import
6
  from langchain import OpenAI, LLMChain, PromptTemplate
7
  import sqlparse
8
  import logging
9
+ # Removed unused import: from sql_metadata import Parser
10
+ from sklearn.linear_model import LinearRegression # New import
11
+ from sklearn.model_selection import train_test_split # New import
12
+ from sklearn.metrics import mean_squared_error, r2_score # New import
13
+
14
 
15
  # Initialize conversation history
16
  if 'history' not in st.session_state:
 
72
  st.session_state.history.append({"role": "assistant", "content": assistant_response})
73
  else:
74
  columns = ', '.join(valid_columns)
75
+ response = sql_generation_chain.run({
76
  'question': user_prompt,
77
  'table_name': table_name,
78
  'columns': columns
79
  })
80
 
81
+ # Extract code from response
82
+ code = extract_code(response)
83
+ if code:
84
+ # Determine if the code is SQL or Python
85
+ if code.strip().lower().startswith('select'):
86
+ # It's a SQL query
87
+ st.write(f"Generated SQL Query:\n{code}")
88
+ try:
89
+ result = pd.read_sql_query(code, conn)
90
+ assistant_response = f"Generated SQL Query:\n{code}"
91
+ st.session_state.history.append({"role": "assistant", "content": assistant_response})
92
+ st.session_state.history.append({"role": "assistant", "content": result})
93
+ except Exception as e:
94
+ logging.error(f"An error occurred during SQL execution: {e}")
95
+ assistant_response = f"Error executing SQL query: {e}"
96
+ st.session_state.history.append({"role": "assistant", "content": assistant_response})
97
+ else:
98
+ # It's Python code
99
+ st.write(f"Generated Python Code:\n{code}")
100
+ try:
101
+ # Prepare the local namespace
102
+ local_vars = {
103
+ 'pd': pd,
104
+ 'np': np,
105
+ 'data': data.copy(),
106
+ 'result': None,
107
+ 'LinearRegression': LinearRegression,
108
+ 'train_test_split': train_test_split,
109
+ 'mean_squared_error': mean_squared_error,
110
+ 'r2_score': r2_score
111
+ }
112
+ exec(code, {}, local_vars)
113
+ result = local_vars.get('result')
114
+ if result is not None:
115
+ assistant_response = "Result:"
116
+ st.session_state.history.append({"role": "assistant", "content": assistant_response})
117
+ st.session_state.history.append({"role": "assistant", "content": result})
118
+ else:
119
+ assistant_response = "Code executed successfully."
120
+ st.session_state.history.append({"role": "assistant", "content": assistant_response})
121
+ except Exception as e:
122
+ logging.error(f"An error occurred during code execution: {e}")
123
+ assistant_response = f"Error executing code: {e}"
124
+ st.session_state.history.append({"role": "assistant", "content": assistant_response})
125
+ else:
126
+ assistant_response = response.strip()
127
  st.session_state.history.append({"role": "assistant", "content": assistant_response})
128
 
129
  except Exception as e:
 
134
  # Reset the user_input in session state
135
  st.session_state['user_input'] = ''
136
 
137
+ def extract_code(response):
138
+ """Extracts code enclosed between <CODE> and </CODE> tags."""
139
+ import re
140
+ pattern = r"<CODE>(.*?)</CODE>"
141
+ match = re.search(pattern, response, re.DOTALL)
142
+ if match:
143
+ return match.group(1).strip()
144
+ else:
145
+ return None
146
+
147
+ # Display the conversation history
148
  # Display the conversation history
149
  for message in st.session_state.history:
150
  if message['role'] == 'user':
151
  st.markdown(f"**User:** {message['content']}")
152
  elif message['role'] == 'assistant':
153
+ content = message['content']
154
+ if isinstance(content, pd.DataFrame):
155
+ st.markdown("**Assistant:** Here are the results:")
156
+ st.dataframe(content)
157
+ elif isinstance(content, (int, float, str, list, dict)):
158
+ st.markdown(f"**Assistant:** {content}")
159
  else:
160
+ st.markdown(f"**Assistant:** {content}")
161
+
162
 
163
  # Place the input field at the bottom with the callback
164
  st.text_input("Enter your message:", key='user_input', on_change=process_input)