mgbam commited on
Commit
06300b8
·
verified ·
1 Parent(s): 28cd590

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +163 -144
app.py CHANGED
@@ -1,60 +1,93 @@
1
  import streamlit as st
2
  import numpy as np
3
  import pandas as pd
4
- from langchain.tools import tool
5
- from langchain.agents import initialize_agent, AgentType
6
- from langchain.chat_models import ChatOpenAI
7
- from typing import Union, List, Dict, Optional
8
  import matplotlib.pyplot as plt
9
  import seaborn as sns
10
  import os
11
  import base64
12
  import io
 
 
 
 
 
 
 
 
 
13
 
14
- # Set up LangChain with OpenAI (or any other LLM)
15
- os.environ["OPENAI_API_KEY"] = "your-openai-api-key" # Replace with your OpenAI API key
16
- llm = ChatOpenAI(model="gpt-4", temperature=0.7)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
  @tool
19
- def analyze_basic_stats(data: pd.DataFrame) -> str:
20
- """Calculate basic statistical measures for numerical columns in the dataset.
21
 
22
  Args:
23
- data (pd.DataFrame): The dataset to analyze. It should contain at least one numerical column.
24
-
25
  Returns:
26
- str: A string containing formatted basic statistics for each numerical column,
27
- including mean, median, standard deviation, skewness, and missing value counts.
28
  """
29
- stats = {}
30
- numeric_cols = data.select_dtypes(include=[np.number]).columns
31
-
32
- for col in numeric_cols:
33
- stats[col] = {
34
- 'mean': float(data[col].mean()),
35
- 'median': float(data[col].median()),
36
- 'std': float(data[col].std()),
37
- 'skew': float(data[col].skew()),
38
- 'missing': int(data[col].isnull().sum())
39
  }
40
-
41
- return str(stats)
42
 
43
  @tool
44
- def generate_correlation_matrix(data: pd.DataFrame) -> str:
45
- """Generate a visual correlation matrix for numerical columns in the dataset.
46
 
47
  Args:
48
- data (pd.DataFrame): The dataset to analyze. It should contain at least two numerical columns.
49
-
 
50
  Returns:
51
- str: A base64 encoded string representing the correlation matrix plot image.
52
  """
53
- numeric_data = data.select_dtypes(include=[np.number])
54
-
55
- plt.figure(figsize=(10, 8))
56
- sns.heatmap(numeric_data.corr(), annot=True, cmap='coolwarm')
57
- plt.title('Correlation Matrix')
 
58
 
59
  buf = io.BytesIO()
60
  plt.savefig(buf, format='png')
@@ -62,136 +95,122 @@ def generate_correlation_matrix(data: pd.DataFrame) -> str:
62
  return base64.b64encode(buf.getvalue()).decode()
63
 
64
  @tool
65
- def analyze_categorical_columns(data: pd.DataFrame) -> str:
66
- """Analyze categorical columns in the dataset for distribution and frequencies.
67
 
68
  Args:
69
- data (pd.DataFrame): The dataset to analyze. It should contain at least one categorical column.
70
-
 
 
71
  Returns:
72
- str: A string containing formatted analysis results for each categorical column,
73
- including unique value counts, top categories, and missing value counts.
74
  """
75
- categorical_cols = data.select_dtypes(include=['object', 'category']).columns
76
- analysis = {}
77
-
78
- for col in categorical_cols:
79
- analysis[col] = {
80
- 'unique_values': int(data[col].nunique()),
81
- 'top_categories': data[col].value_counts().head(5).to_dict(),
82
- 'missing': int(data[col].isnull().sum())
83
- }
84
 
85
- return str(analysis)
 
 
 
86
 
87
  @tool
88
- def suggest_features(data: pd.DataFrame) -> str:
89
- """Suggest potential feature engineering steps based on data characteristics.
90
 
91
  Args:
92
- data (pd.DataFrame): The dataset to analyze. It can contain both numerical and categorical columns.
93
-
 
 
94
  Returns:
95
- str: A string containing suggestions for feature engineering based on
96
- the characteristics of the input data.
97
  """
98
- suggestions = []
99
- numeric_cols = data.select_dtypes(include=[np.number]).columns
100
- categorical_cols = data.select_dtypes(include=['object', 'category']).columns
101
 
