sanbu committed on
Commit
7a82ea8
·
1 Parent(s): 2cb56ed

Update space

Browse files
Files changed (2) hide show
  1. app.py +190 -48
  2. requirements.txt +5 -1
app.py CHANGED
@@ -1,63 +1,205 @@
 
1
  import gradio as gr
2
- from huggingface_hub import InferenceClient
 
 
 
 
 
 
 
 
3
 
4
- """
5
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
6
- """
7
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
 
 
 
 
 
 
 
8
 
9
 
10
- def respond(
11
- message,
12
- history: list[tuple[str, str]],
13
- system_message,
14
- max_tokens,
15
- temperature,
16
- top_p,
17
  ):
18
- messages = [{"role": "system", "content": system_message}]
 
 
 
19
 
20
- for val in history:
21
- if val[0]:
22
- messages.append({"role": "user", "content": val[0]})
23
- if val[1]:
24
- messages.append({"role": "assistant", "content": val[1]})
25
 
26
- messages.append({"role": "user", "content": message})
 
 
27
 
28
- response = ""
 
 
 
 
 
 
29
 
30
- for message in client.chat_completion(
31
- messages,
32
- max_tokens=max_tokens,
33
- stream=True,
34
- temperature=temperature,
35
- top_p=top_p,
36
- ):
37
- token = message.choices[0].delta.content
38
 
39
- response += token
40
- yield response
41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  """
43
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
44
- """
45
- demo = gr.ChatInterface(
46
- respond,
47
- additional_inputs=[
48
- gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
49
- gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
50
- gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
51
- gr.Slider(
52
- minimum=0.1,
53
- maximum=1.0,
54
- value=0.95,
55
- step=0.05,
56
- label="Top-p (nucleus sampling)",
57
- ),
58
- ],
59
- )
60
 
61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  if __name__ == "__main__":
63
- demo.launch()
 
1
import os

import gradio as gr
from dotenv import load_dotenv
from huggingface_hub import snapshot_download
from langchain import hub
from langchain_chroma import Chroma
from langchain_community.document_loaders import DirectoryLoader, TextLoader
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_text_splitters import RecursiveCharacterTextSplitter
from tianji.knowledges.langchain_onlinellm.models import ZhipuAIEmbeddings, ZhipuLLM

# Load environment variables (API credentials) from a local .env file.
load_dotenv()

# Fetch the knowledge-base dataset from the Hugging Face Hub into ./temp/tianji-chinese.
destination_folder = os.path.join("./", "temp", "tianji-chinese")
snapshot_download(
    repo_id="sanbu/tianji-chinese",
    local_dir=destination_folder,
    repo_type="dataset",
    local_dir_use_symlinks=False,
)
24
 
25
 
26
def create_vectordb(
    data_path: str,
    persist_directory: str,
    embedding_func,
    chunk_size: int,
    force: bool = False,
    chunk_overlap: int = 200,
):
    """Build (or reuse) a Chroma vector store from the ``*.txt`` files under *data_path*.

    Args:
        data_path: Directory whose ``*.txt`` files are indexed.
        persist_directory: On-disk location of the Chroma database.
        embedding_func: Embedding function handed to Chroma.
        chunk_size: Character length of each document chunk.
        force: When True, delete any existing store at *persist_directory*
            and rebuild it from scratch.
        chunk_overlap: Overlap (in characters) between consecutive chunks.
            Defaults to 200, the value previously hard-coded.

    Returns:
        A ``Chroma`` instance backed by *persist_directory*.

    Raises:
        gr.Error: If splitting the documents yields no chunks.
    """
    # Reuse an existing on-disk store unless the caller asks for a rebuild.
    if os.path.exists(persist_directory) and not force:
        return Chroma(
            persist_directory=persist_directory, embedding_function=embedding_func
        )

    if force and os.path.exists(persist_directory):
        import shutil  # local import: only needed on the rare force-rebuild path

        if os.path.isdir(persist_directory):
            shutil.rmtree(persist_directory)
        else:
            os.remove(persist_directory)

    loader = DirectoryLoader(data_path, glob="*.txt", loader_cls=TextLoader)
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap
    )
    split_docs = text_splitter.split_documents(loader.load())
    if not split_docs:
        raise gr.Error("当前知识数据无效,处理数据后为空")

    return Chroma.from_documents(
        documents=split_docs,
        embedding=embedding_func,
        persist_directory=persist_directory,
    )
 
 
60
 
 
 
61
 
62
def initialize_chain(chunk_size: int, persist_directory: str, data_path: str):
    """Assemble the RAG pipeline (retriever -> prompt -> LLM -> parser) for one scenario."""
    print("初始化数据库开始")
    embeddings = ZhipuAIEmbeddings()
    vectordb = create_vectordb(data_path, persist_directory, embeddings, chunk_size)
    retriever = vectordb.as_retriever()

    # Start from the community RAG prompt, then swap in a Chinese instruction.
    prompt = hub.pull("rlm/rag-prompt")
    prompt.messages[0].prompt.template = """
您是一名用于问答任务的助手。使用检索到的上下文来回答问题。如果没有高度相关上下文 你就自由回答。\
根据检索到的上下文,结合我的问题,直接给出最后的回答,要详细覆盖全方面。\
\n问题:{question} \n上下文:{context} \n回答:
"""
    llm = ZhipuLLM()
    print("初始化数据库结束")

    rag_chain = (
        {"context": retriever | format_docs, "question": RunnablePassthrough()}
        | prompt
        | llm
        | StrOutputParser()
    )
    return rag_chain
