# Listing metadata (copied from the page this file was taken from):
# Spaces: Running · File size: 2,508 Bytes · 10150e1
import json
import os

import pandas as pd
import plotly.express as px
import streamlit as st
# GPT-4o powered function to generate visualization
def ask_gpt4o_for_viz(query, df, llm):
    """Ask the LLM which chart best answers *query* for the columns of *df*.

    The model is prompted to reply with a JSON object of the form
    ``{"chart_type": ..., "x_axis": ..., "y_axis": ..., "group_by": ...}``.

    Args:
        query: Free-text question typed by the user.
        df: DataFrame whose column names are offered to the model.
        llm: Chat model client. ``invoke(prompt)`` is preferred (LangChain
            chat-model API, reply text in ``.content``); a client exposing
            only ``generate(prompt)`` returning a string also works.

    Returns:
        The parsed suggestion ``dict``, or ``None`` (after showing an error in
        the Streamlit app) when the reply does not contain a JSON object.
    """
    # str() so non-string column labels do not break the join.
    columns = ", ".join(map(str, df.columns))
    prompt = f"""
    Analyze the user query and suggest the best way to visualize the data.
    Query: "{query}"
    Available Columns: {columns}
    Use only the available column names; use null for group_by if no grouping is needed.
    Respond only with a JSON object in this format:
    {{
    "chart_type": "bar/box/line/scatter",
    "x_axis": "column_name",
    "y_axis": "column_name",
    "group_by": "optional_column_name"
    }}
    """
    reply = _llm_reply_text(llm, prompt)
    try:
        suggestion = json.loads(_extract_json_object(reply))
    except (json.JSONDecodeError, TypeError):
        suggestion = None
    # The caller treats the result as a mapping, so reject lists/scalars too.
    if not isinstance(suggestion, dict):
        st.error("⚠️ Failed to interpret AI response. Please refine your query.")
        return None
    return suggestion


def _llm_reply_text(llm, prompt):
    """Send *prompt* to *llm* and return the reply as plain text."""
    # LangChain chat models (e.g. ChatOpenAI) take a single prompt via
    # ``invoke`` and return a message whose text is in ``.content``; their
    # ``generate`` expects a batch of message lists and returns an LLMResult
    # that ``json.loads`` cannot parse. Fall back to ``generate`` for clients
    # that only offer the original call.
    send = getattr(llm, "invoke", None) or llm.generate
    reply = send(prompt)
    text = getattr(reply, "content", reply)
    return text if isinstance(text, str) else str(text)


def _extract_json_object(text):
    """Return the outermost ``{...}`` span of *text*, or *text* unchanged."""
    # Models often wrap JSON in ```json fences or add a sentence around it.
    start, end = text.find("{"), text.rfind("}")
    return text[start:end + 1] if start != -1 and end > start else text
# Visualization generator
def generate_viz(suggestion, df, default_y_axis="salary_in_usd"):
    """Build a Plotly Express figure from an LLM chart *suggestion*.

    Args:
        suggestion: Dict with ``chart_type`` (bar/box/line/scatter; default
            "bar"), ``x_axis``, ``y_axis`` and optional ``group_by`` keys.
        df: DataFrame to plot.
        default_y_axis: Column used when the suggestion gives no ``y_axis``;
            applied only if that column exists in *df*.

    Returns:
        A Plotly figure, or ``None`` after showing a Streamlit warning when a
        required column is missing/unknown or the chart type is unsupported.
    """
    # Normalise so null or "Bar" from the model still maps to a plotter.
    chart_type = str(suggestion.get("chart_type") or "bar").strip().lower()
    x_axis = suggestion.get("x_axis")
    y_axis = suggestion.get("y_axis")
    group_by = suggestion.get("group_by")

    # Legacy fallback column: only when the data actually has it, otherwise
    # Plotly would raise instead of the user seeing a warning.
    if not y_axis and _has_column(df, default_y_axis):
        y_axis = default_y_axis

    if not x_axis or not y_axis:
        st.warning("⚠️ Could not identify required columns.")
        return None

    missing = [name for name in (x_axis, y_axis) if not _has_column(df, name)]
    if missing:
        st.warning(
            "⚠️ Suggested column(s) not found in the dataset: "
            + ", ".join(map(str, missing))
        )
        return None

    # group_by is optional; the model may return null, "" or a placeholder
    # such as "optional_column_name", so ignore anything that is not a column.
    if not _has_column(df, group_by):
        group_by = None

    plotters = {
        "bar": px.bar,
        "box": px.box,
        "line": px.line,
        "scatter": px.scatter,
    }
    plot = plotters.get(chart_type)
    if plot is None:
        st.warning("⚠️ Unsupported chart type suggested.")
        return None

    fig = plot(df, x=x_axis, y=y_axis, color=group_by)
    fig.update_layout(title=f"{chart_type.title()} Plot of {y_axis} by {x_axis}")
    return fig


def _has_column(df, name):
    """Return True if *name* is a column of *df*; False for None/unhashable."""
    if name is None:
        return False
    try:
        return name in df.columns
    except TypeError:  # e.g. a list coming from malformed model output
        return False
# Streamlit App
st.title("📊 GPT-4o Powered Data Visualization")
uploaded_file = st.file_uploader("Upload CSV File", type=["csv"])
query = st.text_input("Ask a question about the data:")

if uploaded_file is not None:
    # The original called ``load_data``, which is not defined in this file;
    # read the upload directly and report unreadable files instead of crashing.
    try:
        df = pd.read_csv(uploaded_file)
    except (pd.errors.ParserError, pd.errors.EmptyDataError, UnicodeDecodeError) as err:
        st.error(f"⚠️ Could not read the CSV file: {err}")
        st.stop()
    st.write("### Dataset Preview", df.head())

    # Always render the button so users can see it before typing a question.
    if st.button("Generate Visualization"):
        api_key = os.getenv("OPENAI_API_KEY")
        if not query.strip():
            st.warning("⚠️ Please enter a question about the data first.")
        elif not api_key:
            st.error("⚠️ OPENAI_API_KEY is not set.")
        else:
            # NOTE(review): ChatOpenAI is used but never imported in this file;
            # presumably ``from langchain_openai import ChatOpenAI`` — confirm.
            llm = ChatOpenAI(api_key=api_key, model="gpt-4o")  # Initialize GPT-4o
            suggestion = ask_gpt4o_for_viz(query, df, llm)
            fig = generate_viz(suggestion, df) if suggestion else None
            if fig is not None:
                st.plotly_chart(fig, use_container_width=True)
else:
    st.info("Upload a CSV file to get started.")