import gradio as gr
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
    pipeline,
)
import os
import textwrap
from threading import Thread
import spaces
import time

import langchain
import glob
import gc

from langchain.document_loaders import PyPDFLoader, DirectoryLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain import PromptTemplate
from langchain_community.vectorstores import FAISS
from langchain.llms import HuggingFacePipeline
from langchain.embeddings import HuggingFaceInstructEmbeddings
from langchain.chains import RetrievalQA

import subprocess

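# Install flash-attn at startup; FLASH_ATTENTION_SKIP_CUDA_BUILD skips the CUDA
# compile step (a common workaround on Hugging Face Spaces build machines).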
subprocess.run(
    "pip install flash-attn --no-build-isolation",
    env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    shell=True,
)


class CFG:
    DEBUG = False

    model_name = 'justinj92/phi3-orpo'
    temperature = 0.7
    top_p = 0.90
    repetition_penalty = 1.15
    max_len = 8192
    max_new_tokens = 512

    split_chunk_size = 800
    split_overlap = 400

    embeddings_model_repo = 'BAAI/bge-base-en-v1.5'

    k = 6

    PDFs_path = './data'
    Embeddings_path = './embeddings/input'
    Output_folder = './ml-papers-vector'

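
# Document ingestion: load the PDFs, split them into overlapping chunks, and
# build or load the FAISS index. These statements run at import time, since
# spaces.GPU can only decorate functions.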
loader = DirectoryLoader(CFG.PDFs_path, glob="./*.pdf", loader_cls=PyPDFLoader, use_multithreading=True)
documents = loader.load()

text_splitter = RecursiveCharacterTextSplitter(chunk_size=CFG.split_chunk_size, chunk_overlap=CFG.split_overlap)
texts = text_splitter.split_documents(documents)

embeddings = HuggingFaceInstructEmbeddings(model_name=CFG.embeddings_model_repo, model_kwargs={"device": "cuda"})

# Build and persist the FAISS index on the first run; later runs reuse the saved copy.
if not os.path.exists(f"{CFG.Output_folder}/faiss_index_ml_papers/index.faiss"):
    vectordb = FAISS.from_documents(documents=texts, embedding=embeddings)
    vectordb.save_local(f"{CFG.Output_folder}/faiss_index_ml_papers")

vectordb = FAISS.load_local(CFG.Output_folder + '/faiss_index_ml_papers', embeddings, allow_dangerous_deserialization=True)

def build_model(model_repo=CFG.model_name):
    tokenizer = AutoTokenizer.from_pretrained(model_repo)
    # bfloat16 is assumed here: flash-attention 2 requires fp16/bf16 weights.
    model = AutoModelForCausalLM.from_pretrained(
        model_repo,
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
    )
    return tokenizer, model


tok, model = build_model(model_repo=CFG.model_name)

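# Stop generation on the eos token plus Phi-3's hard-coded chat special-token ids.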
terminators = [
    tok.eos_token_id,
    32007,
    32011,
    32001,
    32000,
]

if torch.cuda.is_available():
    device = torch.device("cuda")
    print(f"Using GPU: {torch.cuda.get_device_name(device)}")
else:
    device = torch.device("cpu")
    print("Using CPU")

model = model.to(device)

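# Text-generation pipeline with the sampling settings from CFG; the LangChain
# wrapper below drives this pipeline.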
pipe = pipeline(
    task="text-generation",
    model=model,
    tokenizer=tok,
    eos_token_id=terminators,
    do_sample=True,
    max_new_tokens=CFG.max_new_tokens,
    temperature=CFG.temperature,
    top_p=CFG.top_p,
    repetition_penalty=CFG.repetition_penalty,
)

llm = HuggingFacePipeline(pipeline=pipe)

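# Retrieval-QA prompt in Phi-3 chat format; {context} is filled with retrieved
# chunks and {question} with the user's query.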
prompt_template = """
<|system|>
You are an expert assistant that answers questions about machine learning and Large Language Models (LLMs).

You are given some extracted parts from machine learning papers along with a question.

If you don't know the answer, just say "I don't know." Don't try to make up an answer.

It is very important that you ALWAYS answer the question in the same language the question is in. Remember to always do that.

Use only the following pieces of context to answer the question at the end.
<|end|>

<|user|>
Context: {context}

Question is below. Remember to answer in the same language:

Question: {question}
<|end|>

<|assistant|>
"""

PROMPT = PromptTemplate(
    template=prompt_template,
    input_variables=["context", "question"],
)

retriever = vectordb.as_retriever(
    search_type="similarity",
    search_kwargs={"k": CFG.k},
)

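# "stuff" chain: the k retrieved chunks are placed into a single prompt, and the
# source documents are returned alongside the answer.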
qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",
    retriever=retriever,
    chain_type_kwargs={"prompt": PROMPT},
    return_source_documents=True,
    verbose=False,
)

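# Example usage (hypothetical query; assumes the index above has been built):
#   out = qa_chain.invoke("What does the ORPO paper propose?")
#   print(out["result"])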

@spaces.GPU(duration=120)
def wrap_text_preserve_newlines(text, width=1500):
    # Re-wrap each line to `width` characters while keeping the original newlines.
    lines = text.split('\n')
    wrapped_lines = [textwrap.fill(line, width=width) for line in lines]
    wrapped_text = '\n'.join(wrapped_lines)
    return wrapped_text


@spaces.GPU(duration=120)
def process_llm_response(llm_response):
    ans = wrap_text_preserve_newlines(llm_response['result'])

    # File name (without extension) and page number for each retrieved chunk.
    sources_used = ' \n'.join(
        [
            source.metadata['source'].split('/')[-1][:-4]
            + ' - page: '
            + str(source.metadata['page'])
            for source in llm_response['source_documents']
        ]
    )

    ans = ans + '\n\nSources: \n' + sources_used

    # Keep only the text after the <|assistant|> marker, i.e. the generated answer.
    pattern = "<|assistant|>"
    index = ans.find(pattern)
    if index != -1:
        ans = ans[index + len(pattern):]

    return ans.strip()

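
# Entry point for gr.ChatInterface, which calls fn(message, history, *additional_inputs).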
@spaces.GPU(duration=120)
def llm_ans(message, history, temperature=CFG.temperature, sampling=True, max_new_tokens=CFG.max_new_tokens):
    # The extra parameters mirror the UI's additional inputs; generation settings
    # are currently fixed in `pipe`, so they are accepted but not used here.
    llm_response = qa_chain.invoke(message)
    ans = process_llm_response(llm_response)
    return ans

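
# Gradio chat UI; the additional inputs appear in a collapsed accordion.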
demo = gr.ChatInterface(
    fn=llm_ans,
    examples=[["Write me a poem about Machine Learning."]],
    additional_inputs_accordion=gr.Accordion(
        label="⚙️ Parameters", open=False, render=False
    ),
    additional_inputs=[
        gr.Slider(
            minimum=0, maximum=1, step=0.1, value=0.9, label="Temperature", render=False
        ),
        gr.Checkbox(label="Sampling", value=True),
        gr.Slider(
            minimum=128,
            maximum=4096,
            step=1,
            value=512,
            label="Max new tokens",
            render=False,
        ),
    ],
    stop_btn="Stop Generation",
    title="Chat With LLMs",
    description="Now Running Phi3-ORPO",
)
demo.launch()