# sql-rag / dummy_funcs.py (8 kB)
# Last commit 964d389 (verified) by DrishtiSharma: "Update dummy_funcs.py"
def _strip_json_fences(text):
    """
    Return *text* without surrounding whitespace or a Markdown code fence.

    Chat models frequently wrap JSON replies in ```json ... ``` even when
    asked for a JSON format, and json.loads rejects the fence characters.
    """
    stripped = text.strip()
    if not stripped.startswith("```"):
        return stripped
    body = stripped[3:]
    if "\n" in body:
        # Drop the rest of the opening fence line (e.g. a "json" language tag).
        body = body.split("\n", 1)[1]
    elif body.lower().startswith("json"):
        body = body[4:]
    body = body.rstrip()
    if body.endswith("```"):
        body = body[:-3]
    return body.strip()


def ask_gpt4o_for_visualization(query, df, llm):
    """
    Ask the LLM to suggest one or more charts for *query* over *df*'s columns.

    Args:
        query: The user's natural-language question.
        df: DataFrame whose column names are listed in the prompt.
        llm: Object exposing ``generate(prompt)``; presumably it returns the
            model's text reply (assumption; verify against the caller).

    Returns:
        The JSON value parsed from the reply (normally a list of suggestion
        dicts with chart_type / x_axis / y_axis / group_by keys), or None when
        the reply is not valid JSON, in which case a Streamlit error is shown.
    """
    columns = ', '.join(df.columns)
    prompt = f"""
Analyze the query and suggest one or more relevant visualizations.
Query: "{query}"
Available Columns: {columns}
Respond in this JSON format (as a list if multiple suggestions):
[
{{
"chart_type": "bar/box/line/scatter",
"x_axis": "column_name",
"y_axis": "column_name",
"group_by": "optional_column_name"
}}
]
"""
    response = llm.generate(prompt)
    # Only text replies can carry a code fence; anything else is passed to
    # json.loads unchanged.
    payload = _strip_json_fences(response) if isinstance(response, str) else response
    try:
        return json.loads(payload)
    except json.JSONDecodeError:
        st.error("⚠️ GPT-4o failed to generate a valid suggestion.")
        return None
def _add_stats_annotation(fig, stats_text):
    """Pin the boxed statistics summary near the top-right corner of *fig*."""
    fig.add_annotation(
        text=stats_text,
        xref="paper",
        yref="paper",
        x=1.02,
        y=1,
        showarrow=False,
        align="left",
        font=dict(size=12, color="black"),
        bordercolor="gray",
        borderwidth=1,
        bgcolor="rgba(255, 255, 255, 0.85)",
    )


def add_stats_to_figure(fig, df, y_axis, chart_type, value_prefix="$"):
    """
    Add statistical annotations for ``df[y_axis]`` to *fig* based on the chart type.

    - bar / line: a boxed summary (min, max, average, median, std dev) plus
      horizontal reference lines at min, median, average and max.
    - scatter: the boxed summary only.
    - box: nothing (the box itself shows the distribution).
    - pie / heatmap: nothing; an informational message is shown instead.
    - anything else: nothing; a warning is shown.

    Args:
        fig: Plotly figure to annotate; it is modified in place and returned.
        df: DataFrame holding the plotted data.
        y_axis: Name of the column to summarize.
        chart_type: Chart kind; compared case-insensitively.
        value_prefix: Text placed before every formatted number. Defaults to
            "$" (the historical behavior); pass "" for non-monetary columns.

    Returns:
        The same *fig*. It is returned without additions (after a Streamlit
        warning) when the column is missing, non-numeric, or entirely null.
    """
    try:
        column = df[y_axis]
    except KeyError:
        st.warning(f"⚠️ Cannot compute statistics: column '{y_axis}' not found.")
        return fig
    if not pd.api.types.is_numeric_dtype(column):
        st.warning(f"⚠️ Cannot compute statistics for non-numeric column: {y_axis}")
        return fig
    values = column.dropna()
    if values.empty:
        # NaN statistics would produce meaningless text and reference lines.
        st.warning(f"⚠️ No non-null values to compute statistics for: {y_axis}")
        return fig

    kind = chart_type.lower() if isinstance(chart_type, str) else chart_type

    min_val = values.min()
    max_val = values.max()
    avg_val = values.mean()
    median_val = values.median()
    std_dev_val = values.std()

    def _fmt(value):
        return f"{value_prefix}{value:,.2f}"

    # Plotly annotation text understands a small HTML subset (<br>, <b>),
    # not Markdown: "\n" is not a line break and "**" is shown literally.
    stats_text = (
        "📊 <b>Statistics</b><br><br>"
        f"- <b>Min:</b> {_fmt(min_val)}<br>"
        f"- <b>Max:</b> {_fmt(max_val)}<br>"
        f"- <b>Average:</b> {_fmt(avg_val)}<br>"
        f"- <b>Median:</b> {_fmt(median_val)}<br>"
        f"- <b>Std Dev:</b> {_fmt(std_dev_val)}"
    )

    if kind in ("bar", "line"):
        _add_stats_annotation(fig, stats_text)
        reference_lines = (
            (min_val, "dot", "red", "Min", "bottom right"),
            (median_val, "dash", "orange", "Median", "top right"),
            (avg_val, "dashdot", "green", "Avg", "top right"),
            (max_val, "dot", "blue", "Max", "top right"),
        )
        for y_value, dash, color, label, position in reference_lines:
            fig.add_hline(
                y=y_value,
                line_dash=dash,
                line_color=color,
                annotation_text=label,
                annotation_position=position,
            )
    elif kind == "scatter":
        # Summary box only; reference lines would clutter a point cloud.
        _add_stats_annotation(fig, stats_text)
    elif kind == "box":
        # Box plots inherently show the distribution; no extra stats needed.
        pass
    elif kind == "pie":
        st.info("📊 Pie charts represent proportions. Additional stats are not applicable.")
    elif kind == "heatmap":
        st.info("📊 Heatmaps inherently reflect distribution. No additional stats added.")
    else:
        st.warning(f"⚠️ No statistical overlays applied for unsupported chart type: '{chart_type}'.")
    return fig
