KevinHuSh committed
Commit 8f9784a · 1 Parent(s): 21cf732

Support Ollama (#261)

### What problem does this PR solve?

Issue link: #221

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

README.md CHANGED
@@ -1,6 +1,6 @@
  <div align="center">
  <a href="https://demo.ragflow.io/">
- <img src="web/src/assets/logo-with-text.png" width="350" alt="ragflow logo">
+ <img src="web/src/assets/logo-with-text.png" width="520" alt="ragflow logo">
  </a>
  </div>
  
@@ -124,12 +124,12 @@
  
     * Running on all addresses (0.0.0.0)
     * Running on http://127.0.0.1:9380
-    * Running on http://172.22.0.5:9380
+    * Running on http://x.x.x.x:9380
     INFO:werkzeug:Press CTRL+C to quit
     ```
  
- 5. In your web browser, enter the IP address of your server as prompted and log in to RAGFlow.
-    > In the given scenario, you only need to enter `http://IP_of_RAGFlow ` (sans port number) as the default HTTP serving port `80` can be omitted when using the default configurations.
+ 5. In your web browser, enter the IP address of your server and log in to RAGFlow.
+    > In the given scenario, you only need to enter `http://IP_OF_YOUR_MACHINE` (sans port number), as the default HTTP serving port `80` can be omitted when using the default configurations.
  6. In [service_conf.yaml](./docker/service_conf.yaml), select the desired LLM factory in `user_default_llm` and update the `API_KEY` field with the corresponding API key.
  
     > See [./docs/llm_api_key_setup.md](./docs/llm_api_key_setup.md) for more information.
@@ -168,6 +168,11 @@ $ cd ragflow/docker
  $ docker compose up -d
  ```
  
+ ## 🆕 Latest Features
+ 
+ - Support [Ollama](./docs/ollama.md) for local LLM deployment.
+ - Support Chinese UI.
+ 
  ## 📜 Roadmap
  
  See the [RAGFlow Roadmap 2024](https://github.com/infiniflow/ragflow/issues/162)
README_ja.md CHANGED
@@ -124,12 +124,12 @@
  
     * Running on all addresses (0.0.0.0)
     * Running on http://127.0.0.1:9380
-    * Running on http://172.22.0.5:9380
+    * Running on http://x.x.x.x:9380
     INFO:werkzeug:Press CTRL+C to quit
     ```
  
  5. ウェブブラウザで、プロンプトに従ってサーバーの IP アドレスを入力し、RAGFlow にログインします。
-    > デフォルトの設定を使用する場合、デフォルトの HTTP サービングポート `80` は省略できるので、与えられたシナリオでは、`http://172.22.0.5`(ポート番号は省略)だけを入力すればよい。
+    > デフォルトの設定を使用する場合、デフォルトの HTTP サービングポート `80` は省略できるので、与えられたシナリオでは、`http://IP_OF_YOUR_MACHINE`(ポート番号は省略)だけを入力すればよい。
  6. [service_conf.yaml](./docker/service_conf.yaml) で、`user_default_llm` で希望の LLM ファクトリを選択し、`API_KEY` フィールドを対応する API キーで更新する。
  
     > 詳しくは [./docs/llm_api_key_setup.md](./docs/llm_api_key_setup.md) を参照してください。
@@ -168,6 +168,11 @@ $ cd ragflow/docker
  $ docker compose up -d
  ```
  
+ ## 🆕 最新の新機能
+ 
+ - [Ollama](./docs/ollama.md) を使用した大規模モデルのローカライズされたデプロイメントをサポートします。
+ - 中国語インターフェースをサポートします。
+ 
  ## 📜 ロードマップ
  
  [RAGFlow ロードマップ 2024](https://github.com/infiniflow/ragflow/issues/162) を参照
README_zh.md CHANGED
@@ -124,12 +124,12 @@
  
     * Running on all addresses (0.0.0.0)
     * Running on http://127.0.0.1:9380
-    * Running on http://172.22.0.5:9380
+    * Running on http://x.x.x.x:9380
     INFO:werkzeug:Press CTRL+C to quit
     ```
  
- 5. 根据刚才的界面提示在你的浏览器中输入你的服务器对应的 IP 地址并登录 RAGFlow。
-    > 上面这个例子中,您只需输入 http://172.22.0.5 即可:未改动过配置则无需输入端口(默认的 HTTP 服务端口 80)。
+ 5. 在你的浏览器中输入你的服务器对应的 IP 地址并登录 RAGFlow。
+    > 上面这个例子中,您只需输入 http://IP_OF_YOUR_MACHINE 即可:未改动过配置则无需输入端口(默认的 HTTP 服务端口 80)。
  6. 在 [service_conf.yaml](./docker/service_conf.yaml) 文件的 `user_default_llm` 栏配置 LLM factory,并在 `API_KEY` 栏填写和你选择的大模型相对应的 API key。
  
     > 详见 [./docs/llm_api_key_setup.md](./docs/llm_api_key_setup.md)。
@@ -168,9 +168,14 @@ $ cd ragflow/docker
  $ docker compose up -d
  ```
  
+ ## 🆕 最近新特性
+ 
+ - 支持用 [Ollama](./docs/ollama.md) 对大模型进行本地化部署。
+ - 支持中文界面。
+ 
  ## 📜 路线图
  
- 详见 [RAGFlow Roadmap 2024](https://github.com/infiniflow/ragflow/issues/162)。
+ 详见 [RAGFlow Roadmap 2024](https://github.com/infiniflow/ragflow/issues/162)
  
  ## 🏄 开源社区
  
@@ -179,7 +184,7 @@ $ docker compose up -d
  
  ## 🙌 贡献指南
  
- RAGFlow 只有通过开源协作才能蓬勃发展。秉持这一精神,我们欢迎来自社区的各种贡献。如果您有意参与其中,请查阅我们的[贡献者指南](https://github.com/infiniflow/ragflow/blob/main/docs/CONTRIBUTING.md)。
+ RAGFlow 只有通过开源协作才能蓬勃发展。秉持这一精神,我们欢迎来自社区的各种贡献。如果您有意参与其中,请查阅我们的[贡献者指南](https://github.com/infiniflow/ragflow/blob/main/docs/CONTRIBUTING.md)
  
  ## 👥 加入社区
  
api/apps/conversation_app.py CHANGED
@@ -126,7 +126,7 @@ def message_fit_in(msg, max_length=4000):
      if c < max_length:
          return c, msg
  
-     msg_ = [m for m in msg[:-1] if m.role == "system"]
+     msg_ = [m for m in msg[:-1] if m["role"] == "system"]
      msg_.append(msg[-1])
      msg = msg_
      c = count()
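The change above matters because each entry in `msg` is a plain dict, so `m.role` raises `AttributeError`; indexing with `m["role"]` is the correct lookup. A minimal sketch, with made-up messages for illustration:

```python
# Hypothetical history entries shaped like the ones message_fit_in() receives.
msg = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "An earlier question"},
    {"role": "user", "content": "The latest question"},
]

# m.role would fail on a dict; m["role"] works.
msg_ = [m for m in msg[:-1] if m["role"] == "system"]
msg_.append(msg[-1])
print(msg_)  # keeps the system prompt plus the latest user message
```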
api/apps/document_app.py CHANGED
@@ -81,7 +81,7 @@ def upload():
          "parser_id": kb.parser_id,
          "parser_config": kb.parser_config,
          "created_by": current_user.id,
-         "type": filename_type(filename),
+         "type": filetype,
          "name": filename,
          "location": location,
          "size": len(blob),
api/apps/llm_app.py CHANGED
@@ -91,6 +91,57 @@ def set_api_key():
      return get_json_result(data=True)
  
  
+ @manager.route('/add_llm', methods=['POST'])
+ @login_required
+ @validate_request("llm_factory", "llm_name", "model_type")
+ def add_llm():
+     req = request.json
+     llm = {
+         "tenant_id": current_user.id,
+         "llm_factory": req["llm_factory"],
+         "model_type": req["model_type"],
+         "llm_name": req["llm_name"],
+         "api_base": req.get("api_base", ""),
+         "api_key": "xxxxxxxxxxxxxxx"
+     }
+ 
+     factory = req["llm_factory"]
+     msg = ""
+     if llm["model_type"] == LLMType.EMBEDDING.value:
+         mdl = EmbeddingModel[factory](
+             key=None, model_name=llm["llm_name"], base_url=llm["api_base"])
+         try:
+             arr, tc = mdl.encode(["Test if the api key is available"])
+             if len(arr[0]) == 0 or tc == 0:
+                 raise Exception("Fail")
+         except Exception as e:
+             msg += f"\nFail to access embedding model({llm['llm_name']})." + str(e)
+     elif llm["model_type"] == LLMType.CHAT.value:
+         mdl = ChatModel[factory](
+             key=None, model_name=llm["llm_name"], base_url=llm["api_base"])
+         try:
+             m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {
+                 "temperature": 0.9})
+             if not tc:
+                 raise Exception(m)
+         except Exception as e:
+             msg += f"\nFail to access model({llm['llm_name']})." + str(
+                 e)
+     else:
+         # TODO: check other type of models
+         pass
+ 
+     if msg:
+         return get_data_error_result(retmsg=msg)
+ 
+ 
+     if not TenantLLMService.filter_update(
+             [TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == factory, TenantLLM.llm_name == llm["llm_name"]], llm):
+         TenantLLMService.save(**llm)
+ 
+     return get_json_result(data=True)
+ 
+ 
  @manager.route('/my_llms', methods=['GET'])
  @login_required
  def my_llms():
@@ -125,6 +176,12 @@ def list():
      for m in llms:
          m["available"] = m["fid"] in facts or m["llm_name"].lower() == "flag-embedding"
  
+     llm_set = set([m["llm_name"] for m in llms])
+     for o in objs:
+         if not o.api_key:continue
+         if o.llm_name in llm_set:continue
+         llms.append({"llm_name": o.llm_name, "model_type": o.model_type, "fid": o.llm_factory, "available": True})
+ 
      res = {}
      for m in llms:
          if model_type and m["model_type"] != model_type:
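For context, a hedged sketch of how a logged-in client might exercise the new endpoint. The `/v1/llm/add_llm` URL prefix, the authenticated `requests` session, and the `"chat"` value for `model_type` are assumptions about the surrounding app rather than something shown in this diff:

```python
import requests

# Assumes a session that has already authenticated against the RAGFlow web API,
# since the route is decorated with @login_required.
session = requests.Session()

payload = {
    "llm_factory": "Ollama",                                   # factory name registered in init_data.py
    "llm_name": "mistral",                                     # model served by your Ollama instance
    "model_type": "chat",                                      # assumed string value of LLMType.CHAT.value
    "api_base": "http://<your-ollama-endpoint-domain>:11434",  # stored and later passed as base_url
}
resp = session.post("http://IP_OF_YOUR_MACHINE/v1/llm/add_llm", json=payload)
print(resp.json())  # data should be true once the probe chat above succeeds
```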
api/apps/user_app.py CHANGED
@@ -181,6 +181,10 @@ def user_info():
  
  
  def rollback_user_registration(user_id):
+     try:
+         UserService.delete_by_id(user_id)
+     except Exception as e:
+         pass
      try:
          TenantService.delete_by_id(user_id)
      except Exception as e:
api/db/init_data.py CHANGED
@@ -18,7 +18,7 @@ import time
  import uuid
  
  from api.db import LLMType, UserTenantRole
- from api.db.db_models import init_database_tables as init_web_db
+ from api.db.db_models import init_database_tables as init_web_db, LLMFactories, LLM
  from api.db.services import UserService
  from api.db.services.llm_service import LLMFactoriesService, LLMService, TenantLLMService, LLMBundle
  from api.db.services.user_service import TenantService, UserTenantService
@@ -100,16 +100,16 @@ factory_infos = [{
      "status": "1",
  },
  {
-     "name": "Local",
+     "name": "Ollama",
      "logo": "",
      "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
      "status": "1",
  }, {
-     "name": "Moonshot",
+     "name": "Moonshot",
      "logo": "",
      "tags": "LLM,TEXT EMBEDDING",
      "status": "1",
- }
+ },
  # {
  #     "name": "文心一言",
  #     "logo": "",
@@ -230,20 +230,6 @@ def init_llm_factory():
      "max_tokens": 512,
      "model_type": LLMType.EMBEDDING.value
  },
- # ---------------------- 本地 ----------------------
- {
-     "fid": factory_infos[3]["name"],
-     "llm_name": "qwen-14B-chat",
-     "tags": "LLM,CHAT,",
-     "max_tokens": 4096,
-     "model_type": LLMType.CHAT.value
- }, {
-     "fid": factory_infos[3]["name"],
-     "llm_name": "flag-embedding",
-     "tags": "TEXT EMBEDDING,",
-     "max_tokens": 128 * 1000,
-     "model_type": LLMType.EMBEDDING.value
- },
  # ------------------------ Moonshot -----------------------
  {
      "fid": factory_infos[4]["name"],
@@ -282,6 +268,9 @@ def init_llm_factory():
      except Exception as e:
          pass
  
+     LLMFactoriesService.filter_delete([LLMFactories.name == "Local"])
+     LLMService.filter_delete([LLM.fid == "Local"])
+ 
      """
      drop table llm;
      drop table llm_factories;
@@ -295,8 +284,7 @@ def init_llm_factory():
  def init_web_data():
      start_time = time.time()
  
-     if LLMFactoriesService.get_all().count() != len(factory_infos):
-         init_llm_factory()
+     init_llm_factory()
      if not UserService.get_all().count():
          init_superuser()
  
docker/docker-compose-CN.yml CHANGED
@@ -20,6 +20,7 @@ services:
        - 443:443
      volumes:
        - ./service_conf.yaml:/ragflow/conf/service_conf.yaml
+       - ./entrypoint.sh:/ragflow/entrypoint.sh
        - ./ragflow-logs:/ragflow/logs
        - ./nginx/ragflow.conf:/etc/nginx/conf.d/ragflow.conf
        - ./nginx/proxy.conf:/etc/nginx/proxy.conf
docs/ollama.md ADDED
@@ -0,0 +1,40 @@
+ # Ollama
+ 
+ <div align="center" style="margin-top:20px;margin-bottom:20px;">
+ <img src="https://github.com/infiniflow/ragflow/assets/12318111/2019e7ee-1e8a-412e-9349-11bbf702e549" width="130"/>
+ </div>
+ 
+ One-click deployment of local LLMs using [Ollama](https://github.com/ollama/ollama).
+ 
+ ## Install
+ 
+ - [Ollama on Linux](https://github.com/ollama/ollama/blob/main/docs/linux.md)
+ - [Ollama Windows Preview](https://github.com/ollama/ollama/blob/main/docs/windows.md)
+ - [Docker](https://hub.docker.com/r/ollama/ollama)
+ 
+ ## Launch Ollama
+ 
+ Decide which LLM you want to deploy ([here is a list of supported LLMs](https://ollama.com/library)), say, **mistral**:
+ ```bash
+ $ ollama run mistral
+ ```
+ Or,
+ ```bash
+ $ docker exec -it ollama ollama run mistral
+ ```
+ 
+ ## Use Ollama in RAGFlow
+ 
+ - Go to 'Settings > Model Providers > Models to be added > Ollama'.
+ 
+ <div align="center" style="margin-top:20px;margin-bottom:20px;">
+ <img src="https://github.com/infiniflow/ragflow/assets/12318111/2019e7ee-1e8a-412e-9349-11bbf702e549" width="130"/>
+ </div>
+ 
+ > Base URL: Enter the base URL where the Ollama service is accessible, e.g., http://<your-ollama-endpoint-domain>:11434
+ 
+ - Use Ollama Models.
+ 
+ <div align="center" style="margin-top:20px;margin-bottom:20px;">
+ <img src="https://github.com/infiniflow/ragflow/assets/12318111/2019e7ee-1e8a-412e-9349-11bbf702e549" width="130"/>
+ </div>
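Before registering the model in RAGFlow, it can help to confirm that the Ollama endpoint answers at the base URL you plan to enter. A small sketch using the `ollama` Python client, with placeholder host and model name:

```python
from ollama import Client

# Point the client at the same base URL you will enter in RAGFlow.
client = Client(host="http://<your-ollama-endpoint-domain>:11434")

# One-off probe; if this prints a reply, RAGFlow should be able to reach the model too.
response = client.chat(
    model="mistral",
    messages=[{"role": "user", "content": "Say hello in one short sentence."}],
)
print(response["message"]["content"])
```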
rag/llm/__init__.py CHANGED
@@ -19,7 +19,7 @@ from .cv_model import *
  
  
  EmbeddingModel = {
-     "Local": HuEmbedding,
+     "Ollama": OllamaEmbed,
      "OpenAI": OpenAIEmbed,
      "Tongyi-Qianwen": HuEmbedding, #QWenEmbed,
      "ZHIPU-AI": ZhipuEmbed,
@@ -29,7 +29,7 @@ EmbeddingModel = {
  
  CvModel = {
      "OpenAI": GptV4,
-     "Local": LocalCV,
+     "Ollama": OllamaCV,
      "Tongyi-Qianwen": QWenCV,
      "ZHIPU-AI": Zhipu4V,
      "Moonshot": LocalCV
@@ -40,7 +40,7 @@ ChatModel = {
      "OpenAI": GptTurbo,
      "ZHIPU-AI": ZhipuChat,
      "Tongyi-Qianwen": QWenChat,
-     "Local": LocalLLM,
+     "Ollama": OllamaChat,
      "Moonshot": MoonshotChat
  }
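These registries are how the rest of RAGFlow turns a factory name stored with the tenant's configuration into a concrete wrapper class, the same way the new `add_llm` endpoint does. A hedged sketch of that dispatch, with placeholder model name and base URL:

```python
from rag.llm import ChatModel, EmbeddingModel

# "Ollama" is the registry key added above; model_name and base_url are placeholders
# standing in for whatever the tenant configured.
chat_mdl = ChatModel["Ollama"](
    key=None, model_name="mistral",
    base_url="http://<your-ollama-endpoint-domain>:11434")
embd_mdl = EmbeddingModel["Ollama"](
    key=None, model_name="mistral",
    base_url="http://<your-ollama-endpoint-domain>:11434")
```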
rag/llm/chat_model.py CHANGED
@@ -18,6 +18,7 @@ from dashscope import Generation
  from abc import ABC
  from openai import OpenAI
  import openai
+ from ollama import Client
  from rag.nlp import is_english
  from rag.utils import num_tokens_from_string
  
@@ -129,6 +130,32 @@ class ZhipuChat(Base):
          return "**ERROR**: " + str(e), 0
  
  
+ class OllamaChat(Base):
+     def __init__(self, key, model_name, **kwargs):
+         self.client = Client(host=kwargs["base_url"])
+         self.model_name = model_name
+ 
+     def chat(self, system, history, gen_conf):
+         if system:
+             history.insert(0, {"role": "system", "content": system})
+         try:
+             options = {"temperature": gen_conf.get("temperature", 0.1),
+                        "num_predict": gen_conf.get("max_tokens", 128),
+                        "top_k": gen_conf.get("top_p", 0.3),
+                        "presence_penalty": gen_conf.get("presence_penalty", 0.4),
+                        "frequency_penalty": gen_conf.get("frequency_penalty", 0.7),
+                        }
+             response = self.client.chat(
+                 model=self.model_name,
+                 messages=history,
+                 options=options
+             )
+             ans = response["message"]["content"].strip()
+             return ans, response["eval_count"]
+         except Exception as e:
+             return "**ERROR**: " + str(e), 0
+ 
+ 
  class LocalLLM(Base):
      class RPCProxy:
          def __init__(self, host, port):
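A short usage sketch for the new wrapper, with placeholder host and model name; note how `gen_conf` keys such as `max_tokens` are translated into Ollama `options` like `num_predict` before the request is sent:

```python
from rag.llm.chat_model import OllamaChat

mdl = OllamaChat(key=None, model_name="mistral",
                 base_url="http://<your-ollama-endpoint-domain>:11434")

history = [{"role": "user", "content": "Summarize RAGFlow in one sentence."}]
answer, eval_count = mdl.chat(
    system="You answer briefly.",                     # inserted at the front of history
    history=history,
    gen_conf={"temperature": 0.2, "max_tokens": 64},  # mapped to temperature / num_predict
)
print(answer, eval_count)
```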
rag/llm/cv_model.py CHANGED
@@ -16,7 +16,7 @@
  from zhipuai import ZhipuAI
  import io
  from abc import ABC
- 
+ from ollama import Client
  from PIL import Image
  from openai import OpenAI
  import os
@@ -140,6 +140,28 @@ class Zhipu4V(Base):
          return res.choices[0].message.content.strip(), res.usage.total_tokens
  
  
+ class OllamaCV(Base):
+     def __init__(self, key, model_name, lang="Chinese", **kwargs):
+         self.client = Client(host=kwargs["base_url"])
+         self.model_name = model_name
+         self.lang = lang
+ 
+     def describe(self, image, max_tokens=1024):
+         prompt = self.prompt("")
+         try:
+             options = {"num_predict": max_tokens}
+             response = self.client.generate(
+                 model=self.model_name,
+                 prompt=prompt[0]["content"][1]["text"],
+                 images=[image],
+                 options=options
+             )
+             ans = response["response"].strip()
+             return ans, 128
+         except Exception as e:
+             return "**ERROR**: " + str(e), 0
+ 
+ 
  class LocalCV(Base):
      def __init__(self, key, model_name="glm-4v", lang="Chinese", **kwargs):
          pass
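A hedged usage sketch for the vision wrapper. `llava` is one example of a multimodal model available in the Ollama library, and the image is passed as a base64 string, which the Ollama client accepts:

```python
import base64
from rag.llm.cv_model import OllamaCV

mdl = OllamaCV(key=None, model_name="llava", lang="English",
               base_url="http://<your-ollama-endpoint-domain>:11434")

# Encode a local image as base64 before handing it to describe().
with open("figure.png", "rb") as f:
    img_b64 = base64.b64encode(f.read()).decode("utf-8")

description, token_estimate = mdl.describe(img_b64, max_tokens=256)
print(description)
```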
rag/llm/embedding_model.py CHANGED
@@ -16,13 +16,12 @@
  from zhipuai import ZhipuAI
  import os
  from abc import ABC
- 
+ from ollama import Client
  import dashscope
  from openai import OpenAI
  from FlagEmbedding import FlagModel
  import torch
  import numpy as np
- from huggingface_hub import snapshot_download
  
  from api.utils.file_utils import get_project_base_directory
  from rag.utils import num_tokens_from_string
@@ -150,3 +149,24 @@ class ZhipuEmbed(Base):
          res = self.client.embeddings.create(input=text,
                                              model=self.model_name)
          return np.array(res.data[0].embedding), res.usage.total_tokens
+ 
+ 
+ class OllamaEmbed(Base):
+     def __init__(self, key, model_name, **kwargs):
+         self.client = Client(host=kwargs["base_url"])
+         self.model_name = model_name
+ 
+     def encode(self, texts: list, batch_size=32):
+         arr = []
+         tks_num = 0
+         for txt in texts:
+             res = self.client.embeddings(prompt=txt,
+                                          model=self.model_name)
+             arr.append(res["embedding"])
+             tks_num += 128
+         return np.array(arr), tks_num
+ 
+     def encode_queries(self, text):
+         res = self.client.embeddings(prompt=text,
+                                      model=self.model_name)
+         return np.array(res["embedding"]), 128
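A short sketch of the embedding wrapper in use. Note that it sends one request per text and reports a flat estimate of 128 tokens per call rather than a real token count; `nomic-embed-text` is just an example of an embedding-capable model from the Ollama library:

```python
from rag.llm.embedding_model import OllamaEmbed

mdl = OllamaEmbed(key=None, model_name="nomic-embed-text",
                  base_url="http://<your-ollama-endpoint-domain>:11434")

vectors, token_estimate = mdl.encode(["hello world", "RAGFlow with Ollama"])
print(vectors.shape, token_estimate)   # (2, embedding_dim), 256

query_vec, _ = mdl.encode_queries("hello world")
print(query_vec.shape)                 # (embedding_dim,)
```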
rag/svr/task_executor.py CHANGED
@@ -23,7 +23,8 @@ import re
  import sys
  import traceback
  from functools import partial
- 
+ import signal
+ from contextlib import contextmanager
  from rag.settings import database_logger
  from rag.settings import cron_logger, DOC_MAXIMUM_SIZE
  
@@ -97,8 +98,21 @@ def collect(comm, mod, tm):
      cron_logger.info("TOTAL:{}, To:{}".format(len(tasks), mtm))
      return tasks
  
+ @contextmanager
+ def timeout(time):
+     # Register a function to raise a TimeoutError on the signal.
+     signal.signal(signal.SIGALRM, raise_timeout)
+     # Schedule the signal to be sent after ``time``.
+     signal.alarm(time)
+     yield
+ 
+ 
+ def raise_timeout(signum, frame):
+     raise TimeoutError
+ 
  
  def build(row):
+     from timeit import default_timer as timer
      if row["size"] > DOC_MAXIMUM_SIZE:
          set_progress(row["id"], prog=-1, msg="File size exceeds( <= %dMb )" %
                       (int(DOC_MAXIMUM_SIZE / 1024 / 1024)))
@@ -111,11 +125,14 @@ def build(row):
                        row["to_page"])
      chunker = FACTORY[row["parser_id"].lower()]
      try:
-         cron_logger.info(
-             "Chunkking {}/{}".format(row["location"], row["name"]))
-         cks = chunker.chunk(row["name"], binary=MINIO.get(row["kb_id"], row["location"]), from_page=row["from_page"],
+         st = timer()
+         with timeout(30):
+             binary = MINIO.get(row["kb_id"], row["location"])
+         cks = chunker.chunk(row["name"], binary=binary, from_page=row["from_page"],
                              to_page=row["to_page"], lang=row["language"], callback=callback,
                              kb_id=row["kb_id"], parser_config=row["parser_config"], tenant_id=row["tenant_id"])
+         cron_logger.info(
+             "Chunkking({}) {}/{}".format(timer()-st, row["location"], row["name"]))
      except Exception as e:
          if re.search("(No such file|not found)", str(e)):
              callback(-1, "Can not find file <%s>" % row["name"])
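For reference, a standalone sketch of the same SIGALRM-based timeout pattern used for the MinIO fetch above. It only works on Unix and in the main thread; unlike the committed helper, this variant also clears the alarm when the block exits, which is a defensive addition, not part of the diff:

```python
import signal
import time
from contextlib import contextmanager


def raise_timeout(signum, frame):
    raise TimeoutError


@contextmanager
def timeout(seconds):
    # Deliver SIGALRM after `seconds` and turn it into a TimeoutError.
    signal.signal(signal.SIGALRM, raise_timeout)
    signal.alarm(seconds)
    try:
        yield
    finally:
        signal.alarm(0)  # cancel the pending alarm once the block finishes


try:
    with timeout(2):
        time.sleep(5)  # stands in for a slow MINIO.get() call
except TimeoutError:
    print("operation timed out")
```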