# AgentLlama007B / agent_llama_ui.py
import glob
import io
import os
import time

import streamlit as st
from dotenv import load_dotenv
from PIL import Image
from streamlit_chat import message

from RBotReloaded import SmartAgent, MODELS_DIR
load_dotenv()
default_model = ""
default_context = 8192
default_load_type = "Auto"
default_iterations = 3
default_temperature = 0.2
default_topp = 0.95
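# The agent factory below is memoized with st.cache_resource: Streamlit keeps
# one SmartAgent instance per unique combination of arguments, so applying new
# sidebar settings builds a fresh agent while identical settings reuse the
# cached one.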
@st.cache_resource
def agent(model, temperature, top_p, context_length, load_8bit, load_4bit, max_iterations):
    # Prefer a local file under MODELS_DIR; otherwise pass the name through
    # (a remote model id or an API endpoint).
    model_path = f"{MODELS_DIR}/{model}" if os.path.exists(f"{MODELS_DIR}/{model}") else model
    ag = SmartAgent(
        model_path, temp=temperature, top_p=top_p,
        load_in_4bit=load_4bit, load_in_8bit=load_8bit,
        ctx_len=context_length, max_iterations=max_iterations,
    ) if model else None
    # Mirror the applied settings back into session state (this only runs on a
    # cache miss, i.e. when a new agent is actually built).
    st.session_state["temperature_executive"] = temperature
    st.session_state["max_iterations_executive"] = max_iterations
    st.session_state["model_executive"] = model
    st.session_state["context_length_executive"] = context_length
    st.session_state["load_options_executive"] = "Load 8-bit" if load_8bit else "Load 4-bit" if load_4bit else "Auto"
    st.session_state["top_p_executive"] = top_p
    return ag
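# Model discovery: scan MODELS_DIR for local weight files and offer one remote
# default. On a first run with no local models, a small GGUF model and a
# Stable Diffusion checkpoint are downloaded from the Hugging Face Hub.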
def get_models():
    supported_extensions = ["bin", "pth", "gguf"]
    models_directory = MODELS_DIR
    os.makedirs(models_directory, exist_ok=True)
    models = os.listdir(models_directory)
    # Keep only files with a supported extension (skip subdirectories).
    models = [m for m in models if m.lower().split(".")[-1] in supported_extensions and os.path.isfile(os.path.join(models_directory, m))]
    if len(models) == 0:
        st.write("Downloading models")
        from huggingface_hub import hf_hub_download
        st.write("Downloading mistral-7b-instruct-v0.1.Q4_K_M.gguf")
        hf_hub_download(repo_id="TheBloke/Mistral-7B-Instruct-v0.1-GGUF", filename="mistral-7b-instruct-v0.1.Q4_K_M.gguf", local_dir=models_directory)
        st.write("Downloading dreamshaper_8.safetensors")
        hf_hub_download(repo_id="digiplay/DreamShaper_8", filename="dreamshaper_8.safetensors", local_dir=models_directory)
        st.experimental_rerun()
    # Append the remote default only after the empty check above, so that a
    # missing local model still triggers the first-run download.
    models.append("garage-bAInd/Platypus2-70B-instruct")
    # models.append("http://localhost:5000")  # to use with text-generation-webui
    return models
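# Widget keys ("model", "temperature", ...) hold the live sidebar values,
# while their "_executive" counterparts hold the values last applied to the
# agent. current_agent() prefers the applied values, so sidebar edits only
# take effect once "Apply Changes to Model" is clicked.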
def current_agent():
    model = st.session_state.get("model", default_model)
    temperature = st.session_state.get("temperature", default_temperature)
    max_iterations = st.session_state.get("max_iterations", default_iterations)
    context_length = st.session_state.get("context_length", default_context)
    load_options = st.session_state.get("load_options", default_load_type)
    top_p = st.session_state.get("top_p", default_topp)
    model = st.session_state.get("model_executive", model)
    temperature = st.session_state.get("temperature_executive", temperature)
    max_iterations = st.session_state.get("max_iterations_executive", max_iterations)
    context_length = st.session_state.get("context_length_executive", context_length)
    load_options = st.session_state.get("load_options_executive", load_options)
    top_p = st.session_state.get("top_p_executive", top_p)
    return agent(model, temperature, top_p, context_length, load_options == "Load 8-bit", load_options == "Load 4-bit", max_iterations)
def history():
    ag = current_agent()
    return [] if ag is None else ag.chat_history
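# Run a single agent turn, timing the full response loop so the elapsed time
# can be appended to the reply shown in the chat.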
#@st.cache_data
def generate_text(input):
    ag = current_agent()
    start_time = time.time()
    output = "Error: Model not Loaded!" if ag is None else ag.agent_generate_response(input)
    elapsed_time = time.time() - start_time
    print("\n----------------------")
    print(f"Input: {input}")
    print(f"Agent Reply: {output}")
    print(f"Elapsed Time: {elapsed_time} seconds")
    print("----------------------")
    return output + f" ({round(elapsed_time, 2)}s)"
def get_generated_files():
    # Images produced by the agent are saved under ./generated_images as .jpg;
    # adjust the extension if the generation format changes.
    directory = "./generated_images"
    return glob.glob(f"{directory}/*.jpg")
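# Knowledge-base helpers: uploaded documents are persisted under
# ./knowledge_base/ and handed to the agent's memory chain
# (addDocumentToMemory), which presumably embeds them into the FAISS index
# whose size is reported below.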
# List the files currently stored in the "./knowledge_base/" folder.
def list_files_in_knowledge_base_folder():
    knowledge_base_folder = "./knowledge_base/"
    os.makedirs(knowledge_base_folder, exist_ok=True)
    files = os.listdir(knowledge_base_folder)
    return [file for file in files if os.path.isfile(os.path.join(knowledge_base_folder, file))]

# Save an uploaded file into "./knowledge_base/" and index it in the agent's
# long-term memory.
def add_file_to_knowledge_base(file):
    knowledge_base_folder = "./knowledge_base/"
    os.makedirs(knowledge_base_folder, exist_ok=True)
    final_path = os.path.join(knowledge_base_folder, file.name)
    with open(final_path, "wb") as f:
        f.write(file.read())
    if current_agent() is None:
        st.error("Model Not Loaded!")
    else:
        current_agent().memory_chain.addDocumentToMemory(final_path)
# Save an uploaded image as "./image_gen_guide.jpg", to be used as guidance
# for image generation.
def set_image_gen_guide(file):
    bytes_data = io.BytesIO(file.read())
    image = Image.open(bytes_data)
    image = image.convert("RGB")
    image.save("./image_gen_guide.jpg")

def unset_image_gen_guide():
    if os.path.exists("./image_gen_guide.jpg"):
        os.remove("./image_gen_guide.jpg")
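# The FAISS index size on disk is used as a rough proxy for how much the
# agent currently holds in long-term memory.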
def get_index_size():
    index_file_path = "./knowledge_base/index.faiss"
    if os.path.exists(index_file_path):
        # Size in KB.
        return os.path.getsize(index_file_path) / 1024
    else:
        print(f"{index_file_path} does not exist or is not accessible.")
        return 0
# @cl.langchain_factory(use_async=True)
# def factory():
# return current_agent().smartAgent
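# Main page: a sidebar with model, knowledge-base, and image-generation
# controls, plus the chat transcript and prompt box.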
def render_simple_chat():
    st.markdown("<h3 style='text-align: center;'>To fully utilize all functionalities of this demo, you'll need at least a 16-core CPU and 32GB of RAM. Note that the limited resources of Hugging Face free Spaces may lead to slow responses and crashes from out-of-memory errors during image generation.</h3>", unsafe_allow_html=True)
    models = get_models()
    models.append("")
    model = st.session_state.get("model", default_model)
    temperature = st.session_state.get("temperature", default_temperature)
    max_iterations = st.session_state.get("max_iterations", default_iterations)
    context_length = st.session_state.get("context_length", default_context)
    load_options = st.session_state.get("load_options", default_load_type)
    top_p = st.session_state.get("top_p", default_topp)
    with st.sidebar:
        st.image("./avatar.png")
        st.sidebar.title("LLM Options")
        max_iterations = st.sidebar.slider("Max Iterations", min_value=1, max_value=10, step=1, key="max_iterations")
        model = st.selectbox(label="Model", options=models, key="model")
        # Sampling and load options only apply to locally loaded models, not
        # to remote API endpoints.
        if not model.startswith("http"):
            temperature = st.sidebar.slider("Temperature", min_value=0.1, max_value=1.0, step=0.1, key="temperature")
            top_p = st.sidebar.slider("top_p", min_value=0.1, max_value=1.0, step=0.1, key="top_p")
            context_length = st.sidebar.slider("Context Length", min_value=1024, max_value=131072, step=1024, key="context_length")
            load_options = st.sidebar.radio("Load Options", ["Auto", "Load 4-bit", "Load 8-bit"], key="load_options")
        if st.sidebar.button("Apply Changes to Model"):
            st.session_state["temperature_executive"] = temperature
            st.session_state["max_iterations_executive"] = max_iterations
            st.session_state["model_executive"] = model
            st.session_state["context_length_executive"] = context_length
            st.session_state["load_options_executive"] = load_options
            st.session_state["top_p_executive"] = top_p
            #st.experimental_rerun()
        if st.sidebar.button("Reset Chat Context", disabled=not (current_agent() is not None and len(current_agent().chat_history) > 0)) and current_agent() is not None:
            current_agent().reset_context()
        st.sidebar.write("-----")
        st.sidebar.title("Documents Context")
        st.sidebar.subheader(f"Current Memory Size {round(get_index_size() / 1024, 2)}MB")
        uploaded_file = st.sidebar.file_uploader("Drag and Drop a File to ./knowledge_base/", type=["txt", "pdf", "docx"])
        knowledge_files = glob.glob("./knowledge_base/*.*")
        st.sidebar.subheader("Knowledge Files")
        for file_path in knowledge_files:
            if "index." not in file_path.lower():
                file_path = file_path.replace("\\", "/")
                file_name = file_path.split("/")[-1]
                st.sidebar.markdown(f"[{file_name}](/{file_path})", unsafe_allow_html=True)
        if st.sidebar.button("Reset Long Term Memory", disabled=not (current_agent() is not None and get_index_size() > 0)) and current_agent() is not None:
            current_agent().reset_knowledge()
        st.sidebar.write("-----")
        st.sidebar.title("Image Generation")
        if os.path.exists("./image_gen_guide.jpg"):
            st.sidebar.image("./image_gen_guide.jpg")
            if st.sidebar.button("Remove Image Generation Guidance"):
                unset_image_gen_guide()
                st.experimental_rerun()
        else:
            image_gen_guide = st.sidebar.file_uploader("Drag and Drop an optional image to use as Guidance", type=["jpg", "png"])
            if image_gen_guide:
                set_image_gen_guide(image_gen_guide)
                st.sidebar.success(f"File '{image_gen_guide.name}' set as image generation guidance.")
        if uploaded_file:
            add_file_to_knowledge_base(uploaded_file)
            st.sidebar.success(f"File '{uploaded_file.name}' added to Knowledge Base.")
    with st.sidebar:
        generated_files = get_generated_files()
        st.sidebar.subheader("Generated Files")
        for file_path in generated_files:
            file_path = file_path.replace("\\", "/")
            file_name = file_path.split("/")[-1]
            st.write("---")
            st.markdown(f"[{file_name}](/{file_path})", unsafe_allow_html=True)
            st.image(file_path, use_column_width=True)
    i = 0
    for m in history():
        i += 1
        gen = str(m.content)
        # Replies that end with a saved image path (e.g.
        # "./generated_images/image_202310091819331.jpg") are also rendered inline.
        if gen.endswith(".jpg") and os.path.exists(gen.split(" ")[-1]):
            st.image(gen.split(" ")[-1])
        message(gen, is_user=m.type.lower() == "human", key=str(i))
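    # Chat input: echo the user message, run the agent, then append its reply.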
    user_input = st.chat_input("Prompt", key="input_text")
    if user_input:
        message(user_input, is_user=True, key=str(i + 1))
        res = generate_text(user_input)
        message(res, is_user=False, key=str(i + 2))
##### BEGIN MAIN #####
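# App entry point: configure the page, seed first-run defaults for every
# widget key, and render the chat UI.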
# st.set_page_config should be the first Streamlit command on the page.
st.set_page_config(page_title="Agent Llama", page_icon="🤖", layout="wide")
if 'generated' not in st.session_state:
    st.session_state['generated'] = []
if 'past' not in st.session_state:
    st.session_state['past'] = []
if 'model' not in st.session_state:
    st.session_state['model'] = default_model
    st.session_state['temperature'] = default_temperature
    st.session_state['max_iterations'] = default_iterations
    st.session_state['context_length'] = default_context
    st.session_state['load_options'] = default_load_type
    st.session_state['top_p'] = default_topp
st.title("Agent Llama")
render_simple_chat()