import warnings
warnings.filterwarnings("ignore")

import os
import glob
import textwrap
import time
import locale

import langchain
from langchain.document_loaders import PyPDFLoader, DirectoryLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain import PromptTemplate, LLMChain
from langchain.vectorstores import FAISS
from langchain.llms import HuggingFacePipeline
from langchain.embeddings import HuggingFaceInstructEmbeddings
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.chains import RetrievalQA

import torch
import transformers
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    pipeline
)

import gradio as gr

class CFG:
    # LLM settings
    model_name = 'llama2-13b-chat'
    temperature = 0
    top_p = 0.95
    repetition_penalty = 1.15

    # document splitting
    split_chunk_size = 800
    split_overlap = 0

    # embedding model
    embeddings_model_repo = 'sentence-transformers/all-MiniLM-L6-v2'

    # number of chunks to retrieve
    k = 6

    # paths
    PDFs_path = './'
    Embeddings_path = './faiss-hp-sentence-transformers'
    Output_folder = './rag-vectordb'

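# Download the selected backbone and its tokenizer, quantized to 4-bit with bitsandbytes.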
def get_model(model = CFG.model_name):

    print('\nDownloading model: ', model, '\n\n')

    if model == 'wizardlm':
        model_repo = 'TheBloke/wizardLM-7B-HF'

        tokenizer = AutoTokenizer.from_pretrained(model_repo)

        bnb_config = BitsAndBytesConfig(
            load_in_4bit = True,
            bnb_4bit_quant_type = "nf4",
            bnb_4bit_compute_dtype = torch.float16,
            bnb_4bit_use_double_quant = True,
        )

        model = AutoModelForCausalLM.from_pretrained(
            model_repo,
            quantization_config = bnb_config,
            device_map = 'auto',
            low_cpu_mem_usage = True
        )

        max_len = 1024

    elif model == 'llama2-7b-chat':
        model_repo = 'daryl149/llama-2-7b-chat-hf'

        tokenizer = AutoTokenizer.from_pretrained(model_repo, use_fast=True)

        bnb_config = BitsAndBytesConfig(
            load_in_4bit = True,
            bnb_4bit_quant_type = "nf4",
            bnb_4bit_compute_dtype = torch.float16,
            bnb_4bit_use_double_quant = True,
        )

        model = AutoModelForCausalLM.from_pretrained(
            model_repo,
            quantization_config = bnb_config,
            device_map = 'auto',
            low_cpu_mem_usage = True,
            trust_remote_code = True
        )

        max_len = 2048

    elif model == 'llama2-13b-chat':
        model_repo = 'daryl149/llama-2-13b-chat-hf'

        tokenizer = AutoTokenizer.from_pretrained(model_repo, use_fast=True)

        bnb_config = BitsAndBytesConfig(
            load_in_4bit = True,
            bnb_4bit_quant_type = "nf4",
            bnb_4bit_compute_dtype = torch.float16,
            bnb_4bit_use_double_quant = True,
        )

        model = AutoModelForCausalLM.from_pretrained(
            model_repo,
            quantization_config = bnb_config,
            device_map = 'auto',
            low_cpu_mem_usage = True,
            trust_remote_code = True
        )

        max_len = 2048

    elif model == 'mistral-7B':
        model_repo = 'mistralai/Mistral-7B-v0.1'

        tokenizer = AutoTokenizer.from_pretrained(model_repo)

        bnb_config = BitsAndBytesConfig(
            load_in_4bit = True,
            bnb_4bit_quant_type = "nf4",
            bnb_4bit_compute_dtype = torch.float16,
            bnb_4bit_use_double_quant = True,
        )

        model = AutoModelForCausalLM.from_pretrained(
            model_repo,
            quantization_config = bnb_config,
            device_map = 'auto',
            low_cpu_mem_usage = True,
        )

        max_len = 1024

    else:
        raise NotImplementedError(f"Not implemented model (tokenizer and backbone): {model}")

    return tokenizer, model, max_len

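# Helpers for formatting the chain output shown in the chat UI.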
def wrap_text_preserve_newlines(text, width=700):
    lines = text.split('\n')
    wrapped_lines = [textwrap.fill(line, width=width) for line in lines]
    wrapped_text = '\n'.join(wrapped_lines)
    return wrapped_text

def process_llm_response(llm_response):
    ans = wrap_text_preserve_newlines(llm_response['result'])

    sources_used = ' \n'.join(
        [
            source.metadata['source'].split('/')[-1][:-4]
            + ' - page: '
            + str(source.metadata['page'])
            for source in llm_response['source_documents']
        ]
    )

    ans = ans + '\n\nSources: \n' + sources_used
    return ans

def llm_ans(query):
    start = time.time()

    llm_response = qa_chain.invoke(query)
    ans = process_llm_response(llm_response)

    end = time.time()
    time_elapsed = int(round(end - start, 0))
    time_elapsed_str = f'\n\nTime elapsed: {time_elapsed} s'
    return ans + time_elapsed_str

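# Gradio chat callback: `history` is required by gr.ChatInterface but unused here.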
def predict(message, history):
    output = str(llm_ans(message)).replace("\n", "<br/>")
    return output

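# Instantiate the configured model and wrap it in a text-generation pipeline.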
tokenizer, model, max_len = get_model(model = CFG.model_name)

pipe = pipeline(
    task = "text-generation",
    model = model,
    tokenizer = tokenizer,
    pad_token_id = tokenizer.eos_token_id,
    max_length = max_len,
    temperature = CFG.temperature,
    top_p = CFG.top_p,
    repetition_penalty = CFG.repetition_penalty
)

llm = HuggingFacePipeline(pipeline = pipe)

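# Optional sanity check of the bare LLM before wiring it into the retrieval chain
# (uncomment to run; the prompt below is only an example):
# print(llm.invoke("What is retrieval-augmented generation?"))
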
loader = DirectoryLoader(
    CFG.PDFs_path,
    glob="./*.pdf",
    loader_cls=PyPDFLoader,
    show_progress=True,
    use_multithreading=True
)

documents = loader.load()

text_splitter = RecursiveCharacterTextSplitter(
    chunk_size = CFG.split_chunk_size,
    chunk_overlap = CFG.split_overlap
)

texts = text_splitter.split_documents(documents)

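# Purely informational: report how much material was loaded and chunked.
print(f'Loaded {len(documents)} pages, split into {len(texts)} chunks')
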
vectordb = FAISS.from_documents(
    texts,
    HuggingFaceEmbeddings(model_name='sentence-transformers/all-mpnet-base-v2')
)

vectordb.save_local(f"{CFG.Output_folder}/faiss_index_rag")

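# Optional: on later runs the persisted index could be reloaded instead of
# re-embedding the PDFs (a sketch; it assumes the same embedding model is used
# and, in newer langchain versions, may also need allow_dangerous_deserialization=True):
# vectordb = FAISS.load_local(
#     f"{CFG.Output_folder}/faiss_index_rag",
#     HuggingFaceEmbeddings(model_name='sentence-transformers/all-mpnet-base-v2')
# )
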
retriever = vectordb.as_retriever(search_type = "similarity", search_kwargs = {"k": CFG.k})

prompt_template = """
Don't try to make up an answer; if you don't know, just say that you don't know.
Answer in the same language the question was asked.
Use only the following pieces of context to answer the question at the end.

{context}

Question: {question}
Answer:"""

PROMPT = PromptTemplate(
    template = prompt_template,
    input_variables = ["context", "question"]
)

qa_chain = RetrievalQA.from_chain_type(
    llm = llm,
    chain_type = "stuff",
    retriever = retriever,
    chain_type_kwargs = {"prompt": PROMPT},
    return_source_documents = True,
    verbose = False
)

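# Optional smoke test of retrieval alone (uncomment to run; the query is only an example):
# for doc in retriever.get_relevant_documents("an example question about the PDFs"):
#     print(doc.metadata['source'], '- page:', doc.metadata['page'])
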
# Force UTF-8 as the preferred encoding (works around locale issues in some notebook environments).
locale.getpreferredencoding = lambda: "UTF-8"

demo = gr.ChatInterface(
    predict,
    title = f'Open-Source LLM ({CFG.model_name}) Question Answering'
)

demo.queue()
demo.launch()
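# When running in a hosted notebook, a temporary public link can be obtained
# with demo.launch(share=True) instead.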