"""Streamlit app: ask GPT-4o how to chart an uploaded CSV, then render it with Plotly."""

import json
import os
import re

import pandas as pd
import plotly.express as px
import streamlit as st

# NOTE(review): ChatOpenAI was referenced but never imported in the original file;
# langchain-openai is assumed to be the intended provider -- confirm against the
# project's requirements.
from langchain_openai import ChatOpenAI

# Plotly Express constructor for each chart type the prompt allows the model to pick.
CHART_BUILDERS = {
    "bar": px.bar,
    "box": px.box,
    "line": px.line,
    "scatter": px.scatter,
}

# Chat models frequently wrap JSON in a Markdown fence (```json ... ```); capture the inside.
_CODE_FENCE_RE = re.compile(r"^```(?:json)?\s*(.*?)\s*```$", re.DOTALL | re.IGNORECASE)


def load_data(uploaded_file):
    """Read an uploaded CSV into a DataFrame.

    Returns:
        The parsed DataFrame, or None after showing an error when the file cannot be parsed.
    """
    try:
        return pd.read_csv(uploaded_file)
    except (pd.errors.ParserError, pd.errors.EmptyDataError, UnicodeDecodeError) as exc:
        st.error(f"⚠️ Could not read the CSV file: {exc}")
        return None


def _llm_text(llm, prompt):
    """Send *prompt* to *llm* and return the reply as plain text.

    LangChain chat models expose ``invoke``, which returns a message whose text lives in
    ``.content``. Their ``generate`` method expects lists of messages and returns an
    ``LLMResult``, so it must not be called with a bare string. Clients without ``invoke``
    fall back to ``generate(prompt)``, as the original code called it.
    """
    reply = llm.invoke(prompt) if hasattr(llm, "invoke") else llm.generate(prompt)
    return str(getattr(reply, "content", reply))


def _parse_json_reply(text):
    """Parse JSON from a model reply, tolerating a surrounding Markdown code fence.

    Raises:
        json.JSONDecodeError: if the (unfenced) text is not valid JSON.
    """
    text = text.strip()
    fenced = _CODE_FENCE_RE.match(text)
    if fenced:
        text = fenced.group(1)
    return json.loads(text)


# GPT-4o powered function to generate visualization
def ask_gpt4o_for_viz(query, df, llm):
    """Ask the LLM which chart best answers *query* over the columns of *df*.

    Args:
        query: Free-text question from the user.
        df: The uploaded DataFrame; only its column names are sent to the model.
        llm: A chat model (LangChain-style ``invoke``) or a client with ``generate(prompt)``.

    Returns:
        A dict with keys such as ``chart_type``, ``x_axis``, ``y_axis`` and ``group_by``,
        or None (after showing an error) when the reply is not a JSON object.
    """
    columns = ", ".join(map(str, df.columns))
    prompt = f"""
Analyze the user query and suggest the best way to visualize the data.
Query: "{query}"
Available Columns: {columns}
Respond with only a JSON object in this format:
{{
    "chart_type": "bar/box/line/scatter",
    "x_axis": "column_name",
    "y_axis": "column_name",
    "group_by": "optional_column_name"
}}
"""
    reply = _llm_text(llm, prompt)
    try:
        suggestion = _parse_json_reply(reply)
    except json.JSONDecodeError:
        suggestion = None
    # A valid JSON list/string/number is still unusable: callers need a mapping.
    if not isinstance(suggestion, dict):
        st.error("⚠️ Failed to interpret AI response. Please refine your query.")
        return None
    return suggestion


def _is_column(df, name):
    """Return True when *name* is a string naming an existing column of *df*."""
    return isinstance(name, str) and name in df.columns


# Visualization generator
def generate_viz(suggestion, df):
    """Build a Plotly figure from an LLM chart suggestion.

    Args:
        suggestion: Mapping as returned by ``ask_gpt4o_for_viz``.
        df: The DataFrame to plot.

    Returns:
        The Plotly figure, or None (after showing a warning) when the suggestion names
        missing columns, an unsupported chart type, or data Plotly cannot draw.
    """
    chart_type = str(suggestion.get("chart_type") or "bar").strip().lower()
    x_axis = suggestion.get("x_axis")
    y_axis = suggestion.get("y_axis", "salary_in_usd")
    group_by = suggestion.get("group_by")

    if not x_axis or not y_axis:
        st.warning("⚠️ Could not identify required columns.")
        return None
    missing = [str(col) for col in (x_axis, y_axis) if not _is_column(df, col)]
    if missing:
        st.warning(f"⚠️ Suggested column(s) not found in the data: {', '.join(missing)}")
        return None
    # group_by is optional: drop placeholders such as "optional_column_name" or "none".
    if not _is_column(df, group_by):
        group_by = None

    builder = CHART_BUILDERS.get(chart_type)
    if builder is None:
        st.warning("⚠️ Unsupported chart type suggested.")
        return None
    try:
        fig = builder(df, x=x_axis, y=y_axis, color=group_by)
    except ValueError as exc:  # Plotly rejects incompatible column/data combinations
        st.warning(f"⚠️ Could not draw the chart: {exc}")
        return None
    fig.update_layout(title=f"{chart_type.title()} Plot of {y_axis} by {x_axis}")
    return fig


# Streamlit App
st.title("📊 GPT-4o Powered Data Visualization")
uploaded_file = st.file_uploader("Upload CSV File", type=["csv"])
query = st.text_input("Ask a question about the data:")

if uploaded_file:
    df = load_data(uploaded_file)
    if df is None:
        st.stop()
    st.write("### Dataset Preview", df.head())
    if query and st.button("Generate Visualization"):
        api_key = os.getenv("OPENAI_API_KEY")
        if not api_key:
            st.error("⚠️ OPENAI_API_KEY is not set.")
            st.stop()
        llm = ChatOpenAI(api_key=api_key, model="gpt-4o")  # Initialize GPT-4o
        try:
            with st.spinner("Asking GPT-4o..."):
                suggestion = ask_gpt4o_for_viz(query, df, llm)
        except Exception as exc:  # UI boundary: report API/network failures, don't crash
            st.error(f"⚠️ GPT-4o request failed: {exc}")
            suggestion = None
        if suggestion:
            fig = generate_viz(suggestion, df)
            if fig is not None:
                st.plotly_chart(fig, use_container_width=True)
else:
    st.info("Upload a CSV file to get started.")