DrishtiSharma commited on
Commit
9dc25a4
Β·
verified Β·
1 Parent(s): 7ff7723

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -106
app.py CHANGED
@@ -191,11 +191,12 @@ COLUMN_SYNONYMS = {
191
  }
192
 
193
 
194
- # Fuzzy match to map query terms to dataset columns
195
  def fuzzy_match_columns(query, n=2):
196
  query = query.lower()
197
  all_synonyms = {synonym: col for col, synonyms in COLUMN_SYNONYMS.items() for synonym in synonyms}
198
- words = query.replace("and", "").replace("vs", "").split() # Remove "and"/"vs" for better matching
 
199
 
200
  matched_columns = []
201
  for word in words:
@@ -203,70 +204,70 @@ def fuzzy_match_columns(query, n=2):
203
  for match in matches:
204
  matched_columns.append(all_synonyms[match])
205
 
206
- # Remove duplicates while preserving order
207
- matched_columns = list(dict.fromkeys(matched_columns))
208
- return matched_columns
 
 
 
 
 
 
 
 
 
 
 
 
 
 
209
 
210
- # Visualization generator with dynamic groupby handling
211
  def generate_visual_from_query(query, df):
212
  try:
213
- # Step 1: Fuzzy match columns mentioned in the query
214
  matched_columns = fuzzy_match_columns(query)
215
 
216
- # Step 2: Detect groupby intent (handling "and", "vs", "by")
217
- if "and" in query or "vs" in query or "by" in query or len(matched_columns) > 1:
218
- if len(matched_columns) >= 2:
219
- x_axis = matched_columns[0]
220
- group_by = matched_columns[1]
221
- else:
222
- x_axis, group_by = matched_columns[0], None
223
  else:
224
- x_axis = matched_columns[0] if matched_columns else None
225
- group_by = None
226
 
227
- # Step 3: Visualization logic
228
- if "distribution" in query and x_axis:
229
  fig = px.box(df, x=x_axis, y="salary_in_usd", color=group_by,
230
  title=f"Salary Distribution by {x_axis.replace('_', ' ').title()}"
231
  + (f" and {group_by.replace('_', ' ').title()}" if group_by else ""))
232
- return fig
233
 
 
234
  elif "average" in query or "mean" in query:
235
  grouped_df = df.groupby([x_axis] + ([group_by] if group_by else []))["salary_in_usd"].mean().reset_index()
236
  fig = px.bar(grouped_df, x=x_axis, y="salary_in_usd", color=group_by,
237
- barmode="group",
238
  title=f"Average Salary by {x_axis.replace('_', ' ').title()}"
239
  + (f" and {group_by.replace('_', ' ').title()}" if group_by else ""))
240
- return fig
241
 
242
- elif "trend" in query and "work_year" in df.columns and x_axis:
 
243
  grouped_df = df.groupby(["work_year", x_axis])["salary_in_usd"].mean().reset_index()
244
  fig = px.line(grouped_df, x="work_year", y="salary_in_usd", color=x_axis,
245
- title=f"Salary Trend over Years by {x_axis.replace('_', ' ').title()}")
246
- return fig
247
 
 
248
  elif "remote" in query:
249
  grouped_df = df.groupby(["remote_ratio"] + ([group_by] if group_by else []))["salary_in_usd"].mean().reset_index()
250
  fig = px.bar(grouped_df, x="remote_ratio", y="salary_in_usd", color=group_by,
251
- barmode="group", title="Remote Work Impact on Salary")
252
- return fig
253
-
254
- elif "company size" in query:
255
- grouped_df = df.groupby(["company_size"] + ([group_by] if group_by else []))["salary_in_usd"].mean().reset_index()
256
- fig = px.bar(grouped_df, x="company_size", y="salary_in_usd", color=group_by,
257
- title=f"Salary by Company Size"
258
- + (f" and {group_by.replace('_', ' ').title()}" if group_by else ""))
259
- return fig
260
-
261
- elif "country" in query or "location" in query:
262
- grouped_df = df.groupby(["employee_residence"] + ([group_by] if group_by else []))["salary_in_usd"].mean().reset_index()
263
- fig = px.bar(grouped_df, x="employee_residence", y="salary_in_usd", color=group_by,
264
- title=f"Salary by Employee Residence"
265
- + (f" and {group_by.replace('_', ' ').title()}" if group_by else ""))
266
- return fig
267
 
 
268
  else:
