DrishtiSharma commited on
Commit
d1f7f7b
Β·
verified Β·
1 Parent(s): 9dc25a4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -14
app.py CHANGED
@@ -192,21 +192,45 @@ COLUMN_SYNONYMS = {
192
 
193
 
194
  # Fuzzy matcher for mapping query terms to dataset columns
195
- def fuzzy_match_columns(query, n=2):
196
  query = query.lower()
197
  all_synonyms = {synonym: col for col, synonyms in COLUMN_SYNONYMS.items() for synonym in synonyms}
198
-
199
  words = query.replace("and", "").replace("vs", "").replace("by", "").split()
200
-
201
  matched_columns = []
202
  for word in words:
203
- matches = get_close_matches(word, all_synonyms.keys(), n=n, cutoff=0.6)
204
- for match in matches:
205
- matched_columns.append(all_synonyms[match])
206
 
207
  return list(dict.fromkeys(matched_columns))
208
 
209
- # Statistical annotations for plots
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
210
  def add_stats_to_figure(fig, df, y_axis):
211
  min_salary = df[y_axis].min()
212
  max_salary = df[y_axis].max()
@@ -222,18 +246,25 @@ def add_stats_to_figure(fig, df, y_axis):
222
  )
223
  return fig
224
 
225
- # Visualization generator
226
- def generate_visual_from_query(query, df):
227
  try:
228
  matched_columns = fuzzy_match_columns(query)
229
 
230
- # Detect and handle multiple grouping columns
 
 
 
 
 
 
 
231
  if len(matched_columns) >= 2:
232
  x_axis, group_by = matched_columns[0], matched_columns[1]
233
  elif len(matched_columns) == 1:
234
  x_axis, group_by = matched_columns[0], None
235
  else:
236
- st.warning("❓ No matching columns found. Try rephrasing your query.")
237
  return None
238
 
239
  # Handle distribution queries
@@ -258,16 +289,15 @@ def generate_visual_from_query(query, df):
258
  title=f"Salary Trend Over Years by {x_axis.replace('_', ' ').title()}")
259
  return add_stats_to_figure(fig, df, "salary_in_usd")
260
 
261
- # Handle remote work queries
262
  elif "remote" in query:
263
  grouped_df = df.groupby(["remote_ratio"] + ([group_by] if group_by else []))["salary_in_usd"].mean().reset_index()
264
  fig = px.bar(grouped_df, x="remote_ratio", y="salary_in_usd", color=group_by,
265
  title="Remote Work Impact on Salary")
266
  return add_stats_to_figure(fig, df, "salary_in_usd")
267
 
268
- # Default behavior if query doesn't match anything specific
269
  else:
270
- st.warning("❓ No suitable visualization generated. Try refining your query.")
271
  return None
272
 
273
  except Exception as e:
@@ -275,6 +305,7 @@ def generate_visual_from_query(query, df):
275
  return None
276
 
277
 
 
278
  # SQL-RAG Analysis
279
  if st.session_state.df is not None:
280
  temp_dir = tempfile.TemporaryDirectory()
 
192
 
193
 
194
  # Fuzzy matcher for mapping query terms to dataset columns
195
+ def fuzzy_match_columns(query):
196
  query = query.lower()
197
  all_synonyms = {synonym: col for col, synonyms in COLUMN_SYNONYMS.items() for synonym in synonyms}
 
198
  words = query.replace("and", "").replace("vs", "").replace("by", "").split()
199
+
200
  matched_columns = []
201
  for word in words:
202
+ matches = get_close_matches(word, all_synonyms.keys(), n=1, cutoff=0.6)
203
+ matched_columns.extend([all_synonyms[match] for match in matches])
 
204
 
205
  return list(dict.fromkeys(matched_columns))
206
 
207
+ # Ask LLM to suggest relevant columns if fuzzy matching fails
208
+ def ask_llm_for_columns(query, llm, df):
209
+ columns = ', '.join(df.columns)
210
+ prompt = f"""
211
+ Analyze this user query and suggest the most relevant dataset columns for visualization.
212
+
213
+ Query: "{query}"
214
+
215
+ Available Columns: {columns}
216
+
217
+ Respond in this JSON format:
218
+ {{
219
+ "x_axis": "column_name",
220
+ "y_axis": "column_name",
221
+ "group_by": "optional_column_name"
222
+ }}
223
+ """
224
+
225
+ response = llm.generate(prompt)
226
+ try:
227
+ suggestion = json.loads(response)
228
+ return suggestion
229
+ except json.JSONDecodeError:
230
+ st.error("⚠️ Failed to interpret AI response. Please refine your query.")
231
+ return None
232
+
233
+ # Add min, max, and average salary annotations to the chart
234
  def add_stats_to_figure(fig, df, y_axis):
235
  min_salary = df[y_axis].min()
236
  max_salary = df[y_axis].max()
 
246
  )
247
  return fig
248
 
249
+ # Unified visualization function with LLM fallback
250
+ def generate_visual_from_query(query, df, llm=None):
251
  try:
252
  matched_columns = fuzzy_match_columns(query)
253
 
254
+ # Fallback to LLM if fuzzy matching fails
255
+ if not matched_columns and llm:
256
+ st.info("πŸ€– No match found. Asking AI for suggestions...")
257
+ suggestion = ask_llm_for_columns(query, llm, df)
258
+ if suggestion:
259
+ matched_columns = [suggestion.get("x_axis"), suggestion.get("group_by")]
260
+
261
+ # Handle cases when we have columns to plot
262
  if len(matched_columns) >= 2:
263
  x_axis, group_by = matched_columns[0], matched_columns[1]
264
  elif len(matched_columns) == 1:
265
  x_axis, group_by = matched_columns[0], None
266
  else:
267
+ st.warning("❓ No matching columns found. Please refine your query.")
268
  return None
269
 
270
  # Handle distribution queries
 
289
  title=f"Salary Trend Over Years by {x_axis.replace('_', ' ').title()}")
290
  return add_stats_to_figure(fig, df, "salary_in_usd")
291
 
292
+ # Handle remote work impact
293
  elif "remote" in query:
294
  grouped_df = df.groupby(["remote_ratio"] + ([group_by] if group_by else []))["salary_in_usd"].mean().reset_index()
295
  fig = px.bar(grouped_df, x="remote_ratio", y="salary_in_usd", color=group_by,
296
  title="Remote Work Impact on Salary")
297
  return add_stats_to_figure(fig, df, "salary_in_usd")
298
 
 
299
  else:
300
+ st.warning("⚠️ No suitable visualization generated. Please refine your query.")
301
  return None
302
 
303
  except Exception as e:
 
305
  return None
306
 
307
 
308
+
309
  # SQL-RAG Analysis
310
  if st.session_state.df is not None:
311
  temp_dir = tempfile.TemporaryDirectory()