# Hugging Face Space (status: Running)
import json
import os
import re

import pandas as pd
import plotly.express as px
import streamlit as st
# GPT-4o powered function to suggest a visualization

# Matches a reply wrapped in a Markdown code fence such as ```json ... ```.
_CODE_FENCE_RE = re.compile(r"^```[\w-]*\s*(.*?)\s*```$", re.DOTALL)


def _llm_response_text(llm, prompt):
    """Send *prompt* to *llm* and return its reply as a plain string.

    LangChain chat models expose ``invoke``, which accepts a plain string and
    returns a message whose text is in ``.content``; their ``generate`` method
    instead expects a list of message lists. Objects without ``invoke`` fall
    back to ``generate(prompt)``.
    """
    if hasattr(llm, "invoke"):
        response = llm.invoke(prompt)
    else:
        response = llm.generate(prompt)
    # Chat messages carry their text in ``.content``; plain strings pass through.
    content = getattr(response, "content", response)
    return content if isinstance(content, str) else str(content)


def _strip_code_fence(text):
    """Return *text* with a surrounding Markdown code fence removed, if any."""
    text = text.strip()
    match = _CODE_FENCE_RE.match(text)
    return match.group(1) if match else text


def ask_gpt4o_for_viz(query, df, llm):
    """Ask the language model how to chart *df* to answer *query*.

    Args:
        query: The user's natural-language question.
        df: DataFrame whose column labels are offered to the model.
        llm: Model object; a LangChain-style chat model with ``invoke`` is
            preferred, otherwise ``generate(prompt)`` must return text.

    Returns:
        The parsed suggestion dict (expected keys: ``chart_type``, ``x_axis``,
        ``y_axis``, ``group_by``), or ``None`` after showing a Streamlit error
        when the reply is not a JSON object.
    """
    # str() so non-string column labels (e.g. integers) don't break join().
    columns = ', '.join(map(str, df.columns))
    prompt = f"""
Analyze the user query and suggest the best way to visualize the data.
Query: "{query}"
Available Columns: {columns}
Respond with only a JSON object in this format:
{{
"chart_type": "bar/box/line/scatter",
"x_axis": "column_name",
"y_axis": "column_name",
"group_by": "optional_column_name"
}}
"""
    raw_reply = _llm_response_text(llm, prompt)
    try:
        # Models frequently wrap JSON in ```json fences; strip them first.
        suggestion = json.loads(_strip_code_fence(raw_reply))
    except json.JSONDecodeError:
        st.error("⚠️ Failed to interpret AI response. Please refine your query.")
        return None
    # Callers use dict.get() on the result, so anything else is unusable.
    if not isinstance(suggestion, dict):
        st.error("⚠️ Failed to interpret AI response. Please refine your query.")
        return None
    return suggestion
# Visualization generator
def _has_column(df, name):
    """Return True when *name* is a column label of *df*."""
    try:
        return name in df.columns
    except TypeError:
        # Unhashable values (e.g. a JSON list from the model) cannot be labels.
        return False


def generate_viz(suggestion, df):
    """Build a Plotly Express figure from a chart suggestion.

    Args:
        suggestion: Mapping with keys ``chart_type`` (``"bar"``, ``"box"``,
            ``"line"`` or ``"scatter"``; defaults to ``"bar"``), ``x_axis``,
            ``y_axis`` (defaults to ``"salary_in_usd"``, presumably a column of
            the dataset this app was built for) and optional ``group_by``.
        df: DataFrame to plot.

    Returns:
        The figure, or ``None`` after showing a Streamlit warning when a
        required column is missing/unknown or the chart type is unsupported.
    """
    plotters = {
        "bar": px.bar,
        "box": px.box,
        "line": px.line,
        "scatter": px.scatter,
    }

    chart_type = suggestion.get("chart_type", "bar")
    if isinstance(chart_type, str):
        # Accept spellings like "Bar" or " line " from the model.
        chart_type = chart_type.strip().lower()
    x_axis = suggestion.get("x_axis")
    y_axis = suggestion.get("y_axis", "salary_in_usd")
    group_by = suggestion.get("group_by")

    if not x_axis or not y_axis:
        st.warning("⚠️ Could not identify required columns.")
        return None
    # Plotly raises ValueError for unknown column names; report them instead.
    missing = [col for col in (x_axis, y_axis) if not _has_column(df, col)]
    if missing:
        st.warning(
            "⚠️ Suggested column(s) not found in the data: "
            + ", ".join(map(str, missing))
        )
        return None
    # group_by is optional: placeholders such as "optional_column_name", ""
    # or null are dropped rather than passed to Plotly as a color column.
    if not group_by or not _has_column(df, group_by):
        group_by = None

    plot = plotters.get(chart_type) if isinstance(chart_type, str) else None
    if plot is None:
        st.warning("⚠️ Unsupported chart type suggested.")
        return None

    fig = plot(df, x=x_axis, y=y_axis, color=group_by)
    fig.update_layout(title=f"{chart_type.title()} Plot of {y_axis} by {x_axis}")
    return fig
# Streamlit App
st.title("📊 GPT-4o Powered Data Visualization")
uploaded_file = st.file_uploader("Upload CSV File", type=["csv"])
query = st.text_input("Ask a question about the data:")

if uploaded_file:
    # Read the uploaded CSV directly with pandas.
    try:
        df = pd.read_csv(uploaded_file)
    except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError) as err:
        st.error(f"⚠️ Could not read the CSV file: {err}")
        st.stop()
    st.write("### Dataset Preview", df.head())

    # Create the button on every run; `query and st.button(...)` would skip
    # rendering the widget entirely while the text box is empty.
    if st.button("Generate Visualization"):
        api_key = os.getenv("OPENAI_API_KEY")
        if not query.strip():
            st.warning("⚠️ Please enter a question about the data first.")
        elif not api_key:
            st.error("⚠️ OPENAI_API_KEY is not set.")
        else:
            # NOTE(review): `ChatOpenAI` is used here but never imported in this
            # file; presumably `from langchain_openai import ChatOpenAI` is
            # missing — confirm and add it.
            llm = ChatOpenAI(api_key=api_key, model="gpt-4o")  # Initialize GPT-4o
            suggestion = ask_gpt4o_for_viz(query, df, llm)
            if suggestion:
                fig = generate_viz(suggestion, df)
                if fig:
                    st.plotly_chart(fig, use_container_width=True)
else:
    st.info("Upload a CSV file to get started.")