sanbu committed on
Commit 5101d82 · 1 Parent(s): 86681f9
Files changed (3)
  1. app.py +10 -200
  2. models.py +0 -86
  3. requirements.txt +17 -2
app.py CHANGED
@@ -1,205 +1,15 @@
  import os
- import gradio as gr
- from dotenv import load_dotenv
- from models import ZhipuAIEmbeddings, ZhipuLLM
- from langchain_chroma import Chroma
- from langchain_community.document_loaders import DirectoryLoader, TextLoader
- from langchain_text_splitters import RecursiveCharacterTextSplitter
- from langchain_core.runnables import RunnablePassthrough
- from langchain_core.output_parsers import StrOutputParser
- from langchain import hub
- from huggingface_hub import snapshot_download
-
- # Load environment variables
- load_dotenv()
-
- # Download dataset using Hugging Face's huggingface_hub
- destination_folder = os.path.join("./", "temp", "tianji-chinese")
- snapshot_download(
-     repo_id="sanbu/tianji-chinese",
-     local_dir=destination_folder,
-     repo_type="dataset",
-     local_dir_use_symlinks=False,
- )
-
-
- def create_vectordb(
-     data_path: str,
-     persist_directory: str,
-     embedding_func,
-     chunk_size: int,
-     force: bool = False,
- ):
-     if os.path.exists(persist_directory) and not force:
-         return Chroma(
-             persist_directory=persist_directory, embedding_function=embedding_func
-         )
-
-     if force and os.path.exists(persist_directory):
-         if os.path.isdir(persist_directory):
-             import shutil
-
-             shutil.rmtree(persist_directory)
-         else:
-             os.remove(persist_directory)
-
-     loader = DirectoryLoader(data_path, glob="*.txt", loader_cls=TextLoader)
-     text_splitter = RecursiveCharacterTextSplitter(
-         chunk_size=chunk_size, chunk_overlap=200
-     )
-     split_docs = text_splitter.split_documents(loader.load())
-     if len(split_docs) == 0:
-         raise gr.Error("当前知识数据无效,处理数据后为空")
-
-     vector_db = Chroma.from_documents(
-         documents=split_docs,
-         embedding=embedding_func,
-         persist_directory=persist_directory,
-     )
-     return vector_db
-
-
- def initialize_chain(chunk_size: int, persist_directory: str, data_path: str):
-     print("初始化数据库开始")
-     embeddings = ZhipuAIEmbeddings()
-     vectordb = create_vectordb(data_path, persist_directory, embeddings, chunk_size)
-     retriever = vectordb.as_retriever()
-     prompt = hub.pull("rlm/rag-prompt")
-     prompt.messages[
-         0
-     ].prompt.template = """
-     您是一名用于问答任务的助手。使用检索到的上下文来回答问题。如果没有高度相关上下文 你就自由回答。\
-     根据检索到的上下文,结合我的问题,直接给出最后的回答,要详细覆盖全方面。\
-     \n问题:{question} \n上下文:{context} \n回答:
-     """
-     llm = ZhipuLLM()
-     print("初始化数据库结束")
-     return (
-         {"context": retriever | format_docs, "question": RunnablePassthrough()}
-         | prompt
-         | llm
-         | StrOutputParser()
-     )
-
-
- def format_docs(docs):
-     return "\n\n".join(doc.page_content for doc in docs)
-
-
- def handle_question(chain, question: str, chat_history):
-     if not question:
-         return "", chat_history
-     try:
-         result = chain.invoke(question)
-         chat_history.append((question, result))
-         return "", chat_history
-     except Exception as e:
-         return str(e), chat_history
-
-
- # Define scenarios
- scenarios = {
-     "敬酒礼仪文化": "1-etiquette",
-     "请客礼仪文化": "2-hospitality",
-     "送礼礼仪文化": "3-gifting",
-     "如何说对话": "5-communication",
-     "化解尴尬场合": "6-awkwardness",
-     "矛盾&冲突应对": "7-conflict",
- }
-
- # Initialize chains for all scenarios
- chains = {}
- for scenario_name, scenario_folder in scenarios.items():
-     data_path = os.path.join(
-         "./", "temp", "tianji-chinese", "RAG", scenario_folder
-     )
-     if not os.path.exists(data_path):
-         raise FileNotFoundError(f"Data path does not exist: {data_path}")
-
-     chunk_size = 1280
-     persist_directory = os.path.join("./", "temp", f"chromadb_{scenario_folder}")
-     chains[scenario_name] = initialize_chain(chunk_size, persist_directory, data_path)
-
- # Create Gradio interface
- TITLE = """
- # Tianji 人情世故大模型系统完整版(基于知识库实现) 欢迎star!\n
- ## 💫开源项目地址:https://github.com/SocialAI-tianji/Tianji
- ## 使用方法:选择你想提问的场景,输入提示,或点击Example自动填充
- ## 如果觉得回答不满意,可补充更多信息重复提问。
- ### 我们的愿景是构建一个从数据收集开始的大模型全栈垂直领域开源实践.
- """
-
-
- def get_examples_for_scenario(scenario):
-     # Define examples for each scenario
-     examples_dict = {
-         "敬酒礼仪文化": [
-             "喝酒座位怎么排",
-             "喝酒的完整流程是什么",
-             "推荐的敬酒词怎么说",
-             "宴会怎么点菜",
-             "喝酒容易醉怎么办",
-             "喝酒的规矩是什么",
-         ],
-         "请客礼仪文化": ["请客有那些规矩", "如何选择合适的餐厅", "怎么请别人吃饭"],
-         "送礼礼仪文化": ["送什么礼物给长辈好", "怎么送礼", "回礼的礼节是什么"],
-         "如何说对话": [
-             "怎么和导师沟通",
-             "怎么提高情商",
-             "如何读懂潜台词",
-             "怎么安慰别人",
-             "怎么和孩子沟通",
-             "如何与男生聊天",
-             "如何与女生聊天",
-             "职场高情商回应技巧",
-         ],
-         "化解尴尬场合": ["怎么回应赞美", "怎么拒绝借钱", "如何高效沟通", "怎么和对象沟通", "聊天技巧", "怎么拒绝别人", "职场怎么沟通"],
-         "矛盾&冲突应对": [
-             "怎么控制情绪",
-             "怎么向别人道歉",
-             "和别人吵架了怎么办",
-             "如何化解尴尬",
-             "孩子有情绪怎么办",
-             "夫妻吵架怎么办",
-             "情侣冷战怎么办",
-         ],
-     }
-     return examples_dict.get(scenario, [])
-
-
- with gr.Blocks() as demo:
-     gr.Markdown(TITLE)
-
-     init_status = gr.Textbox(label="初始化状态", value="数据库已初始化", interactive=False)
-
-     with gr.Tabs() as tabs:
-         for scenario_name in scenarios.keys():
-             with gr.Tab(scenario_name):
-                 chatbot = gr.Chatbot(height=450, show_copy_button=True)
-                 msg = gr.Textbox(label="输入你的疑问")
-
-                 examples = gr.Examples(
-                     label="快速示例",
-                     examples=get_examples_for_scenario(scenario_name),
-                     inputs=[msg],
-                 )
-
-                 with gr.Row():
-                     chat_button = gr.Button("聊天")
-                     clear_button = gr.ClearButton(components=[chatbot], value="清除聊天记录")
-
-                 # Define a function to invoke the chain for the current scenario
-                 def invoke_chain(question, chat_history, scenario=scenario_name):
-                     print(question)
-                     return handle_question(chains[scenario], question, chat_history)
-
-                 chat_button.click(
-                     invoke_chain,
-                     inputs=[msg, chatbot],
-                     outputs=[msg, chatbot],
-                 )
-
-
- # Launch Gradio application
- if __name__ == "__main__":
-     demo.launch()
+ import subprocess
+
+ repo_url = "https://github.com/SocialAI-tianji/Tianji.git"
+ clone_dir = "Tianji"
+
+ print("正在克隆 Tianji 仓库...")
+ subprocess.run(["git", "clone", repo_url], check=True)
+
+ print("正在安装依赖...")
+ os.chdir(clone_dir)
+ subprocess.run(["pip", "install", "-e", "."], check=True)
+
+ print("正在运行示例...")
+ subprocess.run(["python", "run/demo_rag_langchain_all.py"], check=True)
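The rewritten app.py is now a thin bootstrap: it clones the upstream SocialAI-tianji/Tianji repository, installs it in editable mode, and hands off to the repository's own run/demo_rag_langchain_all.py. One hedged caveat, not addressed by the committed script: if it is run a second time in the same working directory, the bare `git clone` fails because the `Tianji` folder already exists. A minimal sketch of a guard around the same commands follows; it is an illustration only, and the `sys.executable` calls are a substitution for the bare `pip`/`python` invocations above.

import os
import subprocess
import sys

repo_url = "https://github.com/SocialAI-tianji/Tianji.git"
clone_dir = "Tianji"

# Clone only when the checkout is missing; otherwise reuse it
# (assumption: an existing directory means a previous bootstrap succeeded).
if not os.path.isdir(clone_dir):
    subprocess.run(["git", "clone", repo_url], check=True)

os.chdir(clone_dir)
# Install into the interpreter that is actually running this script.
subprocess.run([sys.executable, "-m", "pip", "install", "-e", "."], check=True)
subprocess.run([sys.executable, "run/demo_rag_langchain_all.py"], check=True)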
models.py DELETED
@@ -1,86 +0,0 @@
- from langchain_core.language_models.llms import LLM
- from langchain_core.callbacks.manager import CallbackManagerForLLMRun
- from langchain.embeddings.base import Embeddings
- from typing import Any, Dict, List, Optional
- import os
- from zhipuai import ZhipuAI
- from langchain.pydantic_v1 import BaseModel, root_validator
-
-
- class ZhipuLLM(LLM):
-     """A custom chat model for ZhipuAI."""
-
-     client: Any = None
-
-     def __init__(self):
-         super().__init__()
-         print("Initializing model...")
-         self.client = ZhipuAI(api_key=os.environ.get("ZHIPUAI_API_KEY"))
-         print("Model initialization complete")
-
-     def _call(
-         self,
-         prompt: str,
-         stop: Optional[List[str]] = None,
-         run_manager: Optional[CallbackManagerForLLMRun] = None,
-         **kwargs: Any,
-     ) -> str:
-         """Run the LLM on the given input."""
-
-         response = self.client.chat.completions.create(
-             model="glm-4-flash",
-             messages=[
-                 {"role": "user", "content": prompt},
-             ],
-         )
-         return response.choices[0].message.content
-
-     @property
-     def _identifying_params(self) -> Dict[str, Any]:
-         """Return a dictionary of identifying parameters."""
-         return {"model_name": "ZhipuAI"}
-
-     @property
-     def _llm_type(self) -> str:
-         """Get the type of language model used by this chat model."""
-         return "ZhipuAI"
-
-
- class ZhipuAIEmbeddings(BaseModel, Embeddings):
-     """`Zhipuai Embeddings` embedding models."""
-
-     zhipuai_api_key: Optional[str] = None
-
-     @root_validator()
-     def validate_environment(cls, values: Dict) -> Dict:
-         values["zhupuai_api_key"] = values.get("zhupuai_api_key") or os.getenv(
-             "ZHIPUAI_API_KEY"
-         )
-         try:
-             import zhipuai
-
-             zhipuai.api_key = values["zhupuai_api_key"]
-             values["client"] = zhipuai.ZhipuAI()
-         except ImportError:
-             raise ValueError(
-                 "Zhipuai package not found, please install it with `pip install zhipuai`"
-             )
-         return values
-
-     def _embed(self, texts: str) -> List[float]:
-         try:
-             resp = self.client.embeddings.create(
-                 model="embedding-3",
-                 input=texts,
-             )
-         except Exception as e:
-             raise ValueError(f"Error raised by inference endpoint: {e}")
-         embeddings = resp.data[0].embedding
-         return embeddings
-
-     def embed_query(self, text: str) -> List[float]:
-         resp = self.embed_documents([text])
-         return resp[0]
-
-     def embed_documents(self, texts: List[str]) -> List[List[float]]:
-         return [self._embed(text) for text in texts]
requirements.txt CHANGED
@@ -1,6 +1,21 @@
- huggingface_hub==0.22.2
  python-dotenv
+ modelscope
+ tiktoken
+ einops
+ loguru
+ bitsandbytes
+ duckduckgo_search==5.3.1b1
+ beautifulsoup4==4.12.3
+ gradio
+ streamlit
+ streamlit_chat
+ zhipuai
+
+ # for agent
+ # metagpt==0.8.1 # 注意, metagpt 安装可能会导致 llamaindex 及 langchian 版本改变
+
+ # for rag
  langchain==0.2.15
  langchain_community==0.2.14
  langchain_chroma
- zhipuai
+ sentence-transformers