Update app.py

app.py CHANGED
@@ -156,24 +156,86 @@ def create_vector_store(df_text):
     os.unlink(temp_path)
     return vector_store
 
+def update_vector_store_with_plot(plot_text, existing_vector_store):
+    """Update the FAISS vector store with plot-derived text"""
+    with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as temp_file:
+        temp_file.write(plot_text)
+        temp_path = temp_file.name
+
+    loader = TextLoader(temp_path)
+    documents = loader.load()
+    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)
+    texts = text_splitter.split_documents(documents)
+
+    if existing_vector_store:
+        existing_vector_store.add_documents(texts)
+    else:
+        existing_vector_store = FAISS.from_documents(texts, embeddings)
+
+    os.unlink(temp_path)
+    return existing_vector_store
+
+def extract_plot_data(plot_info, df):
+    """Extract numerical data from the last generated plot and convert to text"""
+    plot_type = plot_info["type"]
+    x_col = plot_info["x"]
+    y_col = plot_info["y"] if "y" in plot_info else None
+    data = pd.read_json(plot_info["data"])
+
+    plot_text = f"Plot Type: {plot_type}\n"
+    plot_text += f"X-Axis: {x_col}\n"
+    if y_col:
+        plot_text += f"Y-Axis: {y_col}\n"
+
+    if plot_type == "Scatter Plot" and y_col:
+        correlation = data[x_col].corr(data[y_col])
+        slope, intercept, r_value, p_value, std_err = stats.linregress(data[x_col].dropna(), data[y_col].dropna())
+        plot_text += f"Correlation: {correlation:.2f}\n"
+        plot_text += f"Linear Regression: Slope={slope:.2f}, Intercept={intercept:.2f}, R²={r_value**2:.2f}, p-value={p_value:.4f}\n"
+        plot_text += f"X Stats: Mean={data[x_col].mean():.2f}, Std={data[x_col].std():.2f}, Min={data[x_col].min():.2f}, Max={data[x_col].max():.2f}\n"
+        plot_text += f"Y Stats: Mean={data[y_col].mean():.2f}, Std={data[y_col].std():.2f}, Min={data[y_col].min():.2f}, Max={data[y_col].max():.2f}\n"
+    elif plot_type == "Histogram":
+        plot_text += f"Stats: Mean={data[x_col].mean():.2f}, Median={data[x_col].median():.2f}, Std={data[x_col].std():.2f}\n"
+        plot_text += f"Skewness: {data[x_col].skew():.2f}\n"
+        plot_text += f"Range: [{data[x_col].min():.2f}, {data[x_col].max():.2f}]\n"
+    elif plot_type == "Box Plot" and y_col:
+        q1, q3 = data[y_col].quantile(0.25), data[y_col].quantile(0.75)
+        iqr = q3 - q1
+        plot_text += f"Y Stats: Median={data[y_col].median():.2f}, Q1={q1:.2f}, Q3={q3:.2f}, IQR={iqr:.2f}\n"
+        plot_text += f"Outliers: {len(data[y_col][(data[y_col] < q1 - 1.5 * iqr) | (data[y_col] > q3 + 1.5 * iqr)])} potential outliers\n"
+    elif plot_type == "Line Chart" and y_col:
+        plot_text += f"Y Stats: Mean={data[y_col].mean():.2f}, Std={data[y_col].std():.2f}, Trend={'increasing' if data[y_col].iloc[-1] > data[y_col].iloc[0] else 'decreasing'}\n"
+    elif plot_type == "Bar Chart":
+        plot_text += f"Counts: {data[x_col].value_counts().to_dict()}\n"
+    elif plot_type == "Correlation Matrix":
+        corr = data.corr()
+        plot_text += "Correlation Matrix:\n"
+        for col1 in corr.columns:
+            for col2 in corr.index:
+                if col1 < col2:  # Avoid duplicates
+                    plot_text += f"{col1} vs {col2}: {corr.loc[col2, col1]:.2f}\n"
+
+    return plot_text
+
 def get_chatbot_response(user_input, app_mode, vector_store=None, model="llama3-70b-8192"):
-    """Get response from Groq with vector store context"""
+    """Get response from Groq with vector store context including plot data"""
     system_prompt = (
         "You are an AI assistant in Data-Vision Pro, a data analysis app with RAG capabilities. "
         f"The user is on the '{app_mode}' page:\n"
         "- **Data Upload**: Upload CSV/XLSX files, view stats, or generate reports.\n"
         "- **Data Cleaning**: Clean data (e.g., handle missing values, encode variables).\n"
-        "- **EDA**: Visualize data (e.g., scatter plots, histograms).\n"
+        "- **EDA**: Visualize data (e.g., scatter plots, histograms) and analyze plots.\n"
+        "When analyzing plots, provide detailed insights based on numerical data extracted from them."
     )
 
     context = ""
     if vector_store:
         docs = vector_store.similarity_search(user_input, k=3)
         if docs:
-            context = "\n\nDataset Context:\n" + "\n".join([f"- {doc.page_content}" for doc in docs])
-            system_prompt += f"Use this dataset context to augment your response:\n{context}"
+            context = "\n\nDataset and Plot Context:\n" + "\n".join([f"- {doc.page_content}" for doc in docs])
+            system_prompt += f"Use this dataset and plot context to augment your response:\n{context}"
         else:
-            system_prompt += "No dataset is loaded. Assist based on app functionality."
+            system_prompt += "No dataset or plot data is loaded. Assist based on app functionality."
 
     try:
         response = client.chat.completions.create(

@@ -230,20 +292,8 @@ def analyze_plot():
         return "No plot available to analyze."
     plot_info = st.session_state.last_plot
     df = pd.read_json(plot_info["data"])
-    plot_type = plot_info["type"]
-    x_col = plot_info["x"]
-    y_col = plot_info["y"] if "y" in plot_info else None
-
-    if plot_type == "Scatter Plot" and y_col:
-        correlation = df[x_col].corr(df[y_col])
-        strength = "strong" if abs(correlation) > 0.7 else "moderate" if abs(correlation) > 0.3 else "weak"
-        direction = "positive" if correlation > 0 else "negative"
-        return f"The scatter plot of {x_col} vs {y_col} shows a {strength} {direction} correlation (Pearson r = {correlation:.2f})."
-    elif plot_type == "Histogram":
-        skewness = df[x_col].skew()
-        skew_desc = "positively skewed" if skewness > 1 else "negatively skewed" if skewness < -1 else "approximately symmetric"
-        return f"The histogram of {x_col} is {skew_desc} (skewness = {skewness:.2f})."
-    return "Inference not available for this plot type."
+    plot_text = extract_plot_data(plot_info, df)
+    return f"Analysis of the last plot:\n{plot_text}"
 
 def parse_command(command):
     command = command.lower().strip()

@@ -529,6 +579,11 @@ def main():
                 "y": y_axis,
                 "data": df[[x_axis, y_axis]].to_json() if y_axis else df[[x_axis]].to_json()
             }
+            # Extract numerical data and update vector store
+            plot_text = extract_plot_data(st.session_state.last_plot, df)
+            st.session_state.vector_store = update_vector_store_with_plot(plot_text, st.session_state.vector_store)
+            with st.expander("Extracted Plot Data"):
+                st.text(plot_text)
         else:
             st.error("Please provide required inputs for the selected plot type.")
     except Exception as e:
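For reference, a minimal usage sketch of the two new helpers outside Streamlit. It assumes app.py's existing module-level imports and objects (pandas as pd, scipy's stats, tempfile, TextLoader, RecursiveCharacterTextSplitter, FAISS, and the embeddings instance) are in scope; demo_df and the plot_info dict below are illustrative stand-ins, not part of this commit.

    # Illustrative sketch only -- demo_df and plot_info are made-up stand-ins for
    # the uploaded dataset and st.session_state.last_plot; extract_plot_data,
    # update_vector_store_with_plot, and embeddings come from app.py.
    import pandas as pd

    demo_df = pd.DataFrame({"height": [150, 160, 170, 180],
                            "weight": [51, 59, 72, 79]})

    # Mirrors the dict that main() stores in st.session_state.last_plot.
    plot_info = {
        "type": "Scatter Plot",
        "x": "height",
        "y": "weight",
        "data": demo_df[["height", "weight"]].to_json(),
    }

    plot_text = extract_plot_data(plot_info, demo_df)        # numeric summary as plain text
    store = update_vector_store_with_plot(plot_text, None)   # None -> builds a fresh FAISS store
    hits = store.similarity_search("correlation between height and weight", k=1)
    print(hits[0].page_content)

Passing None exercises the branch that creates a new FAISS index; in the app, the current st.session_state.vector_store is passed instead so plot summaries are appended to the dataset's existing store.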