import warnings
warnings.filterwarnings("ignore")

import os
import glob
import textwrap
import time

import langchain
from langchain.document_loaders import PyPDFLoader, DirectoryLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain import PromptTemplate, LLMChain
from langchain.vectorstores import FAISS
from langchain.llms import HuggingFacePipeline
from langchain.embeddings import HuggingFaceInstructEmbeddings
from langchain.chains import RetrievalQA

import torch
import transformers
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    pipeline
)

sorted(glob.glob('/content/anatomy_vol_*'))

class CFG:
    # LLM
    model_name = 'llama2-13b-chat'
    temperature = 0
    top_p = 0.95
    repetition_penalty = 1.15

    # splitting
    split_chunk_size = 800
    split_overlap = 0

    # embeddings
    embeddings_model_repo = 'sentence-transformers/all-MiniLM-L6-v2'

    # number of similar passages to retrieve
    k = 6

    # paths
    PDFs_path = '/content/'
    Embeddings_path = '/content/faiss-hp-sentence-transformers'
    Output_folder = './rag-vectordb'


def get_model(model = CFG.model_name):

    print('\nDownloading model: ', model, '\n\n')

    if model == 'wizardlm':
        model_repo = 'TheBloke/wizardLM-7B-HF'

        tokenizer = AutoTokenizer.from_pretrained(model_repo)

        bnb_config = BitsAndBytesConfig(
            load_in_4bit = True,
            bnb_4bit_quant_type = "nf4",
            bnb_4bit_compute_dtype = torch.float16,
            bnb_4bit_use_double_quant = True,
        )

        model = AutoModelForCausalLM.from_pretrained(
            model_repo,
            quantization_config = bnb_config,
            device_map = 'auto',
            low_cpu_mem_usage = True
        )

        max_len = 1024

    elif model == 'llama2-7b-chat':
        model_repo = 'daryl149/llama-2-7b-chat-hf'

        tokenizer = AutoTokenizer.from_pretrained(model_repo, use_fast=True)

        bnb_config = BitsAndBytesConfig(
            load_in_4bit = True,
            bnb_4bit_quant_type = "nf4",
            bnb_4bit_compute_dtype = torch.float16,
            bnb_4bit_use_double_quant = True,
        )

        model = AutoModelForCausalLM.from_pretrained(
            model_repo,
            quantization_config = bnb_config,
            device_map = 'auto',
            low_cpu_mem_usage = True,
            trust_remote_code = True
        )

        max_len = 2048

    elif model == 'llama2-13b-chat':
        model_repo = 'daryl149/llama-2-13b-chat-hf'

        tokenizer = AutoTokenizer.from_pretrained(model_repo, use_fast=True)

        bnb_config = BitsAndBytesConfig(
            load_in_4bit = True,
            bnb_4bit_quant_type = "nf4",
            bnb_4bit_compute_dtype = torch.float16,
            bnb_4bit_use_double_quant = True,
        )

        model = AutoModelForCausalLM.from_pretrained(
            model_repo,
            quantization_config = bnb_config,
            device_map = 'auto',
            low_cpu_mem_usage = True,
            trust_remote_code = True
        )

        max_len = 2048

    elif model == 'mistral-7B':
        model_repo = 'mistralai/Mistral-7B-v0.1'

        tokenizer = AutoTokenizer.from_pretrained(model_repo)

        bnb_config = BitsAndBytesConfig(
            load_in_4bit = True,
            bnb_4bit_quant_type = "nf4",
            bnb_4bit_compute_dtype = torch.float16,
            bnb_4bit_use_double_quant = True,
        )

        model = AutoModelForCausalLM.from_pretrained(
            model_repo,
            quantization_config = bnb_config,
            device_map = 'auto',
            low_cpu_mem_usage = True,
        )

        max_len = 1024

    else:
        raise NotImplementedError(f"Model '{model}' is not implemented (tokenizer and backbone)")

    return tokenizer, model, max_len


print(torch.cuda.is_available())
print(torch.cuda.device_count())

# download the chosen model and tokenizer (4-bit quantized)
tokenizer, model, max_len = get_model(model = CFG.model_name)

model.eval()

# check how the model layers were placed across the available devices
model.hf_device_map

pipe = pipeline(
    task = "text-generation",
    model = model,
    tokenizer = tokenizer,
    pad_token_id = tokenizer.eos_token_id,
    max_length = max_len,
    temperature = CFG.temperature,
    top_p = CFG.top_p,
    repetition_penalty = CFG.repetition_penalty
)

# wrap the transformers pipeline so LangChain can use it as an LLM
llm = HuggingFacePipeline(pipeline = pipe)

llm

query = "what is the structural organization of the human body"
llm.invoke(query)
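# Note: the call above queries the bare quantized model, with no retrieval;
# compare its answer with the context-grounded answer from the RAG chain below.
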
"""Langchain""" |
|
|
|
CFG.model_name |
|
|
|
"""Loader""" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
print(f'We have {len(documents)} pages in total') |
|
|
|
documents[8].page_content |
"""Splitter""" |
|
|
|
text_splitter = RecursiveCharacterTextSplitter( |
|
chunk_size = CFG.split_chunk_size, |
|
chunk_overlap = CFG.split_overlap |
|
) |
|
|
|
texts = text_splitter.split_documents(documents) |
|
|
|
print(f'We have created {len(texts)} chunks from {len(documents)} pages') |
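# Optional sanity check (not part of the original flow): peek at one chunk to
# confirm the splitter produced reasonably sized pieces of text.
print(len(texts[0].page_content))
print(texts[0].page_content[:300])
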
"""Create Embeddings""" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
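# Alternative (illustrative): if a previously saved index already exists at
# CFG.Embeddings_path, it can be reloaded instead of re-embedding the PDFs:
# vectordb = FAISS.load_local(CFG.Embeddings_path, embeddings)
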
"""Prompt Template""" |
|
|
|
prompt_template = """ |
|
Don't try to make up an answer, if you don't know just say that you don't know. |
|
Answer in the same language the question was asked. |
|
Use only the following pieces of context to answer the question at the end. |
|
|
|
{context} |
|
|
|
Question: {question} |
|
Answer:""" |
|
|
|
|
|
PROMPT = PromptTemplate( |
|
template = prompt_template, |
|
input_variables = ["context", "question"] |
|
) |
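# Quick look (illustrative) at what the "stuff" chain will actually send to the
# LLM: the template rendered with a dummy context and question.
print(PROMPT.format(
    context = "Example context: the body is organized into chemical, cellular, tissue, organ, organ system and organism levels.",
    question = "what is the structural organization of the human body"
))
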
"""Retriever chain""" |
|
|
|
retriever = vectordb.as_retriever(search_kwargs = {"k": CFG.k, "search_type" : "similarity"}) |
|
|
|
qa_chain = RetrievalQA.from_chain_type( |
|
llm = llm, |
|
chain_type = "stuff", |
|
retriever = retriever, |
|
chain_type_kwargs = {"prompt": PROMPT}, |
|
return_source_documents = True, |
|
verbose = False |
|
) |
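# For reference: qa_chain.invoke() returns a dict whose "result" key holds the
# generated answer and whose "source_documents" key holds the retrieved chunks
# (because return_source_documents=True above); llm_ans() below relies on both.
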
question = "what is the structural organization of the human body"
vectordb.max_marginal_relevance_search(question, k = CFG.k)

question = "what is the structural organization of the human body"
vectordb.similarity_search(question, k = CFG.k)
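# Both calls inspect what the retriever would return: similarity_search gives
# the k nearest chunks, while max_marginal_relevance_search trades some
# similarity for diversity among the returned chunks.
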
"""Post-process outputs""" |
|
|
|
def wrap_text_preserve_newlines(text, width=700): |
|
|
|
lines = text.split('\n') |
|
|
|
|
|
wrapped_lines = [textwrap.fill(line, width=width) for line in lines] |
|
|
|
|
|
wrapped_text = '\n'.join(wrapped_lines) |
|
|
|
return wrapped_text |
|
|
|
|
|
def process_llm_response(llm_response): |
|
ans = wrap_text_preserve_newlines(llm_response['result']) |
|
|
|
sources_used = ' \n'.join( |
|
[ |
|
source.metadata['source'].split('/')[-1][:-4] |
|
+ ' - page: ' |
|
+ str(source.metadata['page']) |
|
for source in llm_response['source_documents'] |
|
] |
|
) |
|
|
|
ans = ans + '\n\nSources: \n' + sources_used |
|
return ans |
def llm_ans(query):
    start = time.time()

    llm_response = qa_chain.invoke(query)
    ans = process_llm_response(llm_response)

    end = time.time()

    time_elapsed = int(round(end - start, 0))
    time_elapsed_str = f'\n\nTime elapsed: {time_elapsed} s'
    return ans + time_elapsed_str


query = "what is the structural organization of the human body"
print(llm_ans(query))

"""Gradio Chat UI (Inspired from HinePo)""" |
|
|
|
import gradio as gr |
|
import locale |
|
locale.getpreferredencoding = lambda: "UTF-8" |
|
|
|
def predict(message, history): |
|
|
|
output = str(llm_ans(message)).replace("\n", "<br/>") |
|
return output |
|
|
|
demo = gr.ChatInterface( |
|
fn=predict, |
|
title=f'Open-Source LLM ({CFG["model_name"]}) Question Answering' |
|
) |
|
demo.queue() |
|
demo.launch() |