102
- if len(numeric_cols) >= 2:
103
- suggestions.append("Consider creating interaction terms between numerical features")
 
104
 
105
- if len(categorical_cols) > 0:
106
- suggestions.append("Consider one-hot encoding for categorical variables")
107
-
108
- for col in numeric_cols:
109
- if data[col].skew() > 1 or data[col].skew() < -1:
110
- suggestions.append(f"Consider log transformation for {col} due to skewness")
111
 
112
- return '\n'.join(suggestions)
 
 
 
 
113
 
114
  def main():
115
- st.title("Data Analysis Assistant")
116
- st.write("Upload your dataset and get automated analysis with natural language interaction.")
117
 
118
  # Initialize session state
119
  if 'data' not in st.session_state:
120
- st.session_state['data'] = None
121
- if 'agent' not in st.session_state:
122
- st.session_state['agent'] = None
123
-
124
- # Drag-and-drop file upload
125
- uploaded_file = st.file_uploader("Drag and drop a CSV file here", type="csv")
126
-
127
- try:
128
- if uploaded_file is not None:
129
- with st.spinner('Loading and processing your data...'):
130
- # Load the dataset
131
- data = pd.read_csv(uploaded_file)
132
- st.session_state['data'] = data
133
-
134
- # Initialize the LangChain agent with the tools
135
- tools = [analyze_basic_stats, generate_correlation_matrix,
136
- analyze_categorical_columns, suggest_features]
137
- st.session_state['agent'] = initialize_agent(
138
- tools=tools,
139
- llm=llm,
140
- agent=AgentType.OPENAI_FUNCTIONS,
141
- verbose=True
142
- )
 
 
 
 
 
 
143
 
144
- st.success(f'Successfully loaded dataset with {data.shape[0]} rows and {data.shape[1]} columns')
145
- st.subheader("Data Preview")
146
- st.dataframe(data.head())
 
 
 
147
 
148
- if st.session_state['data'] is not None:
149
- analysis_type = st.selectbox(
150
- "Choose analysis type",
151
- ["Basic Statistics", "Correlation Analysis", "Categorical Analysis",
152
- "Feature Engineering", "Custom Question"]
153
- )
154
-
155
- if analysis_type == "Basic Statistics":
156
- with st.spinner('Analyzing basic statistics...'):
157
- result = st.session_state['agent'].run(
158
- f"Analyze the dataset and provide basic statistics: {st.session_state['data']}"
159
- )
160
- st.write(result)
161
-
162
- elif analysis_type == "Correlation Analysis":
163
- with st.spinner('Generating correlation matrix...'):
164
- result = st.session_state['agent'].run(
165
- f"Generate a correlation matrix for the dataset: {st.session_state['data']}"
166
- )
167
- if isinstance(result, str) and result.startswith('data:image') or ',' in result:
168
- st.image(f"data:image/png;base64,{result.split(',')[-1]}")
169
- else:
170
- st.write(result)
171
-
172
- elif analysis_type == "Categorical Analysis":
173
- with st.spinner('Analyzing categorical columns...'):
174
- result = st.session_state['agent'].run(
175
- f"Analyze categorical columns in the dataset: {st.session_state['data']}"
176
- )
177
- st.write(result)
178
-
179
- elif analysis_type == "Feature Engineering":
180
- with st.spinner('Generating feature suggestions...'):
181
- result = st.session_state['agent'].run(
182
- f"Suggest feature engineering steps for the dataset: {st.session_state['data']}"
183
- )
184
- st.write(result)
185
-
186
- elif analysis_type == "Custom Question":
187
- question = st.text_input("What would you like to know about your data?")
188
- if question:
189
- with st.spinner('Analyzing...'):
190
- result = st.session_state['agent'].run(question)
191
- st.write(result)
192
-
193
- except Exception as e:
194
- st.error(f"An error occurred: {str(e)}")
195
 
196
  if __name__ == "__main__":
197
  main()
 
1
  import streamlit as st
2
  import numpy as np
3
  import pandas as pd
 
 
 
 
4
  import matplotlib.pyplot as plt
5
  import seaborn as sns
6
  import os
7
  import base64
8
  import io
