File size: 2,508 Bytes
10150e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import json
import os

import pandas as pd
import plotly.express as px
import streamlit as st


# GPT-4o powered function to generate visualization
def ask_gpt4o_for_viz(query, df, llm):
    """Ask the LLM which chart best answers *query* for the columns of *df*.

    Args:
        query: Free-text question from the user.
        df: DataFrame whose column names are listed in the prompt.
        llm: Language-model client. If it exposes a callable ``invoke``
            (LangChain chat models such as ``ChatOpenAI`` do, and accept a
            plain string there) that is used; otherwise ``llm.generate(prompt)``
            is called, as before.

    Returns:
        The parsed suggestion dict (expected keys ``chart_type``, ``x_axis``,
        ``y_axis``, ``group_by``), or ``None`` when the reply cannot be parsed
        into a JSON object, in which case an error is shown with ``st.error``.
    """
    # map(str, ...) so non-string column labels (e.g. integers) don't break join.
    columns = ', '.join(map(str, df.columns))
    prompt = f"""
    Analyze the user query and suggest the best way to visualize the data.
    Query: "{query}"
    Available Columns: {columns}
    Respond in this JSON format:
    {{
      "chart_type": "bar/box/line/scatter",
      "x_axis": "column_name",
      "y_axis": "column_name",
      "group_by": "optional_column_name"
    }}
    """

    # Chat models' ``generate`` expects a list of message lists and returns a
    # result object rather than text, so prefer ``invoke`` when available.
    invoke = getattr(llm, "invoke", None)
    response = invoke(prompt) if callable(invoke) else llm.generate(prompt)
    text = _llm_response_text(response)

    try:
        suggestion = json.loads(_outer_json_object(text))
    except json.JSONDecodeError:
        suggestion = None

    # Callers use ``suggestion.get(...)``, so anything but a JSON object is a failure.
    if not isinstance(suggestion, dict):
        st.error("⚠️ Failed to interpret AI response. Please refine your query.")
        return None
    return suggestion


def _llm_response_text(response):
    """Return the text payload of an LLM reply, whatever wrapper it came in."""
    if isinstance(response, str):
        return response
    # Chat message objects (e.g. the AIMessage returned by ``invoke``).
    content = getattr(response, "content", None)
    if isinstance(content, str):
        return content
    # LangChain-style LLMResult returned by ``generate``: generations[0][0].text.
    generations = getattr(response, "generations", None)
    if generations:
        try:
            return generations[0][0].text
        except (IndexError, TypeError, AttributeError):
            pass
    return str(response)


def _outer_json_object(text):
    """Return *text* sliced from its first '{' to its last '}'.

    Models often wrap JSON in Markdown code fences or add a sentence around it.
    Slicing to the outermost braces drops that wrapper. Without a brace pair
    the text is returned unchanged so that ``json.loads`` reports the failure.
    """
    start = text.find("{")
    end = text.rfind("}")
    if start == -1 or end < start:
        return text
    return text[start:end + 1]

# Visualization generator
def generate_viz(suggestion, df, default_y_axis="salary_in_usd"):
    """Build a Plotly Express figure from an LLM chart suggestion.

    Args:
        suggestion: Dict with ``chart_type`` (bar/box/line/scatter; matched
            case-insensitively, "bar" when absent or null), ``x_axis``,
            ``y_axis`` and an optional ``group_by`` column used for colour.
        df: DataFrame to plot.
        default_y_axis: Column used when the suggestion has no ``y_axis`` key.
            Defaults to "salary_in_usd", the value previously hard-coded.

    Returns:
        The Plotly figure, or ``None`` after a ``st.warning`` when a required
        column is missing/unknown or the chart type is unsupported.
    """
    chart_builders = {
        "bar": px.bar,
        "box": px.box,
        "line": px.line,
        "scatter": px.scatter,
    }

    chart_type = str(suggestion.get("chart_type") or "bar").strip().lower()
    x_axis = suggestion.get("x_axis")
    y_axis = suggestion.get("y_axis", default_y_axis)
    group_by = suggestion.get("group_by")

    if not x_axis or not y_axis:
        st.warning("⚠️ Could not identify required columns.")
        return None

    # The LLM can name columns that don't exist; plotly would raise ValueError
    # and crash the app, so report them instead.
    missing = [col for col in (x_axis, y_axis) if not _is_df_column(df, col)]
    if missing:
        st.warning(
            "⚠️ Suggested column(s) not found in the data: "
            + ", ".join(map(str, missing))
        )
        return None

    # group_by is optional: the model may return "", a placeholder, or an
    # unknown column. Fall back to an ungrouped chart rather than failing.
    if not group_by or not _is_df_column(df, group_by):
        group_by = None

    builder = chart_builders.get(chart_type)
    if builder is None:
        st.warning("⚠️ Unsupported chart type suggested.")
        return None

    fig = builder(df, x=x_axis, y=y_axis, color=group_by)
    fig.update_layout(title=f"{chart_type.title()} Plot of {y_axis} by {x_axis}")
    return fig


def _is_df_column(df, name):
    """Return True if *name* is a column label of *df* (False if unhashable)."""
    try:
        return name in df.columns
    except TypeError:  # e.g. the model returned a list instead of a string
        return False

# Streamlit App
st.title("📊 GPT-4o Powered Data Visualization")
uploaded_file = st.file_uploader("Upload CSV File", type=["csv"])
query = st.text_input("Ask a question about the data:")

if uploaded_file:
    # `load_data` is not defined anywhere in this file (NameError on upload);
    # pandas can read the uploaded file-like object directly.
    try:
        df = pd.read_csv(uploaded_file)
    except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError) as err:
        st.error(f"⚠️ Could not read the CSV file: {err}")
        df = None

    if df is not None:
        st.write("### Dataset Preview", df.head())

        if query and st.button("Generate Visualization"):
            api_key = os.getenv("OPENAI_API_KEY")
            if not api_key:
                st.error("⚠️ OPENAI_API_KEY environment variable is not set.")
            else:
                # NOTE(review): ChatOpenAI is never imported in this file;
                # presumably `from langchain_openai import ChatOpenAI` is
                # intended. Confirm and add it to the imports.
                llm = ChatOpenAI(api_key=api_key, model="gpt-4o")  # Initialize GPT-4o
                suggestion = ask_gpt4o_for_viz(query, df, llm)
                fig = generate_viz(suggestion, df) if suggestion else None
                if fig:
                    st.plotly_chart(fig, use_container_width=True)
else:
    st.info("Upload a CSV file to get started.")