# Dynamically generate Plotly visualizations based on GPT-4o suggestions
def generate_visualization(suggestion, df):
    """
    Generate a Plotly visualization based on GPT-4o's suggestion.

    If the Y-axis is missing it is inferred: the first of salary_in_usd,
    income, earnings, revenue that is a numeric column, otherwise the first
    numeric column other than the X-axis.

    Args:
        suggestion: Dict with "chart_type" (missing/empty means "bar"),
            "x_axis", and optional "y_axis" and "group_by".
        df: DataFrame to plot.

    Returns:
        The Plotly figure with statistics overlaid by ``add_stats_to_figure``,
        or None when the axes cannot be determined, the chart type is not a
        Plotly Express function, or plotting fails (a Streamlit message is
        shown in each case).
    """
    chart_type = str(suggestion.get("chart_type") or "bar").lower()
    x_axis = suggestion.get("x_axis")
    y_axis = suggestion.get("y_axis")
    group_by = suggestion.get("group_by")

    # Step 1: Infer the Y-axis if not provided.
    if not y_axis:
        # Never plot a column against itself.
        numeric_columns = [
            col for col in df.select_dtypes(include="number").columns if col != x_axis
        ]
        priority_columns = ["salary_in_usd", "income", "earnings", "revenue"]
        y_axis = next(
            (col for col in priority_columns if col in numeric_columns),
            numeric_columns[0] if numeric_columns else None,
        )

    # Step 2: Validate axes.
    if not x_axis or not y_axis:
        st.warning("⚠️ Unable to determine appropriate columns for visualization.")
        return None

    # Step 3: Resolve the Plotly Express function. The name comes from the
    # LLM, so reject private attributes and non-callables (e.g. px.colors).
    plotly_function = None if chart_type.startswith("_") else getattr(px, chart_type, None)
    if not callable(plotly_function):
        st.warning(f"⚠️ Unsupported chart type '{chart_type}' suggested by GPT-4o.")
        return None

    # Step 4: Prepare dynamic plot arguments.
    plot_args = {"data_frame": df, "x": x_axis, "y": y_axis}
    if group_by and group_by in df.columns:
        plot_args["color"] = group_by

    def _label(name):
        # str() so non-string column names (e.g. ints) still get a title.
        return str(name).replace("_", " ").title()

    try:
        # Step 5: Generate the visualization.
        fig = plotly_function(**plot_args)
        fig.update_layout(
            title=f"{chart_type.title()} Plot of {_label(y_axis)} by {_label(x_axis)}",
            xaxis_title=_label(x_axis),
            yaxis_title=_label(y_axis),
        )
        # Step 6: Overlay summary statistics for the Y column.
        fig = add_stats_to_figure(fig, df, y_axis, chart_type)
        return fig
    except Exception as e:
        # UI boundary: report any plotting failure instead of crashing the app.
        st.error(f"⚠️ Failed to generate visualization: {e}")
        return None
def generate_multiple_visualizations(suggestions, df):
    """
    Generate one visualization per GPT-4o suggestion, skipping any that fail.

    Handles both a list of suggestions and a single suggestion dict.
    ``generate_visualization`` already overlays the statistics, so they are
    not applied again here. Applying them twice would duplicate the summary
    box and reference lines. Reading ``suggestion["y_axis"]`` here would also
    fail when the Y-axis was inferred.

    Args:
        suggestions: A list of suggestion dicts, or a single suggestion dict.
            None is treated as no suggestions.
        df: DataFrame to plot.

    Returns:
        A list of the figures that were produced; it never contains None.
        A Streamlit warning is shown when suggestions were given but none of
        them produced a figure.
    """
    if isinstance(suggestions, dict):
        # A lone suggestion dict would otherwise be iterated key by key.
        suggestions = [suggestions]
    suggestions = suggestions or []

    visualizations = []
    for suggestion in suggestions:
        fig = generate_visualization(suggestion, df)
        if fig is not None:
            visualizations.append(fig)

    # No retry: generate_visualization is deterministic for the same inputs,
    # so re-running a failed suggestion would only fail again.
    if not visualizations and suggestions:
        st.warning("⚠️ No valid visualization could be generated from the suggestions.")
    return visualizations
def handle_visualization_suggestions(suggestions, df):
    """
    Render the chart(s) described by *suggestions* in the Streamlit app.

    - A list with two or more entries is delegated to
      ``generate_multiple_visualizations``.
    - A single dict, or a one-element list, is built with
      ``generate_visualization``.
    - Any other input (None, an empty list, other types) produces no charts.

    A warning is shown when nothing was produced. Otherwise every figure is
    drawn with ``st.plotly_chart`` at full container width.
    """
    is_list = isinstance(suggestions, list)
    figures = []

    if is_list and len(suggestions) > 1:
        figures = generate_multiple_visualizations(suggestions, df)
    elif isinstance(suggestions, dict) or (is_list and len(suggestions) == 1):
        single = suggestions if not is_list else suggestions[0]
        figure = generate_visualization(single, df)
        if figure:
            figures.append(figure)

    if not figures:
        st.warning("⚠️ Unable to generate any visualization based on the suggestion.")
        return

    for figure in figures:
        st.plotly_chart(figure, use_container_width=True)