9
+ from groq import Groq
10
+ from langchain.tools import tool
11
+ from langchain.agents import AgentType, initialize_agent
12
+ from langchain.chains import LLMChain
13
+ from langchain.prompts import PromptTemplate
14
+ from typing import Optional, Dict, List
15
+
16
+ # Initialize Groq Client
17
+ client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
18
 
19
+ class GroqAnalyst:
20
+ """Advanced AI Researcher & Data Analyst using Groq"""
21
+ def __init__(self, model_name="mixtral-8x7b-32768"):
22
+ self.model_name = model_name
23
+ self.system_prompt = """
24
+ You are an expert AI research assistant and data scientist.
25
+ Provide detailed, technical analysis with professional visualizations.
26
+ """
27
+
28
+ def analyze(self, prompt: str, data: pd.DataFrame) -> str:
29
+ """Execute complex data analysis using Groq"""
30
+ try:
31
+ dataset_info = f"""
32
+ Dataset Shape: {data.shape}
33
+ Columns: {', '.join(data.columns)}
34
+ Data Types: {data.dtypes.to_dict()}
35
+ Sample Data: {data.head(3).to_dict()}
36
+ """
37
+
38
+ completion = client.chat.completions.create(
39
+ messages=[
40
+ {"role": "system", "content": self.system_prompt},
41
+ {"role": "user", "content": f"{dataset_info}\n\nTask: {prompt}"}
42
+ ],
43
+ model=self.model_name,
44
+ temperature=0.3,
45
+ max_tokens=4096,
46
+ stream=False
47
+ )
48
+
49
+ return completion.choices[0].message.content
50
+
51
+ except Exception as e:
52
+ return f"Analysis Error: {str(e)}"
53
 
54
  @tool
55
+ def advanced_eda(data: pd.DataFrame) -> Dict:
56
+ """Perform comprehensive exploratory data analysis.
57
 
58
  Args:
59
+ data (pd.DataFrame): Input dataset for analysis
60
+
61
  Returns:
62
+ Dict: Contains statistical summary, missing values, and data quality report
 
63
  """
64
+ analysis = {
65
+ "statistical_summary": data.describe().to_dict(),
66
+ "missing_values": data.isnull().sum().to_dict(),
67
+ "data_quality": {
68
+ "duplicates": data.duplicated().sum(),
69
+ "zero_values": (data == 0).sum().to_dict()
 
 
 
 
70
  }
71
+ }
72
+ return analysis
73
 
74
  @tool
75
+ def visualize_distributions(data: pd.DataFrame, columns: List[str]) -> str:
76
+ """Generate distribution plots for specified numerical columns.
77
 
78
  Args:
79
+ data (pd.DataFrame): Input dataset
80
+ columns (List[str]): List of numerical columns to visualize
81
+
82
  Returns:
83
+ str: Base64 encoded image of the visualization
84
  """
85
+ plt.figure(figsize=(12, 6))
86
+ for i, col in enumerate(columns, 1):
87
+ plt.subplot(1, len(columns), i)
88
+ sns.histplot(data[col], kde=True)
89
+ plt.title(f'Distribution of {col}')
90
+ plt.tight_layout()
91
 
92
  buf = io.BytesIO()
93
  plt.savefig(buf, format='png')
 
95
  return base64.b64encode(buf.getvalue()).decode()
96
 
97
  @tool
98
+ def temporal_analysis(data: pd.DataFrame, time_col: str, value_col: str) -> str:
99
+ """Analyze time series data and generate trend visualization.
100
 
101
  Args:
102
+ data (pd.DataFrame): Dataset containing time series
103
+ time_col (str): Name of timestamp column
104
+ value_col (str): Name of value column to analyze
105
+
106
  Returns:
107
+ str: Base64 encoded image of time series plot
 
108
  """
109
+ plt.figure(figsize=(12, 6))
110
+ data[time_col] = pd.to_datetime(data[time_col])
111
+ data.set_index(time_col)[value_col].plot()
112
+ plt.title(f'Temporal Trend of {value_col}')
113
+ plt.xlabel('Date')
114
+ plt.ylabel('Value')
 
 
 
115
 
116
+ buf = io.BytesIO()
117
+ plt.savefig(buf, format='png')
118
+ plt.close()
119
+ return base64.b64encode(buf.getvalue()).decode()
120
 
121
  @tool
