# sql-rag / mylab / gpt4o_based_sol.py
# Author: DrishtiSharma
# Commit message: "Update mylab/gpt4o_based_sol.py"
# Commit: 10150e1 (verified)
# File size: 2.51 kB
import json
import os

import pandas as pd
import plotly.express as px
import streamlit as st
# GPT-4o powered function to generate visualization
def ask_gpt4o_for_viz(query, df, llm):
    """Ask the LLM to suggest a chart for ``query`` over the columns of ``df``.

    Args:
        query: The user's natural-language question.
        df: DataFrame whose column names are offered to the model.
        llm: Client object exposing ``generate(prompt)``. Its return value may
            be a plain string or an object carrying the text in ``.content``.

    Returns:
        The parsed suggestion dict (expected keys: ``chart_type``, ``x_axis``,
        ``y_axis``, ``group_by``), or ``None`` after showing a Streamlit error
        when the reply cannot be parsed into a JSON object.
    """
    # str() guards against non-string column labels (e.g. integer headers).
    columns = ", ".join(map(str, df.columns))
    prompt = (
        "Analyze the user query and suggest the best way to visualize the data.\n"
        f'Query: "{query}"\n'
        f"Available Columns: {columns}\n"
        "Respond with only a JSON object (no markdown fences, no extra text) "
        "in this format:\n"
        "{\n"
        '  "chart_type": "bar/box/line/scatter",\n'
        '  "x_axis": "column_name",\n'
        '  "y_axis": "column_name",\n'
        '  "group_by": "optional_column_name"\n'
        "}\n"
    )
    # NOTE(review): the client type is not visible in this file. If `llm` is a
    # LangChain chat model, `.generate` expects batched message lists and
    # `.invoke(prompt)` is the single-prompt call; confirm against the client.
    response = llm.generate(prompt)
    # Chat clients commonly return a message object; plain clients return str.
    text = getattr(response, "content", response)
    if not isinstance(text, str):
        text = str(text)
    # Models often wrap JSON in ```json fences or add prose around it; keep
    # only the outermost {...} span before parsing.
    start, end = text.find("{"), text.rfind("}")
    if start != -1 and end > start:
        text = text[start:end + 1]
    try:
        suggestion = json.loads(text)
    except json.JSONDecodeError:
        suggestion = None
    # A bare list/string/number is valid JSON but unusable downstream, which
    # reads the suggestion as a key/value mapping.
    if not isinstance(suggestion, dict):
        st.error("⚠️ Failed to interpret AI response. Please refine your query.")
        return None
    return suggestion


# Visualization generator
def generate_viz(suggestion, df):
    """Build a Plotly Express figure from a chart suggestion.

    Args:
        suggestion: Dict with keys ``chart_type`` (bar/box/line/scatter,
            default ``"bar"``), ``x_axis``, ``y_axis`` (default
            ``"salary_in_usd"``) and optional ``group_by``.
        df: DataFrame to plot.

    Returns:
        The figure, or ``None`` after a Streamlit warning when the suggestion
        is not a dict, names a column missing from ``df``, or requests an
        unsupported chart type.
    """
    chart_builders = {
        "bar": px.bar,
        "box": px.box,
        "line": px.line,
        "scatter": px.scatter,
    }
    if not isinstance(suggestion, dict):
        st.warning("⚠️ Could not identify required columns.")
        return None
    # LLM output is free-form: tolerate null and case/whitespace variants.
    chart_type = str(suggestion.get("chart_type") or "bar").strip().lower()
    x_axis = suggestion.get("x_axis")
    # NOTE(review): this default is specific to one dataset; it is only used
    # when that column actually exists (validated below).
    y_axis = suggestion.get("y_axis", "salary_in_usd")
    group_by = suggestion.get("group_by")
    if not x_axis or not y_axis:
        st.warning("⚠️ Could not identify required columns.")
        return None
    # list membership compares with ==, so unhashable LLM values cannot raise.
    available = list(df.columns)
    missing = [name for name in (x_axis, y_axis) if name not in available]
    if missing:
        st.warning(
            "⚠️ Suggested column(s) not found in the data: "
            + ", ".join(map(str, missing))
        )
        return None
    # group_by is optional; drop placeholders ("", "none", unknown names)
    # instead of passing them to Plotly as a color column.
    if group_by not in available:
        group_by = None
    builder = chart_builders.get(chart_type)
    if builder is None:
        st.warning("⚠️ Unsupported chart type suggested.")
        return None
    fig = builder(df, x=x_axis, y=y_axis, color=group_by)
    fig.update_layout(title=f"{chart_type.title()} Plot of {y_axis} by {x_axis}")
    return fig


# Streamlit App
st.title("📊 GPT-4o Powered Data Visualization")
uploaded_file = st.file_uploader("Upload CSV File", type=["csv"])
query = st.text_input("Ask a question about the data:")
if uploaded_file is not None:
    # No loader helper is defined in this file; read the upload with pandas
    # and report malformed files instead of crashing the app.
    try:
        df = pd.read_csv(uploaded_file)
    except (pd.errors.ParserError, pd.errors.EmptyDataError, UnicodeDecodeError) as err:
        df = None
        st.error(f"⚠️ Could not read the CSV file: {err}")
    if df is not None:
        st.write("### Dataset Preview", df.head())
        if query and st.button("Generate Visualization"):
            api_key = os.getenv("OPENAI_API_KEY")
            if not api_key:
                st.error("⚠️ OPENAI_API_KEY is not set.")
            else:
                # NOTE(review): ChatOpenAI is never imported in this file;
                # presumably `from langchain_openai import ChatOpenAI`. Add it.
                llm = ChatOpenAI(api_key=api_key, model="gpt-4o")  # Initialize GPT-4o
                suggestion = ask_gpt4o_for_viz(query, df, llm)
                if suggestion:
                    fig = generate_viz(suggestion, df)
                    if fig is not None:
                        st.plotly_chart(fig, use_container_width=True)
else:
    st.info("Upload a CSV file to get started.")