import os
import pickle

import numpy as np
import torch
from langchain.schema import Document
from langchain.llms import HuggingFacePipeline
from langchain.chains.question_answering import load_qa_chain
from sentence_transformers import SentenceTransformer, util
from transformers import (
    T5Tokenizer,
    T5ForConditionalGeneration,
    pipeline,
    StoppingCriteria,
    StoppingCriteriaList,
)

hf_auth = os.getenv('hf_auth')  # Hugging Face token (unused here; both checkpoints below are public)

# LLM model
model_id = 'google/t5-v1_1-base'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tokenizer = T5Tokenizer.from_pretrained(model_id)

# Load model
llm_model = T5ForConditionalGeneration.from_pretrained(model_id)
llm_model.to(device)
llm_model.eval()  # set model to evaluation mode

# Stop sequences; add_special_tokens=False keeps the EOS token out of the
# encoded stop IDs so the comparison in StopOnTokens can actually match.
stop_list = ['\nHuman:', '\n```\n']
stop_token_ids = [tokenizer.encode(x, add_special_tokens=False) for x in stop_list]
stop_token_ids = [torch.LongTensor(x).to(device) for x in stop_token_ids]


class StopOnTokens(StoppingCriteria):
    """Stop generation as soon as the output ends with any of the stop sequences."""

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        for stop_ids in stop_token_ids:
            if input_ids.shape[-1] >= stop_ids.shape[-1] and torch.all(
                input_ids[0, -stop_ids.shape[-1]:] == stop_ids
            ):
                return True
        return False


stopping_criteria = StoppingCriteriaList([StopOnTokens()])

generate_text = pipeline(
    'text2text-generation',
    model=llm_model,
    tokenizer=tokenizer,
    device=0 if torch.cuda.is_available() else -1,
    stopping_criteria=stopping_criteria,  # without this the model rambles during chat
    temperature=0.1,                      # 'randomness' of outputs, 0.0 is the min and 1.0 the max
    max_length=512,                       # max number of tokens to generate in the output
    repetition_penalty=1.1                # without this the output begins repeating
)

llm = HuggingFacePipeline(pipeline=generate_text)

# Embedding model
embed_model = SentenceTransformer('all-mpnet-base-v2').to(device)

# Load the pre-chunked documents
with open('documents.pkl', 'rb') as f:
    documents = pickle.load(f)

# Load the precomputed embeddings and move them to the same device as the
# query embeddings so the cosine-similarity call doesn't hit a CPU/GPU mismatch
embeddings = torch.from_numpy(np.load('embeddings.npy')).float().to(device)


def get_similar_docs(query, top_k, docs, embeddings):
    """Return the top_k documents most similar to the query by cosine similarity."""
    query_embedding = embed_model.encode(query, convert_to_tensor=True)
    similarity_scores = util.pytorch_cos_sim(query_embedding, embeddings)[0]
    result_idx = torch.topk(similarity_scores, k=top_k).indices
    return [docs[i] for i in result_idx]


chain = load_qa_chain(llm, chain_type='stuff')

conversation_history = []


def get_answer(query):
    """Retrieve relevant documents, prepend the chat history, and run the QA chain."""
    global conversation_history
    similar_docs = get_similar_docs(query, 10, documents, embeddings)
    conversation_history.append(Document(page_content=f"User: {query}"))
    combined_docs = conversation_history + similar_docs
    answer = chain.run(input_documents=combined_docs, question=query)
    if "\nHelpful Answer:" in answer:
        answer = answer.split("\nHelpful Answer:")[1].split("\nUnhelpful Answers:")[0].strip()
    else:
        answer = "I don't know, please provide more context."
    conversation_history.append(Document(page_content=f"Assistant: {answer}"))
    return answer
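

# --- Minimal usage sketch ---
# Assumes 'documents.pkl' (a list of langchain Documents) and 'embeddings.npy'
# (their all-mpnet-base-v2 embeddings, in the same order) were built beforehand;
# the example questions below are placeholders.
if __name__ == '__main__':
    print(get_answer("What topics do the indexed documents cover?"))
    # conversation_history persists between calls, so follow-up questions see earlier turns
    print(get_answer("Can you expand on the first topic?"))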