hiwei commited on
Commit
606df26
·
verified ·
1 Parent(s): 1fd4334

add chatbot

Browse files
Files changed (1) hide show
  1. app.py +40 -2
app.py CHANGED
@@ -1,11 +1,15 @@
 
 
1
  import gradio
2
  import gradio as gr
3
  import spacy
4
- from langchain.chains import RetrievalQA
 
5
  from langchain.text_splitter import SpacyTextSplitter
6
  from langchain_community.document_loaders import PyPDFLoader
7
  from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
8
  from langchain_community.vectorstores import Chroma
 
9
  from langchain_core.prompts import PromptTemplate
10
  from langchain_google_genai import ChatGoogleGenerativeAI
11
 
@@ -19,6 +23,14 @@ Helpful Answer:"""
19
  QA_CHAIN_PROMPT = PromptTemplate.from_template(template)
20
 
21
 
 
 
 
 
 
 
 
 
22
  class RAGDemo(object):
23
  def __init__(self):
24
  self.embedding = None
@@ -79,8 +91,24 @@ class RAGDemo(object):
79
  resp = basic_qa.invoke(input_text)
80
  return resp['result']
81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  def __call__(self):
83
- with gr.Blocks() as demo:
84
  gr.Markdown("# RAG Demo\n\nbase on the [RAG learning note](https://www.jianshu.com/p/9792f1e6c3f9) and "
85
  "[rag-practice](https://github.com/hiwei93/rag-practice/tree/main)")
86
  with gr.Tab("Settings"):
@@ -109,6 +137,10 @@ class RAGDemo(object):
109
  submit_btn = gr.Button("submit")
110
  with gr.Column():
111
  output = gr.TextArea(label="answer")
 
 
 
 
112
  initial_btn.click(
113
  self._init_settings,
114
  inputs=[model_name, api_key, embedding_model, embedding_api_key, data_file]
@@ -119,6 +151,12 @@ class RAGDemo(object):
119
  inputs=input_text,
120
  outputs=output,
121
  )
 
 
 
 
 
 
122
  return demo
123
 
124
 
 
1
+ from typing import List
2
+
3
  import gradio
4
  import gradio as gr
5
  import spacy
6
+ from langchain.chains import RetrievalQA, ConversationalRetrievalChain
7
+ from langchain.memory import ConversationBufferMemory
8
  from langchain.text_splitter import SpacyTextSplitter
9
  from langchain_community.document_loaders import PyPDFLoader
10
  from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
11
  from langchain_community.vectorstores import Chroma
12
+ from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
13
  from langchain_core.prompts import PromptTemplate
14
  from langchain_google_genai import ChatGoogleGenerativeAI
15
 
 
23
  QA_CHAIN_PROMPT = PromptTemplate.from_template(template)
24
 
25
 
26
+ def convert_chat_history_to_messages(chat_history) -> List[BaseMessage]:
27
+ result = []
28
+ for human, ai in chat_history:
29
+ result.append(HumanMessage(content=human))
30
+ result.append(AIMessage(content=ai))
31
+ return result
32
+
33
+
34
  class RAGDemo(object):
35
  def __init__(self):
36
  self.embedding = None
 
91
  resp = basic_qa.invoke(input_text)
92
  return resp['result']
93
 
94
+ def _chat_qa(self, message, chat_history):
95
+ memory = ConversationBufferMemory(
96
+ chat_memory=convert_chat_history_to_messages(chat_history),
97
+ memory_key="chat_history",
98
+ return_messages=True,
99
+ )
100
+ qa = ConversationalRetrievalChain.from_llm(
101
+ self.chat_model,
102
+ retriever=self.vector_db.as_retriever(),
103
+ memory=memory,
104
+ verbose=True,
105
+ )
106
+ resp = qa.invoke(message)
107
+ chat_history.append((message, resp['answer']))
108
+ return "", chat_history
109
+
110
  def __call__(self):
111
+ with gr.Blocks(title="🔥 RAG Demo") as demo:
112
  gr.Markdown("# RAG Demo\n\nbase on the [RAG learning note](https://www.jianshu.com/p/9792f1e6c3f9) and "
113
  "[rag-practice](https://github.com/hiwei93/rag-practice/tree/main)")
114
  with gr.Tab("Settings"):
 
137
  submit_btn = gr.Button("submit")
138
  with gr.Column():
139
  output = gr.TextArea(label="answer")
140
+ with gr.Tab("Chat RAG"):
141
+ chatbot = gr.Chatbot(label="chat with pdf")
142
+ input_msg = gr.Textbox(placeholder="input your question...", label="input")
143
+ clear_btn = gr.ClearButton([chatbot, input_msg])
144
  initial_btn.click(
145
  self._init_settings,
146
  inputs=[model_name, api_key, embedding_model, embedding_api_key, data_file]
 
151
  inputs=input_text,
152
  outputs=output,
153
  )
154
+
155
+ input_msg.submit(
156
+ self._chat_qa,
157
+ inputs=[input_msg, chatbot],
158
+ outputs=[input_msg, chatbot]
159
+ )
160
  return demo
161
 
162