269
- st.warning("❓ No suitable visualization detected. Please refine your query.")
270
  return None
271
 
272
  except Exception as e:
@@ -274,71 +275,6 @@ def generate_visual_from_query(query, df):
274
  return None
275
 
276
 
277
-
278
-
279
-
280
-
281
-
282
-
283
-
284
-
285
-
286
- """def map_query_to_column(query):
287
- query = query.lower()
288
- all_synonyms = {synonym: col for col, synonyms in COLUMN_SYNONYMS.items() for synonym in synonyms}
289
- matches = get_close_matches(query, all_synonyms.keys(), n=1, cutoff=0.6)
290
-
291
- if matches:
292
- return all_synonyms[matches[0]]
293
- else:
294
- for col, synonyms in COLUMN_SYNONYMS.items():
295
- if any(term in query for term in synonyms):
296
- return col
297
- return None"""
298
-
299
-
300
- """# Visualization generator with synonym handling
301
- def generate_visual_from_query(query, df):
302
- try:
303
- query = query.lower()
304
-
305
- # Map user terms to actual dataset columns
306
- col1 = map_query_to_column(query)
307
- col2 = None # For dual-column charts
308
-
309
- # Handle common queries
310
- if "distribution" in query and col1:
311
- fig = px.box(df, x=col1, y="salary_in_usd", title=f"Salary Distribution by {col1.replace('_', ' ').title()}")
312
- return fig
313
-
314
- elif "average salary" in query and col1:
315
- grouped_df = df.groupby(col1)["salary_in_usd"].mean().reset_index()
316
- fig = px.bar(grouped_df, x=col1, y="salary_in_usd", title=f"Average Salary by {col1.replace('_', ' ').title()}")
317
- return fig
318
-
319
- elif "remote" in query:
320
- grouped_df = df.groupby("remote_ratio")["salary_in_usd"].mean().reset_index()
321
- fig = px.bar(grouped_df, x="remote_ratio", y="salary_in_usd", title="Remote Work Impact on Salary")
322
- return fig
323
-
324
- elif "company size" in query or "organization size" in query:
325
- grouped_df = df.groupby("company_size")["salary_in_usd"].mean().reset_index()
326
- fig = px.bar(grouped_df, x="company_size", y="salary_in_usd", title="Salary by Company Size")
327
- return fig
328
-
329
- elif "country" in query or "location" in query:
330
- grouped_df = df.groupby("employee_residence")["salary_in_usd"].mean().reset_index()
331
- fig = px.bar(grouped_df, x="employee_residence", y="salary_in_usd", title="Salary by Employee Residence")
332
- return fig
333
-
334
- else:
335
- st.warning("❓ I couldn't understand the query for visualization. Try asking about salary distribution, experience level, remote work, etc.")
336
- return None
337
-
338
- except Exception as e:
339
- st.error(f"Error generating visualization: {e}")
340
- return None"""
341
-
342
  # SQL-RAG Analysis
343
  if st.session_state.df is not None:
344
  temp_dir = tempfile.TemporaryDirectory()
 
191
  }
192
 
193
 
194
+ # Fuzzy matcher for mapping query terms to dataset columns
195
  def fuzzy_match_columns(query, n=2):
196
  query = query.lower()
197
  all_synonyms = {synonym: col for col, synonyms in COLUMN_SYNONYMS.items() for synonym in synonyms}
