|
import streamlit as st |
|
import pandas as pd |
|
import os |
|
from crewai import Agent, Task, Crew |
|
from langchain_groq import ChatGroq |
|
import streamlit_ace as st_ace |
|
import traceback |
|
import contextlib |
|
import io |
|
from crewai_tools import FileReadTool |
|
import matplotlib.pyplot as plt |
|
import glob |
|
from dotenv import load_dotenv |
|
|
|
|
|
# Pull variables from a local .env file (when one exists) into os.environ
# before anything reads configuration.
load_dotenv()

# API key handed to ChatGroq; None when GROQ_API_KEY is not set.
groq_api_key = os.environ.get("GROQ_API_KEY")
|
|
|
|
|
def main():
    """Render the CrewAI ML assistant page and handle its button actions.

    Each Streamlit rerun:
      1. applies the CSS, seeds session state and draws header/sidebar;
      2. reads the optional CSV upload into ``df``;
      3. "Process" runs the agent crew to generate starter code;
      4. once code exists, "Modify"/"Debug" re-run the crew on it and
         "Run" executes it, showing stdout, errors and matplotlib plots;
      5. lists files found in ``Output/`` as sidebar download buttons.
    """
    set_custom_css()

    # Session state survives reruns; seed it on first load only.
    if 'edited_code' not in st.session_state:
        st.session_state['edited_code'] = ""
    if 'code_generated' not in st.session_state:
        st.session_state['code_generated'] = False

    st.markdown("""
    <div class="header">
        <h1>CrewAI Machine Learning Assistant</h1>
        <p>Your AI-powered partner for machine learning projects.</p>
    </div>
    """, unsafe_allow_html=True)

    st.sidebar.title('Customization')
    model = st.sidebar.selectbox(
        'Choose a model',
        ['llama3-8b-8192', "llama3-70b-8192"]
    )
    llm = initialize_llm(model)

    user_question = st.text_area("Describe your ML problem:", key="user_question")
    uploaded_file = st.file_uploader("Upload a sample .csv of your data (optional)", key="uploaded_file")
    # Without an upload the agents are still told to use a default file name.
    file_name = uploaded_file.name if uploaded_file is not None else "dataset.csv"

    agents = initialize_agents(llm, file_name)

    # df is always bound (None when missing/unreadable), so create_tasks
    # below never hits an unbound local.
    df, data_upload = _load_dataset(uploaded_file)

    if st.button('Process'):
        tasks = create_tasks("Process", user_question, file_name, data_upload, df, None,
                             st.session_state['edited_code'], None, agents)
        with st.spinner('Processing...'):
            code = _run_crew(agents, tasks)
        if code:
            st.session_state['edited_code'] = code
            st.session_state['code_generated'] = True

    _code_editor()

    if st.session_state['code_generated']:
        suggestion = st.text_area("Suggest modifications to the generated code (optional):", key="suggestion")
        if st.button('Modify') and st.session_state['edited_code'] and suggestion:
            tasks = create_tasks("Modify", user_question, file_name, data_upload, df, suggestion,
                                 st.session_state['edited_code'], None, agents)
            with st.spinner('Modifying code...'):
                code = _run_crew(agents, tasks)
            if code:
                st.session_state['edited_code'] = code
            st.write("Modified code:")
            _code_editor()

        debugger = st.text_area("Paste error message here for debugging (optional):", key="debugger")
        if st.button('Debug') and st.session_state['edited_code'] and debugger:
            tasks = create_tasks("Debug", user_question, file_name, data_upload, df, None,
                                 st.session_state['edited_code'], debugger, agents)
            with st.spinner('Debugging code...'):
                code = _run_crew(agents, tasks)
            if code:
                st.session_state['edited_code'] = code
            st.write("Debugged code:")
            _code_editor()

        if st.button('Run'):
            _run_generated_code(st.session_state["edited_code"], df)

    with st.sidebar:
        st.header('Output Files:')
        # Sorted so the button order is stable between reruns.
        for path in sorted(glob.glob(os.path.join("Output/", '*'))):
            if os.path.isfile(path):
                base_name = os.path.basename(path)
                with open(path, 'rb') as f:
                    st.download_button(label=f'Download {base_name}', data=f, file_name=base_name)


def _load_dataset(uploaded_file):
    """Read the optional CSV upload and preview it.

    Returns ``(df, data_upload)``; ``df`` is None and ``data_upload`` False
    when nothing was uploaded or the file could not be parsed.
    """
    if uploaded_file is None:
        return None, False
    try:
        df = pd.read_csv(uploaded_file)
    except Exception as e:
        st.error(f"Error reading the file: {e}")
        return None, False
    st.write("Data successfully uploaded:")
    st.dataframe(df.head())
    return df, True


def _run_crew(agents, tasks):
    """Run a crew over *tasks* and return the code it produced ("" if none)."""
    crew = Crew(
        agents=list(agents.values()),
        tasks=tasks,
        verbose=2
    )
    result = crew.kickoff()
    if not result:
        return ""
    # str() handles both plain-string results and result objects
    # (e.g. CrewOutput) returned by newer crewai versions.
    return _extract_code(str(result))


