Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -3,9 +3,11 @@ import pandas as pd
|
|
3 |
import numpy as np
|
4 |
import seaborn as sns
|
5 |
import matplotlib.pyplot as plt
|
|
|
|
|
6 |
from groq import Groq
|
7 |
|
8 |
-
# Groq API Key setup
|
9 |
GROQ_API_KEY = "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"
|
10 |
client = Groq(api_key=GROQ_API_KEY)
|
11 |
|
@@ -17,6 +19,7 @@ def chat_with_groq(prompt):
|
|
17 |
model="llama3-8b-8192",
|
18 |
stream=False
|
19 |
)
|
|
|
20 |
return chat_completion.choices[0].message.content
|
21 |
except Exception as e:
|
22 |
return f"Error fetching response: {e}"
|
@@ -24,10 +27,12 @@ def chat_with_groq(prompt):
|
|
24 |
def generate_code_with_groq(prompt):
|
25 |
try:
|
26 |
chat_completion = client.chat.completions.create(
|
27 |
-
messages=[{"role": "user", "content": prompt}, {"role": "assistant", "content": "
|
|
|
28 |
model="gemma-7b-it",
|
29 |
stream=False,
|
30 |
-
stop="
|
|
|
31 |
)
|
32 |
return chat_completion.choices[0].message.content
|
33 |
except Exception as e:
|
@@ -47,7 +52,7 @@ def parse_file(uploaded_file):
|
|
47 |
# Preprocess DataFrame to Fix Type Issues
|
48 |
def preprocess_dataframe(df):
|
49 |
try:
|
50 |
-
# Convert problematic columns to string to avoid serialization issues
|
51 |
for col in df.columns:
|
52 |
if df[col].dtype.name == 'object' or df[col].dtype.name == 'category':
|
53 |
df[col] = df[col].astype(str)
|
@@ -82,6 +87,9 @@ def analyze_data(data, visualization_type, class_size=10):
|
|
82 |
fig, ax = plt.subplots(figsize=(8, 6))
|
83 |
sns.heatmap(numeric_data.corr(), annot=True, ax=ax, cmap="coolwarm", fmt=".2f")
|
84 |
st.pyplot(fig)
|
|
|
|
|
|
|
85 |
elif visualization_type == "Line Graph" and not numeric_data.empty:
|
86 |
st.subheader("Line Graph")
|
87 |
x_col = st.selectbox("Select the X-axis column for the Line Graph (Non-Numeric):", numeric_data.columns)
|
@@ -92,6 +100,10 @@ def analyze_data(data, visualization_type, class_size=10):
|
|
92 |
ax.set_xlabel(x_col)
|
93 |
ax.set_ylabel(y_col)
|
94 |
st.pyplot(fig)
|
|
|
|
|
|
|
|
|
95 |
elif visualization_type == "Area Chart" and not numeric_data.empty:
|
96 |
st.subheader("Area Chart")
|
97 |
column = st.selectbox("Select a column for the Area Chart:", numeric_data.columns)
|
@@ -100,6 +112,7 @@ def analyze_data(data, visualization_type, class_size=10):
|
|
100 |
ax.set_xlabel(column)
|
101 |
ax.set_ylabel("Area")
|
102 |
st.pyplot(fig)
|
|
|
103 |
else:
|
104 |
st.warning("No valid visualization option selected or data available.")
|
105 |
|
@@ -107,39 +120,21 @@ def analyze_data(data, visualization_type, class_size=10):
|
|
107 |
prompt = generate_groq_prompt(data, visualization_type, class_size)
|
108 |
return prompt
|
109 |
|
110 |
-
#
|
111 |
def generate_groq_prompt(data, visualization_type, class_size):
|
112 |
-
#
|
113 |
-
|
114 |
-
|
115 |
-
# Compute column widths for alignment
|
116 |
-
column_widths = [max(len(str(val)) for val in [col] + data_snippet[col].tolist()) for col in data_snippet.columns]
|
117 |
-
|
118 |
-
# Create the table header with spacing
|
119 |
-
table_header = " ".join(f"{col:<{column_widths[i]}}" for i, col in enumerate(data_snippet.columns))
|
120 |
-
|
121 |
-
# Create the rows with spacing
|
122 |
-
table_rows = "\n".join(
|
123 |
-
" " + " ".join(f"{str(val):<{column_widths[i]}}" for i, val in enumerate(row)) # Add indentation
|
124 |
-
for row in data_snippet.values
|
125 |
-
)
|
126 |
-
|
127 |
-
# Combine header and rows into a single table
|
128 |
-
formatted_table = f" {table_header}\n{table_rows}"
|
129 |
-
|
130 |
-
# Create the textual prompt
|
131 |
prompt = f"""
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
{formatted_table}
|
136 |
|
137 |
-
|
|
|
138 |
"""
|
139 |
|
140 |
return prompt
|
141 |
|
142 |
-
|
143 |
# Streamlit App
|
144 |
st.title("Data Analysis AI")
|
145 |
st.markdown("Upload a file (CSV or Excel) to analyze it.")
|
@@ -152,7 +147,7 @@ if uploaded_file is not None:
|
|
152 |
if data is not None:
|
153 |
data = preprocess_dataframe(data) # Fix serialization issues
|
154 |
st.subheader("Uploaded Data")
|
155 |
-
st.
|
156 |
|
157 |
# Visualization Selection
|
158 |
visualization_type = st.selectbox(
|
@@ -163,14 +158,18 @@ if uploaded_file is not None:
|
|
163 |
# User input for class size customization
|
164 |
class_size = st.slider("Select the class size for certain plots (e.g., Histogram)", 5, 50, 10)
|
165 |
|
166 |
-
# Perform Analysis and
|
167 |
prompt = analyze_data(data, visualization_type, class_size)
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
st.
|
173 |
-
st.
|
|
|
|
|
|
|
|
|
174 |
|
175 |
# Groq Code Generation Section
|
176 |
st.subheader("Generate Python Code with Groq")
|
|
|
3 |
import numpy as np
|
4 |
import seaborn as sns
|
5 |
import matplotlib.pyplot as plt
|
6 |
+
import tempfile
|
7 |
+
import subprocess
|
8 |
from groq import Groq
|
9 |
|
10 |
+
# Groq API Key setup .
|
11 |
GROQ_API_KEY = "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"
|
12 |
client = Groq(api_key=GROQ_API_KEY)
|
13 |
|
|
|
19 |
model="llama3-8b-8192",
|
20 |
stream=False
|
21 |
)
|
22 |
+
print(prompt)
|
23 |
return chat_completion.choices[0].message.content
|
24 |
except Exception as e:
|
25 |
return f"Error fetching response: {e}"
|
|
|
27 |
def generate_code_with_groq(prompt):
|
28 |
try:
|
29 |
chat_completion = client.chat.completions.create(
|
30 |
+
messages=[{"role": "user", "content": prompt}, {"role": "assistant", "content": "
|
31 |
+
python"}],
|
32 |
model="gemma-7b-it",
|
33 |
stream=False,
|
34 |
+
stop="
|
35 |
+
"
|
36 |
)
|
37 |
return chat_completion.choices[0].message.content
|
38 |
except Exception as e:
|
|
|
52 |
# Preprocess DataFrame to Fix Type Issues
|
53 |
def preprocess_dataframe(df):
|
54 |
try:
|
55 |
+
# Convert problematic columns to string to avoid Arrow serialization issues
|
56 |
for col in df.columns:
|
57 |
if df[col].dtype.name == 'object' or df[col].dtype.name == 'category':
|
58 |
df[col] = df[col].astype(str)
|
|
|
87 |
fig, ax = plt.subplots(figsize=(8, 6))
|
88 |
sns.heatmap(numeric_data.corr(), annot=True, ax=ax, cmap="coolwarm", fmt=".2f")
|
89 |
st.pyplot(fig)
|
90 |
+
|
91 |
+
|
92 |
+
|
93 |
elif visualization_type == "Line Graph" and not numeric_data.empty:
|
94 |
st.subheader("Line Graph")
|
95 |
x_col = st.selectbox("Select the X-axis column for the Line Graph (Non-Numeric):", numeric_data.columns)
|
|
|
100 |
ax.set_xlabel(x_col)
|
101 |
ax.set_ylabel(y_col)
|
102 |
st.pyplot(fig)
|
103 |
+
|
104 |
+
|
105 |
+
|
106 |
+
|
107 |
elif visualization_type == "Area Chart" and not numeric_data.empty:
|
108 |
st.subheader("Area Chart")
|
109 |
column = st.selectbox("Select a column for the Area Chart:", numeric_data.columns)
|
|
|
112 |
ax.set_xlabel(column)
|
113 |
ax.set_ylabel("Area")
|
114 |
st.pyplot(fig)
|
115 |
+
|
116 |
else:
|
117 |
st.warning("No valid visualization option selected or data available.")
|
118 |
|
|
|
120 |
prompt = generate_groq_prompt(data, visualization_type, class_size)
|
121 |
return prompt
|
122 |
|
123 |
+
# Function to generate a prompt based on the data analysis
|
124 |
def generate_groq_prompt(data, visualization_type, class_size):
|
125 |
+
# Convert DataFrame to a string without the index
|
126 |
+
data_without_index = data.to_string(index=False)
|
127 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
128 |
prompt = f"""
|
129 |
+
Here is the summary statistics for the dataset:
|
130 |
+
{data_without_index}
|
|
|
|
|
131 |
|
132 |
+
The user has selected the '{visualization_type}' visualization type with a class size of {class_size}.
|
133 |
+
Please generate Python code that does this and for any data, please don't use any file input. Write the data in the code.
|
134 |
"""
|
135 |
|
136 |
return prompt
|
137 |
|
|
|
138 |
# Streamlit App
|
139 |
st.title("Data Analysis AI")
|
140 |
st.markdown("Upload a file (CSV or Excel) to analyze it.")
|
|
|
147 |
if data is not None:
|
148 |
data = preprocess_dataframe(data) # Fix serialization issues
|
149 |
st.subheader("Uploaded Data")
|
150 |
+
st.write(data.head())
|
151 |
|
152 |
# Visualization Selection
|
153 |
visualization_type = st.selectbox(
|
|
|
158 |
# User input for class size customization
|
159 |
class_size = st.slider("Select the class size for certain plots (e.g., Histogram)", 5, 50, 10)
|
160 |
|
161 |
+
# Perform Analysis and Visualization
|
162 |
prompt = analyze_data(data, visualization_type, class_size)
|
163 |
+
st.text(f"Prompt sent to Groq:\n{prompt}")
|
164 |
+
|
165 |
+
# Chat with Groq Section
|
166 |
+
st.subheader("Chat with Groq")
|
167 |
+
chat_input = st.text_area("Ask Groq questions about the data:")
|
168 |
+
if st.button("Chat"):
|
169 |
+
if chat_input:
|
170 |
+
chat_response = chat_with_groq(f"Here is the data:\n{data}\n\n{chat_input}")
|
171 |
+
st.write("Groq's Response:")
|
172 |
+
st.write(chat_response)
|
173 |
|
174 |
# Groq Code Generation Section
|
175 |
st.subheader("Generate Python Code with Groq")
|