lchakkei committed
Commit 595b6eb · verified · 1 Parent(s): 68f9bf8

Create handler.py

Files changed (1)
  1. handler.py +173 -0
handler.py ADDED
@@ -0,0 +1,173 @@
+ import torch
+ import os
+ from operator import itemgetter
+ from typing import Dict, Any
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
+ from langchain.llms import HuggingFacePipeline
+ from langchain.retrievers.document_compressors import LLMChainExtractor
+ from langchain.retrievers import ContextualCompressionRetriever
+ from langchain.vectorstores import Chroma
+ from langchain.prompts import PromptTemplate
+ from langchain.memory import ConversationBufferMemory
+ from langchain.embeddings import HuggingFaceBgeEmbeddings
+ from langchain.document_loaders import WebBaseLoader
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain.schema import format_document
+ from langchain_core.prompts import ChatPromptTemplate
+ from langchain_core.output_parsers import StrOutputParser
+ from langchain_core.runnables import RunnableLambda, RunnablePassthrough
+ from langchain_core.messages import get_buffer_string
+
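+ # Custom Inference Endpoints handler: at startup it loads a causal LM and a
+ # BGE-style embedding model, indexes two web articles into Chroma, and wires
+ # up a conversational RAG chain; __call__ answers one question per request.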
+ class EndpointHandler():
+     def __init__(self, path=""):
+
+         # Config LangChain
+         # os.environ["LANGCHAIN_TRACING_V2"] = "true"
+         # os.environ["LANGCHAIN_API_KEY"] = "ls__9834e6b2ff094d43a28418c9ecea2fd5"
+
+         # Create LLM
+         model_id = path
+
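+         # load_in_8bit=True quantizes the weights with bitsandbytes (which must
+         # be installed); torch_dtype=float16 applies to the non-quantized parts.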
+         model = AutoModelForCausalLM.from_pretrained(
+             model_id,
+             device_map='auto',
+             torch_dtype=torch.float16,
+             load_in_8bit=True
+         )
+         model.eval()
+
+         # model_kwargs = {
+         #     "input_ids": input_ids,
+         #     "max_new_tokens": 1024,
+         #     "do_sample": True,
+         #     "top_k": 50,
+         #     "top_p": self.top_p,
+         #     "temperature": self.temperature,
+         #     "repetition_penalty": 1.2,
+         #     "eos_token_id": self.tokenizer.eos_token_id,
+         #     "bos_token_id": self.tokenizer.bos_token_id,
+         #     "pad_token_id": self.tokenizer.pad_token_id
+         # }
+
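+         # Many causal LMs ship without a pad token, so EOS is reused as PAD below.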
+         tokenizer = AutoTokenizer.from_pretrained(
+             model_id,
+         )
+         tokenizer.pad_token = tokenizer.eos_token
+
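+         # Wrap the transformers pipeline so LangChain can drive it as an LLM.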
+         pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=1024)
+         chat = HuggingFacePipeline(pipeline=pipe)
+
+         # Create Text-Embedding Model
+         embedding_function = HuggingFaceBgeEmbeddings(
+             model_name="DMetaSoul/Dmeta-embedding",
+             model_kwargs={'device': 'cuda'},
+             encode_kwargs={'normalize_embeddings': True}
+         )
+
+         # Load Vector db
+         urls = [
+             "https://www.wenweipo.com/epaper/view/newsDetail/1582436861224292352.html",
+             "https://www.thinkhk.com/article/2023-03/24/59874.html"
+         ]
+
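+         # The pages are fetched, chunked, and embedded once per process at
+         # startup; the Chroma index lives in memory and is rebuilt on every boot.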
+         loader = WebBaseLoader(urls)
+         data = loader.load()
+
+         text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=16)
+         all_splits = text_splitter.split_documents(data)
+
+         vectorstore = Chroma.from_documents(documents=all_splits, embedding=embedding_function)
+         retriever = vectorstore.as_retriever()
+
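+         # A contextual-compression retriever is assembled here, but note that
+         # the chain below queries the plain `retriever`, not this one.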
+         compressor = LLMChainExtractor.from_llm(chat)
+         compression_retriever = ContextualCompressionRetriever(
+             base_compressor=compressor, base_retriever=retriever
+         )
+
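+         # The [INST] ... [/INST] markers assume a Llama-2/Mistral-style chat
+         # model; adjust them if the checkpoint expects a different prompt format.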
+         _template = """[INST] Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question, in its original language.
+ Chat History:
+ {chat_history}
+ Follow Up Input: {question}
+ Standalone question: [/INST]"""
+
+         CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)
+
+         template = """[INST] Answer the question based only on the following context:
+ {context}
+
+ Question: {question} [/INST]
+ """
+
+         ANSWER_PROMPT = ChatPromptTemplate.from_template(template)
+
+         self.memory = ConversationBufferMemory(
+             return_messages=True, output_key="answer", input_key="question"
+         )
+
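+         # The steps below follow LangChain's LCEL "RAG with memory" recipe:
+         # load history -> condense the follow-up question -> retrieve -> answer.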
+         # First we add a step to load memory
+         # This adds a "memory" key to the input object
+         loaded_memory = RunnablePassthrough.assign(
+             chat_history=RunnableLambda(self.memory.load_memory_variables) | itemgetter("history"),
+         )
+         # Now we calculate the standalone question
+         standalone_question = {
+             "standalone_question": {
+                 "question": lambda x: x["question"],
+                 "chat_history": lambda x: get_buffer_string(x["chat_history"]),
+             }
+             | CONDENSE_QUESTION_PROMPT
+             | chat
+             | StrOutputParser(),
+         }
+
+         DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")
+
+         def _combine_documents(
+             docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, document_separator="\n\n"
+         ):
+             doc_strings = [format_document(doc, document_prompt) for doc in docs]
+             return document_separator.join(doc_strings)
+
+         # Now we retrieve the documents
+         retrieved_documents = {
+             "docs": itemgetter("standalone_question") | retriever,
+             "question": lambda x: x["standalone_question"],
+         }
+         # Now we construct the inputs for the final prompt
+         final_inputs = {
+             "context": lambda x: _combine_documents(x["docs"]),
+             "question": itemgetter("question"),
+         }
+         # And finally, we do the part that returns the answers
+         answer = {
+             "answer": final_inputs | ANSWER_PROMPT | chat,
+             "docs": itemgetter("docs"),
+         }
+         # And now we put it all together!
+         self.final_chain = loaded_memory | standalone_question | retrieved_documents | answer
+
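+     # Called by the Inference Endpoints runtime once per request, with the
+     # parsed JSON body as `data`.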
+     def __call__(self, data: Dict[str, Any]) -> str:
+         # get inputs
+         inputs = data.pop("inputs", data)
+         date = data.pop("date", None)
+
+         result = self.final_chain.invoke({"question": inputs})
+
+         answer = result['answer']
+
+         # Note that the memory does not save automatically
+         # This will be improved in the future
+         # For now you need to save it yourself
+         # self.memory.save_context(inputs, {"answer": answer})
+         self.memory.load_memory_variables({})
+
+         return answer
+
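For reference, a minimal local smoke test of this handler could look like the sketch below. This is an assumption-laden example, not part of the commit: the Inference Endpoints runtime normally instantiates `EndpointHandler` itself, passing the checked-out model repository as `path` (commonly mounted at `/repository`); the path and question used here are placeholders.

from handler import EndpointHandler

# Point `path` at a local model directory (placeholder; the runtime supplies it).
handler = EndpointHandler(path="/repository")

# Request bodies carry the question under "inputs"; the handler returns the answer text.
print(handler({"inputs": "Summarise the main points of the two indexed articles."}))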