Joshua Sundance Bailey
commited on
Commit
·
930d412
1
Parent(s):
5f851d5
bm25
Browse files- langchain-streamlit-demo/app.py +14 -2
- langchain-streamlit-demo/qagen.py +1 -7
- requirements.txt +2 -1
langchain-streamlit-demo/app.py
CHANGED
@@ -18,6 +18,7 @@ from langchain.document_loaders import PyPDFLoader
|
|
18 |
from langchain.embeddings import OpenAIEmbeddings
|
19 |
from langchain.memory import ConversationBufferMemory, StreamlitChatMessageHistory
|
20 |
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
|
|
|
21 |
from langchain.schema.document import Document
|
22 |
from langchain.schema.retriever import BaseRetriever
|
23 |
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
@@ -142,8 +143,19 @@ def get_texts_and_retriever(
|
|
142 |
)
|
143 |
texts = text_splitter.split_documents(documents)
|
144 |
embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key)
|
145 |
-
|
146 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
147 |
|
148 |
|
149 |
# --- Sidebar ---
|
|
|
18 |
from langchain.embeddings import OpenAIEmbeddings
|
19 |
from langchain.memory import ConversationBufferMemory, StreamlitChatMessageHistory
|
20 |
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
|
21 |
+
from langchain.retrievers import BM25Retriever, EnsembleRetriever
|
22 |
from langchain.schema.document import Document
|
23 |
from langchain.schema.retriever import BaseRetriever
|
24 |
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
|
|
143 |
)
|
144 |
texts = text_splitter.split_documents(documents)
|
145 |
embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key)
|
146 |
+
|
147 |
+
bm25_retriever = BM25Retriever.from_documents(texts)
|
148 |
+
bm25_retriever.k = 4
|
149 |
+
|
150 |
+
faiss_vectorstore = FAISS.from_documents(texts, embeddings)
|
151 |
+
faiss_retriever = faiss_vectorstore.as_retriever(search_kwargs={"k": 4})
|
152 |
+
|
153 |
+
ensemble_retriever = EnsembleRetriever(
|
154 |
+
retrievers=[bm25_retriever, faiss_retriever],
|
155 |
+
weights=[0.5, 0.5],
|
156 |
+
)
|
157 |
+
|
158 |
+
return texts, ensemble_retriever
|
159 |
|
160 |
|
161 |
# --- Sidebar ---
|
langchain-streamlit-demo/qagen.py
CHANGED
@@ -7,19 +7,13 @@ from langchain.prompts.chat import (
|
|
7 |
)
|
8 |
from langchain.schema.language_model import BaseLanguageModel
|
9 |
from langchain.schema.runnable import RunnableSequence
|
10 |
-
from pydantic import BaseModel,
|
11 |
|
12 |
|
13 |
class QuestionAnswerPair(BaseModel):
|
14 |
question: str = Field(..., description="The question that will be answered.")
|
15 |
answer: str = Field(..., description="The answer to the question that was asked.")
|
16 |
|
17 |
-
@field_validator("question")
|
18 |
-
def validate_question(cls, v: str) -> str:
|
19 |
-
if not v.endswith("?"):
|
20 |
-
raise ValueError("Question must end with a question mark.")
|
21 |
-
return v
|
22 |
-
|
23 |
|
24 |
class QuestionAnswerPairList(BaseModel):
|
25 |
QuestionAnswerPairs: List[QuestionAnswerPair]
|
|
|
7 |
)
|
8 |
from langchain.schema.language_model import BaseLanguageModel
|
9 |
from langchain.schema.runnable import RunnableSequence
|
10 |
+
from pydantic import BaseModel, Field
|
11 |
|
12 |
|
13 |
class QuestionAnswerPair(BaseModel):
|
14 |
question: str = Field(..., description="The question that will be answered.")
|
15 |
answer: str = Field(..., description="The answer to the question that was asked.")
|
16 |
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
|
18 |
class QuestionAnswerPairList(BaseModel):
|
19 |
QuestionAnswerPairs: List[QuestionAnswerPair]
|
requirements.txt
CHANGED
@@ -5,7 +5,8 @@ langsmith==0.0.40
|
|
5 |
numpy>=1.22.2 # not directly required, pinned by Snyk to avoid a vulnerability
|
6 |
openai==0.28.0
|
7 |
pypdf==3.16.2
|
8 |
-
|
|
|
9 |
streamlit-feedback==0.1.2
|
10 |
tiktoken==0.5.1
|
11 |
tornado>=6.3.3 # not directly required, pinned by Snyk to avoid a vulnerability
|
|
|
5 |
numpy>=1.22.2 # not directly required, pinned by Snyk to avoid a vulnerability
|
6 |
openai==0.28.0
|
7 |
pypdf==3.16.2
|
8 |
+
rank_bm25==0.2.2
|
9 |
+
streamlit==1.27.1
|
10 |
streamlit-feedback==0.1.2
|
11 |
tiktoken==0.5.1
|
12 |
tornado>=6.3.3 # not directly required, pinned by Snyk to avoid a vulnerability
|