DrishtiSharma commited on
Commit
10150e1
Β·
verified Β·
1 Parent(s): 73e6a7e

Update mylab/gpt4o_based_sol.py

Browse files
Files changed (1) hide show
  1. mylab/gpt4o_based_sol.py +75 -0
mylab/gpt4o_based_sol.py CHANGED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ import plotly.express as px
4
+ import json
5
+
6
+
7
+ # GPT-4o powered function to generate visualization
8
+ def ask_gpt4o_for_viz(query, df, llm):
9
+ columns = ', '.join(df.columns)
10
+ prompt = f"""
11
+ Analyze the user query and suggest the best way to visualize the data.
12
+ Query: "{query}"
13
+ Available Columns: {columns}
14
+ Respond in this JSON format:
15
+ {{
16
+ "chart_type": "bar/box/line/scatter",
17
+ "x_axis": "column_name",
18
+ "y_axis": "column_name",
19
+ "group_by": "optional_column_name"
20
+ }}
21
+ """
22
+
23
+ response = llm.generate(prompt)
24
+ try:
25
+ suggestion = json.loads(response)
26
+ return suggestion
27
+ except json.JSONDecodeError:
28
+ st.error("⚠️ Failed to interpret AI response. Please refine your query.")
29
+ return None
30
+
31
+ # Visualization generator
32
+ def generate_viz(suggestion, df):
33
+ chart_type = suggestion.get("chart_type", "bar")
34
+ x_axis = suggestion.get("x_axis")
35
+ y_axis = suggestion.get("y_axis", "salary_in_usd")
36
+ group_by = suggestion.get("group_by")
37
+
38
+ if not x_axis or not y_axis:
39
+ st.warning("⚠️ Could not identify required columns.")
40
+ return None
41
+
42
+ # Generate the specified chart
43
+ if chart_type == "bar":
44
+ fig = px.bar(df, x=x_axis, y=y_axis, color=group_by)
45
+ elif chart_type == "box":
46
+ fig = px.box(df, x=x_axis, y=y_axis, color=group_by)
47
+ elif chart_type == "line":
48
+ fig = px.line(df, x=x_axis, y=y_axis, color=group_by)
49
+ elif chart_type == "scatter":
50
+ fig = px.scatter(df, x=x_axis, y=y_axis, color=group_by)
51
+ else:
52
+ st.warning("⚠️ Unsupported chart type suggested.")
53
+ return None
54
+
55
+ fig.update_layout(title=f"{chart_type.title()} Plot of {y_axis} by {x_axis}")
56
+ return fig
57
+
58
+ # Streamlit App
59
+ st.title("πŸ“Š GPT-4o Powered Data Visualization")
60
+ uploaded_file = st.file_uploader("Upload CSV File", type=["csv"])
61
+ query = st.text_input("Ask a question about the data:")
62
+
63
+ if uploaded_file:
64
+ df = load_data(uploaded_file)
65
+ st.write("### Dataset Preview", df.head())
66
+
67
+ if query and st.button("Generate Visualization"):
68
+ llm = ChatOpenAI(api_key=os.getenv("OPENAI_API_KEY"), model="gpt-4o") # Initialize GPT-4o
69
+ suggestion = ask_gpt4o_for_viz(query, df, llm)
70
+ if suggestion:
71
+ fig = generate_viz(suggestion, df)
72
+ if fig:
73
+ st.plotly_chart(fig, use_container_width=True)
74
+ else:
75
+ st.info("Upload a CSV file to get started.")