Chris Alexiuk committed
Commit 643f5c3 · Parent: 9044dde
Update app.py

app.py CHANGED
@@ -4,6 +4,9 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain.vectorstores import Chroma
 from langchain.chains import RetrievalQAWithSourcesChain
 from langchain.chat_models import ChatOpenAI
+from typing import Any, List, Mapping, Optional
+from langchain.callbacks.manager import CallbackManagerForLLMRun
+from langchain.llms.base import LLM
 from langchain.prompts.chat import (
     ChatPromptTemplate,
     SystemMessagePromptTemplate,
@@ -53,12 +56,55 @@ async def init():
     # docsearch = await cl.make_async(Chroma.from_documents)(pdf_data, embeddings)
     docsearch = Chroma.from_documents(pdf_data, embeddings)
 
+    # custom SageMaker Model
+    class Llama2SageMaker(LLM):
+        max_new_tokens: int = 256
+        top_p: float = 0.9
+        temperature: float = 0.1
+
+        @property
+        def _llm_type(self) -> str:
+            return "Llama2SageMaker"
+
+        def _call(
+            self,
+            prompt: str,
+            stop: Optional[List[str]] = None,
+            run_manager: Optional[CallbackManagerForLLMRun] = None,
+        ) -> str:
+            if stop is not None:
+                raise ValueError("stop kwargs are not permitted.")
+
+            json_body = {
+                "inputs" : [
+                    [{"role" : "user", "content" : prompt}]
+                ],
+                "parameters" : {
+                    "max_new_tokens" : self.max_new_tokens,
+                    "top_p" : self.top_p,
+                    "temperature" : self.temperature
+                }
+            }
+
+            response = requests.post(model_api_gateway, json=json_body)
+
+            return response.json()[0]["generation"]["content"]
+
+        @property
+        def _identifying_params(self) -> Mapping[str, Any]:
+            """Get the identifying parameters."""
+            return {
+                "max_new_tokens" : self.max_new_tokens,
+                "top_p" : self.top_p,
+                "temperature" : self.temperature
+            }
+
+    # set our llm to the custom Llama2SageMaker endpoint model
+    llm = Llama2SageMaker()
+
     # Create a chain that uses the Chroma vector store
     chain = RetrievalQAWithSourcesChain.from_chain_type(
-        ChatOpenAI(
-            model_name="gpt-3.5-turbo-16k",
-            temperature=0,
-        ),
+        llm=llm,
         chain_type="stuff",
         retriever=docsearch.as_retriever(),
         return_source_documents=True,
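Net effect of the diff: `RetrievalQAWithSourcesChain` no longer calls OpenAI's `gpt-3.5-turbo-16k`; generation is routed through a custom LangChain `LLM` subclass that POSTs the prompt to a Llama 2 model served from a SageMaker endpoint behind API Gateway. Two names used in `_call`, `requests` and `model_api_gateway`, do not appear in these hunks, so both must already be imported/defined elsewhere in app.py. The response parsing (`response.json()[0]["generation"]["content"]`) matches the payload shape returned by the SageMaker JumpStart Llama-2-chat containers.

A minimal sketch of exercising the new class on its own, assuming the `Llama2SageMaker` definition above is in scope and substituting a hypothetical placeholder for `model_api_gateway`:

```python
import requests  # used inside Llama2SageMaker._call

# Hypothetical stand-in for the module-level value app.py is expected to
# define; the real URL would point at the API Gateway stage that fronts
# the SageMaker Llama 2 endpoint.
model_api_gateway = "https://<api-id>.execute-api.<region>.amazonaws.com/prod"

# Defaults come from the class attributes in the diff (max_new_tokens=256,
# top_p=0.9, temperature=0.1); any can be overridden at construction time
# because the LangChain LLM base class is a pydantic model.
llm = Llama2SageMaker(max_new_tokens=128)

# LangChain's base LLM is callable: __call__ forwards to _call, which wraps
# the prompt in the [[{"role": "user", ...}]] chat format the endpoint
# expects and unwraps the assistant reply from the JSON response.
print(llm("Summarize retrieval-augmented generation in one sentence."))
```

One design consequence worth noting: because `_call` raises `ValueError` whenever `stop` is provided, this LLM only works in chains, such as the `stuff` chain configured here, that do not pass stop sequences.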