# --------------------------------libraries-----------------------------------
import streamlit as st
#import torch
import os
import logging
import sys
from llama_index.callbacks import CallbackManager, LlamaDebugHandler
from llama_index.llms import LlamaCPP
from llama_index.embeddings import InstructorEmbedding
from llama_index import ServiceContext, VectorStoreIndex, SimpleDirectoryReader
from tqdm.notebook import tqdm
from dotenv import load_dotenv
from llama_index.llms import ChatMessage, MessageRole
from llama_index.prompts import ChatPromptTemplate
# --------------------------------env variables-----------------------------------
# Load environment variables
load_dotenv(dotenv_path=".env")
no_proxy = os.getenv("no_proxy")
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
OPENAI_API_BASE = os.getenv("OPENAI_API_BASE")
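# Note: only OPENAI_API_BASE is surfaced further down (sidebar input field); the LLM itself
# runs locally through llama.cpp, so no OpenAI endpoint is actually required.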
# Text QA Prompt
chat_text_qa_msgs = [
ChatMessage(
role=MessageRole.SYSTEM,
content=(
"You are Dolphin, a helpful AI assistant. "
"Answer questions based solely on the context provided. "
"Do not use information outside of the context. "
"Respond in the same language as the question. Be concise."
),
),
ChatMessage(
role=MessageRole.USER,
content=(
"Context information is below:\n"
"---------------------\n"
"{context_str}\n"
"---------------------\n"
"Based on this context, answer the question: {query_str}\n"
),
),
]
text_qa_template = ChatPromptTemplate(chat_text_qa_msgs)
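# text_qa_template drives the first answer pass: retrieved chunks fill {context_str}
# and the user question fills {query_str}.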
# Refine Prompt
chat_refine_msgs = [
ChatMessage(
role=MessageRole.SYSTEM,
content=(
"You are Dolphin, focused on refining answers with additional context. "
"Use new context to refine the answer. "
"If the new context isn't useful, restate the original answer. "
"Be precise and match the language of the query."
),
),
ChatMessage(
role=MessageRole.USER,
content=(
"New context for refinement:\n"
"------------\n"
"{context_msg}\n"
"------------\n"
"Refine the original answer with this context for the question: {query_str}. "
"Original Answer: {existing_answer}"
),
),
]
refine_template = ChatPromptTemplate(chat_refine_msgs)
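# refine_template is applied when more chunks are retrieved than fit in a single pass:
# the previous answer ({existing_answer}) is revised against each new chunk ({context_msg}).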
template = (
"system\n"
"\"You are Dolphin, a helpful AI assistant. Your responses should be based solely on the content of documents you have access to, "
"including the specific context provided below. Do not provide information that is not contained in the documents or the context. "
"If a question is asked about content not in the documents or context, respond with 'I do not have that information.' "
"Always respond in the same language as the question was asked. Be concise.\n"
"Respond to the best of your ability. Try to respond in markdown.\"\n"
"context\n"
"{context}\n"
"user\n"
"{prompt}\n"
"assistant\n"
)
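# NOTE: this raw string template is not used below; the query engines rely on
# text_qa_template and refine_template instead.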
# --------------------------------logging & debug-----------------------------------
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))
llama_debug = LlamaDebugHandler(print_trace_on_end=True)
callback_manager = CallbackManager([llama_debug])
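# LlamaDebugHandler prints an event trace after each query. Note: callback_manager is not
# passed to the ServiceContext below, so it has no effect unless wired in
# (e.g. ServiceContext.from_defaults(..., callback_manager=callback_manager)).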
# Embedding index + query engine for a single uploaded document
def load_emb_uploaded_document(filename):
    # Only build the index after the app has initialized (the 'init' flag is set in tab 3).
if 'init' in st.session_state:
embed_model_inst = InstructorEmbedding("models/hkunlp_instructor-base")
service_context = ServiceContext.from_defaults(embed_model=embed_model_inst, llm=llm, chunk_size=500)
documents = SimpleDirectoryReader(input_files=[filename]).load_data()
index = VectorStoreIndex.from_documents(
documents, service_context=service_context, show_progress=True)
return index.as_query_engine(text_qa_template=text_qa_template, refine_template=refine_template)
return None
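# Unlike load_emb_model below, this helper is not cached with @st.cache_resource,
# so the uploaded document is re-embedded on every call.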
# --------------------------------cache Embedding model-----------------------------------
@st.cache_resource
def load_emb_model():
if not os.path.exists("data"):
st.error("Data directory does not exist. Please upload the data.")
os.makedirs("data")
return None #
    embed_model_inst = InstructorEmbedding("models/hkunlp_instructor-base")  # local copy of hkunlp/instructor-base
    service_context = ServiceContext.from_defaults(embed_model=embed_model_inst, llm=llm, chunk_size=500)
documents = SimpleDirectoryReader("data").load_data()
print(f"Number of documents: {len(documents)}")
index = VectorStoreIndex.from_documents(
documents, service_context=service_context, show_progress=True)
return index.as_query_engine(text_qa_template=text_qa_template, refine_template=refine_template)
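# The returned query engine answers from the persistent "data" folder index using the
# QA and refine templates defined above.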
# --------------------------------cache LLM model-----------------------------------
# LLM
@st.cache_resource
def load_llm_model():
if not os.path.exists("models"):
st.error("models directory does not exist. Please download and copy paste a model in folder models.")
os.makedirs("models")
return None #
llm = LlamaCPP(
#model_url="https://huggingface.co/TheBloke/Llama-2-13B-chat-GGUF/resolve/main/llama-2-13b-chat.Q5_K_M.gguf",
model_path="models/dolphin-2.1-mistral-7b.Q4_K_S.gguf",
temperature=0.0,
max_new_tokens=100,
context_window=4096,
generate_kwargs={},
model_kwargs={"n_gpu_layers": 20},
verbose=True,
)
return llm
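# n_gpu_layers controls how many layers llama.cpp offloads to the GPU;
# set it to 0 on CPU-only machines and raise it as VRAM allows.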
# ------------------------------------session state----------------------------------------
if 'memory' not in st.session_state:
st.session_state.memory = ""
# LLM Model Loading
if 'llm_model' not in st.session_state:
st.session_state.llm_model = load_llm_model()
# Use the models from session state
llm = st.session_state.llm_model
# Embedding Model Loading
if 'emb_model' not in st.session_state:
st.session_state.emb_model = load_emb_model()
# Use the models from session state
query_engine = st.session_state.emb_model
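# The LLM and the default query engine are created once and kept in st.session_state so that
# Streamlit reruns (triggered by every widget interaction) reuse the already-loaded models.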
# ------------------------------------layout----------------------------------------
with st.sidebar:
api_server_info = st.text_input("Local LLM API server", OPENAI_API_BASE ,key="openai_api_base")
st.title("πŸ€– Llama Index πŸ“š")
if st.button('Clear Memory'):
del st.session_state["memory"]
st.session_state.memory = ""
st.write("Local LLM API server in this demo is useles, we are loading local model using llama_index integration of llama cpp")
st.write("πŸš€ This app allows you to chat with local LLM using api server or loaded in cache")
st.subheader("πŸ’» System Requirements: ")
st.markdown("- CPU: the faster the better ")
st.markdown("- RAM: 16 GB or higher")
st.markdown("- GPU: optional but very useful for Cuda acceleration")
st.subheader("Developer Information:")
st.write("This app is developed and maintained by **@mohcineelharras**")
# Define your app's tabs
tab1, tab2, tab3 = st.tabs(["LLM only", "LLM RAG QA with database", "One single document Q&A"])
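# Tab 1 talks to the LLM directly, tab 2 runs retrieval-augmented QA over the "data" index,
# and tab 3 builds a temporary index over a single uploaded file.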
# -----------------------------------LLM only---------------------------------------------
with tab1:
st.title("πŸ’¬ LLM only")
prompt = st.text_input(
"Ask your question here",
placeholder="How do miners contribute to the security of the blockchain ?",
)
if prompt:
contextual_prompt = st.session_state.memory + "\n" + prompt
        # Include the conversation memory so previous turns inform the answer.
        response = llm.complete(contextual_prompt, max_tokens=100, temperature=0, top_p=0.95, top_k=10)
        text_response = response.text
st.write("### Answer")
st.markdown(text_response)
st.session_state.memory = f"Prompt: {contextual_prompt}\nResponse:\n {text_response}"
with open("short_memory.txt", 'w') as file:
file.write(st.session_state.memory)
# -----------------------------------LLM Q&A-------------------------------------------------
with tab2:
st.title("πŸ’¬ LLM RAG QA with database")
st.write("To consult files that are available in the database, go to https://huggingface.co/spaces/mohcineelharras/llama-index-docs-spaces/tree/main/data")
prompt = st.text_input(
"Ask your question here",
placeholder="Who is Mohcine ?",
)
if prompt:
contextual_prompt = st.session_state.memory + "\n" + prompt
response = query_engine.query(contextual_prompt)
text_response = response.response
st.write("### Answer")
st.markdown(text_response)
st.session_state.memory = f"Prompt: {contextual_prompt}\nResponse:\n {text_response}"
with st.expander("Document Similarity Search"):
for i, node in enumerate(response.source_nodes):
dict_source_i = node.node.metadata
dict_source_i.update({"Text":node.node.text})
st.write("Source nΒ°"+str(i+1), dict_source_i)
break
st.session_state.memory = f"Prompt: {contextual_prompt}\nResponse:\n {text_response}"
with open("short_memory.txt", 'w') as file:
file.write(st.session_state.memory)
# -----------------------------------Upload File Q&A-----------------------------------------
with tab3:
st.title("πŸ“ One single document Q&A with Llama Index using local open llms")
if st.button('Reinitialize Query Engine', key='reinit_engine'):
query_engine = st.session_state.emb_model
st.write("Query engine reinitialized.")
    uploaded_file = st.file_uploader("Upload a file", type=("txt", "csv", "md", "pdf"))
question = st.text_input(
"Ask something about the files",
placeholder="Can you give me a short summary?",
disabled=not uploaded_file,
)
if 'init' not in st.session_state:
st.session_state.init = True
if uploaded_file:
if not os.path.exists("draft_docs"):
st.error("draft_docs directory does not exist. Please download and copy paste a model in folder models.")
os.makedirs("draft_docs")
with open("draft_docs/"+uploaded_file.name, "wb") as f:
text = uploaded_file.read()
f.write(text)
text = uploaded_file.read()
# if load_emb_uploaded_document:
# load_emb_uploaded_document.clear()
#load_emb_uploaded_document.clear()
query_engine = load_emb_uploaded_document("draft_docs/"+uploaded_file.name)
st.write("File ",uploaded_file.name, "was loaded successfully")
if uploaded_file and question and api_server_info:
contextual_prompt = st.session_state.memory + "\n" + question
response = query_engine.query(contextual_prompt)
text_response = response.response
st.write("### Answer")
st.markdown(text_response)
st.session_state.memory = f"Prompt: {contextual_prompt}\nResponse:\n {text_response}"
with open("short_memory.txt", 'w') as file:
file.write(st.session_state.memory)
with st.expander("Document Similarity Search"):
#st.write(len(response.source_nodes))
for i, node in enumerate(response.source_nodes):
dict_source_i = node.node.metadata
dict_source_i.update({"Text":node.node.text})
st.write("Source nΒ°"+str(i+1), dict_source_i)
#st.write("Source nΒ°"+str(i))
#st.write("Meta Data :", node.node.metadata)
#st.write("Text :", node.node.text)
#st.write()
#print("Is File uploaded : ",uploaded_file==True, "Is question asked : ", question==True, "Is question asked : ", api_server_info==True)
st.subheader('⚠️ Warning: to avoid lags')
st.markdown("Please **delete the input prompt** and **clear memory** (button in the sidebar) each time you switch to another tab.")
st.markdown("With a local GPU, execution can be a **lot faster** (approximately 5 seconds on my machine).")
st.markdown("""
<div style="text-align: center; margin-top: 20px;">
<a href="https://github.com/mohcineelharras/llama-index-docs" target="_blank" style="margin: 10px; display: inline-block;">
<img src="https://img.shields.io/badge/Repository-333?logo=github&style=for-the-badge" alt="Repository" style="vertical-align: middle;">
</a>
<a href="https://www.linkedin.com/in/mohcine-el-harras" target="_blank" style="margin: 10px; display: inline-block;">
<img src="https://img.shields.io/badge/-LinkedIn-0077B5?style=for-the-badge&logo=linkedin" alt="LinkedIn" style="vertical-align: middle;">
</a>
<a href="https://mohcineelharras.github.io" target="_blank" style="margin: 10px; display: inline-block;">
<img src="https://img.shields.io/badge/Visit-Portfolio-9cf?style=for-the-badge" alt="GitHub" style="vertical-align: middle;">
</a>
</div>
<div style="text-align: center; margin-top: 20px; color: #666; font-size: 0.85em;">
Β© 2023 Mohcine EL HARRAS
</div>
""", unsafe_allow_html=True)
# -----------------------------------end-----------------------------------------