Chu Thi Thanh committed on
Commit
af8db98
·
1 Parent(s): 815da53

Upload files

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ chroma/chroma.sqlite3 filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ .DS_Store
2
+ __pycache__
3
+ .venv
app.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from src.ui import UI
3
+ from src.chatbot import ChatBot
4
+
5
+ def clear_session():
6
+ return "", []
7
+
8
+ def add_query(chat_history, input):
9
+ if not input:
10
+ raise gr.Error("Please enter a question.")
11
+ chat_history.append((input, None))
12
+ return chat_history
13
+
14
+ def response(chat_history, query):
15
+ res_msg, ref_docs = chatbot.generate_response(query, chat_history[:-1])
16
+ chat_history[-1] = (query, res_msg)
17
+ return "", chat_history, ref_docs
18
+
19
+ if __name__ == "__main__":
20
+ demo, chatspace, ref_docs, text_input, clear_btn = UI.create_demo()
21
+
22
+ chatbot = ChatBot(is_debug=True)
23
+ with demo:
24
+ # Event handler for submitting text and generating response
25
+ text_input.submit(add_query, inputs=[chatspace, text_input], outputs=[chatspace], concurrency_limit=1).\
26
+ success(response, inputs=[chatspace, text_input], outputs=[text_input, chatspace, ref_docs])
27
+ clear_btn.click(clear_session, inputs=[], outputs=[text_input, chatspace])
28
+
29
+ demo.queue(api_open=False)
30
+ demo.launch()
chroma/c6cac3f8-bcae-47bc-a6c2-8fdc2770f19d/data_level0.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2dbee596e8dcd6fe25e98be3ab98e6e07d4594cca8ff319396c9adabca1c7261
3
+ size 37704000
chroma/c6cac3f8-bcae-47bc-a6c2-8fdc2770f19d/header.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:da7633500b48fa102f767e4c2b993bb473a37f64d13f12ccf076b4d058b671c9
3
+ size 100
chroma/c6cac3f8-bcae-47bc-a6c2-8fdc2770f19d/index_metadata.pickle ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:383344cb187856235bccb00d648931a7988d828b7e52dedcb37e603489ae94ca
3
+ size 346049
chroma/c6cac3f8-bcae-47bc-a6c2-8fdc2770f19d/length.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:83c28735eeddff73cb1a12abd616afa354bf52e37ee424611241a641c30c913e
3
+ size 24000
chroma/c6cac3f8-bcae-47bc-a6c2-8fdc2770f19d/link_lists.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ef7ee997a8e34ddb4bc8487557d3299ebc0a8132ac19c9ef1b191a093eba3a69
3
+ size 52152
chroma/chroma.sqlite3 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9b30e09521687b19affc30b8f7082ceabd97a21b09735ecaf40610dc6cc1d230
3
+ size 72232960
data/comments.csv ADDED
The diff for this file is too large to render. See raw diff
 
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ gradio==4.29.0
2
+ langchain==0.1.17
3
+ langchain_chroma==0.1.0
4
+ langchain_community==0.0.36
5
+ langchain_core==0.1.50
6
+ langchain_openai==0.1.6
7
+ langchain_text_splitters==0.0.1
src/chatbot.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from typing import Any, Dict, List, Tuple
3
+ from langchain_chroma import Chroma
4
+ from langchain_core.callbacks import BaseCallbackHandler
5
+ from langchain_core.messages import AIMessage, HumanMessage
6
+ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
7
+ from langchain_openai import ChatOpenAI, OpenAIEmbeddings
8
+ from langchain.chains.history_aware_retriever import create_history_aware_retriever
9
+ from langchain.chains.retrieval import create_retrieval_chain
10
+ from langchain.chains.combine_documents import create_stuff_documents_chain
11
+ from langchain_core.callbacks import CallbackManagerForRetrieverRun
12
+ from langchain_core.documents import Document
13
+ from langchain_core.retrievers import BaseRetriever
14
+ import pandas as pd
15
+
16
class CustomHandler(BaseCallbackHandler):
    """LangChain callback that captures the most recent prompt sent to the LLM.

    The captured text is surfaced in the UI's "Prompt" tab for debugging.
    """

    def __init__(self):
        # Last fully-rendered prompt, all prompts of the call newline-joined.
        self.prompt = ""

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> Any:
        self.prompt = "\n".join(prompts)
