import random
import time
import streamlit as st
import os
import pathlib
from typing import List
from models.llms import (
    llm_llama_2_7b_chat,
    llm_mixtral_8x7b,
    llm_bloomz_560m,
    llm_gpt_3_5_turbo,
    llm_gpt_3_5_turbo_0125,
    llm_gpt_4_0125,
    llm_llama_13b_v2_replicate,
)
from models.embeddings import hf_embed_model, openai_embed_model
from models.llamaCustom import LlamaCustom
from utils.chatbox import show_previous_messages, show_chat_input
from llama_index.core import (
    SimpleDirectoryReader,
    Document,
    VectorStoreIndex,
    StorageContext,
    Settings,
    load_index_from_storage,
)
from llama_index.core.memory import ChatMemoryBuffer
from llama_index.core.base.llms.types import ChatMessage
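
# Local folders: uploaded documents live in SAVE_DIR, persisted vector indexes in VECTOR_STORE_DIR.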
SAVE_DIR = "uploaded_files"
VECTOR_STORE_DIR = "vectorStores"
# global
Settings.embed_model = hf_embed_model
llama_llms = {
"bigscience/bloomz-560m": llm_bloomz_560m,
"mistral/mixtral": llm_mixtral_8x7b,
"meta-llama/Llama-2-7b-chat-hf": llm_llama_2_7b_chat,
# "openai/gpt-3.5-turbo": llm_gpt_3_5_turbo,
"openai/gpt-3.5-turbo-0125": llm_gpt_3_5_turbo_0125,
# "openai/gpt-4-0125-preview": llm_gpt_4_0125,
# "meta/llama-2-13b-chat": llm_llama_13b_v2_replicate,
}
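
# Seed the session state so chat messages, the LlamaIndex chat history, and the
# active model survive Streamlit reruns.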
def init_session_state():
if "llama_messages" not in st.session_state:
st.session_state.llama_messages = [
{"role": "assistant", "content": "How can I help you today?"}
]
# TODO: create a chat history for each different document
if "llama_chat_history" not in st.session_state:
st.session_state.llama_chat_history = [
ChatMessage.from_str(role="assistant", content="How can I help you today?")
]
if "llama_custom" not in st.session_state:
st.session_state.llama_custom = None
# @st.cache_resource
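# Load the persisted vector index for the given file if one exists; otherwise build
# it from the uploaded document and persist it for later runs.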
def index_docs(
    filename: str,
) -> VectorStoreIndex:
    try:
        index_path = pathlib.Path(f"{VECTOR_STORE_DIR}/{filename.replace('.', '_')}")
        if index_path.exists():
            print("Loading index from storage ...")
            storage_context = StorageContext.from_defaults(persist_dir=index_path)
            index = load_index_from_storage(storage_context=storage_context)

            # test the index
            index.as_query_engine().query("What is the capital of France?")
        else:
            reader = SimpleDirectoryReader(input_files=[f"{SAVE_DIR}/{filename}"])
            docs = reader.load_data(show_progress=True)
            index = VectorStoreIndex.from_documents(
                documents=docs,
                show_progress=True,
            )
            index.storage_context.persist(
                persist_dir=f"{VECTOR_STORE_DIR}/{filename.replace('.', '_')}"
            )
    except Exception as e:
        print(f"Error: {e}")
        index = None
    return index
def load_llm(model_name: str):
    return llama_llms[model_name]
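
# Streamlit entry point: initialize session state, then render the Config and Chat tabs.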
init_session_state()
st.set_page_config(page_title="Llama", page_icon="🦙")
st.header("Llama Index with Custom LLM Demo")
tab1, tab2 = st.tabs(["Config", "Chat"])
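# Config tab: pick a model and an uploaded file, then build or load its vector index.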
with tab1:
with st.form(key="llama_form"):
selected_llm_name = st.selectbox(label="Select a model:", options=llama_llms.keys())
if selected_llm_name.startswith("openai"):
# ask for the api key
if st.secrets.get("OPENAI_API_KEY") is None:
# st.stop()
st.info("OpenAI API Key not found in secrets. Please enter it below.")
st.secrets["OPENAI_API_KEY"] = st.text_input(
"OpenAI API Key",
type="password",
help="Get your API key from https://platform.openai.com/account/api-keys",
)
        selected_file = st.selectbox(
            label="Choose a file to chat with: ", options=os.listdir(SAVE_DIR)
        )
if st.form_submit_button(label="Submit"):
with st.status("Loading ...", expanded=True) as status:
st.write("Loading Model ...")
llama_llm = load_llm(selected_llm_name)
Settings.llm = llama_llm
st.write("Processing Data ...")
index = index_docs(selected_file)
if index is None:
st.error("Failed to index the documents.")
st.stop()
st.write("Finishing Up ...")
llama_custom = LlamaCustom(model_name=selected_llm_name, index=index)
st.session_state.llama_custom = llama_custom
status.update(label="Ready to query!", state="complete", expanded=False)
with tab2:
    messages_container = st.container(height=300)
    show_previous_messages(framework="llama", messages_container=messages_container)
    show_chat_input(
        disabled=False,
        framework="llama",
        model=st.session_state.llama_custom,
        messages_container=messages_container,
    )
    def clear_history():
        messages_container.empty()
        st.session_state.llama_messages = [
            {"role": "assistant", "content": "How can I help you today?"}
        ]
        st.session_state.llama_chat_history = [
            ChatMessage.from_str(role="assistant", content="How can I help you today?")
        ]

    if st.button("Clear Chat History"):
        clear_history()
        st.rerun()