import gradio as gr
import os
from langchain_community.document_loaders import PyPDFLoader  # Updated import
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma  # Updated import
from langchain_community.embeddings import HuggingFaceEmbeddings  # Updated import
from langchain_community.llms import HuggingFaceHub  # Updated import
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from pathlib import Path
import chromadb
from transformers import AutoTokenizer
import transformers
import torch
import tqdm
import accelerate
# Default LLM model
llm_model = "mistralai/Mistral-7B-Instruct-v0.2"

# Other settings
default_persist_directory = './chroma_HF/'
list_llm = [
    "mistralai/Mistral-7B-Instruct-v0.2", "mistralai/Mixtral-8x7B-Instruct-v0.1",
    "mistralai/Mistral-7B-Instruct-v0.1", "google/gemma-7b-it", "google/gemma-2b-it",
    "HuggingFaceH4/zephyr-7b-beta", "meta-llama/Llama-2-7b-chat-hf", "microsoft/phi-2",
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "mosaicml/mpt-7b-instruct",
    "tiiuae/falcon-7b-instruct", "google/flan-t5-xxl",
]
list_llm_simple = [os.path.basename(llm) for llm in list_llm]
# Load vector database
def load_db():
    embedding = HuggingFaceEmbeddings()
    vectordb = Chroma(
        persist_directory=default_persist_directory,
        embedding_function=embedding)
    return vectordb
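
# Illustrative ingestion sketch (an assumption, not part of the original file;
# the app's own ingestion code is among the functions elided below). It shows
# one plausible way the imported PyPDFLoader, RecursiveCharacterTextSplitter
# and Chroma could populate the directory that load_db() reads from. The
# function name and chunk sizes here are placeholders.
def create_db(pdf_path):
    pages = PyPDFLoader(pdf_path).load()
    splitter = RecursiveCharacterTextSplitter(chunk_size=600, chunk_overlap=40)
    docs = splitter.split_documents(pages)
    vectordb = Chroma.from_documents(
        documents=docs,
        embedding=HuggingFaceEmbeddings(),
        persist_directory=default_persist_directory)
    return vectordb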
# Initialize langchain LLM chain
def initialize_llmchain(vector_db, progress=gr.Progress()):
    progress(0.5, desc="Initializing HF Hub...")
    # Note: trust_remote_code must be passed through model_kwargs
    # because of a langchain issue:
    # https://github.com/langchain-ai/langchain/issues/6080
    if llm_model == "mistralai/Mixtral-8x7B-Instruct-v0.1":
        llm = HuggingFaceHub(
            repo_id=llm_model,
            model_kwargs={"temperature": 0.7, "max_new_tokens": 1024, "top_k": 3, "load_in_8bit": True}
        )
    # ... (other model configurations for different model options)
    else:
        llm = HuggingFaceHub(
            repo_id=llm_model,
            model_kwargs={"temperature": 0.7, "max_new_tokens": 1024, "top_k": 3}
        )
    progress(0.75, desc="Defining buffer memory...")
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        output_key='answer',
        return_messages=True
    )
    retriever = vector_db.as_retriever()
    progress(0.8, desc="Defining retrieval chain...")
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm,
        retriever=retriever,
        chain_type="stuff",
        memory=memory,
        return_source_documents=True,
        verbose=False,
    )
    progress(0.9, desc="Done!")
    return qa_chain
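
# Illustrative helper (an assumption, not in the original file): shows the
# calling convention of the chain built above. ConversationalRetrievalChain
# takes the new turn under "question"; the buffer memory injects
# "chat_history" automatically, and the reply comes back under the
# output_key "answer" alongside the retrieved source documents.
def ask(qa_chain, question):
    result = qa_chain({"question": question})
    return result["answer"], result["source_documents"]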
# ... (other functions remain the same)
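
# Minimal Gradio wiring sketch (an assumption; the original UI code is elided
# above). One plausible way to connect the chain to a chat box; in the real
# app, initialize_llmchain runs as a Gradio event handler so gr.Progress()
# can report status in the UI.
def build_demo():
    qa_chain = initialize_llmchain(load_db())

    def respond(message, history):
        result = qa_chain({"question": message})
        return result["answer"]

    return gr.ChatInterface(respond)

# build_demo().launch()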