import os
import torch
import pickle
import numpy as np
from langchain.schema import Document
from langchain.llms import HuggingFacePipeline
from sentence_transformers import SentenceTransformer, util
from langchain.chains.question_answering import load_qa_chain
from transformers import T5Tokenizer, T5ForConditionalGeneration, pipeline, StoppingCriteria, StoppingCriteriaList
# Hugging Face auth token, read from the environment (not needed for the public models below)
hf_auth = os.getenv('hf_auth')
# LLM Model
model_id = 'google/t5-v1_1-base'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tokenizer = T5Tokenizer.from_pretrained(model_id)
# Load model
llm_model = T5ForConditionalGeneration.from_pretrained(model_id)
llm_model.to(device)
llm_model.eval() # Set model to evaluation mode
# Phrases that should halt generation mid-stream
stop_list = ['\nHuman:', '\n```\n']
# Encode without special tokens so the trailing </s> that T5 appends does not break matching
stop_token_ids = [tokenizer.encode(x, add_special_tokens=False) for x in stop_list]
stop_token_ids = [torch.LongTensor(x).to(device) for x in stop_token_ids]
class StopOnTokens(StoppingCriteria):
    """Stop generation once the tail of the output matches any stop sequence."""

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        for stop_ids in stop_token_ids:
            # Compare the last len(stop_ids) generated tokens against each stop sequence
            if input_ids.shape[-1] >= stop_ids.shape[-1] and torch.all(input_ids[0, -stop_ids.shape[-1]:] == stop_ids):
                return True
        return False

stopping_criteria = StoppingCriteriaList([StopOnTokens()])
generate_text = pipeline(
    'text2text-generation',
    model=llm_model,
    tokenizer=tokenizer,
    device=0 if torch.cuda.is_available() else -1,
    stopping_criteria=stopping_criteria,  # without this the model rambles during chat
    do_sample=True,          # required for temperature to take effect in recent transformers versions
    temperature=0.1,         # 'randomness' of outputs; 0.0 is the min and 1.0 the max
    max_length=512,          # max number of tokens to generate in the output
    repetition_penalty=1.1   # without this the output begins repeating
)
llm = HuggingFacePipeline(pipeline=generate_text)
# Embedding Model
embed_model = SentenceTransformer("all-mpnet-base-v2").to(device)
with open('documents.pkl', 'rb') as f:
    documents = pickle.load(f)
# Load embeddings from the file
embeddings = np.load('embeddings.npy')
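# A minimal sketch, not called anywhere in this module: assuming documents.pkl holds the
# same Document objects that embeddings.npy was built from, the embedding file could be
# regenerated with the same model roughly like this. rebuild_embeddings is a hypothetical
# helper, not part of the original pipeline.
def rebuild_embeddings(docs, path='embeddings.npy'):
    vectors = embed_model.encode([d.page_content for d in docs], convert_to_numpy=True)
    np.save(path, vectors)
    return vectors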
def get_similar_docs(query, top_k, docs, embeddings):
    """Return the top_k documents whose embeddings are most similar to the query."""
    query_embedding = embed_model.encode(query, convert_to_tensor=True)
    # Move the precomputed NumPy embeddings to the same device as the query embedding,
    # otherwise the similarity computation fails on GPU
    doc_embeddings = torch.as_tensor(embeddings, device=query_embedding.device)
    similarity_scores = util.pytorch_cos_sim(query_embedding, doc_embeddings)[0]
    result_idx = torch.topk(similarity_scores, k=min(top_k, len(docs))).indices
    return [docs[i] for i in result_idx]
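# Example usage (hypothetical query), returning the 3 closest documents:
#   hits = get_similar_docs("How do I reset my password?", 3, documents, embeddings)
#   for doc in hits:
#       print(doc.page_content[:80])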
# 'stuff' chain: concatenates all input documents into a single prompt for the LLM
chain = load_qa_chain(llm, chain_type='stuff')
conversation_history = []
def get_answer(query):
    """Answer a query using retrieved documents plus the running conversation history."""
    global conversation_history
    similar_docs = get_similar_docs(query, 10, documents, embeddings)
    conversation_history.append(Document(page_content=f"User: {query}"))
    combined_docs = conversation_history + similar_docs
    answer = chain.run(input_documents=combined_docs, question=query)
    # Keep only the text between the 'Helpful Answer:' and 'Unhelpful Answers:' markers
    if "\nHelpful Answer:" in answer:
        answer = answer.split("\nHelpful Answer:")[1].split("\nUnhelpful Answers:")[0].strip()
    else:
        answer = "I don't know, please provide more context."
    conversation_history.append(Document(page_content=f"Assistant: {answer}"))
    return answer
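
# A minimal interactive sketch for local testing (assumes documents.pkl and
# embeddings.npy sit next to this file); the hosting app is expected to import
# get_answer directly instead of running this loop.
if __name__ == '__main__':
    while True:
        user_query = input('You: ')
        if user_query.strip().lower() in ('exit', 'quit'):
            break
        print(f"Assistant: {get_answer(user_query)}")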