def _extract_code(text):
    """Return the code inside the first Markdown code fence of *text*.

    An opening language tag such as ``python`` is dropped. Leaving it in place
    would put a bare ``python`` line at the top of the code and make ``exec``
    fail with NameError. Text without any fence is returned stripped.
    """
    text = text.strip()
    if "```" not in text:
        return text
    _, _, after_open = text.partition("```")
    body, _, _ = after_open.partition("```")
    first_line, newline, rest = body.partition("\n")
    if newline and first_line.strip().lower() in {"", "python", "py", "python3"}:
        body = rest
    return body.strip()


def _code_editor():
    """Show the Ace editor seeded with the current code and store its contents back."""
    st.session_state['edited_code'] = st_ace.st_ace(
        value=st.session_state['edited_code'],
        language='python',
        theme='monokai',
        keybinding='vscode',
        min_lines=20,
        max_lines=50
    )


def _run_generated_code(code, df):
    """Execute *code* with ``dataset`` bound to *df*; show output, errors and plots."""
    with st.expander("Final Code"):
        st.code(code, language='python')

    # NOTE(security): this exec's LLM-generated, user-editable code with the
    # app's full privileges. Acceptable only for a trusted, local tool.
    # Use a *copy* of the module globals. The code can still see pd/plt/etc.,
    # but it cannot overwrite the app's own names.
    namespace = {**globals(), 'dataset': df}
    output = io.StringIO()
    try:
        with contextlib.redirect_stdout(output):
            exec(code, namespace)
    except Exception:
        success = False
        # Keep whatever was printed and append the full traceback. The user
        # can paste it into the Debug box.
        result = output.getvalue() + traceback.format_exc()
    else:
        success = True
        result = output.getvalue()

    st.subheader('Output:')
    st.text(result)

    figures = [plt.figure(num) for num in plt.get_fignums()]
    if figures:
        st.subheader('Generated Plots:')
        for fig in figures:
            st.pyplot(fig)
            # pyplot keeps figures alive until closed. Close them so they are
            # not re-displayed on the next Run and their memory is freed.
            plt.close(fig)

    if success:
        st.success("Code executed successfully!")
    else:
        st.error("Code execution failed! Waiting for debugging input...")
|
|
|
|
|
|
|
|
|
def set_custom_css():
    """Inject the page stylesheet: dark body, purple gradient header, purple buttons."""
    stylesheet = """
    <style>
    body {
        background: #0e0e0e;
        color: #e0e0e0;
        font-family: 'Roboto', sans-serif;
    }
    .header {
        background: linear-gradient(135deg, #6e3aff, #b839ff);
        padding: 10px;
        border-radius: 10px;
    }
    .header h1, .header p {
        color: white;
        text-align: center;
    }
    .stButton button {
        background-color: #b839ff;
        color: white;
        border-radius: 10px;
        font-size: 16px;
        padding: 10px 20px;
    }
    .stButton button:hover {
        background-color: #6e3aff;
        color: #e0e0e0;
    }
    .spinner {
        display: flex;
        justify-content: center;
        align-items: center;
    }
    </style>
    """
    # Raw HTML is required for a <style> tag to take effect.
    st.markdown(stylesheet, unsafe_allow_html=True)
|
|
|
|
|
def initialize_llm(model):
    """Build the Groq chat model named *model* with temperature 0.

    Authenticates with the module-level ``groq_api_key``.
    """
    llm_settings = {
        "temperature": 0,
        "groq_api_key": groq_api_key,
        "model_name": model,
    }
    return ChatGroq(**llm_settings)
|
|
|
|
|
def initialize_agents(llm, file_name):
    """Create every CrewAI agent used by the app, keyed by agent name.

    All agents share *llm*, are verbose and may not delegate. Only the data
    reader receives a tool (a ``FileReadTool``). *file_name* is embedded in
    the starter-code agent's goal. Agents are inserted in the order listed
    below; callers pass ``list(agents.values())`` to ``Crew``.
    """
    read_tool = FileReadTool()

    # (dict key, role, goal, backstory). Key and role match for all agents
    # except the compiler.
    agent_specs = [
        ("Data_Reader_Agent", 'Data_Reader_Agent',
         "Read the uploaded dataset and provide it to other agents.",
         "Responsible for reading the uploaded dataset."),
        ("Problem_Definition_Agent", 'Problem_Definition_Agent',
         "Clarify the machine learning problem the user wants to solve.",
         "Expert in defining machine learning problems."),
        ("EDA_Agent", 'EDA_Agent',
         "Perform all possible Exploratory Data Analysis (EDA) on the data provided by the user.",
         "Specializes in conducting comprehensive EDA to understand the data characteristics, distributions, and relationships."),
        ("Feature_Engineering_Agent", 'Feature_Engineering_Agent',
         "Perform feature engineering on the data based on the EDA results provided by the EDA agent.",
         "Expert in deriving new features, transforming existing features, and preprocessing data to prepare it for modeling."),
        ("Model_Recommendation_Agent", 'Model_Recommendation_Agent',
         "Suggest the most suitable machine learning models.",
         "Expert in recommending machine learning algorithms."),
        ("Starter_Code_Generator_Agent", 'Starter_Code_Generator_Agent',
         f"Generate starter Python code for the project. Always give dataset name as {file_name}",
         "Code wizard for generating starter code templates."),
        ("Code_Modification_Agent", 'Code_Modification_Agent',
         "Modify the generated Python code based on user suggestions.",
         "Expert in adapting code according to user feedback."),
        ("Code_Debugger_Agent", 'Code_Debugger_Agent',
         "Debug the generated Python code.",
         "Seasoned code debugger."),
        ("Compiler_Agent", "Code_compiler",
         "Extract only the python code.",
         "You are the compiler which extract only the python code."),
    ]

    roster = {}
    for key, role, goal, backstory in agent_specs:
        agent_kwargs = dict(
            role=role,
            goal=goal,
            backstory=backstory,
            verbose=True,
            allow_delegation=False,
            llm=llm,
        )
        # Only the reader is given a tool; the rest get no `tools` kwarg at all.
        if key == "Data_Reader_Agent":
            agent_kwargs["tools"] = [read_tool]
        roster[key] = Agent(**agent_kwargs)
    return roster
