import gradio as gr
import os

from langchain_community.document_loaders import PyPDFLoader  # Updated import
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma  # Updated import
from langchain_community.embeddings import HuggingFaceEmbeddings  # Updated import
from langchain_community.llms import HuggingFaceHub  # Updated import
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory

from pathlib import Path
import chromadb

from transformers import AutoTokenizer
import transformers
import torch
import tqdm
import accelerate


# Default LLM model
llm_model = "mistralai/Mistral-7B-Instruct-v0.2"

# Other settings
default_persist_directory = './chroma_HF/'
list_llm = [
    "mistralai/Mistral-7B-Instruct-v0.2", "mistralai/Mixtral-8x7B-Instruct-v0.1",
    "mistralai/Mistral-7B-Instruct-v0.1", "google/gemma-7b-it", "google/gemma-2b-it",
    "HuggingFaceH4/zephyr-7b-beta", "meta-llama/Llama-2-7b-chat-hf", "microsoft/phi-2",
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "mosaicml/mpt-7b-instruct",
    "tiiuae/falcon-7b-instruct", "google/flan-t5-xxl",
]
list_llm_simple = [os.path.basename(llm) for llm in list_llm]

# Load vector database
def load_db():
    embedding = HuggingFaceEmbeddings()
    vectordb = Chroma(
        persist_directory=default_persist_directory, 
        embedding_function=embedding)
    return vectordb
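

# Illustrative sketch (not in the original file): load_db() assumes a Chroma index
# already exists at default_persist_directory. A build step along these lines,
# using the PyPDFLoader and RecursiveCharacterTextSplitter imported above, would
# create it. The name create_db and the chunk sizes are assumptions, not the
# app's actual values.
def create_db(pdf_path):
    # Load the PDF as one Document per page
    pages = PyPDFLoader(pdf_path).load()
    # Split pages into overlapping chunks sized for retrieval
    splitter = RecursiveCharacterTextSplitter(chunk_size=600, chunk_overlap=40)
    doc_splits = splitter.split_documents(pages)
    # Embed the chunks and persist them where load_db() expects to find them
    vectordb = Chroma.from_documents(
        documents=doc_splits,
        embedding=HuggingFaceEmbeddings(),
        persist_directory=default_persist_directory,
    )
    return vectordb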


# Initialize langchain LLM chain
def initialize_llmchain(vector_db, progress=gr.Progress()):
    progress(0.5, desc="Initializing HF Hub...")
    # trust_remote_code must be passed through model_kwargs (known langchain issue):
    # https://github.com/langchain-ai/langchain/issues/6080
    if llm_model == "mistralai/Mixtral-8x7B-Instruct-v0.1":
        llm = HuggingFaceHub(
            repo_id=llm_model, 
            model_kwargs={"temperature": 0.7, "max_new_tokens": 1024, "top_k": 3, "load_in_8bit": True}
        )
    # ... (model-specific configurations for the other LLM options)
    else:
        llm = HuggingFaceHub(
            repo_id=llm_model, 
            model_kwargs={"temperature": 0.7, "max_new_tokens": 1024, "top_k": 3}
        )
    
    progress(0.75, desc="Defining buffer memory...")
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        output_key='answer',
        return_messages=True
    )
    retriever = vector_db.as_retriever()
    progress(0.8, desc="Defining retrieval chain...")
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm,
        retriever=retriever,
        chain_type="stuff", 
        memory=memory,
        return_source_documents=True,
        verbose=False,
    )
    progress(0.9, desc="Done!")
    return qa_chain


# ... (other functions remain the same)
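

# Illustrative usage sketch (assumed, not part of the original file): with the
# buffer memory attached, the ConversationalRetrievalChain only needs the new
# question; chat history is tracked internally. Running it requires a valid
# HUGGINGFACEHUB_API_TOKEN in the environment and a previously built index.
if __name__ == "__main__":
    qa_chain = initialize_llmchain(load_db())
    response = qa_chain({"question": "What is this document about?"})
    print(response["answer"])
    # Inspect which chunks the answer was grounded in
    for doc in response["source_documents"]:
        print(doc.metadata)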