EditsPaarth commited on
Commit
b9123fb
·
verified ·
1 Parent(s): 8215549
Files changed (1) hide show
  1. app.py +210 -0
app.py ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ import numpy as np
4
+ import seaborn as sns
5
+ import matplotlib.pyplot as plt
6
+ import tempfile
7
+ import subprocess
8
+ from groq import Groq
9
+
10
+ # Groq API Key setup
11
+ GROQ_API_KEY = "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"
12
+ client = Groq(api_key=GROQ_API_KEY)
13
+
14
+ # Groq Chat Function
15
+ def chat_with_groq(prompt):
16
+ try:
17
+ chat_completion = client.chat.completions.create(
18
+ messages=[{"role": "user", "content": prompt}],
19
+ model="gemma-7b-it",
20
+ stream=False
21
+ )
22
+ print(prompt)
23
+ return chat_completion.choices[0].message.content
24
+ except Exception as e:
25
+ return f"Error fetching response: {e}"
26
+
27
+ def generate_code_with_groq(prompt):
28
+ try:
29
+ chat_completion = client.chat.completions.create(
30
+ messages=[{"role": "user", "content": prompt}, {"role": "assistant", "content": "```python"}],
31
+ model="gemma-7b-it",
32
+ stream=False,
33
+ stop="```"
34
+ )
35
+ return chat_completion.choices[0].message.content
36
+ except Exception as e:
37
+ return f"Error fetching response: {e}"
38
+
39
+ # File Parsing Functions
40
+ def parse_file(uploaded_file):
41
+ filename = uploaded_file.name
42
+ if filename.endswith('.csv'):
43
+ return pd.read_csv(uploaded_file)
44
+ elif filename.endswith('.xlsx'):
45
+ return pd.read_excel(uploaded_file)
46
+ else:
47
+ st.error("Unsupported file type! Only CSV and Excel are supported.")
48
+ return None
49
+
50
+ # Preprocess DataFrame to Fix Type Issues
51
+ def preprocess_dataframe(df):
52
+ try:
53
+ # Convert problematic columns to string to avoid Arrow serialization issues
54
+ for col in df.columns:
55
+ if df[col].dtype.name == 'object' or df[col].dtype.name == 'category':
56
+ df[col] = df[col].astype(str)
57
+ return df
58
+ except Exception as e:
59
+ st.error(f"Error preprocessing data: {e}")
60
+ return None
61
+
62
+ # Analysis Function
63
+ def analyze_data(data, visualization_type, class_size=10):
64
+ st.subheader("Basic Analysis")
65
+ st.write("Shape of Data:", data.shape)
66
+ st.write("Data Types:")
67
+ st.write(data.dtypes)
68
+
69
+ # Combine numerical and non-numerical summaries
70
+ st.write("Summary Statistics:")
71
+ combined_stats = pd.concat(
72
+ [
73
+ data.describe(include=[np.number]),
74
+ data.describe(include=['object', 'category'])
75
+ ],
76
+ axis=1
77
+ )
78
+ st.write(combined_stats)
79
+
80
+ numeric_data = data.select_dtypes(include=[np.number])
81
+
82
+ # Visualization logic
83
+ if visualization_type == "Heatmap" and not numeric_data.empty:
84
+ st.subheader("Correlation Heatmap")
85
+ fig, ax = plt.subplots(figsize=(8, 6))
86
+ sns.heatmap(numeric_data.corr(), annot=True, ax=ax, cmap="coolwarm", fmt=".2f")
87
+ st.pyplot(fig)
88
+
89
+ elif visualization_type == "Bar Chart" and not numeric_data.empty:
90
+ st.subheader("Bar Chart")
91
+ x_col = st.selectbox("Select the X-axis column for the Bar Chart:", data.columns)
92
+ y_col = st.selectbox("Select the Y-axis column for the Bar Chart:", data.columns)
93
+
94
+ fig, ax = plt.subplots(figsize=(8, 6))
95
+ data.groupby(x_col)[y_col].sum().plot(kind='bar', ax=ax)
96
+ ax.set_xlabel(x_col)
97
+ ax.set_ylabel(y_col)
98
+ st.pyplot(fig)
99
+
100
+ elif visualization_type == "Line Graph" and not numeric_data.empty:
101
+ st.subheader("Line Graph")
102
+ x_col = st.selectbox("Select the X-axis column for the Line Graph:", numeric_data.columns)
103
+ y_col = st.selectbox("Select the Y-axis column for the Line Graph:", numeric_data.columns)
104
+
105
+ fig, ax = plt.subplots(figsize=(8, 6))
106
+ ax.plot(data[x_col], data[y_col])
107
+ ax.set_xlabel(x_col)
108
+ ax.set_ylabel(y_col)
109
+ st.pyplot(fig)
110
+
111
+ elif visualization_type == "Scatter Plot" and not numeric_data.empty:
112
+ st.subheader("Scatter Plot")
113
+ x_col = st.selectbox("Select the X-axis column for the Scatter Plot:", numeric_data.columns)
114
+ y_col = st.selectbox("Select the Y-axis column for the Scatter Plot:", numeric_data.columns)
115
+
116
+ fig, ax = plt.subplots(figsize=(8, 6))
117
+ ax.scatter(data[x_col], data[y_col])
118
+ ax.set_xlabel(x_col)
119
+ ax.set_ylabel(y_col)
120
+ st.pyplot(fig)
121
+
122
+ elif visualization_type == "Histogram" and not numeric_data.empty:
123
+ st.subheader("Histogram")
124
+ column = st.selectbox("Select a column for the Histogram:", numeric_data.columns)
125
+ fig, ax = plt.subplots(figsize=(8, 6))
126
+ data[column].plot(kind='hist', bins=class_size, ax=ax)
127
+ ax.set_xlabel(column)
128
+ ax.set_ylabel("Frequency")
129
+ st.pyplot(fig)
130
+
131
+ elif visualization_type == "Area Chart" and not numeric_data.empty:
132
+ st.subheader("Area Chart")
133
+ column = st.selectbox("Select a column for the Area Chart:", numeric_data.columns)
134
+ fig, ax = plt.subplots(figsize=(8, 6))
135
+ data[column].plot(kind='area', ax=ax)
136
+ ax.set_xlabel(column)
137
+ ax.set_ylabel("Area")
138
+ st.pyplot(fig)
139
+
140
+ else:
141
+ st.warning("No valid visualization option selected or data available.")
142
+
143
+ # Automatically generate a prompt for Groq based on the analysis
144
+ prompt = generate_groq_prompt(data, visualization_type, class_size)
145
+ return prompt
146
+
147
+ # Function to generate a prompt based on the data analysis
148
+ def generate_groq_prompt(data, visualization_type, class_size):
149
+ # Convert DataFrame to a string without the index
150
+ data_without_index = data.to_string(index=False)
151
+
152
+ prompt = f"""
153
+ Here is the summary statistics for the dataset:
154
+ {data_without_index}
155
+
156
+ The user has selected the '{visualization_type}' visualization type with a class size of {class_size}.
157
+ Please generate Python code that does this and for any data, please don't use any file input. Write the data in the code.
158
+ """
159
+
160
+ return prompt
161
+
162
+ # Streamlit App
163
+ st.title("Data Analysis AI")
164
+ st.markdown("Upload a file (CSV or Excel) to analyze it.")
165
+
166
+ uploaded_file = st.file_uploader("Choose a file", type=['csv', 'xlsx'])
167
+
168
+ if uploaded_file is not None:
169
+ try:
170
+ data = parse_file(uploaded_file)
171
+ if data is not None:
172
+ data = preprocess_dataframe(data) # Fix serialization issues
173
+ st.subheader("Uploaded Data")
174
+ st.write(data.head())
175
+
176
+ # Visualization Selection
177
+ visualization_type = st.selectbox(
178
+ "Select a visualization type:",
179
+ ["Heatmap", "Bar Chart", "Line Graph", "Scatter Plot", "Histogram", "Area Chart"]
180
+ )
181
+
182
+ # User input for class size customization
183
+ class_size = st.slider("Select the class size for certain plots (e.g., Histogram)", 5, 50, 10)
184
+
185
+ # Perform Analysis and Visualization
186
+ prompt = analyze_data(data, visualization_type, class_size)
187
+ st.text(f"Prompt sent to Groq:\n{prompt}")
188
+
189
+ # Chat with Groq Section
190
+ st.subheader("Chat with Groq")
191
+ chat_input = st.text_area("Ask Groq questions about the data:")
192
+ if st.button("Chat"):
193
+ if chat_input:
194
+ chat_response = chat_with_groq(f"Here is the data:\n{data}\n\n{chat_input}")
195
+ st.write("Groq's Response:")
196
+ st.write(chat_response)
197
+
198
+ # Groq Code Generation Section
199
+ st.subheader("Generate Python Code with Groq")
200
+ prompt_input = st.text_area("Describe the analysis or visualization you want to generate code for:")
201
+ if st.button("Generate Code"):
202
+ if prompt_input:
203
+ prompt += f"\n\nUser request: {prompt_input}"
204
+ response = generate_code_with_groq(prompt)
205
+
206
+ # Display the Groq response
207
+ st.subheader("Generated Code")
208
+ st.code(response, language="python")
209
+ except Exception as e:
210
+ st.error(f"An error occurred: {e}")