import pickle

import numpy as np
import torch
from langchain.schema import Document
from sentence_transformers import SentenceTransformer
from transformers import T5Tokenizer, T5ForConditionalGeneration

# Load the tokenizer and the seq2seq model used for answer generation
tokenizer = T5Tokenizer.from_pretrained("t5-large")
llm_model = T5ForConditionalGeneration.from_pretrained("t5-large")

# Load the pre-chunked documents and their precomputed embeddings
with open('documents.pkl', 'rb') as f:
    documents = pickle.load(f)
embeddings = np.load('embeddings.npy')

# Sentence-transformer used to embed queries at retrieval time
model = SentenceTransformer("all-mpnet-base-v2")


def get_similar_docs(query, top_k, docs, embeddings):
    """Return the top_k documents most similar to the query."""
    query_embedding = model.encode(query)
    # Cosine similarity between the query and every document embedding
    similarity = model.similarity(query_embedding, embeddings).squeeze(0)
    result_idx = similarity.argsort(descending=True)[:top_k]
    return [docs[i] for i in result_idx]


chat_history = []


def get_answer(query):
    global chat_history

    # Retrieve context and prepend the running chat history
    similar_docs = get_similar_docs(query, 10, documents, embeddings)
    chat_history.append(Document(page_content=query))
    combined_docs = chat_history + similar_docs
    input_text = " ".join(doc.page_content for doc in combined_docs)

    # Encoder input is the combined context; the decoder is seeded with the query
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids  # batch size 1
    decoder_input_ids = tokenizer(query, return_tensors="pt").input_ids  # batch size 1

    # Generate and decode the answer
    outputs = llm_model.generate(
        input_ids=input_ids,
        decoder_input_ids=decoder_input_ids,
        max_length=500,
    )
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # The decoded text starts with the query (it seeded the decoder), so keep
    # only the continuation; split(..., 1)[-1] avoids an IndexError if the
    # query is not echoed back verbatim.
    answer = generated_text.split(query, 1)[-1].capitalize()
    chat_history.append(Document(page_content=answer.replace('.', '. ')))
    return answer
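
# Minimal usage sketch (assumption: documents.pkl and embeddings.npy were built
# from the same corpus; the questions below are hypothetical placeholders).
if __name__ == "__main__":
    print(get_answer("What topics do the indexed documents cover?"))
    # A follow-up turn reuses chat_history, so earlier questions and answers
    # are prepended to the retrieved context on the next call.
    print(get_answer("Summarize that in one sentence."))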