singhtech commited on
Commit
13fbfb8
·
verified ·
1 Parent(s): 15f8686

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +282 -282
app.py CHANGED
@@ -1,283 +1,283 @@
1
- import streamlit as st
2
- import pandas as pd
3
- import os
4
- from crewai import Crew
5
- from langchain_groq import ChatGroq
6
- import streamlit_ace as st_ace
7
- import traceback
8
- import contextlib
9
- import io
10
- from crewai_tools import FileReadTool
11
- import matplotlib.pyplot as plt
12
- import glob
13
- from dotenv import load_dotenv
14
- from autotabml_agents import initialize_agents
15
- from autotabml_tasks import create_tasks
16
-
17
-
18
- TEMP_DIR = "temp_dir"
19
- OUTPUT_DIR = "Output_dir"
20
- # Ensure the temporary directory exists
21
- if not os.path.exists(TEMP_DIR):
22
- os.makedirs(TEMP_DIR)
23
-
24
- # Ensure the Output directory exits
25
- if not os.path.exists(OUTPUT_DIR):
26
- os.makedirs(OUTPUT_DIR)
27
-
28
- # Function to save uploaded file
29
- def save_uploaded_file(uploaded_file):
30
- file_path = os.path.join(TEMP_DIR, uploaded_file.name)
31
- with open(file_path, 'wb') as f:
32
- f.write(uploaded_file.getbuffer())
33
- return file_path
34
-
35
- # load the .env file
36
- load_dotenv()
37
- # Set up Groq API key
38
- groq_api_key = os.environ.get("GROQ_API_KEY") # os.environ["GROQ_API_KEY"] =
39
-
40
-
41
- def main():
42
- # Set custom CSS for UI
43
- set_custom_css()
44
-
45
- # Initialize session state for edited code
46
- if 'edited_code' not in st.session_state:
47
- st.session_state['edited_code'] = ""
48
-
49
- # Initialize session state for whether the initial code is generated
50
- if 'code_generated' not in st.session_state:
51
- st.session_state['code_generated'] = False
52
-
53
- # Header with futuristic design
54
- st.markdown("""
55
- <div class="header">
56
- <h1>AutoTabML</h1>
57
- <p>Automated Machine Learning Code Generation for Tabluar Data</p>
58
- </div>
59
- """, unsafe_allow_html=True)
60
-
61
- # Sidebar for customization options
62
- st.sidebar.title('LLM Model')
63
- model = st.sidebar.selectbox(
64
- 'Model',
65
- ["llama3-70b-8192"]
66
- )
67
-
68
- # Initialize LLM
69
- llm = initialize_llm(model)
70
-
71
-
72
-
73
- # User inputs
74
- user_question = st.text_area("Describe your ML problem:", key="user_question")
75
- uploaded_file = st.file_uploader("Upload a sample .csv of your data", key="uploaded_file")
76
- try:
77
- file_name = uploaded_file.name
78
- except:
79
- file_name = "dataset.csv"
80
-
81
- # Initialize agents
82
- agents = initialize_agents(llm,file_name)
83
- # Process uploaded file
84
- if uploaded_file:
85
- try:
86
- file_path = save_uploaded_file(uploaded_file)
87
- df = pd.read_csv(uploaded_file)
88
- st.write("Data successfully uploaded:")
89
- st.dataframe(df.head())
90
- data_upload = True
91
- except Exception as e:
92
- st.error(f"Error reading the file: {e}")
93
- data_upload = False
94
- else:
95
- df = None
96
- data_upload = False
97
-
98
- # Process button
99
- if st.button('Process'):
100
- tasks = create_tasks("Process",user_question,file_name, data_upload, df, None, st.session_state['edited_code'], None, agents)
101
- with st.spinner('Processing...'):
102
- crew = Crew(
103
- agents=list(agents.values()),
104
- tasks=tasks,
105
- verbose=2
106
- )
107
-
108
- result = crew.kickoff()
109
-
110
- if result: # Only call st_ace if code has a valid value
111
- code = result.strip("```")
112
- try:
113
- filt_idx = code.index("```")
114
- code = code[:filt_idx]
115
- except:
116
- pass
117
- st.session_state['edited_code'] = code
118
- st.session_state['code_generated'] = True
119
-
120
- st.session_state['edited_code'] = st_ace.st_ace(
121
- value=st.session_state['edited_code'],
122
- language='python',
123
- theme='monokai',
124
- keybinding='vscode',
125
- min_lines=20,
126
- max_lines=50
127
- )
128
-
129
- if st.session_state['code_generated']:
130
- # Show options for modification, debugging, and running the code
131
- suggestion = st.text_area("Suggest modifications to the generated code (optional):", key="suggestion")
132
- if st.button('Modify'):
133
- if st.session_state['edited_code'] and suggestion:
134
- tasks = create_tasks("Modify",user_question,file_name, data_upload, df, suggestion, st.session_state['edited_code'], None, agents)
135
- with st.spinner('Modifying code...'):
136
- crew = Crew(
137
- agents=list(agents.values()),
138
- tasks=tasks,
139
- verbose=2
140
- )
141
-
142
- result = crew.kickoff()
143
-
144
- if result: # Only call st_ace if code has a valid value
145
- code = result.strip("```")
146
- try:
147
- filter_idx = code.index("```")
148
- code = code[:filter_idx]
149
- except:
150
- pass
151
- st.session_state['edited_code'] = code
152
-
153
- st.write("Modified code:")
154
- st.session_state['edited_code']= st_ace.st_ace(
155
- value=st.session_state['edited_code'],
156
- language='python',
157
- theme='monokai',
158
- keybinding='vscode',
159
- min_lines=20,
160
- max_lines=50
161
- )
162
-
163
- debugger = st.text_area("Paste error message here for debugging (optional):", key="debugger")
164
- if st.button('Debug'):
165
- if st.session_state['edited_code'] and debugger:
166
- tasks = create_tasks("Debug",user_question,file_name, data_upload, df, None, st.session_state['edited_code'], debugger, agents)
167
- with st.spinner('Debugging code...'):
168
- crew = Crew(
169
- agents=list(agents.values()),
170
- tasks=tasks,
171
- verbose=2
172
- )
173
-
174
- result = crew.kickoff()
175
-
176
- if result: # Only call st_ace if code has a valid value
177
- code = result.strip("```")
178
- try:
179
- filter_idx = code.index("```")
180
- code = code[:filter_idx]
181
- except:
182
- pass
183
- st.session_state['edited_code'] = code
184
-
185
- st.write("Debugged code:")
186
- st.session_state['edited_code'] = st_ace.st_ace(
187
- value=st.session_state['edited_code'],
188
- language='python',
189
- theme='monokai',
190
- keybinding='vscode',
191
- min_lines=20,
192
- max_lines=50
193
- )
194
-
195
- if st.button('Run'):
196
- output = io.StringIO()
197
- with contextlib.redirect_stdout(output):
198
- try:
199
- globals().update({'dataset': df})
200
- final_code = st.session_state["edited_code"]
201
-
202
- with st.expander("Final Code"):
203
- st.code(final_code, language='python')
204
-
205
- exec(final_code, globals())
206
- result = output.getvalue()
207
- success = True
208
- except Exception as e:
209
- result = str(e)
210
- success = False
211
-
212
- st.subheader('Output:')
213
- st.text(result)
214
-
215
- figs = [manager.canvas.figure for manager in plt._pylab_helpers.Gcf.get_all_fig_managers()]
216
- if figs:
217
- st.subheader('Generated Plots:')
218
- for fig in figs:
219
- st.pyplot(fig)
220
-
221
- if success:
222
- st.success("Code executed successfully!")
223
- else:
224
- st.error("Code execution failed! Waiting for debugging input...")
225
-
226
- # Move the generated files section to the sidebar
227
- with st.sidebar:
228
- st.header('Output_dir :')
229
- files = glob.glob(os.path.join(OUTPUT_DIR,"/", '*'))
230
- for file in files:
231
- if os.path.isfile(file):
232
- with open(file, 'rb') as f:
233
- st.download_button(label=f'Download {os.path.basename(file)}', data=f, file_name=os.path.basename(file))
234
-
235
-
236
-
237
- # Function to set custom CSS for futuristic UI
238
- def set_custom_css():
239
- st.markdown("""
240
- <style>
241
- body {
242
- background: #0e0e0e;
243
- color: #e0e0e0;
244
- font-family: 'Roboto', sans-serif;
245
- }
246
- .header {
247
- background: linear-gradient(135deg, #6e3aff, #b839ff);
248
- padding: 10px;
249
- border-radius: 10px;
250
- }
251
- .header h1, .header p {
252
- color: white;
253
- text-align: center;
254
- }
255
- .stButton button {
256
- background-color: #b839ff;
257
- color: white;
258
- border-radius: 10px;
259
- font-size: 16px;
260
- padding: 10px 20px;
261
- }
262
- .stButton button:hover {
263
- background-color: #6e3aff;
264
- color: #e0e0e0;
265
- }
266
- .spinner {
267
- display: flex;
268
- justify-content: center;
269
- align-items: center;
270
- }
271
- </style>
272
- """, unsafe_allow_html=True)
273
-
274
- # Function to initialize LLM
275
- def initialize_llm(model):
276
- return ChatGroq(
277
- temperature=0,
278
- groq_api_key=groq_api_key,
279
- model_name=model
280
- )
281
-
282
- if __name__ == "__main__":
283
  main()
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ import os
4
+ from crewai import Crew
5
+ from langchain_groq import ChatGroq
6
+ import streamlit_ace as st_ace
7
+ import traceback
8
+ import contextlib
9
+ import io
10
+ from crewai_tools import FileReadTool
11
+ import matplotlib.pyplot as plt
12
+ import glob
13
+ from dotenv import load_dotenv
14
+ from autotabml_agents import initialize_agents
15
+ from autotabml_tasks import create_tasks
16
+
17
+
18
+ TEMP_DIR = "temp_dir"
19
+ OUTPUT_DIR = "Output_dir"
20
+ # Ensure the temporary directory exists
21
+ if not os.path.exists(TEMP_DIR):
22
+ os.makedirs(TEMP_DIR)
23
+
24
+ # Ensure the Output directory exits
25
+ if not os.path.exists(OUTPUT_DIR):
26
+ os.makedirs(OUTPUT_DIR)
27
+
28
+ # Function to save uploaded file
29
+ def save_uploaded_file(uploaded_file):
30
+ file_path = os.path.join(TEMP_DIR, uploaded_file.name)
31
+ with open(file_path, 'wb') as f:
32
+ f.write(uploaded_file.getbuffer())
33
+ return file_path
34
+
35
+ # load the .env file
36
+ load_dotenv()
37
+ # Set up Groq API key
38
+ groq_api_key = os.environ.get("GROQ_API_KEY") # os.environ["GROQ_API_KEY"] =
39
+
40
+
41
+ def main():
42
+ # Set custom CSS for UI
43
+ set_custom_css()
44
+
45
+ # Initialize session state for edited code
46
+ if 'edited_code' not in st.session_state:
47
+ st.session_state['edited_code'] = ""
48
+
49
+ # Initialize session state for whether the initial code is generated
50
+ if 'code_generated' not in st.session_state:
51
+ st.session_state['code_generated'] = False
52
+
53
+ # Header with futuristic design
54
+ st.markdown("""
55
+ <div class="header">
56
+ <h1>AutoTabML</h1>
57
+ <p>Automated Machine Learning Code Generation for Tabluar Data</p>
58
+ </div>
59
+ """, unsafe_allow_html=True)
60
+
61
+ # Sidebar for customization options
62
+ st.sidebar.title('LLM Model')
63
+ model = st.sidebar.selectbox(
64
+ 'Model',
65
+ ["llama3-70b-8192"]
66
+ )
67
+
68
+ # Initialize LLM
69
+ llm = initialize_llm(model)
70
+
71
+
72
+
73
+ # User inputs
74
+ user_question = st.text_area("Describe your ML problem:", key="user_question")
75
+ uploaded_file = st.file_uploader("Upload a sample .csv of your data", key="uploaded_file")
76
+ try:
77
+ file_name = uploaded_file.name
78
+ except:
79
+ file_name = "dataset.csv"
80
+
81
+ # Initialize agents
82
+ agents = initialize_agents(llm,file_name,TEMP_DIR)
83
+ # Process uploaded file
84
+ if uploaded_file:
85
+ try:
86
+ file_path = save_uploaded_file(uploaded_file)
87
+ df = pd.read_csv(uploaded_file)
88
+ st.write("Data successfully uploaded:")
89
+ st.dataframe(df.head())
90
+ data_upload = True
91
+ except Exception as e:
92
+ st.error(f"Error reading the file: {e}")
93
+ data_upload = False
94
+ else:
95
+ df = None
96
+ data_upload = False
97
+
98
+ # Process button
99
+ if st.button('Process'):
100
+ tasks = create_tasks("Process",user_question,file_name, data_upload, df, None, st.session_state['edited_code'], None, agents)
101
+ with st.spinner('Processing...'):
102
+ crew = Crew(
103
+ agents=list(agents.values()),
104
+ tasks=tasks,
105
+ verbose=2
106
+ )
107
+
108
+ result = crew.kickoff()
109
+
110
+ if result: # Only call st_ace if code has a valid value
111
+ code = result.strip("```")
112
+ try:
113
+ filt_idx = code.index("```")
114
+ code = code[:filt_idx]
115
+ except:
116
+ pass
117
+ st.session_state['edited_code'] = code
118
+ st.session_state['code_generated'] = True
119
+
120
+ st.session_state['edited_code'] = st_ace.st_ace(
121
+ value=st.session_state['edited_code'],
122
+ language='python',
123
+ theme='monokai',
124
+ keybinding='vscode',
125
+ min_lines=20,
126
+ max_lines=50
127
+ )
128
+
129
+ if st.session_state['code_generated']:
130
+ # Show options for modification, debugging, and running the code
131
+ suggestion = st.text_area("Suggest modifications to the generated code (optional):", key="suggestion")
132
+ if st.button('Modify'):
133
+ if st.session_state['edited_code'] and suggestion:
134
+ tasks = create_tasks("Modify",user_question,file_name, data_upload, df, suggestion, st.session_state['edited_code'], None, agents)
135
+ with st.spinner('Modifying code...'):
136
+ crew = Crew(
137
+ agents=list(agents.values()),
138
+ tasks=tasks,
139
+ verbose=2
140
+ )
141
+
142
+ result = crew.kickoff()
143
+
144
+ if result: # Only call st_ace if code has a valid value
145
+ code = result.strip("```")
146
+ try:
147
+ filter_idx = code.index("```")
148
+ code = code[:filter_idx]
149
+ except:
150
+ pass
151
+ st.session_state['edited_code'] = code
152
+
153
+ st.write("Modified code:")
154
+ st.session_state['edited_code']= st_ace.st_ace(
155
+ value=st.session_state['edited_code'],
156
+ language='python',
157
+ theme='monokai',
158
+ keybinding='vscode',
159
+ min_lines=20,
160
+ max_lines=50
161
+ )
162
+
163
+ debugger = st.text_area("Paste error message here for debugging (optional):", key="debugger")
164
+ if st.button('Debug'):
165
+ if st.session_state['edited_code'] and debugger:
166
+ tasks = create_tasks("Debug",user_question,file_name, data_upload, df, None, st.session_state['edited_code'], debugger, agents)
167
+ with st.spinner('Debugging code...'):
168
+ crew = Crew(
169
+ agents=list(agents.values()),
170
+ tasks=tasks,
171
+ verbose=2
172
+ )
173
+
174
+ result = crew.kickoff()
175
+
176
+ if result: # Only call st_ace if code has a valid value
177
+ code = result.strip("```")
178
+ try:
179
+ filter_idx = code.index("```")
180
+ code = code[:filter_idx]
181
+ except:
182
+ pass
183
+ st.session_state['edited_code'] = code
184
+
185
+ st.write("Debugged code:")
186
+ st.session_state['edited_code'] = st_ace.st_ace(
187
+ value=st.session_state['edited_code'],
188
+ language='python',
189
+ theme='monokai',
190
+ keybinding='vscode',
191
+ min_lines=20,
192
+ max_lines=50
193
+ )
194
+
195
+ if st.button('Run'):
196
+ output = io.StringIO()
197
+ with contextlib.redirect_stdout(output):
198
+ try:
199
+ globals().update({'dataset': df})
200
+ final_code = st.session_state["edited_code"]
201
+
202
+ with st.expander("Final Code"):
203
+ st.code(final_code, language='python')
204
+
205
+ exec(final_code, globals())
206
+ result = output.getvalue()
207
+ success = True
208
+ except Exception as e:
209
+ result = str(e)
210
+ success = False
211
+
212
+ st.subheader('Output:')
213
+ st.text(result)
214
+
215
+ figs = [manager.canvas.figure for manager in plt._pylab_helpers.Gcf.get_all_fig_managers()]
216
+ if figs:
217
+ st.subheader('Generated Plots:')
218
+ for fig in figs:
219
+ st.pyplot(fig)
220
+
221
+ if success:
222
+ st.success("Code executed successfully!")
223
+ else:
224
+ st.error("Code execution failed! Waiting for debugging input...")
225
+
226
+ # Move the generated files section to the sidebar
227
+ with st.sidebar:
228
+ st.header('Output_dir :')
229
+ files = glob.glob(os.path.join(OUTPUT_DIR,"/", '*'))
230
+ for file in files:
231
+ if os.path.isfile(file):
232
+ with open(file, 'rb') as f:
233
+ st.download_button(label=f'Download {os.path.basename(file)}', data=f, file_name=os.path.basename(file))
234
+
235
+
236
+
237
+ # Function to set custom CSS for futuristic UI
238
+ def set_custom_css():
239
+ st.markdown("""
240
+ <style>
241
+ body {
242
+ background: #0e0e0e;
243
+ color: #e0e0e0;
244
+ font-family: 'Roboto', sans-serif;
245
+ }
246
+ .header {
247
+ background: linear-gradient(135deg, #6e3aff, #b839ff);
248
+ padding: 10px;
249
+ border-radius: 10px;
250
+ }
251
+ .header h1, .header p {
252
+ color: white;
253
+ text-align: center;
254
+ }
255
+ .stButton button {
256
+ background-color: #b839ff;
257
+ color: white;
258
+ border-radius: 10px;
259
+ font-size: 16px;
260
+ padding: 10px 20px;
261
+ }
262
+ .stButton button:hover {
263
+ background-color: #6e3aff;
264
+ color: #e0e0e0;
265
+ }
266
+ .spinner {
267
+ display: flex;
268
+ justify-content: center;
269
+ align-items: center;
270
+ }
271
+ </style>
272
+ """, unsafe_allow_html=True)
273
+
274
+ # Function to initialize LLM
275
+ def initialize_llm(model):
276
+ return ChatGroq(
277
+ temperature=0,
278
+ groq_api_key=groq_api_key,
279
+ model_name=model
280
+ )
281
+
282
+ if __name__ == "__main__":
283
  main()