"""Minimal RAG question-answering demo: load PDFs, split them into chunks,
index the chunks with FAISS, and answer questions about them with a Llama-2
chat model behind a Gradio chat interface."""

import warnings

warnings.filterwarnings("ignore")

import os
import glob
import textwrap
import time

import langchain
from langchain.document_loaders import PyPDFLoader, DirectoryLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain import PromptTemplate, LLMChain
from langchain.vectorstores import FAISS
from langchain.llms import HuggingFacePipeline
from langchain.embeddings import HuggingFaceInstructEmbeddings
from langchain.chains import RetrievalQA

import torch
import transformers
from transformers import (
    AutoTokenizer, AutoModelForCausalLM,
    pipeline
)

import gradio as gr
import locale
import shutil

transformers.logging.set_verbosity_error()
shutil.rmtree('./.cache', ignore_errors=True)

class CFG:
    # LLM settings
    model_name = 'llama2-13b-chat'
    temperature = 0
    top_p = 0.95
    repetition_penalty = 1.15

    # Text splitting
    split_chunk_size = 800
    split_overlap = 0

    # Embedding model used to build the FAISS index
    embeddings_model_repo = 'sentence-transformers/all-mpnet-base-v2'

    # Number of chunks retrieved per query
    k = 6

    # Paths
    PDFs_path = './'
    Embeddings_path = './faiss-hp-sentence-transformers'  # prebuilt index (not used below)
    Output_folder = './rag-vectordb'

def get_model(model=CFG.model_name):
    print('\nDownloading model: ', model, '\n\n')

    model_repo = 'daryl149/llama-2-13b-chat-hf' if model == 'llama2-13b-chat' else None

    if not model_repo:
        raise ValueError("Model not implemented: " + model)

    tokenizer = AutoTokenizer.from_pretrained(model_repo, use_fast=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_repo,
        device_map="auto",
        offload_folder="./offload",
        trust_remote_code=True
    )

    max_len = 2048

    return tokenizer, model, max_len
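
# Note: the checkpoint above is loaded in full precision; device_map="auto"
# together with the offload folder lets transformers spill layers to CPU/disk
# when GPU memory runs short. A possible alternative (not part of this setup,
# and assuming bitsandbytes is installed) would be to quantize the weights,
# e.g. by passing load_in_4bit=True to from_pretrained.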

def wrap_text_preserve_newlines(text, width=700):
    lines = text.split('\n')
    wrapped_lines = [textwrap.fill(line, width=width) for line in lines]

    return '\n'.join(wrapped_lines)

def process_llm_response(llm_response):
    ans = wrap_text_preserve_newlines(llm_response['result'])

    # List the source files (file name without the ".pdf" extension) and pages
    # that the retrieved chunks came from.
    sources_used = ' \n'.join(
        [
            f"{source.metadata['source'].split('/')[-1][:-4]} - page: {source.metadata['page']}"
            for source in llm_response['source_documents']
        ]
    )

    return ans + '\n\nSources: \n' + sources_used

def llm_ans(query):
    start = time.time()

    llm_response = qa_chain.invoke(query)
    ans = process_llm_response(llm_response)

    end = time.time()

    time_elapsed_str = f'\n\nTime elapsed: {int(round(end - start))} s'

    return ans + time_elapsed_str


def predict(message, history):
    output = str(llm_ans(message)).replace("\n", "<br/>")
    return output

tokenizer, model, max_len = get_model(model=CFG.model_name)

pipe = pipeline(
    task="text-generation",
    model=model,
    tokenizer=tokenizer,
    pad_token_id=tokenizer.eos_token_id,
    max_length=max_len,
    temperature=CFG.temperature,
    top_p=CFG.top_p,
    repetition_penalty=CFG.repetition_penalty
)

llm = HuggingFacePipeline(pipeline=pipe)
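
# max_length above counts prompt tokens plus generated tokens, so the retrieved
# context passed to the model later has to leave room for the answer inside the
# 2048-token budget.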

loader = DirectoryLoader(
    CFG.PDFs_path,
    glob="./*.pdf",
    loader_cls=PyPDFLoader,
    show_progress=True,
)

documents = loader.load()

text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=CFG.split_chunk_size,
    chunk_overlap=CFG.split_overlap
)

texts = text_splitter.split_documents(documents)
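
# With split_overlap = 0 adjacent chunks share no text; a small overlap
# (e.g. 50-100 characters) is commonly used so that a sentence cut at a chunk
# boundary still appears complete in at least one chunk.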

vectordb = FAISS.from_documents(
    texts,
    HuggingFaceInstructEmbeddings(model_name=CFG.embeddings_model_repo)
)

vectordb.save_local(f"{CFG.Output_folder}/faiss_index_rag")
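
# To reuse the saved index in a later session without re-embedding the PDFs
# (a sketch, assuming the same embedding model), it could be loaded back with:
#   vectordb = FAISS.load_local(
#       f"{CFG.Output_folder}/faiss_index_rag",
#       HuggingFaceInstructEmbeddings(model_name=CFG.embeddings_model_repo)
#   )
# Newer LangChain releases also require allow_dangerous_deserialization=True.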

retriever = vectordb.as_retriever(search_kwargs={"k": CFG.k})
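
# The "stuff" chain below concatenates all k retrieved chunks into a single
# prompt, so k = 6 chunks of ~800 characters plus the question and answer must
# fit within the pipeline's 2048-token max_length.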

prompt_template = """
Don't try to make up an answer; if you don't know, just say that you don't know.
Answer in the same language the question was asked.
Use only the following pieces of context to answer the question at the end.

{context}

Question: {question}
Answer:"""

PROMPT = PromptTemplate(
    template=prompt_template,
    input_variables=["context", "question"]
)

qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",
    retriever=retriever,
    chain_type_kwargs={"prompt": PROMPT},
    return_source_documents=True,
)
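
# Optional sanity check before launching the UI (assumes at least one PDF was
# found and indexed above); uncomment to run a single query end-to-end:
# print(llm_ans('What is this document about?'))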

# Work around environments that report a non-UTF-8 preferred encoding
# (e.g. some hosted notebook images).
locale.getpreferredencoding = lambda: "UTF-8"

demo = gr.ChatInterface(
    fn=predict,
    title=f'Open-Source LLM ({CFG.model_name}) Question Answering'
)

demo.queue()
demo.launch()