198
+
199
+ words = query.replace("and", "").replace("vs", "").replace("by", "").split()
200
 
201
  matched_columns = []
202
  for word in words:
 
204
  for match in matches:
205
  matched_columns.append(all_synonyms[match])
206
 
207
+ return list(dict.fromkeys(matched_columns))
208
+
209
+ # Statistical annotations for plots
210
+ def add_stats_to_figure(fig, df, y_axis):
211
+ min_salary = df[y_axis].min()
212
+ max_salary = df[y_axis].max()
213
+ avg_salary = df[y_axis].mean()
214
+
215
+ fig.add_annotation(
216
+ text=f"Min: ${min_salary:,.2f} | Max: ${max_salary:,.2f} | Avg: ${avg_salary:,.2f}",
217
+ xref="paper", yref="paper",
218
+ x=0.5, y=1.1,
219
+ showarrow=False,
220
+ font=dict(size=12, color="black"),
221
+ bgcolor="rgba(255, 255, 255, 0.7)"
222
+ )
223
+ return fig
224
 
225
+ # Visualization generator
226
  def generate_visual_from_query(query, df):
227
  try:
 
228
  matched_columns = fuzzy_match_columns(query)
229
 
230
+ # Detect and handle multiple grouping columns
231
+ if len(matched_columns) >= 2:
232
+ x_axis, group_by = matched_columns[0], matched_columns[1]
233
+ elif len(matched_columns) == 1:
234
+ x_axis, group_by = matched_columns[0], None
 
 
235
  else:
236
+ st.warning("❓ No matching columns found. Try rephrasing your query.")
237
+ return None
238
 
239
+ # Handle distribution queries
240
+ if "distribution" in query:
241
  fig = px.box(df, x=x_axis, y="salary_in_usd", color=group_by,
242
  title=f"Salary Distribution by {x_axis.replace('_', ' ').title()}"
243
  + (f" and {group_by.replace('_', ' ').title()}" if group_by else ""))
244
+ return add_stats_to_figure(fig, df, "salary_in_usd")
245
 
246
+ # Handle average salary queries
247
  elif "average" in query or "mean" in query:
248
  grouped_df = df.groupby([x_axis] + ([group_by] if group_by else []))["salary_in_usd"].mean().reset_index()
249
  fig = px.bar(grouped_df, x=x_axis, y="salary_in_usd", color=group_by,
 
250
  title=f"Average Salary by {x_axis.replace('_', ' ').title()}"
251
  + (f" and {group_by.replace('_', ' ').title()}" if group_by else ""))
252
+ return add_stats_to_figure(fig, df, "salary_in_usd")
253
 
254
+ # Handle salary trends over time
255
+ elif "trend" in query and "work_year" in df.columns:
256
  grouped_df = df.groupby(["work_year", x_axis])["salary_in_usd"].mean().reset_index()
257
  fig = px.line(grouped_df, x="work_year", y="salary_in_usd", color=x_axis,
258
+ title=f"Salary Trend Over Years by {x_axis.replace('_', ' ').title()}")
259
+ return add_stats_to_figure(fig, df, "salary_in_usd")
260
 
261
+ # Handle remote work queries
262
  elif "remote" in query:
263
  grouped_df = df.groupby(["remote_ratio"] + ([group_by] if group_by else []))["salary_in_usd"].mean().reset_index()
264
  fig = px.bar(grouped_df, x="remote_ratio", y="salary_in_usd", color=group_by,
265
+ title="Remote Work Impact on Salary")
266
+ return add_stats_to_figure(fig, df, "salary_in_usd")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
267
 
268
+ # Default behavior if query doesn't match anything specific
269
  else:
270
+ st.warning("❓ No suitable visualization generated. Try refining your query.")
271
  return None
272
 
273
  except Exception as e:
 
275
  return None
276
 
277
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
278
  # SQL-RAG Analysis
279
  if st.session_state.df is not None:
280
  temp_dir = tempfile.TemporaryDirectory()