import os
import torch
import pickle
import numpy as np
from langchain.schema import Document
from langchain.llms import HuggingFacePipeline
from sentence_transformers import SentenceTransformer, util
from langchain.chains.question_answering import load_qa_chain
from transformers import T5Tokenizer, T5ForConditionalGeneration, pipeline, StoppingCriteria, StoppingCriteriaList


hf_auth = os.getenv('hf_auth')  # Hugging Face token from the environment (not used further in this script)
# LLM Model
model_id = 'google/t5-v1_1-base'

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

tokenizer = T5Tokenizer.from_pretrained(model_id)

# Load model
llm_model = T5ForConditionalGeneration.from_pretrained(model_id)
llm_model.to(device)
llm_model.eval()  # Set model to evaluation mode

stop_list = ['\nHuman:', '\n```\n']

stop_token_ids = [tokenizer.encode(x, add_special_tokens=False) for x in stop_list]  # skip the EOS token T5 appends, so the stop sequences can match mid-generation
stop_token_ids = [torch.LongTensor(x).to(device) for x in stop_token_ids]

class StopOnTokens(StoppingCriteria):
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        for stop_ids in stop_token_ids:
            if input_ids.shape[-1] >= stop_ids.shape[-1] and torch.all(input_ids[0, -stop_ids.shape[-1]:] == stop_ids):
                return True
        return False

stopping_criteria = StoppingCriteriaList([StopOnTokens()])

generate_text = pipeline(
    'text2text-generation',
    model=llm_model,
    tokenizer=tokenizer,
    device=0 if torch.cuda.is_available() else -1,
    stopping_criteria=stopping_criteria,  # without this model rambles during chat
    temperature=0.1,  # 'randomness' of outputs; lower is more deterministic (only takes effect when sampling is enabled)
    max_length=512,  # max number of tokens to generate in the output
    repetition_penalty=1.1  # without this output begins repeating
)

llm = HuggingFacePipeline(pipeline=generate_text)

# Embedding Model
embed_model = SentenceTransformer("all-mpnet-base-v2").to(device)

with open('documents.pkl', 'rb') as f:
    documents = pickle.load(f)

# Load the precomputed corpus embeddings and move them to the same device as the
# query embeddings, so cosine similarity does not hit a CPU/GPU device mismatch
embeddings = torch.from_numpy(np.load('embeddings.npy')).to(device)
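
# How the two artifacts above are assumed to have been produced: a hypothetical
# offline step (not part of the original script) that encodes each Document's
# page_content with the same all-mpnet-base-v2 model and saves both files.
def build_corpus_artifacts(docs, docs_path='documents.pkl', emb_path='embeddings.npy'):
    """Hypothetical helper: persist `docs` (a list of Document) and their embeddings."""
    texts = [d.page_content for d in docs]
    corpus_embeddings = embed_model.encode(texts, convert_to_numpy=True, show_progress_bar=True)
    np.save(emb_path, corpus_embeddings)
    with open(docs_path, 'wb') as f_out:
        pickle.dump(docs, f_out)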

def get_similar_docs(query, top_k, docs, embeddings):
    query_embedding = embed_model.encode(query, convert_to_tensor=True)
    similarity_scores = util.pytorch_cos_sim(query_embedding, embeddings)[0]
    result_idx = torch.topk(similarity_scores, k=top_k).indices
    return [docs[i] for i in result_idx]
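
# Retrieval-only sanity check (the query string is just a placeholder); useful for
# verifying the index before involving the LLM:
#   get_similar_docs("what topics do the indexed documents cover?", 3, documents, embeddings)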

chain = load_qa_chain(llm, chain_type='stuff')

conversation_history = []

def get_answer(query):
    global conversation_history
    similar_docs = get_similar_docs(query, 10, documents, embeddings)

    conversation_history.append(Document(page_content=f"User: {query}"))

    combined_docs = conversation_history + similar_docs
    
    answer = chain.run(input_documents=combined_docs, question=query)
    
    if "\nHelpful Answer:" in answer:
        answer = answer.split("\nHelpful Answer:")[1].split("\nUnhelpful Answers:")[0].strip()
    else:
        answer = "I don't know, please provide more context."

    conversation_history.append(Document(page_content=f"Assistant: {answer}"))
    return answer
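
# Minimal usage sketch, assuming the script is run directly; the question below is
# only a placeholder, not part of the original file.
if __name__ == '__main__':
    question = "What do the indexed documents say about this topic?"  # placeholder query
    print(get_answer(question))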