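# app.py -- copied from the Hugging Face file viewer (viewer links "raw" / "history" / "blame" omitted).
# Uploader avatar: Ubai. Last commit: "Update app.py" (d20d8d3, verified). File size: 2.82 kB.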
import gradio as gr
import os
from langchain_community.document_loaders import PyPDFLoader # Updated import
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma # Updated import
from langchain_community.embeddings import HuggingFaceEmbeddings # Updated import
from langchain_community.llms import HuggingFaceHub # Updated import
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from pathlib import Path
import chromadb
from transformers import AutoTokenizer
import transformers
import torch
import tqdm
import accelerate
# Hub repo id of the model used when building the LLM chain.
llm_model = "mistralai/Mistral-7B-Instruct-v0.2"

# Directory the Chroma vector store is persisted to / loaded from.
default_persist_directory = './chroma_HF/'

# Hub repo ids ("<organisation>/<model>") of the selectable LLMs.
list_llm = [
    "mistralai/Mistral-7B-Instruct-v0.2",
    "mistralai/Mixtral-8x7B-Instruct-v0.1",
    "mistralai/Mistral-7B-Instruct-v0.1",
    "google/gemma-7b-it",
    "google/gemma-2b-it",
    "HuggingFaceH4/zephyr-7b-beta",
    "meta-llama/Llama-2-7b-chat-hf",
    "microsoft/phi-2",
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    "mosaicml/mpt-7b-instruct",
    "tiiuae/falcon-7b-instruct",
    "google/flan-t5-xxl",
]

# Same entries with the organisation prefix dropped (last path component only).
list_llm_simple = [os.path.basename(repo_id) for repo_id in list_llm]
# Load vector database
def load_db():
    """Open the Chroma vector store persisted on disk.

    The store is read from ``default_persist_directory`` and is wired to a
    ``HuggingFaceEmbeddings`` instance created with its default settings.

    Returns:
        The ``Chroma`` vector store object.
    """
    return Chroma(
        embedding_function=HuggingFaceEmbeddings(),
        persist_directory=default_persist_directory,
    )
# Initialize langchain LLM chain
def initialize_llmchain(vector_db, progress=gr.Progress()):
    """Build the conversational retrieval chain for the configured model.

    The LLM is a ``HuggingFaceHub`` wrapper around the module-level
    ``llm_model``; conversation history is kept in a
    ``ConversationBufferMemory`` and documents come from ``vector_db``.

    Args:
        vector_db: Vector store exposing ``as_retriever()``.
        progress: Gradio progress tracker, called with a fraction and a
            description at each setup stage.

    Returns:
        A ``ConversationalRetrievalChain`` that also returns its source
        documents.
    """
    progress(0.5, desc="Initializing HF Hub...")

    # Generation settings are passed through model_kwargs.
    # Warning: passing trust_remote_code via model_kwargs is affected by a
    # langchain issue: https://github.com/langchain-ai/langchain/issues/6080
    generation_kwargs = {"temperature": 0.7, "max_new_tokens": 1024, "top_k": 3}
    if llm_model == "mistralai/Mixtral-8x7B-Instruct-v0.1":
        # Mixtral is the only model that gets an extra, model-specific option.
        generation_kwargs["load_in_8bit"] = True
    # ... (configurations for other model options would be added here)
    llm = HuggingFaceHub(repo_id=llm_model, model_kwargs=generation_kwargs)

    progress(0.75, desc="Defining buffer memory...")
    # output_key names which chain output the memory stores, since the chain
    # also returns source documents alongside the answer.
    chat_memory = ConversationBufferMemory(
        memory_key="chat_history",
        output_key='answer',
        return_messages=True,
    )
    doc_retriever = vector_db.as_retriever()

    progress(0.8, desc="Defining retrieval chain...")
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm,
        retriever=doc_retriever,
        chain_type="stuff",
        memory=chat_memory,
        return_source_documents=True,
        verbose=False,
    )

    progress(0.9, desc="Done!")
    return qa_chain
# ... (other functions remain the same)