Update app.py
Browse files
app.py
CHANGED
@@ -1,60 +1,93 @@
|
|
1 |
import streamlit as st
|
2 |
import numpy as np
|
3 |
import pandas as pd
|
4 |
-
from langchain.tools import tool
|
5 |
-
from langchain.agents import initialize_agent, AgentType
|
6 |
-
from langchain.chat_models import ChatOpenAI
|
7 |
-
from typing import Union, List, Dict, Optional
|
8 |
import matplotlib.pyplot as plt
|
9 |
import seaborn as sns
|
10 |
import os
|
11 |
import base64
|
12 |
import io
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
|
14 |
-
|
15 |
-
|
16 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
|
18 |
@tool
|
19 |
-
def
|
20 |
-
"""
|
21 |
|
22 |
Args:
|
23 |
-
data (pd.DataFrame):
|
24 |
-
|
25 |
Returns:
|
26 |
-
|
27 |
-
including mean, median, standard deviation, skewness, and missing value counts.
|
28 |
"""
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
'median': float(data[col].median()),
|
36 |
-
'std': float(data[col].std()),
|
37 |
-
'skew': float(data[col].skew()),
|
38 |
-
'missing': int(data[col].isnull().sum())
|
39 |
}
|
40 |
-
|
41 |
-
return
|
42 |
|
43 |
@tool
|
44 |
-
def
|
45 |
-
"""Generate
|
46 |
|
47 |
Args:
|
48 |
-
data (pd.DataFrame):
|
49 |
-
|
|
|
50 |
Returns:
|
51 |
-
str:
|
52 |
"""
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
|
|
58 |
|
59 |
buf = io.BytesIO()
|
60 |
plt.savefig(buf, format='png')
|
@@ -62,136 +95,122 @@ def generate_correlation_matrix(data: pd.DataFrame) -> str:
|
|
62 |
return base64.b64encode(buf.getvalue()).decode()
|
63 |
|
64 |
@tool
|
65 |
-
def
|
66 |
-
"""Analyze
|
67 |
|
68 |
Args:
|
69 |
-
data (pd.DataFrame):
|
70 |
-
|
|
|
|
|
71 |
Returns:
|
72 |
-
str:
|
73 |
-
including unique value counts, top categories, and missing value counts.
|
74 |
"""
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
'top_categories': data[col].value_counts().head(5).to_dict(),
|
82 |
-
'missing': int(data[col].isnull().sum())
|
83 |
-
}
|
84 |
|
85 |
-
|
|
|
|
|
|
|
86 |
|
87 |
@tool
|
88 |
-
def
|
89 |
-
"""
|
90 |
|
91 |
Args:
|
92 |
-
data (pd.DataFrame):
|
93 |
-
|
|
|
|
|
94 |
Returns:
|
95 |
-
|
96 |
-
the characteristics of the input data.
|
97 |
"""
|
98 |
-
|
99 |
-
numeric_cols = data.select_dtypes(include=[np.number]).columns
|
100 |
-
categorical_cols = data.select_dtypes(include=['object', 'category']).columns
|
101 |
|
102 |
-
|
103 |
-
|
|
|
104 |
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
if data[col].skew() > 1 or data[col].skew() < -1:
|
110 |
-
suggestions.append(f"Consider log transformation for {col} due to skewness")
|
111 |
|
112 |
-
return
|
|
|
|
|
|
|
|
|
113 |
|
114 |
def main():
|
115 |
-
st.title("
|
116 |
-
st.
|
117 |
|
118 |
# Initialize session state
|
119 |
if 'data' not in st.session_state:
|
120 |
-
st.session_state
|
121 |
-
if '
|
122 |
-
st.session_state
|
123 |
-
|
124 |
-
#
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
if uploaded_file
|
129 |
-
with st.spinner(
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
143 |
|
144 |
-
|
145 |
-
|
146 |
-
|
|
|
|
|
|
|
147 |
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
)
|
167 |
-
|
168 |
-
|
169 |
-
else:
|
170 |
-
st.write(result)
|
171 |
-
|
172 |
-
elif analysis_type == "Categorical Analysis":
|
173 |
-
with st.spinner('Analyzing categorical columns...'):
|
174 |
-
result = st.session_state['agent'].run(
|
175 |
-
f"Analyze categorical columns in the dataset: {st.session_state['data']}"
|
176 |
-
)
|
177 |
-
st.write(result)
|
178 |
-
|
179 |
-
elif analysis_type == "Feature Engineering":
|
180 |
-
with st.spinner('Generating feature suggestions...'):
|
181 |
-
result = st.session_state['agent'].run(
|
182 |
-
f"Suggest feature engineering steps for the dataset: {st.session_state['data']}"
|
183 |
-
)
|
184 |
-
st.write(result)
|
185 |
-
|
186 |
-
elif analysis_type == "Custom Question":
|
187 |
-
question = st.text_input("What would you like to know about your data?")
|
188 |
-
if question:
|
189 |
-
with st.spinner('Analyzing...'):
|
190 |
-
result = st.session_state['agent'].run(question)
|
191 |
-
st.write(result)
|
192 |
-
|
193 |
-
except Exception as e:
|
194 |
-
st.error(f"An error occurred: {str(e)}")
|
195 |
|
196 |
if __name__ == "__main__":
|
197 |
main()
|
|
|
1 |
import streamlit as st
|
2 |
import numpy as np
|
3 |
import pandas as pd
|
|
|
|
|
|
|
|
|
4 |
import matplotlib.pyplot as plt
|
5 |
import seaborn as sns
|
6 |
import os
|
7 |
import base64
|
8 |
import io
|
9 |
+
from groq import Groq
|
10 |
+
from langchain.tools import tool
|
11 |
+
from langchain.agents import AgentType, initialize_agent
|
12 |
+
from langchain.chains import LLMChain
|
13 |
+
from langchain.prompts import PromptTemplate
|
14 |
+
from typing import Optional, Dict, List
|
15 |
+
|
16 |
+
# Initialize Groq Client
|
17 |
+
client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
|
18 |
|
19 |
+
class GroqAnalyst:
|
20 |
+
"""Advanced AI Researcher & Data Analyst using Groq"""
|
21 |
+
def __init__(self, model_name="mixtral-8x7b-32768"):
|
22 |
+
self.model_name = model_name
|
23 |
+
self.system_prompt = """
|
24 |
+
You are an expert AI research assistant and data scientist.
|
25 |
+
Provide detailed, technical analysis with professional visualizations.
|
26 |
+
"""
|
27 |
+
|
28 |
+
def analyze(self, prompt: str, data: pd.DataFrame) -> str:
|
29 |
+
"""Execute complex data analysis using Groq"""
|
30 |
+
try:
|
31 |
+
dataset_info = f"""
|
32 |
+
Dataset Shape: {data.shape}
|
33 |
+
Columns: {', '.join(data.columns)}
|
34 |
+
Data Types: {data.dtypes.to_dict()}
|
35 |
+
Sample Data: {data.head(3).to_dict()}
|
36 |
+
"""
|
37 |
+
|
38 |
+
completion = client.chat.completions.create(
|
39 |
+
messages=[
|
40 |
+
{"role": "system", "content": self.system_prompt},
|
41 |
+
{"role": "user", "content": f"{dataset_info}\n\nTask: {prompt}"}
|
42 |
+
],
|
43 |
+
model=self.model_name,
|
44 |
+
temperature=0.3,
|
45 |
+
max_tokens=4096,
|
46 |
+
stream=False
|
47 |
+
)
|
48 |
+
|
49 |
+
return completion.choices[0].message.content
|
50 |
+
|
51 |
+
except Exception as e:
|
52 |
+
return f"Analysis Error: {str(e)}"
|
53 |
|
54 |
@tool
|
55 |
+
def advanced_eda(data: pd.DataFrame) -> Dict:
|
56 |
+
"""Perform comprehensive exploratory data analysis.
|
57 |
|
58 |
Args:
|
59 |
+
data (pd.DataFrame): Input dataset for analysis
|
60 |
+
|
61 |
Returns:
|
62 |
+
Dict: Contains statistical summary, missing values, and data quality report
|
|
|
63 |
"""
|
64 |
+
analysis = {
|
65 |
+
"statistical_summary": data.describe().to_dict(),
|
66 |
+
"missing_values": data.isnull().sum().to_dict(),
|
67 |
+
"data_quality": {
|
68 |
+
"duplicates": data.duplicated().sum(),
|
69 |
+
"zero_values": (data == 0).sum().to_dict()
|
|
|
|
|
|
|
|
|
70 |
}
|
71 |
+
}
|
72 |
+
return analysis
|
73 |
|
74 |
@tool
|
75 |
+
def visualize_distributions(data: pd.DataFrame, columns: List[str]) -> str:
|
76 |
+
"""Generate distribution plots for specified numerical columns.
|
77 |
|
78 |
Args:
|
79 |
+
data (pd.DataFrame): Input dataset
|
80 |
+
columns (List[str]): List of numerical columns to visualize
|
81 |
+
|
82 |
Returns:
|
83 |
+
str: Base64 encoded image of the visualization
|
84 |
"""
|
85 |
+
plt.figure(figsize=(12, 6))
|
86 |
+
for i, col in enumerate(columns, 1):
|
87 |
+
plt.subplot(1, len(columns), i)
|
88 |
+
sns.histplot(data[col], kde=True)
|
89 |
+
plt.title(f'Distribution of {col}')
|
90 |
+
plt.tight_layout()
|
91 |
|
92 |
buf = io.BytesIO()
|
93 |
plt.savefig(buf, format='png')
|
|
|
95 |
return base64.b64encode(buf.getvalue()).decode()
|
96 |
|
97 |
@tool
|
98 |
+
def temporal_analysis(data: pd.DataFrame, time_col: str, value_col: str) -> str:
|
99 |
+
"""Analyze time series data and generate trend visualization.
|
100 |
|
101 |
Args:
|
102 |
+
data (pd.DataFrame): Dataset containing time series
|
103 |
+
time_col (str): Name of timestamp column
|
104 |
+
value_col (str): Name of value column to analyze
|
105 |
+
|
106 |
Returns:
|
107 |
+
str: Base64 encoded image of time series plot
|
|
|
108 |
"""
|
109 |
+
plt.figure(figsize=(12, 6))
|
110 |
+
data[time_col] = pd.to_datetime(data[time_col])
|
111 |
+
data.set_index(time_col)[value_col].plot()
|
112 |
+
plt.title(f'Temporal Trend of {value_col}')
|
113 |
+
plt.xlabel('Date')
|
114 |
+
plt.ylabel('Value')
|
|
|
|
|
|
|
115 |
|
116 |
+
buf = io.BytesIO()
|
117 |
+
plt.savefig(buf, format='png')
|
118 |
+
plt.close()
|
119 |
+
return base64.b64encode(buf.getvalue()).decode()
|
120 |
|
121 |
@tool
|
122 |
+
def hypothesis_testing(data: pd.DataFrame, group_col: str, value_col: str) -> Dict:
|
123 |
+
"""Perform statistical hypothesis testing between groups.
|
124 |
|
125 |
Args:
|
126 |
+
data (pd.DataFrame): Input dataset
|
127 |
+
group_col (str): Categorical column defining groups
|
128 |
+
value_col (str): Numerical column to compare
|
129 |
+
|
130 |
Returns:
|
131 |
+
Dict: Contains test results, p-value, and conclusion
|
|
|
132 |
"""
|
133 |
+
from scipy.stats import ttest_ind
|
|
|
|
|
134 |
|
135 |
+
groups = data[group_col].unique()
|
136 |
+
if len(groups) != 2:
|
137 |
+
return {"error": "Hypothesis testing requires exactly two groups"}
|
138 |
|
139 |
+
group1 = data[data[group_col] == groups[0]][value_col]
|
140 |
+
group2 = data[data[group_col] == groups[1]][value_col]
|
141 |
+
|
142 |
+
t_stat, p_value = ttest_ind(group1, group2)
|
|
|
|
|
143 |
|
144 |
+
return {
|
145 |
+
"t_statistic": t_stat,
|
146 |
+
"p_value": p_value,
|
147 |
+
"conclusion": "Significant difference" if p_value < 0.05 else "No significant difference"
|
148 |
+
}
|
149 |
|
150 |
def main():
|
151 |
+
st.title("🔬 AI Research Assistant with Groq")
|
152 |
+
st.markdown("Advanced data analysis powered by Groq's accelerated computing")
|
153 |
|
154 |
# Initialize session state
|
155 |
if 'data' not in st.session_state:
|
156 |
+
st.session_state.data = None
|
157 |
+
if 'analyst' not in st.session_state:
|
158 |
+
st.session_state.analyst = GroqAnalyst()
|
159 |
+
|
160 |
+
# File upload section
|
161 |
+
with st.sidebar:
|
162 |
+
st.header("Data Upload")
|
163 |
+
uploaded_file = st.file_uploader("Upload dataset (CSV)", type="csv")
|
164 |
+
if uploaded_file:
|
165 |
+
with st.spinner("Analyzing dataset..."):
|
166 |
+
st.session_state.data = pd.read_csv(uploaded_file)
|
167 |
+
st.success(f"Loaded {len(st.session_state.data)} records")
|
168 |
+
|
169 |
+
# Main analysis interface
|
170 |
+
if st.session_state.data is not None:
|
171 |
+
st.subheader("Dataset Overview")
|
172 |
+
st.dataframe(st.session_state.data.head(), use_container_width=True)
|
173 |
+
|
174 |
+
analysis_type = st.selectbox("Select Analysis Type", [
|
175 |
+
"Exploratory Data Analysis",
|
176 |
+
"Temporal Analysis",
|
177 |
+
"Statistical Testing",
|
178 |
+
"Custom Research Query"
|
179 |
+
])
|
180 |
+
|
181 |
+
if analysis_type == "Exploratory Data Analysis":
|
182 |
+
with st.expander("Advanced EDA"):
|
183 |
+
eda_result = advanced_eda(st.session_state.data)
|
184 |
+
st.json(eda_result)
|
185 |
|
186 |
+
num_cols = st.session_state.data.select_dtypes(include=np.number).columns.tolist()
|
187 |
+
if num_cols:
|
188 |
+
selected_cols = st.multiselect("Select columns for distribution analysis", num_cols)
|
189 |
+
if selected_cols:
|
190 |
+
img_data = visualize_distributions(st.session_state.data, selected_cols)
|
191 |
+
st.image(f"data:image/png;base64,{img_data}")
|
192 |
|
193 |
+
elif analysis_type == "Temporal Analysis":
|
194 |
+
time_col = st.selectbox("Select time column", st.session_state.data.columns)
|
195 |
+
value_col = st.selectbox("Select value column", st.session_state.data.select_dtypes(include=np.number).columns)
|
196 |
+
if time_col and value_col:
|
197 |
+
img_data = temporal_analysis(st.session_state.data, time_col, value_col)
|
198 |
+
st.image(f"data:image/png;base64,{img_data}")
|
199 |
+
|
200 |
+
elif analysis_type == "Statistical Testing":
|
201 |
+
group_col = st.selectbox("Select group column", st.session_state.data.select_dtypes(include='object').columns)
|
202 |
+
value_col = st.selectbox("Select metric to compare", st.session_state.data.select_dtypes(include=np.number).columns)
|
203 |
+
if group_col and value_col:
|
204 |
+
test_result = hypothesis_testing(st.session_state.data, group_col, value_col)
|
205 |
+
st.json(test_result)
|
206 |
+
|
207 |
+
elif analysis_type == "Custom Research Query":
|
208 |
+
research_query = st.text_area("Enter your research question:")
|
209 |
+
if research_query:
|
210 |
+
with st.spinner("Conducting advanced analysis..."):
|
211 |
+
result = st.session_state.analyst.analyze(research_query, st.session_state.data)
|
212 |
+
st.markdown("### Research Findings")
|
213 |
+
st.markdown(result)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
214 |
|
215 |
if __name__ == "__main__":
|
216 |
main()
|