import os

import gradio as gr
from huggingface_hub import hf_hub_download, login
from langchain.chains import ConversationalRetrievalChain, LLMChain
from langchain.chains.question_answering import load_qa_chain
from langchain.memory import ConversationBufferWindowMemory
from langchain.prompts import PromptTemplate
from langchain.retrievers import MultiQueryRetriever
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.llms import LlamaCpp
from langchain_community.vectorstores import Chroma
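
# Authenticate with the Hugging Face Hub using the access token from the environment,
# so the model files referenced below can be downloaded.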
login(os.environ['hf_token'])

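# Prompt used to condense the chat history and a follow-up question into a single,
# standalone question before retrieval.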
_template = """Given the following conversation and a follow-up question, rephrase the follow-up question to be a standalone question, without changing the content of the given question.

Chat History:
{chat_history}
Follow Up Input: {question}
Standalone question:"""
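
# System prompt that instructs the model to answer only from the retrieved context.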
system_prompt = """You are a helpful assistant; you will use the provided context to answer user questions.
Read the given context before answering questions and think step by step. If you cannot answer a user question based on the provided context, inform the user.
Do not use any other information to answer the user. Provide a detailed answer to the question."""

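# Download a quantized model checkpoint from the Hugging Face Hub and load it with
# llama.cpp; the repo id and filename are taken from environment variables.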
def load_quantized_model(model_id=None):
    if model_id == "Zephyr-7b-Beta":
        MODEL_ID, MODEL_BASENAME = os.environ['model_id_1'], os.environ['model_basename_1']
    else:
        MODEL_ID, MODEL_BASENAME = os.environ['model_id_2'], os.environ['model_basename_2']

    try:
        model_path = hf_hub_download(
            repo_id=MODEL_ID,
            filename=MODEL_BASENAME,
            resume_download=True,
            cache_dir="models",
        )
        kwargs = {
            'model_path': model_path,
            'n_ctx': 20000,
            'max_tokens': 15000,
            'n_batch': 1024,
        }
        return LlamaCpp(**kwargs)
    except TypeError:
        print("Supported model architectures: Llama, Mistral")
        return None

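# Return the local paths of files added through a Gradio file-upload component.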
def upload_files(files):
    file_paths = [file.name for file in files]
    return file_paths

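# Build the Gradio UI: a chat window plus model/mode selectors and Submit/Clear controls.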
with gr.Blocks() as demo:
    gr.Markdown(
        """
        <h2> <center> PrivateGPT </center> </h2>
        """)

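    # Load the persisted Chroma vector store, the BGE embeddings used to query it,
    # and the quantized LLM.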
    with gr.Row():
        persist_directory = "book1_raw_no_processing"
        embeddings = HuggingFaceBgeEmbeddings(
            model_name="BAAI/bge-large-en-v1.5",
            model_kwargs={"device": "cpu"},
            encode_kwargs={'normalize_embeddings': True},
            cache_folder="models",
        )
        db2 = Chroma(persist_directory=persist_directory, embedding_function=embeddings)

        # Called without a model_id, so the checkpoint referenced by
        # model_id_2/model_basename_2 is loaded.
        llm = load_quantized_model()
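
        # Build the conversational retrieval chain: an LLMChain condenses the chat
        # history and follow-up into a standalone question, a MultiQueryRetriever
        # generates query variants against the Chroma store, and a "stuff" QA chain
        # answers from the retrieved documents.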
        condense_question_prompt_template = PromptTemplate.from_template(_template)
        prompt_template = system_prompt + """
{context}

Question: {question}
Helpful Answer:"""
        qa_prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
        memory = ConversationBufferWindowMemory(memory_key='chat_history', k=1, return_messages=True)
        retriever_from_llm = MultiQueryRetriever.from_llm(
            retriever=db2.as_retriever(search_kwargs={'k': 10}),
            llm=llm,
        )
        qa2 = ConversationalRetrievalChain(
            retriever=retriever_from_llm,
            question_generator=LLMChain(llm=llm, prompt=condense_question_prompt_template, memory=memory, verbose=True),
            combine_docs_chain=load_qa_chain(llm=llm, chain_type="stuff", prompt=qa_prompt, verbose=True),
            memory=memory,
            verbose=True,
        )
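
        # Gradio callbacks: add_text appends the user message to the chat history;
        # bot runs the chain on the latest message and fills in the assistant reply.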
        def add_text(history, text):
            history = history + [[text, None]]
            return history, ""

        def bot(history):
            res = qa2.invoke(
                {
                    'question': history[-1][0],
                    'chat_history': history[:-1],
                }
            )
            history[-1][1] = res['answer']
            return history

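        # Chat area: chatbot window, input textbox, and the model/mode radio selectors
        # (the radios are not wired to any callback in this snippet).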
        with gr.Column(scale=9):
            with gr.Row():
                chatbot = gr.Chatbot([], elem_id="chatbot", label="Chat", height=500, show_label=True, avatar_images=["user.jpeg", "Bot.jpg"])
            with gr.Row():
                with gr.Column(scale=8):
                    txt = gr.Textbox(
                        show_label=False,
                        placeholder="Enter text and press enter",
                        container=False,
                    )
                with gr.Column(scale=1):
                    with gr.Row():
                        model_id = gr.Radio(["Zephyr-7b-Beta", "Llama-2-7b-chat"], value="Zephyr-7b-Beta", label="LLM Model")
                    with gr.Row():
                        mode = gr.Radio(['OITF Manuals', 'Operations Data'], value='Operations Data', label="QA mode")

        with gr.Column(scale=8):
            pass

        with gr.Row():
            clear_btn = gr.Button(
                'Clear',
                variant="stop",
            )
        with gr.Row():
            submit_btn = gr.Button(
                'Submit',
                variant='primary',
            )

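        # Event wiring: pressing Enter or clicking Submit appends the user message and
        # then runs it through the QA chain; Clear empties the chat window.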
        txt.submit(add_text, [chatbot, txt], [chatbot, txt]).then(
            bot, chatbot, chatbot
        )
        submit_btn.click(add_text, [chatbot, txt], [chatbot, txt]).then(
            bot, chatbot, chatbot
        )
        clear_btn.click(lambda: None, None, chatbot, queue=False)

if __name__ == "__main__":
    demo.queue()
    demo.launch(max_threads=8, debug=True)