|
|
|
|
|
def create_tasks(func_call, user_question, file_name, data_upload, df, suggestion, edited_code, debugger, agents):
    """Build the CrewAI task list for one user action.

    Args:
        func_call: "Process", "Modify" or "Debug"; selects which tasks are added.
        user_question: free-text description of the ML problem.
        file_name: dataset file name referenced in the EDA prompt.
        data_upload: True when *df* holds uploaded data; enables EDA and
            feature-engineering tasks during "Process".
        df: uploaded DataFrame, or None when nothing was uploaded.
        suggestion: user's modification request (used by "Modify").
        edited_code: current editor contents (used by "Modify"/"Debug").
        debugger: pasted error message (used by "Debug").
        agents: mapping of agent name -> Agent.

    Returns:
        A list of Task objects, always ending with the Compiler_Agent task.
    """
    if df is not None:
        # DataFrame.info() prints to stdout and returns None. Capture its
        # text, because assigning its return value would always embed "None".
        info_buffer = io.StringIO()
        df.info(buf=info_buffer)
        info = info_buffer.getvalue()
        columns = df.columns.tolist()
    else:
        # No upload: avoid AttributeError on None and say so in the prompts.
        info = "No dataset information available."
        columns = "unknown (no dataset uploaded)"

    tasks = []
    if func_call == "Process":
        tasks.append(Task(
            description=f"Clarify the ML problem: {user_question}",
            agent=agents["Problem_Definition_Agent"],
            expected_output="A clear and concise definition of the ML problem."
        ))

        if data_upload:
            tasks.extend([
                Task(
                    description=f"Evaluate the data provided by the file {file_name}. This is the data: {df}",
                    agent=agents["EDA_Agent"],
                    expected_output="An assessment of the EDA and preprocessing like dataset info, missing value, duplication, outliers etc. on the data provided"
                ),
                Task(
                    description=f"Feature Engineering on data {df} based on EDA output: {info}",
                    agent=agents["Feature_Engineering_Agent"],
                    expected_output="An assessment of the Featuring Engineering and preprocessing like handling missing values, handling duplication, handling outliers, feature encoding, feature scaling etc. on the data provided"
                ),
            ])

        tasks.extend([
            Task(
                description="Suggest suitable ML models.",
                agent=agents["Model_Recommendation_Agent"],
                expected_output="A list of suitable ML models."
            ),
            Task(
                description=f"Generate starter Python code based on feature engineering, where column names are {columns}. Generate only the code without any extra text",
                agent=agents["Starter_Code_Generator_Agent"],
                expected_output="Starter Python code."
            ),
        ])

    elif func_call == "Modify":
        if suggestion:
            tasks.append(Task(
                description=f"Modify the already generated code {edited_code} according to the suggestion: {suggestion} \n\n Do not generate entire new code.",
                agent=agents["Code_Modification_Agent"],
                expected_output="Modified code."
            ))

    elif func_call == "Debug":
        if debugger:
            tasks.append(Task(
                description=f"Debug and fix any errors for data with column names {columns} with data as {df} in the generated code: {edited_code} \n\n According to the debugging: {debugger}. \n\n Do not generate entire new code. Just remove the error in the code by modifying only necessary parts of the code.",
                agent=agents["Code_Debugger_Agent"],
                expected_output="Debugged and successfully executed code."
            ))

    # Appended last so that the final task asks for plain Python code only.
    tasks.append(Task(
        description="Your job is to only extract python code from string",
        agent=agents["Compiler_Agent"],
        expected_output="Running python code."
    ))

    return tasks
|
|
|
# Script entry point. The app is presumably launched via `streamlit run <this file>`,
# which executes the file as __main__ on every rerun.
if __name__ == "__main__":
    main()
|
|