Joshua Sundance Bailey committed
Commit 930d412 · 1 Parent(s): 5f851d5
langchain-streamlit-demo/app.py CHANGED
@@ -18,6 +18,7 @@ from langchain.document_loaders import PyPDFLoader
 from langchain.embeddings import OpenAIEmbeddings
 from langchain.memory import ConversationBufferMemory, StreamlitChatMessageHistory
 from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
+from langchain.retrievers import BM25Retriever, EnsembleRetriever
 from langchain.schema.document import Document
 from langchain.schema.retriever import BaseRetriever
 from langchain.text_splitter import RecursiveCharacterTextSplitter
@@ -142,8 +143,19 @@ def get_texts_and_retriever(
     )
     texts = text_splitter.split_documents(documents)
     embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key)
-    db = FAISS.from_documents(texts, embeddings)
-    return texts, db.as_retriever()
+
+    bm25_retriever = BM25Retriever.from_documents(texts)
+    bm25_retriever.k = 4
+
+    faiss_vectorstore = FAISS.from_documents(texts, embeddings)
+    faiss_retriever = faiss_vectorstore.as_retriever(search_kwargs={"k": 4})
+
+    ensemble_retriever = EnsembleRetriever(
+        retrievers=[bm25_retriever, faiss_retriever],
+        weights=[0.5, 0.5],
+    )
+
+    return texts, ensemble_retriever


 # --- Sidebar ---
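For context, this change replaces the plain FAISS retriever with a hybrid of lexical BM25 scoring and FAISS vector similarity, merged by EnsembleRetriever. The following is a minimal standalone sketch of the same pattern; the documents and query are made-up examples, and it assumes OPENAI_API_KEY is set in the environment rather than passed in as in the app.

# Sketch of the hybrid-retrieval pattern added in get_texts_and_retriever;
# the documents and query below are illustrative, not from the app.
from langchain.embeddings import OpenAIEmbeddings
from langchain.retrievers import BM25Retriever, EnsembleRetriever
from langchain.schema.document import Document
from langchain.vectorstores import FAISS

texts = [
    Document(page_content="BM25 ranks documents by keyword overlap."),
    Document(page_content="FAISS finds nearest neighbors in embedding space."),
]

# Lexical retriever: no embeddings required, backed by the rank_bm25 package.
bm25_retriever = BM25Retriever.from_documents(texts)
bm25_retriever.k = 4

# Dense retriever: FAISS over OpenAI embeddings.
embeddings = OpenAIEmbeddings()  # reads OPENAI_API_KEY from the environment
faiss_retriever = FAISS.from_documents(texts, embeddings).as_retriever(
    search_kwargs={"k": 4},
)

# Ensemble: merge both ranked result lists with equal weights.
ensemble_retriever = EnsembleRetriever(
    retrievers=[bm25_retriever, faiss_retriever],
    weights=[0.5, 0.5],
)

docs = ensemble_retriever.get_relevant_documents("keyword search")

Equal 0.5/0.5 weights mirror the committed code; shifting weight toward either retriever biases the merged ranking toward keyword matches or semantic similarity.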
langchain-streamlit-demo/qagen.py CHANGED
@@ -7,19 +7,13 @@ from langchain.prompts.chat import (
 )
 from langchain.schema.language_model import BaseLanguageModel
 from langchain.schema.runnable import RunnableSequence
-from pydantic import BaseModel, field_validator, Field
+from pydantic import BaseModel, Field


 class QuestionAnswerPair(BaseModel):
     question: str = Field(..., description="The question that will be answered.")
     answer: str = Field(..., description="The answer to the question that was asked.")

-    @field_validator("question")
-    def validate_question(cls, v: str) -> str:
-        if not v.endswith("?"):
-            raise ValueError("Question must end with a question mark.")
-        return v
-

 class QuestionAnswerPairList(BaseModel):
     QuestionAnswerPairs: List[QuestionAnswerPair]
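The practical effect of dropping the field_validator is that a generated question no longer has to end with a question mark to be parsed. A minimal sketch under that assumption, using the models as defined in this commit and made-up example strings:

from typing import List

from pydantic import BaseModel, Field


class QuestionAnswerPair(BaseModel):
    question: str = Field(..., description="The question that will be answered.")
    answer: str = Field(..., description="The answer to the question that was asked.")


class QuestionAnswerPairList(BaseModel):
    QuestionAnswerPairs: List[QuestionAnswerPair]


# With the validator removed, a question without a trailing "?" is accepted
# instead of raising a ValueError during model construction.
pair = QuestionAnswerPair(
    question="Explain ensemble retrieval",
    answer="It merges BM25 and FAISS results.",
)
pairs = QuestionAnswerPairList(QuestionAnswerPairs=[pair])
print(pairs.QuestionAnswerPairs[0].question)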
requirements.txt CHANGED
@@ -5,7 +5,8 @@ langsmith==0.0.40
 numpy>=1.22.2 # not directly required, pinned by Snyk to avoid a vulnerability
 openai==0.28.0
 pypdf==3.16.2
-streamlit==1.27.0
+rank_bm25==0.2.2
+streamlit==1.27.1
 streamlit-feedback==0.1.2
 tiktoken==0.5.1
 tornado>=6.3.3 # not directly required, pinned by Snyk to avoid a vulnerability
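The new rank_bm25 pin is the keyword-scoring library that LangChain's BM25Retriever depends on. A minimal sketch of rank_bm25 used directly, with a made-up corpus and query, to show what the dependency provides:

# Illustrative use of the rank_bm25 package added to requirements.txt;
# corpus and query are made-up examples.
from rank_bm25 import BM25Okapi

corpus = [
    "FAISS performs dense vector similarity search.",
    "BM25 is a classic keyword-based ranking function.",
    "Streamlit builds interactive data apps in Python.",
]
tokenized_corpus = [doc.lower().split() for doc in corpus]

bm25 = BM25Okapi(tokenized_corpus)
query = "keyword ranking".lower().split()

print(bm25.get_scores(query))               # one BM25 score per document
print(bm25.get_top_n(query, corpus, n=1))   # best-matching document text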