import os
import pickle

import numpy as np
import torch
from langchain.chains.question_answering import load_qa_chain
from langchain.llms import HuggingFacePipeline
from langchain.schema import Document
from sentence_transformers import SentenceTransformer, util
from transformers import (
    StoppingCriteria,
    StoppingCriteriaList,
    T5ForConditionalGeneration,
    T5Tokenizer,
    pipeline,
)

# Hugging Face token read from the environment (unused for this public
# checkpoint, kept in case a gated model is substituted).
hf_auth = os.getenv('hf_auth')

model_id = 'google/t5-v1_1-base'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

tokenizer = T5Tokenizer.from_pretrained(model_id)
llm_model = T5ForConditionalGeneration.from_pretrained(model_id)
llm_model.to(device)
llm_model.eval()

# Sequences that should terminate generation early.
stop_list = ['\nHuman:', '\n```\n']

# Encode without special tokens; tokenizer.encode() otherwise appends </s>
# to each stop sequence, so the suffix comparison below could never match.
stop_token_ids = [tokenizer.encode(x, add_special_tokens=False) for x in stop_list]
stop_token_ids = [torch.LongTensor(x).to(device) for x in stop_token_ids]


class StopOnTokens(StoppingCriteria):
    """Stop generation once the output ends with any of the stop sequences."""

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        for stop_ids in stop_token_ids:
            if input_ids.shape[-1] >= stop_ids.shape[-1] and torch.all(
                input_ids[0, -stop_ids.shape[-1]:] == stop_ids
            ):
                return True
        return False


stopping_criteria = StoppingCriteriaList([StopOnTokens()])

generate_text = pipeline(
    'text2text-generation',
    model=llm_model,
    tokenizer=tokenizer,
    device=0 if torch.cuda.is_available() else -1,
    stopping_criteria=stopping_criteria,
    do_sample=True,  # temperature only takes effect when sampling is enabled
    temperature=0.1,
    max_length=512,
    repetition_penalty=1.1,
)

llm = HuggingFacePipeline(pipeline=generate_text)

# Embedding model used for retrieval; runs on the same device as the LLM.
embed_model = SentenceTransformer("all-mpnet-base-v2").to(device)

# Pre-chunked documents and their precomputed embeddings.
with open('documents.pkl', 'rb') as f:
    documents = pickle.load(f)

embeddings = np.load('embeddings.npy')
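
# The embeddings file is assumed to have been built offline with the same
# encoder; a hypothetical build step might look like this:
#
#     docs_text = [d.page_content for d in documents]
#     np.save('embeddings.npy', embed_model.encode(docs_text, convert_to_numpy=True))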

def get_similar_docs(query, top_k, docs, embeddings):
    """Return the top_k documents most similar to the query by cosine similarity."""
    query_embedding = embed_model.encode(query, convert_to_tensor=True)
    # Move the precomputed embeddings onto the query embedding's device;
    # passing the raw NumPy array would produce a CPU tensor and a device
    # mismatch whenever the model runs on GPU.
    doc_embeddings = torch.from_numpy(embeddings).float().to(query_embedding.device)
    similarity_scores = util.pytorch_cos_sim(query_embedding, doc_embeddings)[0]
    result_idx = torch.topk(similarity_scores, k=top_k).indices
    return [docs[i] for i in result_idx]


# The 'stuff' chain concatenates all input documents into a single prompt.
chain = load_qa_chain(llm, chain_type='stuff')

# Running chat history, stored as Documents so it can be stuffed into the chain.
conversation_history = []


def get_answer(query):
    global conversation_history
    similar_docs = get_similar_docs(query, 10, documents, embeddings)

    conversation_history.append(Document(page_content=f"User: {query}"))

    # Note: the history grows without bound and is stuffed into every prompt,
    # so long conversations will eventually exceed the 512-token max_length.
    combined_docs = conversation_history + similar_docs

    answer = chain.run(input_documents=combined_docs, question=query)

    # A text2text pipeline returns only the newly generated text, so the
    # prompt's "Helpful Answer:" marker is usually absent; strip it when the
    # model does echo it, and otherwise keep the raw answer.
    if "\nHelpful Answer:" in answer:
        answer = answer.split("\nHelpful Answer:")[1].split("\nUnhelpful Answers:")[0].strip()
    else:
        answer = answer.strip()
    if not answer:
        answer = "I don't know, please provide more context."

    conversation_history.append(Document(page_content=f"Assistant: {answer}"))
    return answer
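

# A minimal interactive loop to exercise get_answer (a sketch; it assumes
# documents.pkl and embeddings.npy exist alongside this script and were
# produced with the same all-mpnet-base-v2 encoder used above).
if __name__ == '__main__':
    while True:
        user_query = input('You: ').strip()
        if user_query.lower() in ('exit', 'quit', ''):
            break
        print(f"Assistant: {get_answer(user_query)}")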