Update app.py
app.py
CHANGED
@@ -20,8 +20,8 @@ from langchain import PromptTemplate, LLMChain
 # Vector stores
 from langchain.vectorstores import FAISS
 
-# 
-from 
+# Import HuggingFacePipeline from the new package
+from langchain_huggingface import HuggingFacePipeline
 from langchain.embeddings import HuggingFaceInstructEmbeddings
 
 # Retrievers
@@ -43,7 +43,7 @@ shutil.rmtree('./.cache', ignore_errors=True)
 
 class CFG:
     # LLMs configuration
-    model_name = 'llama2-13b-chat'
+    model_name = 'llama2-13b-chat'
     temperature = 0
     top_p = 0.95
     repetition_penalty = 1.15
@@ -130,14 +130,13 @@ pipe = pipeline(
     repetition_penalty=CFG.repetition_penalty
 )
 
-# 
+# Use the updated HuggingFacePipeline class from langchain_huggingface
 llm = HuggingFacePipeline(pipeline=pipe)
 
 loader = DirectoryLoader(
     CFG.PDFs_path,
     glob="./*.pdf",
     loader_cls=PyPDFLoader,
-    show_progress=True,
 )
 
 documents = loader.load()
@@ -160,8 +159,7 @@ retriever = vectordb.as_retriever(search_kwargs={"k": CFG.k})
 
 qa_chain = RetrievalQA.from_chain_type(
     llm=llm,
-    chain_type="stuff",
-    retriever=retriever,
+    chain_type="stuff",
 )
 
 prompt_template = """
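For context, a minimal sketch of the migrated import in use. This is illustrative only, not part of app.py: the model name and generation settings below are placeholders (app.py builds its pipeline from CFG), and the snippet assumes the langchain-huggingface and transformers packages are installed. Note that the removed import line is truncated in the diff view above; before this package split, HuggingFacePipeline was commonly imported from langchain.llms.

# Illustrative sketch only, not part of app.py.
# Assumes: pip install langchain-huggingface transformers
from transformers import pipeline
from langchain_huggingface import HuggingFacePipeline

# "distilgpt2" and max_new_tokens are placeholders; app.py configures its
# pipeline from CFG (model_name, temperature, top_p, repetition_penalty).
pipe = pipeline("text-generation", model="distilgpt2", max_new_tokens=32)
llm = HuggingFacePipeline(pipeline=pipe)

print(llm.invoke("What does a vector store do?"))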