# utils.py — author Harsh2001, "Update utils.py" (commit e2de233, verified)
# (provenance header from the hosting page; kept as a comment so the module parses)
import pickle
import torch
import numpy as np
from langchain.schema import Document
from sentence_transformers import SentenceTransformer
from transformers import T5Tokenizer, T5ForConditionalGeneration
# Load the tokenizer and model
# T5 encoder-decoder used downstream by get_answer() to generate replies.
tokenizer = T5Tokenizer.from_pretrained("t5-large")
llm_model = T5ForConditionalGeneration.from_pretrained("t5-large")
# NOTE(review): pickle.load executes arbitrary code during unpickling —
# only safe if documents.pkl is a trusted local artifact, never user-supplied.
with open('documents.pkl', 'rb') as f:
    # Corpus of langchain Document objects used for retrieval.
    documents = pickle.load(f)
# Load embeddings from the file
# presumably one row per entry in `documents`, produced by the same
# sentence-transformer loaded below — TODO confirm against the build script
embeddings = np.load('embeddings.npy')
# Embedding model used to score query/document similarity at retrieval time.
model = SentenceTransformer("all-mpnet-base-v2")
def get_similar_docs(query, top_k, docs, embeddings):
    """Return the ``top_k`` entries of *docs* most similar to *query*.

    The query is embedded with the module-level SentenceTransformer
    ``model`` and scored against the precomputed *embeddings* matrix;
    documents are returned best-match first.
    """
    # Embed the query with the same model that produced `embeddings`.
    q_vec = model.encode(query)
    # 1 x N score matrix -> flat per-document score vector.
    scores = model.similarity(q_vec, embeddings).squeeze(0)
    # Indices of the highest-scoring documents, descending.
    best = scores.argsort(descending=True)[:top_k]
    return [docs[idx] for idx in best]
# Running conversation state: user queries and generated answers, wrapped
# as langchain Documents and prepended to every subsequent prompt.
chat_history = []
def get_answer(query):
    """Answer *query* using retrieved documents plus the running chat history.

    Retrieves the 10 most similar documents, concatenates them with the
    conversation so far into one prompt, and generates with the T5 model.
    Both the query and the generated answer are appended to the
    module-level ``chat_history`` so later calls see the full conversation.

    Returns the generated answer text (first letter capitalized).
    """
    global chat_history
    # Retrieve supporting context for this query.
    similar_docs = get_similar_docs(query, 10, documents, embeddings)
    chat_history.append(Document(page_content=query))
    combined_docs = chat_history + similar_docs
    # join() instead of += in a loop — avoids quadratic string building.
    input_text = "".join(doc.page_content for doc in combined_docs)
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids  # Batch size 1
    # Seed the decoder with the query so generation continues from it.
    decoder_input_ids = tokenizer(query, return_tensors="pt").input_ids  # Batch size 1
    # Generate text
    outputs = llm_model.generate(
        input_ids=input_ids,
        decoder_input_ids=decoder_input_ids,
        max_length=500,
    )
    # Decode the generated text
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Strip the echoed query prefix.  The original code indexed
    # split(query)[1] unconditionally, which raised IndexError whenever the
    # model did not echo the query verbatim; fall back to the full text.
    _, sep, answer = generated_text.partition(query)
    if not sep:
        answer = generated_text
    chat_history.append(Document(page_content=answer.replace('.', '. ').capitalize()))
    return answer.capitalize()