|
import pickle |
|
import torch |
|
import numpy as np |
|
from langchain.schema import Document |
|
from sentence_transformers import SentenceTransformer |
|
from transformers import T5Tokenizer, T5ForConditionalGeneration |
|
|
|
|
|
# --- Module-level setup: runs once at import time ---

# T5-large tokenizer + seq2seq model used to generate answers.
tokenizer = T5Tokenizer.from_pretrained("t5-large")

llm_model = T5ForConditionalGeneration.from_pretrained("t5-large")


# Pre-pickled corpus of langchain Document objects.
# NOTE(review): pickle.load executes arbitrary code on load — only ever
# load 'documents.pkl' from a trusted source.
with open('documents.pkl', 'rb') as f:

    documents = pickle.load(f)




# Precomputed corpus embeddings; presumably row i corresponds to
# documents[i] — confirm against the script that produced the .npy file.
embeddings = np.load('embeddings.npy')


# Sentence-transformer used to embed incoming queries for retrieval.
model = SentenceTransformer("all-mpnet-base-v2")
|
|
|
def get_similar_docs(query, top_k, docs, embeddings):
    """Return the top_k entries of docs most similar to query.

    Embeds the query with the module-level SentenceTransformer ``model``,
    scores it against the precomputed ``embeddings`` matrix, and returns
    the matching documents in descending similarity order.
    """
    # Embed the query, then score it against every corpus embedding.
    encoded_query = model.encode(query)
    scores = model.similarity(encoded_query, embeddings).squeeze(0)

    # Indices of the highest-scoring rows, best match first.
    top_indices = scores.argsort(descending=True)[:top_k]

    return [docs[idx] for idx in top_indices]
|
|
|
# Running conversation memory: user queries and generated answers are
# appended as langchain Documents and fed back into later generations.
chat_history = []
|
|
|
def get_answer(query):
    """Answer ``query`` using retrieval-augmented T5 generation.

    Retrieves the 10 most similar documents, prepends the running chat
    history, generates with the module-level T5 model, records both the
    query and the answer in ``chat_history``, and returns the answer.

    Side effects: appends two Documents (query, answer) to chat_history.
    """
    global chat_history

    # Retrieve supporting context and record the user turn.
    similar_docs = get_similar_docs(query, 10, documents, embeddings)
    chat_history.append(Document(page_content=query))

    # Build the encoder input: chat history first, then retrieved docs.
    # Newline-joined so adjacent documents don't fuse into one word
    # (the original "+=" concatenation inserted no separator at all).
    combined_docs = chat_history + similar_docs
    input_text = "\n".join(doc.page_content for doc in combined_docs)

    input_ids = tokenizer(input_text, return_tensors="pt").input_ids
    # Seed the decoder with the query tokens so the generated sequence
    # echoes the query followed by the answer.
    decoder_input_ids = tokenizer(query, return_tensors="pt").input_ids

    outputs = llm_model.generate(
        input_ids=input_ids,
        decoder_input_ids=decoder_input_ids,
        max_length=500,
    )

    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Strip the echoed query prefix. partition() never raises, unlike the
    # original split(query)[1], which IndexError'd whenever the decoded
    # text did not contain the query verbatim; fall back to the full text.
    _, sep, answer = generated_text.partition(query)
    if not sep:
        answer = generated_text

    # Store the answer in history with spacing restored after periods.
    chat_history.append(
        Document(page_content=answer.replace('.', '. ').capitalize())
    )
    return answer.capitalize()