Chris Alexiuk committed
Commit 643f5c3 · 1 Parent(s): 9044dde

Update app.py

Files changed (1): app.py (+50 -4)
app.py CHANGED
@@ -4,6 +4,9 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain.vectorstores import Chroma
 from langchain.chains import RetrievalQAWithSourcesChain
 from langchain.chat_models import ChatOpenAI
+from typing import Any, List, Mapping, Optional
+from langchain.callbacks.manager import CallbackManagerForLLMRun
+from langchain.llms.base import LLM
 from langchain.prompts.chat import (
     ChatPromptTemplate,
     SystemMessagePromptTemplate,
@@ -53,12 +56,55 @@ async def init():
     # docsearch = await cl.make_async(Chroma.from_documents)(pdf_data, embeddings)
     docsearch = Chroma.from_documents(pdf_data, embeddings)
 
+    # custom SageMaker Model
+    class Llama2SageMaker(LLM):
+        max_new_tokens: int = 256
+        top_p: float = 0.9
+        temperature: float = 0.1
+
+        @property
+        def _llm_type(self) -> str:
+            return "Llama2SageMaker"
+
+        def _call(
+            self,
+            prompt: str,
+            stop: Optional[List[str]] = None,
+            run_manager: Optional[CallbackManagerForLLMRun] = None,
+        ) -> str:
+            if stop is not None:
+                raise ValueError("stop kwargs are not permitted.")
+
+            json_body = {
+                "inputs" : [
+                    [{"role" : "user", "content" : prompt}]
+                ],
+                "parameters" : {
+                    "max_new_tokens" : self.max_new_tokens,
+                    "top_p" : self.top_p,
+                    "temperature" : self.temperature
+                }
+            }
+
+            response = requests.post(model_api_gateway, json=json_body)
+
+            return response.json()[0]["generation"]["content"]
+
+        @property
+        def _identifying_params(self) -> Mapping[str, Any]:
+            """Get the identifying parameters."""
+            return {
+                "max_new_tokens" : self.max_new_tokens,
+                "top_p" : self.top_p,
+                "temperature" : self.temperature
+            }
+
+    # set our llm to the custom Llama2SageMaker endpoint model
+    llm = Llama2SageMaker()
+
     # Create a chain that uses the Chroma vector store
     chain = RetrievalQAWithSourcesChain.from_chain_type(
-        ChatOpenAI(
-            model_name="gpt-3.5-turbo-16k",
-            temperature=0,
-        ),
+        llm=llm,
         chain_type="stuff",
         retriever=docsearch.as_retriever(),
         return_source_documents=True,
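
Note: the new _call method depends on requests and model_api_gateway, neither of which is introduced in this diff, so both must already be defined elsewhere in app.py. The sketch below shows one way that wiring could look; the gateway URL and the example prompt are hypothetical placeholders, not values from the commit. Only Llama2SageMaker and its parameter defaults come from the code above.

import requests  # required by Llama2SageMaker._call

# Hypothetical placeholder: whatever API Gateway URL fronts the SageMaker
# endpoint. This value does not appear in the commit.
model_api_gateway = "https://<api-id>.execute-api.<region>.amazonaws.com/prod"

# Llama2SageMaker subclasses langchain.llms.base.LLM, so an instance is
# callable; __call__ dispatches to _call, which POSTs the Llama-2 chat
# payload and unpacks response.json()[0]["generation"]["content"].
llm = Llama2SageMaker(max_new_tokens=256, top_p=0.9, temperature=0.1)
print(llm("Summarize retrieval-augmented generation in one sentence."))

The request body matches the conversation format used by the SageMaker JumpStart Llama-2 chat endpoints: a list of conversations, each a list of role/content messages, with sampling parameters sent alongside under "parameters".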