25
+
26
class CustomRetriever(BaseRetriever):
    """Retriever that turns similar past posts into User/Assistant exemplars.

    For each vector-store hit, the post's stored comment is looked up by
    Post_ID in the comments DataFrame and appended as the "Assistant" reply,
    producing few-shot examples for the QA prompt.
    """

    vectorstore: Chroma  # persisted Chroma index over post contents
    comments: pd.DataFrame  # expects 'Post_ID' and 'Comment_content' columns

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        matching_documents = []
        for doc in self.vectorstore.similarity_search(query):
            post_id = int(doc.metadata['source'])
            replies = self.comments.loc[
                self.comments['Post_ID'] == post_id, 'Comment_content'
            ].values
            # Guard: a hit with no matching comment row would otherwise
            # raise IndexError on replies[0]; skip it instead.
            if len(replies) == 0:
                continue
            # Reframe the stored post as the "User" turn of an exemplar.
            # (Use a fresh name — the original rebound the `query` parameter.)
            user_turn = doc.page_content.replace("Content: ", "User: ")
            matching_documents.append(
                Document(
                    page_content=f"{user_turn}\nAssistant: {replies[0]}",
                    metadata=doc.metadata,
                )
            )
        # Removed leftover debug print of the full document list.
        return matching_documents
49
+
50
class ChatBot:
    """History-aware RAG chatbot over a Chroma store of past Q&A posts."""

    def __init__(self, is_debug=False):
        # When is_debug is True, generate_response returns the rendered LLM
        # prompt as the "references" value for inspection in the UI.
        self.is_debug = is_debug
        self.model = ChatOpenAI()
        self.handler = CustomHandler()
        self.embedding_function = OpenAIEmbeddings()
        self.vectorstore = Chroma(
            embedding_function=self.embedding_function,
            collection_name="documents",
            persist_directory="chroma",
        )
        self.comments = pd.read_csv("data/comments.csv")
        self.retriever = CustomRetriever(vectorstore=self.vectorstore, comments=self.comments)

    def create_chain(self):
        """Build the retrieval chain: rephrase query from history, retrieve, answer."""
        qa_system_prompt = """
        You are a helpful and joyous mental therapy assistant. Always answer as helpfully and cheerfully as possible, while being safe.
        Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content.
        Please ensure that your responses are socially unbiased and positive in nature.
        If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct.
        If you don't know the answer to a question, please don't share false information.

        Here are a few examples of answers:
        {context}

        """
        prompt = ChatPromptTemplate.from_messages([
            ("system", qa_system_prompt),
            MessagesPlaceholder(variable_name="chat_history"),
            ("human", "{input}")
        ])

        # Stuffs the retrieved exemplar documents into {context}.
        chain = create_stuff_documents_chain(
            llm=self.model,
            prompt=prompt
        )

        # Rewrites the latest user message into a standalone search query
        # using the conversation so far.
        retriever_prompt = ChatPromptTemplate.from_messages([
            MessagesPlaceholder(variable_name="chat_history"),
            ("human", "{input}"),
            ("human", "Given the above conversation, generate a search query to look up in order to get information relevant to the conversation")
        ])
        history_aware_retriever = create_history_aware_retriever(
            llm=self.model,
            retriever=self.retriever,
            prompt=retriever_prompt
        )

        retrieval_chain = create_retrieval_chain(
            history_aware_retriever,
            chain
        )

        return retrieval_chain

    def process_chat_history(self, chat_history):
        """Convert [(user, assistant), ...] tuples into LangChain messages."""
        history = []
        for (query, response) in chat_history:
            history.append(HumanMessage(content=query))
            history.append(AIMessage(content=response))
        return history

    def generate_response(self, query, chat_history):
        """Answer `query` given prior turns; return (answer, references).

        Raises gr.Error when the query is empty.
        """
        # BUGFIX: the original tested `if not input:` — `input` is the Python
        # builtin function (always truthy), so the guard never fired.
        if not query:
            raise gr.Error("Please enter a question.")

        history = self.process_chat_history(chat_history)
        conversational_chain = self.create_chain()
        response = conversational_chain.invoke(
            {
                "input": query,
                "chat_history": history,
            },
            config={"callbacks": [self.handler]}
        )["answer"]

        references = self.handler.prompt if self.is_debug else "This is for debugging purposes only."
        return response, references
src/ui.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
class UI:
    """Static builders and callbacks for the Gradio interface."""

    @staticmethod
    def feedback(data: gr.LikeData):
        """Print which response the user rated, and how."""
        prefix = (
            "You upvoted this response: "
            if data.liked
            else "You downvoted this response: "
        )
        print(prefix + data.value)

    @staticmethod
    def create_demo():
        """Assemble the Blocks app.

        Returns (demo, chatbot, ref_docs, text_input, clear_btn) so the
        caller can wire event handlers.
        """
        demo = gr.Blocks(title="Chatbot", theme="Soft")
        with demo:
            with gr.Tab("Chat"):
                chatbot = gr.Chatbot(value=[], elem_id="chatbot")
                chatbot.like(UI.feedback, None, None)

                text_input = gr.Textbox(
                    show_label=False,
                    placeholder="Ask me anything!",
                    container=False,
                )

                clear_btn = gr.Button("🧹 Clear")
            with gr.Tab("Prompt"):
                ref_docs = gr.Textbox(label="References", lines=25)

        return demo, chatbot, ref_docs, text_input, clear_btn