Pavan178 commited on
Commit
b840efb
·
verified ·
1 Parent(s): 3f31c68

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -43
app.py CHANGED
@@ -4,8 +4,9 @@ from langchain.document_loaders import PyPDFLoader
4
  from langchain.text_splitter import RecursiveCharacterTextSplitter
5
  from langchain.embeddings import OpenAIEmbeddings
6
  from langchain.vectorstores import FAISS
7
- from langchain.chains import ConversationalRetrievalChain, Chain
8
- from langchain.chat_models import ChatOpenAI
 
9
  from langchain.memory import ConversationBufferMemory
10
  from langchain.prompts import PromptTemplate
11
 
@@ -51,50 +52,39 @@ class AdvancedPdfChatbot:
51
  self.setup_conversation_chain()
52
 
53
  def setup_conversation_chain(self):
54
- class CustomChain(Chain):
55
- refinement_chain: Chain
56
- qa_chain: Chain
57
-
58
- @classmethod
59
- def from_llms(cls, refinement_llm, qa_llm, retriever, memory, prompt):
60
- refinement_chain = Chain(
61
- llm_chain=LLMChain(
62
- llm=refinement_llm,
63
- prompt=self.refinement_prompt,
64
- output_key='refined_query'
65
- )
66
- )
67
- qa_chain = ConversationalRetrievalChain.from_llm(
68
- qa_llm,
69
- retriever=retriever,
70
- memory=memory,
71
- combine_docs_chain_kwargs={"prompt": prompt}
72
- )
73
- return cls(refinement_chain=refinement_chain, qa_chain=qa_chain)
74
-
75
- def _call(self, inputs):
76
- query = inputs['query']
77
- chat_history = inputs.get('chat_history', [])
78
- refined_query = self.refinement_chain.run(query=query, chat_history=chat_history)
79
- response = self.qa_chain({"question": refined_query, "chat_history": chat_history})
80
- self.qa_chain.memory.save_context({"input": query}, {"output": response['answer']})
81
- return {"answer": response['answer']}
82
-
83
- @property
84
- def input_keys(self):
85
- return ['query', 'chat_history']
86
-
87
- @property
88
- def output_keys(self):
89
- return ['answer']
90
-
91
- self.overall_chain = CustomChain.from_llms(
92
- refinement_llm=self.refinement_llm,
93
- qa_llm=self.llm,
94
  retriever=self.db.as_retriever(),
95
  memory=self.memory,
96
- prompt=self.prompt
97
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
 
99
  def chat(self, query):
100
  if not self.overall_chain:
 
4
  from langchain.text_splitter import RecursiveCharacterTextSplitter
5
  from langchain.embeddings import OpenAIEmbeddings
6
  from langchain.vectorstores import FAISS
7
+ from langchain.chains.base import Chain
8
+ from langchain.llms import ChatOpenAI
9
+ from langchain.chains import LLMChain, ConversationalRetrievalChain
10
  from langchain.memory import ConversationBufferMemory
11
  from langchain.prompts import PromptTemplate
12
 
 
52
  self.setup_conversation_chain()
53
 
54
  def setup_conversation_chain(self):
55
+ refinement_chain = LLMChain(
56
+ llm=self.refinement_llm,
57
+ prompt=self.refinement_prompt,
58
+ output_key='refined_query'
59
+ )
60
+ qa_chain = ConversationalRetrievalChain.from_llm(
61
+ self.llm,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  retriever=self.db.as_retriever(),
63
  memory=self.memory,
64
+ combine_docs_chain_kwargs={"prompt": self.prompt}
65
  )
66
+ self.overall_chain = self.CustomChain(refinement_chain=refinement_chain, qa_chain=qa_chain)
67
+
68
+ class CustomChain(Chain):
69
+ def __init__(self, refinement_chain, qa_chain):
70
+ super().__init__()
71
+ self.refinement_chain = refinement_chain
72
+ self.qa_chain = qa_chain
73
+
74
+ @property
75
+ def input_keys(self):
76
+ return ["query", "chat_history"]
77
+
78
+ @property
79
+ def output_keys(self):
80
+ return ["answer"]
81
+
82
+ def _call(self, inputs):
83
+ query = inputs['query']
84
+ chat_history = inputs.get('chat_history', [])
85
+ refined_query = self.refinement_chain.run(query=query, chat_history=chat_history)
86
+ response = self.qa_chain({"question": refined_query, "chat_history": chat_history})
87
+ return {"answer": response['answer']}
88
 
89
  def chat(self, query):
90
  if not self.overall_chain: