Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -8,20 +8,14 @@ from streamlit_pandas_profiling import st_profile_report
|
|
8 |
import os
|
9 |
import requests
|
10 |
import json
|
11 |
-
from datetime import datetime
|
12 |
import re
|
13 |
-
import tempfile
|
14 |
from scipy import stats
|
15 |
-
from sklearn.impute import SimpleImputer
|
16 |
from sklearn.preprocessing import StandardScaler, LabelEncoder, OneHotEncoder
|
17 |
from sklearn.decomposition import PCA
|
18 |
-
import streamlit.components.v1 as components
|
19 |
-
from io import StringIO
|
20 |
from dotenv import load_dotenv
|
21 |
from flask import Flask, request, jsonify
|
22 |
from openai import OpenAI
|
23 |
import threading
|
24 |
-
from sentence_transformers import SentenceTransformer
|
25 |
|
26 |
# Load environment variables
|
27 |
load_dotenv()
|
@@ -30,13 +24,6 @@ load_dotenv()
|
|
30 |
flask_app = Flask(__name__)
|
31 |
FLASK_PORT = 5000 # Internal port for Flask, not exposed externally
|
32 |
|
33 |
-
# Initialize OpenAI client
|
34 |
-
api_key = os.getenv("OPENAI_API_KEY")
|
35 |
-
if not api_key:
|
36 |
-
st.error("OPENAI_API_KEY not set. Please configure it in the Hugging Face Space secrets.")
|
37 |
-
st.stop()
|
38 |
-
client = OpenAI(api_key=api_key)
|
39 |
-
|
40 |
# Flask RAG Endpoint
|
41 |
@flask_app.route('/rag_chat', methods=['POST'])
|
42 |
def rag_chat():
|
@@ -45,7 +32,6 @@ def rag_chat():
|
|
45 |
app_mode = data.get('app_mode', 'Data Upload')
|
46 |
dataset_text = data.get('dataset_text', '')
|
47 |
|
48 |
-
# RAG Logic: Use dataset_text as retrieval context
|
49 |
system_prompt = (
|
50 |
"You are an AI assistant in Data-Vision Pro, a data analysis app with RAG capabilities. "
|
51 |
"The app has three pages:\n"
|
@@ -71,7 +57,7 @@ def rag_chat():
|
|
71 |
{"role": "system", "content": system_prompt},
|
72 |
{"role": "user", "content": user_input}
|
73 |
],
|
74 |
-
max_tokens=100,
|
75 |
temperature=0.7
|
76 |
)
|
77 |
return jsonify({"response": response.choices[0].message.content})
|
@@ -82,7 +68,6 @@ def rag_chat():
|
|
82 |
def run_flask():
|
83 |
flask_app.run(host='0.0.0.0', port=FLASK_PORT, debug=False, use_reloader=False)
|
84 |
|
85 |
-
# Start Flask thread
|
86 |
flask_thread = threading.Thread(target=run_flask, daemon=True)
|
87 |
flask_thread.start()
|
88 |
|
@@ -95,11 +80,11 @@ def update_cleaned_data(df):
|
|
95 |
if 'data_versions' not in st.session_state:
|
96 |
st.session_state.data_versions = [st.session_state.raw_data.copy()]
|
97 |
st.session_state.data_versions.append(df.copy())
|
|
|
98 |
st.success("✅ Action completed successfully!")
|
99 |
st.rerun()
|
100 |
|
101 |
def convert_csv_to_json_and_text(df):
|
102 |
-
"""Convert DataFrame to JSON and then to plain text."""
|
103 |
json_data = df.to_json(orient="records")
|
104 |
data_dict = json.loads(json_data)
|
105 |
text_summary = f"Dataset Summary: {df.shape[0]} rows, {df.shape[1]} columns\n"
|
@@ -115,7 +100,6 @@ def convert_csv_to_json_and_text(df):
|
|
115 |
return text_summary
|
116 |
|
117 |
def get_chatbot_response(user_input, app_mode, dataset_text=""):
|
118 |
-
"""Send request to internal Flask RAG endpoint."""
|
119 |
payload = {
|
120 |
"user_input": user_input,
|
121 |
"app_mode": app_mode,
|
@@ -128,8 +112,88 @@ def get_chatbot_response(user_input, app_mode, dataset_text=""):
|
|
128 |
except requests.exceptions.RequestException as e:
|
129 |
return f"Error: Could not connect to RAG server. {str(e)}"
|
130 |
|
131 |
-
#
|
132 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
133 |
with st.sidebar:
|
134 |
st.title("🔮 Data-Vision Pro")
|
135 |
st.markdown("Your AI-powered data analysis suite with RAG.")
|
@@ -145,6 +209,13 @@ with st.sidebar:
|
|
145 |
st.info("🧹 Clean and preprocess your data using various tools.")
|
146 |
elif app_mode == "EDA":
|
147 |
st.info("🔍 Explore your data visually and statistically.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
148 |
|
149 |
st.markdown("---")
|
150 |
st.markdown("**Note**: Requires dependencies in `requirements.txt`.")
|
@@ -159,15 +230,29 @@ with st.sidebar:
|
|
159 |
st.markdown("Created by Calvin Allen-Crawford")
|
160 |
st.markdown("v1.0 | © 2025")
|
161 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
162 |
# Main App Pages
|
163 |
if app_mode == "Data Upload":
|
164 |
st.title("📤 Data Upload & Profiling")
|
165 |
st.header("Upload Your Dataset")
|
166 |
st.write("Supported formats: CSV, XLSX")
|
167 |
-
|
168 |
if 'raw_data' not in st.session_state:
|
169 |
st.info("It looks like no dataset has been uploaded yet. Would you like to upload a CSV or XLSX file?")
|
170 |
-
|
171 |
uploaded_file = st.file_uploader("Choose a file", type=["csv", "xlsx"], key="file_uploader")
|
172 |
if uploaded_file:
|
173 |
st.session_state.pop('raw_data', None)
|
@@ -182,6 +267,7 @@ if app_mode == "Data Upload":
|
|
182 |
st.error("Uploaded file is empty.")
|
183 |
st.stop()
|
184 |
st.session_state.raw_data = df
|
|
|
185 |
st.session_state.dataset_text = convert_csv_to_json_and_text(df)
|
186 |
if 'data_versions' not in st.session_state:
|
187 |
st.session_state.data_versions = [df.copy()]
|
@@ -226,6 +312,92 @@ elif app_mode == "Data Cleaning":
|
|
226 |
st.session_state.dataset_text = convert_csv_to_json_and_text(st.session_state.cleaned_data)
|
227 |
st.rerun()
|
228 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
229 |
elif app_mode == "EDA":
|
230 |
st.title("🔍 Interactive Data Explorer")
|
231 |
if 'cleaned_data' not in st.session_state:
|
@@ -242,10 +414,109 @@ elif app_mode == "EDA":
|
|
242 |
col3.metric("Missing Values", f"{df.isna().sum().sum()} ({missing_percentage:.1f}%)")
|
243 |
col4.metric("Duplicates", df.duplicated().sum())
|
244 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
245 |
# Chatbot Section
|
246 |
st.markdown("---")
|
247 |
st.subheader("💬 AI Chatbot Assistant (RAG Enabled)")
|
248 |
-
st.info("Ask me about the app or your data! Try: '
|
249 |
if "chat_history" not in st.session_state:
|
250 |
st.session_state.chat_history = []
|
251 |
|
@@ -258,10 +529,13 @@ if user_input:
|
|
258 |
st.session_state.chat_history.append({"role": "user", "content": user_input})
|
259 |
with st.chat_message("user"):
|
260 |
st.markdown(user_input)
|
261 |
-
|
262 |
-
with st.spinner("Thinking with RAG..."):
|
263 |
dataset_text = st.session_state.get("dataset_text", "")
|
264 |
-
|
|
|
|
|
|
|
|
|
265 |
st.session_state.chat_history.append({"role": "assistant", "content": response})
|
266 |
with st.chat_message("assistant"):
|
267 |
st.markdown(response)
|
|
|
8 |
import os
|
9 |
import requests
|
10 |
import json
|
|
|
11 |
import re
|
|
|
12 |
from scipy import stats
|
|
|
13 |
from sklearn.preprocessing import StandardScaler, LabelEncoder, OneHotEncoder
|
14 |
from sklearn.decomposition import PCA
|
|
|
|
|
15 |
from dotenv import load_dotenv
|
16 |
from flask import Flask, request, jsonify
|
17 |
from openai import OpenAI
|
18 |
import threading
|
|
|
19 |
|
20 |
# Load environment variables
|
21 |
load_dotenv()
|
|
|
24 |
flask_app = Flask(__name__)
|
25 |
FLASK_PORT = 5000 # Internal port for Flask, not exposed externally
|
26 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
# Flask RAG Endpoint
|
28 |
@flask_app.route('/rag_chat', methods=['POST'])
|
29 |
def rag_chat():
|
|
|
32 |
app_mode = data.get('app_mode', 'Data Upload')
|
33 |
dataset_text = data.get('dataset_text', '')
|
34 |
|
|
|
35 |
system_prompt = (
|
36 |
"You are an AI assistant in Data-Vision Pro, a data analysis app with RAG capabilities. "
|
37 |
"The app has three pages:\n"
|
|
|
57 |
{"role": "system", "content": system_prompt},
|
58 |
{"role": "user", "content": user_input}
|
59 |
],
|
60 |
+
max_tokens=100,
|
61 |
temperature=0.7
|
62 |
)
|
63 |
return jsonify({"response": response.choices[0].message.content})
|
|
|
68 |
def run_flask():
|
69 |
flask_app.run(host='0.0.0.0', port=FLASK_PORT, debug=False, use_reloader=False)
|
70 |
|
|
|
71 |
flask_thread = threading.Thread(target=run_flask, daemon=True)
|
72 |
flask_thread.start()
|
73 |
|
|
|
80 |
if 'data_versions' not in st.session_state:
|
81 |
st.session_state.data_versions = [st.session_state.raw_data.copy()]
|
82 |
st.session_state.data_versions.append(df.copy())
|
83 |
+
st.session_state.dataset_text = convert_csv_to_json_and_text(df)
|
84 |
st.success("✅ Action completed successfully!")
|
85 |
st.rerun()
|
86 |
|
87 |
def convert_csv_to_json_and_text(df):
|
|
|
88 |
json_data = df.to_json(orient="records")
|
89 |
data_dict = json.loads(json_data)
|
90 |
text_summary = f"Dataset Summary: {df.shape[0]} rows, {df.shape[1]} columns\n"
|
|
|
100 |
return text_summary
|
101 |
|
102 |
def get_chatbot_response(user_input, app_mode, dataset_text=""):
|
|
|
103 |
payload = {
|
104 |
"user_input": user_input,
|
105 |
"app_mode": app_mode,
|
|
|
112 |
except requests.exceptions.RequestException as e:
|
113 |
return f"Error: Could not connect to RAG server. {str(e)}"
|
114 |
|
115 |
+
# Command Functions for LLM
|
116 |
+
def drop_columns(columns):
|
117 |
+
if 'cleaned_data' in st.session_state:
|
118 |
+
df = st.session_state.cleaned_data.copy()
|
119 |
+
columns_to_drop = [col.strip() for col in columns.split(',')]
|
120 |
+
valid_columns = [col for col in columns_to_drop if col in df.columns]
|
121 |
+
if valid_columns:
|
122 |
+
df.drop(valid_columns, axis=1, inplace=True)
|
123 |
+
update_cleaned_data(df)
|
124 |
+
return f"Dropped columns: {', '.join(valid_columns)}"
|
125 |
+
else:
|
126 |
+
return "No valid columns found to drop."
|
127 |
+
return "No dataset loaded."
|
128 |
+
|
129 |
+
# LLM-Driven EDA Commands
|
130 |
+
def generate_scatter_plot(params):
|
131 |
+
df = st.session_state.cleaned_data
|
132 |
+
match = re.search(r"([\w\s]+)\s+vs\s+([\w\s]+)", params)
|
133 |
+
if match and len(match.groups()) >= 2:
|
134 |
+
x_axis, y_axis = match.group(1).strip(), match.group(2).strip()
|
135 |
+
if x_axis in df.columns and y_axis in df.columns:
|
136 |
+
fig = px.scatter(df, x=x_axis, y=y_axis, title=f'Scatter Plot of {x_axis} vs {y_axis}')
|
137 |
+
st.plotly_chart(fig)
|
138 |
+
st.session_state.last_plot = {"type": "Scatter Plot", "x": x_axis, "y": y_axis, "data": df[[x_axis, y_axis]].to_json()}
|
139 |
+
return f"Generated scatter plot of {x_axis} vs {y_axis}"
|
140 |
+
return "Invalid columns for scatter plot."
|
141 |
+
|
142 |
+
def generate_histogram(params):
|
143 |
+
df = st.session_state.cleaned_data
|
144 |
+
x_axis = params.strip()
|
145 |
+
if x_axis in df.columns:
|
146 |
+
fig = px.histogram(df, x=x_axis, title=f'Histogram of {x_axis}')
|
147 |
+
st.plotly_chart(fig)
|
148 |
+
st.session_state.last_plot = {"type": "Histogram", "x": x_axis, "data": df[[x_axis]].to_json()}
|
149 |
+
return f"Generated histogram of {x_axis}"
|
150 |
+
return "Invalid column for histogram."
|
151 |
+
|
152 |
+
# Inference from Plotted Data
|
153 |
+
def analyze_plot():
|
154 |
+
if "last_plot" not in st.session_state:
|
155 |
+
return "No plot available to analyze."
|
156 |
+
plot_info = st.session_state.last_plot
|
157 |
+
df = pd.read_json(plot_info["data"])
|
158 |
+
plot_type = plot_info["type"]
|
159 |
+
x_col = plot_info["x"]
|
160 |
+
y_col = plot_info["y"] if "y" in plot_info else None
|
161 |
+
|
162 |
+
if plot_type == "Scatter Plot" and y_col:
|
163 |
+
correlation = df[x_col].corr(df[y_col])
|
164 |
+
strength = "strong" if abs(correlation) > 0.7 else "moderate" if abs(correlation) > 0.3 else "weak"
|
165 |
+
direction = "positive" if correlation > 0 else "negative"
|
166 |
+
return f"The scatter plot of {x_col} vs {y_col} shows a {strength} {direction} correlation (Pearson r = {correlation:.2f})."
|
167 |
+
elif plot_type == "Histogram":
|
168 |
+
skewness = df[x_col].skew()
|
169 |
+
skew_desc = "positively skewed" if skewness > 1 else "negatively skewed" if skewness < -1 else "approximately symmetric"
|
170 |
+
return f"The histogram of {x_col} is {skew_desc} (skewness = {skewness:.2f})."
|
171 |
+
return "Inference not available for this plot type."
|
172 |
+
|
173 |
+
# Parse Chatbot Commands
|
174 |
+
def parse_command(command):
|
175 |
+
command = command.lower().strip()
|
176 |
+
if "drop columns" in command or "drop column" in command:
|
177 |
+
columns = command.replace("drop columns", "").replace("drop column", "").strip()
|
178 |
+
return drop_columns, columns
|
179 |
+
elif "show a scatter plot" in command or "scatter plot of" in command:
|
180 |
+
params = command.replace("show a scatter plot of", "").replace("scatter plot of", "").strip()
|
181 |
+
return generate_scatter_plot, params
|
182 |
+
elif "show a histogram" in command or "histogram of" in command:
|
183 |
+
params = command.replace("show a histogram of", "").replace("histogram of", "").strip()
|
184 |
+
return generate_histogram, params
|
185 |
+
elif "analyze plot" in command:
|
186 |
+
return lambda x: analyze_plot(), None
|
187 |
+
return None, "Command not recognized. Try 'drop columns X, Y', 'scatter plot of X vs Y', or 'analyze plot'."
|
188 |
+
|
189 |
+
# Dataset Preview Function
|
190 |
+
def display_dataset_preview():
|
191 |
+
if 'cleaned_data' in st.session_state:
|
192 |
+
st.subheader("Current Dataset Preview")
|
193 |
+
st.dataframe(st.session_state.cleaned_data.head(10), use_container_width=True)
|
194 |
+
st.write("---")
|
195 |
+
|
196 |
+
# Sidebar Navigation with API Key Input
|
197 |
with st.sidebar:
|
198 |
st.title("🔮 Data-Vision Pro")
|
199 |
st.markdown("Your AI-powered data analysis suite with RAG.")
|
|
|
209 |
st.info("🧹 Clean and preprocess your data using various tools.")
|
210 |
elif app_mode == "EDA":
|
211 |
st.info("🔍 Explore your data visually and statistically.")
|
212 |
+
|
213 |
+
# API Key Input Field
|
214 |
+
api_key_input = st.text_input(
|
215 |
+
"Enter your API key (optional)",
|
216 |
+
type="password",
|
217 |
+
help="Enter your API key to override the default. Leave blank to use the app's default key."
|
218 |
+
)
|
219 |
|
220 |
st.markdown("---")
|
221 |
st.markdown("**Note**: Requires dependencies in `requirements.txt`.")
|
|
|
230 |
st.markdown("Created by Calvin Allen-Crawford")
|
231 |
st.markdown("v1.0 | © 2025")
|
232 |
|
233 |
+
# Determine which API key to use
|
234 |
+
if api_key_input:
|
235 |
+
api_key = api_key_input # Use the user-provided API key from the sidebar
|
236 |
+
else:
|
237 |
+
api_key = st.secrets.get("OPENAI_API_KEY", os.getenv("OPENAI_API_KEY")) # Fall back to secret or environment variable
|
238 |
+
|
239 |
+
if not api_key:
|
240 |
+
st.error("API key is required. Please provide it in the sidebar or ensure it’s set in the app’s secrets.")
|
241 |
+
st.stop()
|
242 |
+
|
243 |
+
# Initialize OpenAI client with the selected API key
|
244 |
+
client = OpenAI(api_key=api_key)
|
245 |
+
|
246 |
+
# Display dataset preview at the top of each page
|
247 |
+
display_dataset_preview()
|
248 |
+
|
249 |
# Main App Pages
|
250 |
if app_mode == "Data Upload":
|
251 |
st.title("📤 Data Upload & Profiling")
|
252 |
st.header("Upload Your Dataset")
|
253 |
st.write("Supported formats: CSV, XLSX")
|
|
|
254 |
if 'raw_data' not in st.session_state:
|
255 |
st.info("It looks like no dataset has been uploaded yet. Would you like to upload a CSV or XLSX file?")
|
|
|
256 |
uploaded_file = st.file_uploader("Choose a file", type=["csv", "xlsx"], key="file_uploader")
|
257 |
if uploaded_file:
|
258 |
st.session_state.pop('raw_data', None)
|
|
|
267 |
st.error("Uploaded file is empty.")
|
268 |
st.stop()
|
269 |
st.session_state.raw_data = df
|
270 |
+
st.session_state.cleaned_data = df.copy()
|
271 |
st.session_state.dataset_text = convert_csv_to_json_and_text(df)
|
272 |
if 'data_versions' not in st.session_state:
|
273 |
st.session_state.data_versions = [df.copy()]
|
|
|
312 |
st.session_state.dataset_text = convert_csv_to_json_and_text(st.session_state.cleaned_data)
|
313 |
st.rerun()
|
314 |
|
315 |
+
with st.expander("🛠️ Data Cleaning Operations", expanded=True):
|
316 |
+
enhance_section_title("🔍 Missing Values Treatment")
|
317 |
+
missing_cols = df.columns[df.isna().any()].tolist()
|
318 |
+
if missing_cols:
|
319 |
+
cols = st.multiselect("Select columns with missing values", missing_cols)
|
320 |
+
method = st.selectbox("Choose imputation method", [
|
321 |
+
"Drop Missing Values", "Fill with Mean/Median", "Fill with Custom Value", "Forward Fill", "Backward Fill"
|
322 |
+
])
|
323 |
+
if method == "Fill with Custom Value":
|
324 |
+
custom_val = st.text_input("Enter custom value:")
|
325 |
+
if st.button("Apply Missing Value Treatment"):
|
326 |
+
new_df = df.copy()
|
327 |
+
if method == "Drop Missing Values":
|
328 |
+
new_df = new_df.dropna(subset=cols)
|
329 |
+
elif method == "Fill with Mean/Median":
|
330 |
+
for col in cols:
|
331 |
+
if pd.api.types.is_numeric_dtype(new_df[col]):
|
332 |
+
new_df[col] = new_df[col].fillna(new_df[col].median())
|
333 |
+
else:
|
334 |
+
new_df[col] = new_df[col].fillna(new_df[col].mode()[0])
|
335 |
+
elif method == "Fill with Custom Value" and custom_val:
|
336 |
+
new_df[cols] = new_df[cols].fillna(custom_val)
|
337 |
+
elif method == "Forward Fill":
|
338 |
+
new_df[cols] = new_df[cols].ffill()
|
339 |
+
elif method == "Backward Fill":
|
340 |
+
new_df[cols] = new_df[cols].bfill()
|
341 |
+
update_cleaned_data(new_df)
|
342 |
+
else:
|
343 |
+
st.success("✨ No missing values detected!")
|
344 |
+
|
345 |
+
enhance_section_title("🔄 Data Type Conversion")
|
346 |
+
col_to_convert = st.selectbox("Select column to convert", df.columns)
|
347 |
+
new_type = st.selectbox("Select new data type", ["String", "Integer", "Float", "Boolean", "Datetime"])
|
348 |
+
if new_type == "Datetime":
|
349 |
+
date_format = st.text_input("Enter date format (e.g., %Y-%m-%d):", "%Y-%m-%d")
|
350 |
+
if st.button("Convert Data Type"):
|
351 |
+
new_df = df.copy()
|
352 |
+
if new_type == "String":
|
353 |
+
new_df[col_to_convert] = new_df[col_to_convert].astype(str)
|
354 |
+
elif new_type == "Integer":
|
355 |
+
new_df[col_to_convert] = pd.to_numeric(new_df[col_to_convert], errors='coerce').astype('Int64')
|
356 |
+
elif new_type == "Float":
|
357 |
+
new_df[col_to_convert] = pd.to_numeric(new_df[col_to_convert], errors='coerce')
|
358 |
+
elif new_type == "Boolean":
|
359 |
+
new_df[col_to_convert] = new_df[col_to_convert].astype(bool)
|
360 |
+
elif new_type == "Datetime":
|
361 |
+
new_df[col_to_convert] = pd.to_datetime(new_df[col_to_convert], format=date_format, errors='coerce')
|
362 |
+
update_cleaned_data(new_df)
|
363 |
+
|
364 |
+
enhance_section_title("🗑️ Drop Columns")
|
365 |
+
columns_to_drop = st.multiselect("Select columns to remove", df.columns)
|
366 |
+
if columns_to_drop and st.button("Confirm Column Removal"):
|
367 |
+
new_df = df.copy()
|
368 |
+
new_df = new_df.drop(columns=columns_to_drop)
|
369 |
+
update_cleaned_data(new_df)
|
370 |
+
|
371 |
+
enhance_section_title("🔢 Encoding Options")
|
372 |
+
encoding_method = st.radio("Choose encoding method", ("Label Encoding", "One-Hot Encoding"))
|
373 |
+
data_to_encode = st.multiselect("Select columns to encode", df.select_dtypes(include='object').columns)
|
374 |
+
if data_to_encode and st.button("Apply Encoding"):
|
375 |
+
new_df = df.copy()
|
376 |
+
if encoding_method == "Label Encoding":
|
377 |
+
for col in data_to_encode:
|
378 |
+
le = LabelEncoder()
|
379 |
+
new_df[col] = le.fit_transform(new_df[col].astype(str))
|
380 |
+
elif encoding_method == "One-Hot Encoding":
|
381 |
+
new_df = pd.get_dummies(new_df, columns=data_to_encode, drop_first=True, dtype=int)
|
382 |
+
update_cleaned_data(new_df)
|
383 |
+
|
384 |
+
enhance_section_title("📏 StandardScaler")
|
385 |
+
scale_cols = st.multiselect("Select numerical columns to scale", df.select_dtypes(include=np.number).columns)
|
386 |
+
if scale_cols and st.button("Apply StandardScaler"):
|
387 |
+
new_df = df.copy()
|
388 |
+
scaler = StandardScaler()
|
389 |
+
new_df[scale_cols] = scaler.fit_transform(new_df[scale_cols])
|
390 |
+
update_cleaned_data(new_df)
|
391 |
+
|
392 |
+
enhance_section_title("🕵️ Pattern-Based Cleaning")
|
393 |
+
selected_col = st.selectbox("Select text column for pattern cleaning", df.select_dtypes(include='object').columns)
|
394 |
+
pattern = st.text_input("Enter regex pattern:")
|
395 |
+
replacement = st.text_input("Enter replacement value:")
|
396 |
+
if st.button("Apply Pattern Replacement"):
|
397 |
+
new_df = df.copy()
|
398 |
+
new_df[selected_col] = new_df[selected_col].str.replace(pattern, replacement, regex=True)
|
399 |
+
update_cleaned_data(new_df)
|
400 |
+
|
401 |
elif app_mode == "EDA":
|
402 |
st.title("🔍 Interactive Data Explorer")
|
403 |
if 'cleaned_data' not in st.session_state:
|
|
|
414 |
col3.metric("Missing Values", f"{df.isna().sum().sum()} ({missing_percentage:.1f}%)")
|
415 |
col4.metric("Duplicates", df.duplicated().sum())
|
416 |
|
417 |
+
tab1, tab2, tab3 = st.tabs(["Quick Preview", "Column Types", "Missing Matrix"])
|
418 |
+
with tab1:
|
419 |
+
st.write("First few rows of the dataset:")
|
420 |
+
st.dataframe(df.head(), use_container_width=True)
|
421 |
+
with tab2:
|
422 |
+
st.write("Column Data Types:")
|
423 |
+
type_counts = df.dtypes.value_counts().reset_index()
|
424 |
+
type_counts.columns = ['Type', 'Count']
|
425 |
+
st.dataframe(type_counts, use_container_width=True)
|
426 |
+
with tab3:
|
427 |
+
st.write("Missing Values Matrix:")
|
428 |
+
fig_missing = px.imshow(df.isna(), color_continuous_scale=['#e0e0e0', '#66c2a5'])
|
429 |
+
fig_missing.update_layout(coloraxis_colorscale=[[0, 'lightgrey'], [1, '#FF4B4B']])
|
430 |
+
st.plotly_chart(fig_missing, use_container_width=True)
|
431 |
+
|
432 |
+
enhance_section_title("Interactive Visualization Builder")
|
433 |
+
with st.container():
|
434 |
+
col1, col2 = st.columns([1, 3])
|
435 |
+
with col1:
|
436 |
+
plot_type = st.selectbox("Choose visualization type", [
|
437 |
+
"Scatter Plot", "Histogram", "Box Plot", "Violin Plot", "Line Chart", "Bar Chart",
|
438 |
+
"Correlation Matrix", "Heatmap", "3D Scatter", "Parallel Categories", "Segmented Bar Chart",
|
439 |
+
"Swarm Plot", "Ridge Plot", "Bubble Plot", "Density Plot", "Count Plot", "Lollipop Chart"
|
440 |
+
])
|
441 |
+
x_axis = st.selectbox("X-axis", df.columns) if plot_type != "Correlation Matrix" else None
|
442 |
+
y_axis = st.selectbox("Y-axis", df.columns) if plot_type in ["Scatter Plot", "Box Plot", "Violin Plot", "Line Chart", "Heatmap", "Swarm Plot", "Ridge Plot", "Bubble Plot", "Density Plot", "Lollipop Chart"] else None
|
443 |
+
z_axis = st.selectbox("Z-axis", df.columns) if plot_type == "3D Scatter" else None
|
444 |
+
color_by = st.selectbox("Color encoding", ["None"] + df.columns.tolist(), format_func=lambda x: "No color" if x == "None" else x) if plot_type != "Correlation Matrix" else None
|
445 |
+
if plot_type == "Parallel Categories":
|
446 |
+
dimensions = st.multiselect("Dimensions", df.columns.tolist(), default=df.columns[:3].tolist())
|
447 |
+
elif plot_type == "Segmented Bar Chart":
|
448 |
+
segment_col = st.selectbox("Segment Column (Categorical)", df.select_dtypes(exclude=np.number).columns)
|
449 |
+
elif plot_type == "Bubble Plot":
|
450 |
+
size_col = st.selectbox("Size Column", df.columns)
|
451 |
+
|
452 |
+
with col2:
|
453 |
+
try:
|
454 |
+
fig = None
|
455 |
+
if plot_type == "Scatter Plot" and x_axis and y_axis:
|
456 |
+
fig = px.scatter(df, x=x_axis, y=y_axis, color=color_by if color_by != "None" else None, trendline="lowess", title=f'Scatter Plot of {x_axis} vs {y_axis}')
|
457 |
+
elif plot_type == "Histogram" and x_axis:
|
458 |
+
fig = px.histogram(df, x=x_axis, color=color_by if color_by != "None" else None, nbins=30, marginal="box", title=f'Histogram of {x_axis}')
|
459 |
+
elif plot_type == "Box Plot" and x_axis and y_axis:
|
460 |
+
fig = px.box(df, x=x_axis, y=y_axis, color=color_by if color_by != "None" else None, title=f'Box Plot of {x_axis} vs {y_axis}')
|
461 |
+
elif plot_type == "Violin Plot" and x_axis and y_axis:
|
462 |
+
fig = px.violin(df, x=x_axis, y=y_axis, color=color_by if color_by != "None" else None, box=True, title=f'Violin Plot of {x_axis} vs {y_axis}')
|
463 |
+
elif plot_type == "Line Chart" and x_axis and y_axis:
|
464 |
+
fig = px.line(df, x=x_axis, y=y_axis, color=color_by if color_by != "None" else None, title=f'Line Chart of {x_axis} vs {y_axis}')
|
465 |
+
elif plot_type == "Bar Chart" and x_axis:
|
466 |
+
fig = px.bar(df, x=x_axis, color=color_by if color_by != "None" else None, title=f'Bar Chart of {x_axis}')
|
467 |
+
elif plot_type == "Correlation Matrix":
|
468 |
+
numeric_df = df.select_dtypes(include=np.number)
|
469 |
+
if len(numeric_df.columns) > 1:
|
470 |
+
corr = numeric_df.corr()
|
471 |
+
fig = px.imshow(corr, text_auto=True, color_continuous_scale='RdBu_r', zmin=-1, zmax=1, title='Correlation Matrix')
|
472 |
+
elif plot_type == "Heatmap" and x_axis and y_axis:
|
473 |
+
fig = px.density_heatmap(df, x=x_axis, y=y_axis, facet_col=color_by if color_by != "None" else None, title=f'Heatmap of {x_axis} vs {y_axis}')
|
474 |
+
elif plot_type == "3D Scatter" and x_axis and y_axis and z_axis:
|
475 |
+
fig = px.scatter_3d(df, x=x_axis, y=y_axis, z=z_axis, color=color_by if color_by != "None" else None, title=f'3D Scatter Plot of {x_axis} vs {y_axis} vs {z_axis}')
|
476 |
+
elif plot_type == "Parallel Categories" and dimensions:
|
477 |
+
fig = px.parallel_categories(df, dimensions=dimensions, color=color_by if color_by != "None" else None, title='Parallel Categories Plot')
|
478 |
+
elif plot_type == "Segmented Bar Chart" and x_axis and segment_col:
|
479 |
+
segment_counts = df.groupby([x_axis, segment_col]).size().reset_index(name='counts')
|
480 |
+
fig = px.bar(segment_counts, x=x_axis, y='counts', color=segment_col, title=f'Segmented Bar Chart of {x_axis} by {segment_col}')
|
481 |
+
fig.update_layout(yaxis_title="Count")
|
482 |
+
elif plot_type == "Swarm Plot" and x_axis and y_axis:
|
483 |
+
fig = px.strip(df, x=x_axis, y=y_axis, color=color_by if color_by != "None" else None, title=f'Swarm Plot of {x_axis} vs {y_axis}')
|
484 |
+
elif plot_type == "Ridge Plot" and x_axis and y_axis:
|
485 |
+
fig = px.histogram(df, x=x_axis, color=y_axis, marginal="rug", title=f'Ridge Plot of {x_axis} by {y_axis}')
|
486 |
+
elif plot_type == "Bubble Plot" and x_axis and y_axis and size_col:
|
487 |
+
fig = px.scatter(df, x=x_axis, y=y_axis, size=size_col, color=color_by if color_by != "None" else None, title=f'Bubble Plot of {x_axis} vs {y_axis}')
|
488 |
+
elif plot_type == "Density Plot" and x_axis and y_axis:
|
489 |
+
fig = px.density_heatmap(df, x=x_axis, y=y_axis, color_continuous_scale="Viridis", title=f'Density Plot of {x_axis} vs {y_axis}')
|
490 |
+
elif plot_type == "Count Plot" and x_axis:
|
491 |
+
fig = px.bar(df, x=x_axis, color=color_by if color_by != "None" else None, title=f'Count Plot of {x_axis}')
|
492 |
+
fig.update_layout(yaxis_title="Count")
|
493 |
+
elif plot_type == "Lollipop Chart" and x_axis and y_axis:
|
494 |
+
fig = go.Figure()
|
495 |
+
fig.add_trace(go.Scatter(x=df[x_axis], y=df[y_axis], mode='markers', marker=dict(size=10)))
|
496 |
+
for i in range(len(df)):
|
497 |
+
fig.add_trace(go.Scatter(x=[df[x_axis].iloc[i], df[x_axis].iloc[i]], y=[0, df[y_axis].iloc[i]], mode='lines', line=dict(color='gray')))
|
498 |
+
fig.update_layout(showlegend=False, title=f'Lollipop Chart of {x_axis} vs {y_axis}')
|
499 |
+
|
500 |
+
if fig:
|
501 |
+
fig.update_layout(template="plotly_white")
|
502 |
+
st.plotly_chart(fig, use_container_width=True)
|
503 |
+
st.session_state.last_plot = {
|
504 |
+
"type": plot_type,
|
505 |
+
"x": x_axis,
|
506 |
+
"y": y_axis,
|
507 |
+
"z": z_axis,
|
508 |
+
"color": color_by if color_by != "None" else None,
|
509 |
+
"data": df[[x_axis, y_axis] + ([z_axis] if z_axis else [])].to_json() if x_axis and y_axis else df[[x_axis]].to_json()
|
510 |
+
}
|
511 |
+
else:
|
512 |
+
st.error("Please provide required inputs for the selected plot type.")
|
513 |
+
except Exception as e:
|
514 |
+
st.error(f"Couldn't create visualization: {str(e)}")
|
515 |
+
|
516 |
# Chatbot Section
|
517 |
st.markdown("---")
|
518 |
st.subheader("💬 AI Chatbot Assistant (RAG Enabled)")
|
519 |
+
st.info("Ask me about the app or your data! Try: 'drop columns X, Y', 'scatter plot of X vs Y', or 'analyze plot'")
|
520 |
if "chat_history" not in st.session_state:
|
521 |
st.session_state.chat_history = []
|
522 |
|
|
|
529 |
st.session_state.chat_history.append({"role": "user", "content": user_input})
|
530 |
with st.chat_message("user"):
|
531 |
st.markdown(user_input)
|
532 |
+
with st.spinner("Processing..."):
|
|
|
533 |
dataset_text = st.session_state.get("dataset_text", "")
|
534 |
+
func, param = parse_command(user_input)
|
535 |
+
if func:
|
536 |
+
response = func(param) if param else func(None)
|
537 |
+
else:
|
538 |
+
response = get_chatbot_response(user_input, app_mode, dataset_text)
|
539 |
st.session_state.chat_history.append({"role": "assistant", "content": response})
|
540 |
with st.chat_message("assistant"):
|
541 |
st.markdown(response)
|