83
+
84
+
85
def format_docs(docs):
    """Join the page contents of *docs* into one string, separated by blank lines."""
    contents = (doc.page_content for doc in docs)
    return "\n\n".join(contents)
87
+
88
+
89
def handle_question(chain, question: str, chat_history):
    """Run *question* through *chain* and append the (question, answer) pair.

    Returns a ``(textbox_value, chat_history)`` tuple: the textbox value is
    empty on success or when the question is blank; on failure it carries the
    error text so the UI can display it.
    """
    if not question:
        return "", chat_history
    try:
        answer = chain.invoke(question)
    except Exception as err:  # surface the error text instead of crashing the UI
        return str(err), chat_history
    chat_history.append((question, answer))
    return "", chat_history
98
+
99
+
100
# Mapping from UI tab label to the dataset sub-folder holding its knowledge files.
scenarios = {
    "敬酒礼仪文化": "1-etiquette",
    "请客礼仪文化": "2-hospitality",
    "送礼礼仪文化": "3-gifting",
    "如何说对话": "5-communication",
    "化解尴尬场合": "6-awkwardness",
    "矛盾&冲突应对": "7-conflict",
}

# Build one RAG chain per scenario up front so switching tabs is instant.
chunk_size = 1280
chains = {}
for scenario_name, scenario_folder in scenarios.items():
    data_path = os.path.join(
        "./", "temp", "tianji-chinese", "RAG", scenario_folder
    )
    if not os.path.exists(data_path):
        raise FileNotFoundError(f"Data path does not exist: {data_path}")

    persist_directory = os.path.join("./", "temp", f"chromadb_{scenario_folder}")
    chains[scenario_name] = initialize_chain(chunk_size, persist_directory, data_path)

# Markdown banner shown at the top of the Gradio page.
TITLE = """
# Tianji 人情世故大模型系统完整版(基于知识库实现) 欢迎star!\n
## 💫开源项目地址:https://github.com/SocialAI-tianji/Tianji
## 使用方法:选择你想提问的场景,输入提示,或点击Example自动填充
## 如果觉得回答不满意,可补充更多信息重复提问。
### 我们的愿景是构建一个从数据收集开始的大模型全栈垂直领域开源实践.
"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
 
132
 
133
# Per-scenario quick-start example prompts (keys match the *scenarios* tab labels).
_SCENARIO_EXAMPLES = {
    "敬酒礼仪文化": [
        "喝酒座位怎么排",
        "喝酒的完整流程是什么",
        "推荐的敬酒词怎么说",
        "宴会怎么点菜",
        "喝酒容易醉怎么办",
        "喝酒的规矩是什么",
    ],
    "请客礼仪文化": ["请客有那些规矩", "如何选择合适的餐厅", "怎么请别人吃饭"],
    "送礼礼仪文化": ["送什么礼物给长辈好", "怎么送礼", "回礼的礼节是什么"],
    "如何说对话": [
        "怎么和导师沟通",
        "怎么提高情商",
        "如何读懂潜台词",
        "怎么安慰别人",
        "怎么和孩子沟通",
        "如何与男生聊天",
        "如何与女生聊天",
        "职场高情商回应技巧",
    ],
    "化解尴尬场合": [
        "怎么回应赞美",
        "怎么拒绝借钱",
        "如何高效沟通",
        "怎么和对象沟通",
        "聊天技巧",
        "怎么拒绝别人",
        "职场怎么沟通",
    ],
    "矛盾&冲突应对": [
        "怎么控制情绪",
        "怎么向别人道歉",
        "和别人吵架了怎么办",
        "如何化解尴尬",
        "孩子有情绪怎么办",
        "夫妻吵架怎么办",
        "情侣冷战怎么办",
    ],
}


def get_examples_for_scenario(scenario):
    """Return the example prompts for *scenario* (empty list if unknown)."""
    return _SCENARIO_EXAMPLES.get(scenario, [])
168
+
169
+
170
with gr.Blocks() as demo:
    gr.Markdown(TITLE)

    init_status = gr.Textbox(label="初始化状态", value="数据库已初始化", interactive=False)

    with gr.Tabs() as tabs:
        # One tab per scenario, each wired to its own pre-built RAG chain.
        for scenario_name in scenarios:
            with gr.Tab(scenario_name):
                chatbot = gr.Chatbot(height=450, show_copy_button=True)
                msg = gr.Textbox(label="输入你的疑问")

                examples = gr.Examples(
                    label="快速示例",
                    examples=get_examples_for_scenario(scenario_name),
                    inputs=[msg],
                )

                with gr.Row():
                    chat_button = gr.Button("聊天")
                    clear_button = gr.ClearButton(components=[chatbot], value="清除聊天记录")

                # Bind the scenario via a default argument so every tab keeps
                # its own chain (avoids the late-binding closure pitfall).
                def invoke_chain(question, chat_history, scenario=scenario_name):
                    print(question)
                    return handle_question(chains[scenario], question, chat_history)

                chat_button.click(
                    invoke_chain,
                    inputs=[msg, chatbot],
                    outputs=[msg, chatbot],
                )
201
+
202
+
203
# Launch Gradio application
if __name__ == "__main__":
    # Start the Gradio server only when executed as a script (not on import).
    demo.launch()
requirements.txt CHANGED
@@ -1 +1,5 @@
1
- huggingface_hub==0.22.2
 
 
 
 
 
1
+ huggingface_hub==0.22.2
+ langchain==0.2.15
+ langchain_community==0.2.14
+ langchain_chroma
+ zhipuai
+ python-dotenv