122
+ def hypothesis_testing(data: pd.DataFrame, group_col: str, value_col: str) -> Dict:
123
+ """Perform statistical hypothesis testing between groups.
124
 
125
  Args:
126
+ data (pd.DataFrame): Input dataset
127
+ group_col (str): Categorical column defining groups
128
+ value_col (str): Numerical column to compare
129
+
130
  Returns:
131
+ Dict: Contains test results, p-value, and conclusion
 
132
  """
133
+ from scipy.stats import ttest_ind
 
 
134
 
135
+ groups = data[group_col].unique()
136
+ if len(groups) != 2:
137
+ return {"error": "Hypothesis testing requires exactly two groups"}
138
 
139
+ group1 = data[data[group_col] == groups[0]][value_col]
140
+ group2 = data[data[group_col] == groups[1]][value_col]
141
+
142
+ t_stat, p_value = ttest_ind(group1, group2)
 
 
143
 
144
+ return {
145
+ "t_statistic": t_stat,
146
+ "p_value": p_value,
147
+ "conclusion": "Significant difference" if p_value < 0.05 else "No significant difference"
148
+ }
149
 
150
  def main():
151
+ st.title("🔬 AI Research Assistant with Groq")
152
+ st.markdown("Advanced data analysis powered by Groq's accelerated computing")
153
 
154
  # Initialize session state
155
  if 'data' not in st.session_state:
156
+ st.session_state.data = None
157
+ if 'analyst' not in st.session_state:
158
+ st.session_state.analyst = GroqAnalyst()
159
+
160
+ # File upload section
161
+ with st.sidebar:
162
+ st.header("Data Upload")
163
+ uploaded_file = st.file_uploader("Upload dataset (CSV)", type="csv")
164
+ if uploaded_file:
165
+ with st.spinner("Analyzing dataset..."):
166
+ st.session_state.data = pd.read_csv(uploaded_file)
167
+ st.success(f"Loaded {len(st.session_state.data)} records")
168
+
169
+ # Main analysis interface
170
+ if st.session_state.data is not None:
171
+ st.subheader("Dataset Overview")
172
+ st.dataframe(st.session_state.data.head(), use_container_width=True)
173
+
174
+ analysis_type = st.selectbox("Select Analysis Type", [
175
+ "Exploratory Data Analysis",
176
+ "Temporal Analysis",
177
+ "Statistical Testing",
178
+ "Custom Research Query"
179
+ ])
180
+
181
+ if analysis_type == "Exploratory Data Analysis":
182
+ with st.expander("Advanced EDA"):
183
+ eda_result = advanced_eda(st.session_state.data)
184
+ st.json(eda_result)
185
 
186
+ num_cols = st.session_state.data.select_dtypes(include=np.number).columns.tolist()
187
+ if num_cols:
188
+ selected_cols = st.multiselect("Select columns for distribution analysis", num_cols)
189
+ if selected_cols:
190
+ img_data = visualize_distributions(st.session_state.data, selected_cols)
191
+ st.image(f"data:image/png;base64,{img_data}")
192
 
193
+ elif analysis_type == "Temporal Analysis":
194
+ time_col = st.selectbox("Select time column", st.session_state.data.columns)
195
+ value_col = st.selectbox("Select value column", st.session_state.data.select_dtypes(include=np.number).columns)
196
+ if time_col and value_col:
197
+ img_data = temporal_analysis(st.session_state.data, time_col, value_col)
198
+ st.image(f"data:image/png;base64,{img_data}")
199
+
200
+ elif analysis_type == "Statistical Testing":
201
+ group_col = st.selectbox("Select group column", st.session_state.data.select_dtypes(include='object').columns)
202
+ value_col = st.selectbox("Select metric to compare", st.session_state.data.select_dtypes(include=np.number).columns)
203
+ if group_col and value_col:
204
+ test_result = hypothesis_testing(st.session_state.data, group_col, value_col)
205
+ st.json(test_result)
206
+
207
+ elif analysis_type == "Custom Research Query":
208
+ research_query = st.text_area("Enter your research question:")
209
+ if research_query:
210
+ with st.spinner("Conducting advanced analysis..."):
211
+ result = st.session_state.analyst.analyze(research_query, st.session_state.data)
212
+ st.markdown("### Research Findings")
213
+ st.markdown(result)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
214
 
215
  if __name__ == "__main__":
216
  main()