# pages/llama_custom_demo.py
import os
import pathlib

import streamlit as st
from models.llms import (
llm_bloomz_560m,
llm_gpt_3_5_turbo_0125,
)
from models.embeddings import hf_embed_model, openai_embed_model
from models.llamaCustom import LlamaCustom
from utils.chatbox import show_previous_messages, show_chat_input
from llama_index.core import (
SimpleDirectoryReader,
Document,
VectorStoreIndex,
StorageContext,
Settings,
load_index_from_storage,
)
from llama_index.core.base.llms.types import ChatMessage
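# Locations for user-uploaded files and persisted vector indices.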
SAVE_DIR = "uploaded_files"
VECTOR_STORE_DIR = "vectorStores"
# Global llama-index settings: every index in this app embeds with the HF model.
Settings.embed_model = hf_embed_model
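# Display name -> pre-built LLM; commented entries are alternatives kept for reference.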
llama_llms = {
"bigscience/bloomz-560m": llm_bloomz_560m,
# "mistral/mixtral": llm_mixtral_8x7b,
# "meta-llama/Llama-2-7b-chat-hf": llm_llama_2_7b_chat,
# "openai/gpt-3.5-turbo": llm_gpt_3_5_turbo,
"openai/gpt-3.5-turbo-0125": llm_gpt_3_5_turbo_0125,
# "openai/gpt-4-0125-preview": llm_gpt_4_0125,
# "meta/llama-2-13b-chat": llm_llama_13b_v2_replicate,
}
def init_session_state():
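    """Seed the session state with a default greeting and an empty model slot."""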
if "llama_messages" not in st.session_state:
st.session_state.llama_messages = [
{"role": "assistant", "content": "How can I help you today?"}
]
# TODO: create a chat history for each different document
if "llama_chat_history" not in st.session_state:
st.session_state.llama_chat_history = [
ChatMessage.from_str(role="assistant", content="How can I help you today?")
]
if "llama_custom" not in st.session_state:
st.session_state.llama_custom = None
# @st.cache_resource
def index_docs(
filename: str,
) -> VectorStoreIndex:
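    """Load the persisted vector index for `filename` if one exists;
    otherwise build it from the uploaded file and persist it.

    Returns None if loading or building the index fails.
    """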
try:
        index_path = pathlib.Path(VECTOR_STORE_DIR) / filename.replace(".", "_")
        if index_path.exists():
print("Loading index from storage ...")
            storage_context = StorageContext.from_defaults(persist_dir=str(index_path))
index = load_index_from_storage(storage_context=storage_context)
            # Smoke-test the loaded index; this raises if the persisted store is unusable.
            index.as_query_engine().query("What is the capital of France?")
else:
reader = SimpleDirectoryReader(input_files=[f"{SAVE_DIR}/{filename}"])
docs = reader.load_data(show_progress=True)
index = VectorStoreIndex.from_documents(
documents=docs,
show_progress=True,
)
            index.storage_context.persist(persist_dir=str(index_path))
except Exception as e:
print(f"Error: {e}")
index = None
return index
def load_llm(model_name: str):
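    """Look up a pre-built llama-index LLM by its display name."""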
return llama_llms[model_name]
# st.set_page_config must be the first Streamlit call on the page.
st.set_page_config(page_title="Llama", page_icon="🦙")

init_session_state()

st.header("Llama Index with Custom LLM Demo")
tab1, tab2 = st.tabs(["Config", "Chat"])
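# Tab 1 configures the model and document; Tab 2 hosts the chat UI.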
with tab1:
with st.form(key="llama_form"):
selected_llm_name = st.selectbox(
label="Select a model:", options=llama_llms.keys()
)
        if selected_llm_name.startswith("openai"):
            # st.secrets is read-only, so a user-supplied key is exported through
            # the environment instead, where the OpenAI client picks it up.
            if st.secrets.get("OPENAI_API_KEY") is None:
                st.info("OpenAI API Key not found in secrets. Please enter it below.")
                openai_api_key = st.text_input(
                    "OpenAI API Key",
                    type="password",
                    help="Get your API key from https://platform.openai.com/account/api-keys",
                )
                if openai_api_key:
                    os.environ["OPENAI_API_KEY"] = openai_api_key
        selected_file = st.selectbox(
            label="Choose a file to chat with:", options=os.listdir(SAVE_DIR)
        )
if st.form_submit_button(label="Submit"):
with st.status("Loading ...", expanded=True) as status:
st.write("Loading Model ...")
llama_llm = load_llm(selected_llm_name)
Settings.llm = llama_llm
st.write("Processing Data ...")
index = index_docs(selected_file)
if index is None:
st.error("Failed to index the documents.")
st.stop()
st.write("Finishing Up ...")
llama_custom = LlamaCustom(model_name=selected_llm_name, index=index)
st.session_state.llama_custom = llama_custom
status.update(label="Ready to query!", state="complete", expanded=False)
with tab2:
messages_container = st.container(height=300)
show_previous_messages(framework="llama", messages_container=messages_container)
    # Disable chat input until a model has been configured on the Config tab.
    show_chat_input(
        disabled=st.session_state.llama_custom is None,
        framework="llama",
        model=st.session_state.llama_custom,
        messages_container=messages_container,
    )
def clear_history():
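        """Reset the chat container and both message histories to the initial greeting."""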
messages_container.empty()
st.session_state.llama_messages = [
{"role": "assistant", "content": "How can I help you today?"}
]
st.session_state.llama_chat_history = [
ChatMessage.from_str(role="assistant", content="How can I help you today?")
]
if st.button("Clear Chat History"):
clear_history()
st.rerun()