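"""Minimal retrieval-augmented generation (RAG) chat script.

Embeds the user's query with sentence-transformers, retrieves the most
similar precomputed documents, and generates an answer with T5 conditioned
on the retrieved context plus the running chat history.
"""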
import pickle
import numpy as np
from langchain.schema import Document
from sentence_transformers import SentenceTransformer
from transformers import T5Tokenizer, T5ForConditionalGeneration

# Load the tokenizer and model
tokenizer = T5Tokenizer.from_pretrained("t5-large")
llm_model = T5ForConditionalGeneration.from_pretrained("t5-large")

# documents.pkl is assumed to hold a list of langchain Document objects in
# the same order as the rows of embeddings.npy.
with open('documents.pkl', 'rb') as f:
    documents = pickle.load(f)

# Load precomputed document embeddings (one row per document).
embeddings = np.load('embeddings.npy')

# Embedding model; must be the same one used to build embeddings.npy.
model = SentenceTransformer("all-mpnet-base-v2")

def get_similar_docs(query, top_k, docs, embeddings):
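    """Return the top_k entries of `docs` most similar to `query`.

    Uses the embedding model's built-in similarity (cosine by default for
    all-mpnet-base-v2); `model.similarity` requires sentence-transformers
    >= 3.0.
    """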
    query_embedding = model.encode(query)
    # The similarity matrix has shape (1, num_docs); squeeze to a 1-D score vector.
    similarity = model.similarity(query_embedding, embeddings).squeeze(0)
    # Indices of the highest-scoring documents, best first.
    result_idx = similarity.argsort(descending=True)[:top_k]
    
    return [docs[i] for i in result_idx]

# Running conversation history; queries and answers are appended so that
# later turns see earlier context.
chat_history = []

def get_answer(query):
    """Answer `query` using the retrieved documents plus the chat history."""
    # Retrieve context and record the query in the conversation history.
    similar_docs = get_similar_docs(query, 10, documents, embeddings)
    chat_history.append(Document(page_content=query))
    
    combined_docs = chat_history + similar_docs

    # Concatenate history and retrieved docs; newline separators keep
    # document boundaries from running words together.
    input_text = "\n".join(doc.page_content for doc in combined_docs)

    # Tokenize; truncate to the tokenizer's 512-token limit for t5-large.
    input_ids = tokenizer(input_text, return_tensors="pt", truncation=True).input_ids
    # Seed the decoder with the query so generation continues from it.
    decoder_input_ids = tokenizer(query, return_tensors="pt").input_ids

    # Generate the answer, continuing from the query in the decoder.
    outputs = llm_model.generate(input_ids=input_ids, decoder_input_ids=decoder_input_ids, max_length=500)

    # Decode, then strip the echoed query so only the new answer remains.
    # split(..., 1)[-1] falls back to the full text if the query does not
    # appear verbatim in the decoded output.
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    answer = generated_text.split(query, 1)[-1].strip().capitalize()
    chat_history.append(Document(page_content=answer))
    return answer
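
# ---------------------------------------------------------------------------
# Example usage: a minimal interactive loop (a sketch; it assumes
# documents.pkl and embeddings.npy exist next to this script).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    print("Ask a question (empty line to quit).")
    while True:
        user_query = input("You: ").strip()
        if not user_query:
            break
        print("Bot:", get_answer(user_query))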