jinhai-2012 committed
Commit aeb6dbc · Parent: 7a6ad40

Convert file format from Windows/DOS to Unix (#1949)


### What problem does this PR solve?

The related source files were in Windows/DOS (CRLF) line-ending format; this PR
converts them to Unix (LF) format (see the conversion sketch below).

### Type of change

- [x] Refactoring

Signed-off-by: Jin Hai <[email protected]>
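
For reference, a minimal sketch of the kind of CRLF-to-LF conversion this commit performs. It is illustrative only — the actual change was made directly to the repository files, and tools such as `dos2unix` do the same job; the file list here is a stand-in for the affected paths shown in the diff below.

```python
# Convert Windows/DOS (CRLF) line endings to Unix (LF) in place.
# Illustrative sketch: extend FILES with the paths from the diff below.
from pathlib import Path

FILES = [
    "Dockerfile.cuda",
    "agent/component/baidu.py",
]

for name in FILES:
    path = Path(name)
    data = path.read_bytes()
    converted = data.replace(b"\r\n", b"\n")  # CRLF -> LF
    if converted != data:
        path.write_bytes(converted)
        print(f"converted {path} to Unix line endings")
```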

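A `.gitattributes` rule is the usual way to keep these files from drifting back to CRLF on Windows checkouts. This is a suggested follow-up, not part of this commit:

```
# .gitattributes (suggested, not included in this PR)
* text=auto eol=lf
```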
This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full change set.
Files changed (50)
  1. Dockerfile.cuda +27 -27
  2. Dockerfile.scratch +56 -56
  3. Dockerfile.scratch.oc9 +58 -58
  4. agent/component/baidu.py +69 -69
  5. agent/component/baidufanyi.py +99 -99
  6. agent/component/bing.py +85 -85
  7. agent/component/deepl.py +62 -62
  8. agent/component/github.py +61 -61
  9. agent/component/google.py +96 -96
  10. agent/component/googlescholar.py +70 -70
  11. agent/component/qweather.py +111 -111
  12. agent/templates/websearch_assistant.json +0 -0
  13. agent/test/dsl_examples/keyword_wikipedia_and_generate.json +62 -62
  14. api/apps/__init__.py +124 -124
  15. api/apps/api_app.py +734 -734
  16. api/apps/chunk_app.py +318 -318
  17. api/apps/conversation_app.py +177 -177
  18. api/apps/dialog_app.py +172 -172
  19. api/apps/document_app.py +586 -586
  20. api/apps/kb_app.py +153 -153
  21. api/apps/llm_app.py +279 -279
  22. api/apps/user_app.py +391 -391
  23. api/db/__init__.py +102 -102
  24. api/db/db_models.py +972 -972
  25. api/db/db_utils.py +130 -130
  26. api/db/init_data.py +184 -184
  27. api/db/operatioins.py +21 -21
  28. api/db/reload_config_base.py +28 -28
  29. api/db/runtime_config.py +54 -54
  30. api/db/services/__init__.py +38 -38
  31. api/db/services/api_service.py +68 -68
  32. api/db/services/common_service.py +183 -183
  33. api/db/services/dialog_service.py +392 -392
  34. api/db/services/document_service.py +382 -382
  35. api/db/services/knowledgebase_service.py +144 -144
  36. api/db/services/llm_service.py +242 -242
  37. api/db/services/task_service.py +175 -175
  38. api/ragflow_server.py +99 -99
  39. api/settings.py +251 -251
  40. api/utils/__init__.py +346 -346
  41. api/utils/api_utils.py +269 -269
  42. api/utils/commands.py +78 -78
  43. api/utils/file_utils.py +207 -207
  44. api/utils/log_utils.py +313 -313
  45. api/utils/t_crypt.py +24 -24
  46. api/versions.py +27 -27
  47. conf/service_conf.yaml +49 -49
  48. deepdoc/README.md +121 -121
  49. deepdoc/parser/ppt_parser.py +61 -61
  50. deepdoc/parser/resume/__init__.py +64 -64
Dockerfile.cuda CHANGED
@@ -1,27 +1,27 @@
FROM infiniflow/ragflow-base:v2.0
USER root

WORKDIR /ragflow

## for cuda > 12.0
RUN pip uninstall -y onnxruntime-gpu
RUN pip install onnxruntime-gpu --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/


ADD ./web ./web
RUN cd ./web && npm i --force && npm run build

ADD ./api ./api
ADD ./conf ./conf
ADD ./deepdoc ./deepdoc
ADD ./rag ./rag
ADD ./agent ./agent
ADD ./graphrag ./graphrag

ENV PYTHONPATH=/ragflow/
ENV HF_ENDPOINT=https://hf-mirror.com

ADD docker/entrypoint.sh ./entrypoint.sh
RUN chmod +x ./entrypoint.sh

ENTRYPOINT ["./entrypoint.sh"]
Dockerfile.scratch CHANGED
@@ -1,56 +1,56 @@
FROM ubuntu:22.04
USER root

WORKDIR /ragflow

RUN apt-get update && apt-get install -y wget curl build-essential libopenmpi-dev

RUN wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda.sh && \
    bash ~/miniconda.sh -b -p /root/miniconda3 && \
    rm ~/miniconda.sh && ln -s /root/miniconda3/etc/profile.d/conda.sh /etc/profile.d/conda.sh && \
    echo ". /root/miniconda3/etc/profile.d/conda.sh" >> ~/.bashrc && \
    echo "conda activate base" >> ~/.bashrc

ENV PATH /root/miniconda3/bin:$PATH

RUN conda create -y --name py11 python=3.11

ENV CONDA_DEFAULT_ENV py11
ENV CONDA_PREFIX /root/miniconda3/envs/py11
ENV PATH $CONDA_PREFIX/bin:$PATH

RUN curl -sL https://deb.nodesource.com/setup_14.x | bash -
RUN apt-get install -y nodejs

RUN apt-get install -y nginx

ADD ./web ./web
ADD ./api ./api
ADD ./conf ./conf
ADD ./deepdoc ./deepdoc
ADD ./rag ./rag
ADD ./requirements.txt ./requirements.txt
ADD ./agent ./agent
ADD ./graphrag ./graphrag

RUN apt install openmpi-bin openmpi-common libopenmpi-dev
ENV LD_LIBRARY_PATH /usr/lib/x86_64-linux-gnu/openmpi/lib:$LD_LIBRARY_PATH
RUN rm /root/miniconda3/envs/py11/compiler_compat/ld
RUN cd ./web && npm i --force && npm run build
RUN conda run -n py11 pip install -i https://mirrors.aliyun.com/pypi/simple/ -r ./requirements.txt

RUN apt-get update && \
    apt-get install -y libglib2.0-0 libgl1-mesa-glx && \
    rm -rf /var/lib/apt/lists/*

RUN conda run -n py11 pip install -i https://mirrors.aliyun.com/pypi/simple/ ollama
RUN conda run -n py11 python -m nltk.downloader punkt
RUN conda run -n py11 python -m nltk.downloader wordnet

ENV PYTHONPATH=/ragflow/
ENV HF_ENDPOINT=https://hf-mirror.com

ADD docker/entrypoint.sh ./entrypoint.sh
RUN chmod +x ./entrypoint.sh

ENTRYPOINT ["./entrypoint.sh"]
Dockerfile.scratch.oc9 CHANGED
@@ -1,58 +1,58 @@
FROM opencloudos/opencloudos:9.0
USER root

WORKDIR /ragflow

RUN dnf update -y && dnf install -y wget curl gcc-c++ openmpi-devel

RUN wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda.sh && \
    bash ~/miniconda.sh -b -p /root/miniconda3 && \
    rm ~/miniconda.sh && ln -s /root/miniconda3/etc/profile.d/conda.sh /etc/profile.d/conda.sh && \
    echo ". /root/miniconda3/etc/profile.d/conda.sh" >> ~/.bashrc && \
    echo "conda activate base" >> ~/.bashrc

ENV PATH /root/miniconda3/bin:$PATH

RUN conda create -y --name py11 python=3.11

ENV CONDA_DEFAULT_ENV py11
ENV CONDA_PREFIX /root/miniconda3/envs/py11
ENV PATH $CONDA_PREFIX/bin:$PATH

# RUN curl -sL https://rpm.nodesource.com/setup_14.x | bash -
RUN dnf install -y nodejs

RUN dnf install -y nginx

ADD ./web ./web
ADD ./api ./api
ADD ./conf ./conf
ADD ./deepdoc ./deepdoc
ADD ./rag ./rag
ADD ./requirements.txt ./requirements.txt
ADD ./agent ./agent
ADD ./graphrag ./graphrag

RUN dnf install -y openmpi openmpi-devel python3-openmpi
ENV C_INCLUDE_PATH /usr/include/openmpi-x86_64:$C_INCLUDE_PATH
ENV LD_LIBRARY_PATH /usr/lib64/openmpi/lib:$LD_LIBRARY_PATH
RUN rm /root/miniconda3/envs/py11/compiler_compat/ld
RUN cd ./web && npm i --force && npm run build
RUN conda run -n py11 pip install $(grep -ivE "mpi4py" ./requirements.txt) # without mpi4py==3.1.5
RUN conda run -n py11 pip install redis

RUN dnf update -y && \
    dnf install -y glib2 mesa-libGL && \
    dnf clean all

RUN conda run -n py11 pip install ollama
RUN conda run -n py11 python -m nltk.downloader punkt
RUN conda run -n py11 python -m nltk.downloader wordnet

ENV PYTHONPATH=/ragflow/
ENV HF_ENDPOINT=https://hf-mirror.com

ADD docker/entrypoint.sh ./entrypoint.sh
RUN chmod +x ./entrypoint.sh

ENTRYPOINT ["./entrypoint.sh"]
agent/component/baidu.py CHANGED
@@ -1,69 +1,69 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import random
from abc import ABC
from functools import partial
import pandas as pd
import requests
import re
from agent.settings import DEBUG
from agent.component.base import ComponentBase, ComponentParamBase


class BaiduParam(ComponentParamBase):
    """
    Define the Baidu component parameters.
    """

    def __init__(self):
        super().__init__()
        self.top_n = 10

    def check(self):
        self.check_positive_integer(self.top_n, "Top N")


class Baidu(ComponentBase, ABC):
    component_name = "Baidu"

    def _run(self, history, **kwargs):
        ans = self.get_input()
        ans = " - ".join(ans["content"]) if "content" in ans else ""
        if not ans:
            return Baidu.be_output("")

        try:
            url = 'https://www.baidu.com/s?wd=' + ans + '&rn=' + str(self._param.top_n)
            headers = {
                'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/88.0.4324.104 Safari/537.36'}
            response = requests.get(url=url, headers=headers)

            url_res = re.findall(r"'url': \\\"(.*?)\\\"}", response.text)
            title_res = re.findall(r"'title': \\\"(.*?)\\\",\\n", response.text)
            body_res = re.findall(r"\"contentText\":\"(.*?)\"", response.text)
            baidu_res = [{"content": re.sub('<em>|</em>', '', '<a href="' + url + '">' + title + '</a> ' + body)} for
                         url, title, body in zip(url_res, title_res, body_res)]
            del body_res, url_res, title_res
        except Exception as e:
            return Baidu.be_output("**ERROR**: " + str(e))

        if not baidu_res:
            return Baidu.be_output("")

        df = pd.DataFrame(baidu_res)
        if DEBUG: print(df, ":::::::::::::::::::::::::::::::::")
        return df
agent/component/baidufanyi.py CHANGED
@@ -1,99 +1,99 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import random
from abc import ABC
import requests
from agent.component.base import ComponentBase, ComponentParamBase
from hashlib import md5


class BaiduFanyiParam(ComponentParamBase):
    """
    Define the BaiduFanyi component parameters.
    """

    def __init__(self):
        super().__init__()
        self.appid = "xxx"
        self.secret_key = "xxx"
        self.trans_type = 'translate'
        self.parameters = []
        self.source_lang = 'auto'
        self.target_lang = 'auto'
        self.domain = 'finance'

    def check(self):
        self.check_positive_integer(self.top_n, "Top N")
        self.check_empty(self.appid, "BaiduFanyi APPID")
        self.check_empty(self.secret_key, "BaiduFanyi Secret Key")
        self.check_valid_value(self.trans_type, "Translate type", ['translate', 'fieldtranslate'])
        self.check_valid_value(self.trans_type, "Translate domain",
                               ['it', 'finance', 'machinery', 'senimed', 'novel', 'academic', 'aerospace', 'wiki',
                                'news', 'law', 'contract'])
        self.check_valid_value(self.source_lang, "Source language",
                               ['auto', 'zh', 'en', 'yue', 'wyw', 'jp', 'kor', 'fra', 'spa', 'th', 'ara', 'ru', 'pt',
                                'de', 'it', 'el', 'nl', 'pl', 'bul', 'est', 'dan', 'fin', 'cs', 'rom', 'slo', 'swe',
                                'hu', 'cht', 'vie'])
        self.check_valid_value(self.target_lang, "Target language",
                               ['auto', 'zh', 'en', 'yue', 'wyw', 'jp', 'kor', 'fra', 'spa', 'th', 'ara', 'ru', 'pt',
                                'de', 'it', 'el', 'nl', 'pl', 'bul', 'est', 'dan', 'fin', 'cs', 'rom', 'slo', 'swe',
                                'hu', 'cht', 'vie'])
        self.check_valid_value(self.domain, "Translate field",
                               ['it', 'finance', 'machinery', 'senimed', 'novel', 'academic', 'aerospace', 'wiki',
                                'news', 'law', 'contract'])


class BaiduFanyi(ComponentBase, ABC):
    component_name = "BaiduFanyi"

    def _run(self, history, **kwargs):

        ans = self.get_input()
        ans = " - ".join(ans["content"]) if "content" in ans else ""
        if not ans:
            return BaiduFanyi.be_output("")

        try:
            source_lang = self._param.source_lang
            target_lang = self._param.target_lang
            appid = self._param.appid
            salt = random.randint(32768, 65536)
            secret_key = self._param.secret_key

            if self._param.trans_type == 'translate':
                sign = md5((appid + ans + salt + secret_key).encode('utf-8')).hexdigest()
                url = 'http://api.fanyi.baidu.com/api/trans/vip/translate?' + 'q=' + ans + '&from=' + source_lang + '&to=' + target_lang + '&appid=' + appid + '&salt=' + salt + '&sign=' + sign
                headers = {"Content-Type": "application/x-www-form-urlencoded"}
                response = requests.post(url=url, headers=headers).json()

                if response.get('error_code'):
                    BaiduFanyi.be_output("**Error**:" + response['error_msg'])

                return BaiduFanyi.be_output(response['trans_result'][0]['dst'])
            elif self._param.trans_type == 'fieldtranslate':
                domain = self._param.domain
                sign = md5((appid + ans + salt + domain + secret_key).encode('utf-8')).hexdigest()
                url = 'http://api.fanyi.baidu.com/api/trans/vip/fieldtranslate?' + 'q=' + ans + '&from=' + source_lang + '&to=' + target_lang + '&appid=' + appid + '&salt=' + salt + '&domain=' + domain + '&sign=' + sign
                headers = {"Content-Type": "application/x-www-form-urlencoded"}
                response = requests.post(url=url, headers=headers).json()

                if response.get('error_code'):
                    BaiduFanyi.be_output("**Error**:" + response['error_msg'])

                return BaiduFanyi.be_output(response['trans_result'][0]['dst'])

        except Exception as e:
            BaiduFanyi.be_output("**Error**:" + str(e))
agent/component/bing.py CHANGED
@@ -1,85 +1,85 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from abc import ABC
import requests
import pandas as pd
from agent.settings import DEBUG
from agent.component.base import ComponentBase, ComponentParamBase


class BingParam(ComponentParamBase):
    """
    Define the Bing component parameters.
    """

    def __init__(self):
        super().__init__()
        self.top_n = 10
        self.channel = "Webpages"
        self.api_key = "YOUR_ACCESS_KEY"
        self.country = "CN"
        self.language = "en"

    def check(self):
        self.check_positive_integer(self.top_n, "Top N")
        self.check_valid_value(self.channel, "Bing Web Search or Bing News", ["Webpages", "News"])
        self.check_empty(self.api_key, "Bing subscription key")
        self.check_valid_value(self.country, "Bing Country",
                               ['AR', 'AU', 'AT', 'BE', 'BR', 'CA', 'CL', 'DK', 'FI', 'FR', 'DE', 'HK', 'IN', 'ID',
                                'IT', 'JP', 'KR', 'MY', 'MX', 'NL', 'NZ', 'NO', 'CN', 'PL', 'PT', 'PH', 'RU', 'SA',
                                'ZA', 'ES', 'SE', 'CH', 'TW', 'TR', 'GB', 'US'])
        self.check_valid_value(self.language, "Bing Languages",
                               ['ar', 'eu', 'bn', 'bg', 'ca', 'ns', 'nt', 'hr', 'cs', 'da', 'nl', 'en', 'gb', 'et',
                                'fi', 'fr', 'gl', 'de', 'gu', 'he', 'hi', 'hu', 'is', 'it', 'jp', 'kn', 'ko', 'lv',
                                'lt', 'ms', 'ml', 'mr', 'nb', 'pl', 'br', 'pt', 'pa', 'ro', 'ru', 'sr', 'sk', 'sl',
                                'es', 'sv', 'ta', 'te', 'th', 'tr', 'uk', 'vi'])


class Bing(ComponentBase, ABC):
    component_name = "Bing"

    def _run(self, history, **kwargs):
        ans = self.get_input()
        ans = " - ".join(ans["content"]) if "content" in ans else ""
        if not ans:
            return Bing.be_output("")

        try:
            headers = {"Ocp-Apim-Subscription-Key": self._param.api_key, 'Accept-Language': self._param.language}
            params = {"q": ans, "textDecorations": True, "textFormat": "HTML", "cc": self._param.country,
                      "answerCount": 1, "promote": self._param.channel}
            if self._param.channel == "Webpages":
                response = requests.get("https://api.bing.microsoft.com/v7.0/search", headers=headers, params=params)
                response.raise_for_status()
                search_results = response.json()
                bing_res = [{"content": '<a href="' + i["url"] + '">' + i["name"] + '</a> ' + i["snippet"]} for i in
                            search_results["webPages"]["value"]]
            elif self._param.channel == "News":
                response = requests.get("https://api.bing.microsoft.com/v7.0/news/search", headers=headers,
                                        params=params)
                response.raise_for_status()
                search_results = response.json()
                bing_res = [{"content": '<a href="' + i["url"] + '">' + i["name"] + '</a> ' + i["description"]} for i
                            in search_results['news']['value']]
        except Exception as e:
            return Bing.be_output("**ERROR**: " + str(e))

        if not bing_res:
            return Bing.be_output("")

        df = pd.DataFrame(bing_res)
        if DEBUG: print(df, ":::::::::::::::::::::::::::::::::")
        return df
agent/component/deepl.py CHANGED
@@ -1,62 +1,62 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from abc import ABC
import re
from agent.component.base import ComponentBase, ComponentParamBase
import deepl


class DeepLParam(ComponentParamBase):
    """
    Define the DeepL component parameters.
    """

    def __init__(self):
        super().__init__()
        self.auth_key = "xxx"
        self.parameters = []
        self.source_lang = 'ZH'
        self.target_lang = 'EN-GB'

    def check(self):
        self.check_positive_integer(self.top_n, "Top N")
        self.check_valid_value(self.source_lang, "Source language",
                               ['AR', 'BG', 'CS', 'DA', 'DE', 'EL', 'EN', 'ES', 'ET', 'FI', 'FR', 'HU', 'ID', 'IT',
                                'JA', 'KO', 'LT', 'LV', 'NB', 'NL', 'PL', 'PT', 'RO', 'RU', 'SK', 'SL', 'SV', 'TR',
                                'UK', 'ZH'])
        self.check_valid_value(self.target_lang, "Target language",
                               ['AR', 'BG', 'CS', 'DA', 'DE', 'EL', 'EN-GB', 'EN-US', 'ES', 'ET', 'FI', 'FR', 'HU',
                                'ID', 'IT', 'JA', 'KO', 'LT', 'LV', 'NB', 'NL', 'PL', 'PT-BR', 'PT-PT', 'RO', 'RU',
                                'SK', 'SL', 'SV', 'TR', 'UK', 'ZH'])


class DeepL(ComponentBase, ABC):
    component_name = "GitHub"

    def _run(self, history, **kwargs):
        ans = self.get_input()
        ans = " - ".join(ans["content"]) if "content" in ans else ""
        if not ans:
            return DeepL.be_output("")

        try:
            translator = deepl.Translator(self._param.auth_key)
            result = translator.translate_text(ans, source_lang=self._param.source_lang,
                                               target_lang=self._param.target_lang)

            return DeepL.be_output(result.text)
        except Exception as e:
            DeepL.be_output("**Error**:" + str(e))
agent/component/github.py CHANGED
@@ -1,61 +1,61 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from abc import ABC
import pandas as pd
import requests
from agent.settings import DEBUG
from agent.component.base import ComponentBase, ComponentParamBase


class GitHubParam(ComponentParamBase):
    """
    Define the GitHub component parameters.
    """

    def __init__(self):
        super().__init__()
        self.top_n = 10

    def check(self):
        self.check_positive_integer(self.top_n, "Top N")


class GitHub(ComponentBase, ABC):
    component_name = "GitHub"

    def _run(self, history, **kwargs):
        ans = self.get_input()
        ans = " - ".join(ans["content"]) if "content" in ans else ""
        if not ans:
            return GitHub.be_output("")

        try:
            url = 'https://api.github.com/search/repositories?q=' + ans + '&sort=stars&order=desc&per_page=' + str(
                self._param.top_n)
            headers = {"Content-Type": "application/vnd.github+json", "X-GitHub-Api-Version": '2022-11-28'}
            response = requests.get(url=url, headers=headers).json()

            github_res = [{"content": '<a href="' + i["html_url"] + '">' + i["name"] + '</a>' + str(
                i["description"]) + '\n stars:' + str(i['watchers'])} for i in response['items']]
        except Exception as e:
            return GitHub.be_output("**ERROR**: " + str(e))

        if not github_res:
            return GitHub.be_output("")

        df = pd.DataFrame(github_res)
        if DEBUG: print(df, ":::::::::::::::::::::::::::::::::")
        return df
agent/component/google.py CHANGED
@@ -1,96 +1,96 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from abc import ABC
from serpapi import GoogleSearch
import pandas as pd
from agent.settings import DEBUG
from agent.component.base import ComponentBase, ComponentParamBase


class GoogleParam(ComponentParamBase):
    """
    Define the Google component parameters.
    """

    def __init__(self):
        super().__init__()
        self.top_n = 10
        self.api_key = "xxx"
        self.country = "cn"
        self.language = "en"

    def check(self):
        self.check_positive_integer(self.top_n, "Top N")
        self.check_empty(self.api_key, "SerpApi API key")
        self.check_valid_value(self.country, "Google Country",
                               ['af', 'al', 'dz', 'as', 'ad', 'ao', 'ai', 'aq', 'ag', 'ar', 'am', 'aw', 'au', 'at',
                                'az', 'bs', 'bh', 'bd', 'bb', 'by', 'be', 'bz', 'bj', 'bm', 'bt', 'bo', 'ba', 'bw',
                                'bv', 'br', 'io', 'bn', 'bg', 'bf', 'bi', 'kh', 'cm', 'ca', 'cv', 'ky', 'cf', 'td',
                                'cl', 'cn', 'cx', 'cc', 'co', 'km', 'cg', 'cd', 'ck', 'cr', 'ci', 'hr', 'cu', 'cy',
                                'cz', 'dk', 'dj', 'dm', 'do', 'ec', 'eg', 'sv', 'gq', 'er', 'ee', 'et', 'fk', 'fo',
                                'fj', 'fi', 'fr', 'gf', 'pf', 'tf', 'ga', 'gm', 'ge', 'de', 'gh', 'gi', 'gr', 'gl',
                                'gd', 'gp', 'gu', 'gt', 'gn', 'gw', 'gy', 'ht', 'hm', 'va', 'hn', 'hk', 'hu', 'is',
                                'in', 'id', 'ir', 'iq', 'ie', 'il', 'it', 'jm', 'jp', 'jo', 'kz', 'ke', 'ki', 'kp',
                                'kr', 'kw', 'kg', 'la', 'lv', 'lb', 'ls', 'lr', 'ly', 'li', 'lt', 'lu', 'mo', 'mk',
                                'mg', 'mw', 'my', 'mv', 'ml', 'mt', 'mh', 'mq', 'mr', 'mu', 'yt', 'mx', 'fm', 'md',
                                'mc', 'mn', 'ms', 'ma', 'mz', 'mm', 'na', 'nr', 'np', 'nl', 'an', 'nc', 'nz', 'ni',
                                'ne', 'ng', 'nu', 'nf', 'mp', 'no', 'om', 'pk', 'pw', 'ps', 'pa', 'pg', 'py', 'pe',
                                'ph', 'pn', 'pl', 'pt', 'pr', 'qa', 're', 'ro', 'ru', 'rw', 'sh', 'kn', 'lc', 'pm',
                                'vc', 'ws', 'sm', 'st', 'sa', 'sn', 'rs', 'sc', 'sl', 'sg', 'sk', 'si', 'sb', 'so',
                                'za', 'gs', 'es', 'lk', 'sd', 'sr', 'sj', 'sz', 'se', 'ch', 'sy', 'tw', 'tj', 'tz',
                                'th', 'tl', 'tg', 'tk', 'to', 'tt', 'tn', 'tr', 'tm', 'tc', 'tv', 'ug', 'ua', 'ae',
                                'uk', 'gb', 'us', 'um', 'uy', 'uz', 'vu', 've', 'vn', 'vg', 'vi', 'wf', 'eh', 'ye',
                                'zm', 'zw'])
        self.check_valid_value(self.language, "Google languages",
                               ['af', 'ak', 'sq', 'ws', 'am', 'ar', 'hy', 'az', 'eu', 'be', 'bem', 'bn', 'bh',
                                'xx-bork', 'bs', 'br', 'bg', 'bt', 'km', 'ca', 'chr', 'ny', 'zh-cn', 'zh-tw', 'co',
                                'hr', 'cs', 'da', 'nl', 'xx-elmer', 'en', 'eo', 'et', 'ee', 'fo', 'tl', 'fi', 'fr',
                                'fy', 'gaa', 'gl', 'ka', 'de', 'el', 'kl', 'gn', 'gu', 'xx-hacker', 'ht', 'ha', 'haw',
                                'iw', 'hi', 'hu', 'is', 'ig', 'id', 'ia', 'ga', 'it', 'ja', 'jw', 'kn', 'kk', 'rw',
                                'rn', 'xx-klingon', 'kg', 'ko', 'kri', 'ku', 'ckb', 'ky', 'lo', 'la', 'lv', 'ln', 'lt',
                                'loz', 'lg', 'ach', 'mk', 'mg', 'ms', 'ml', 'mt', 'mv', 'mi', 'mr', 'mfe', 'mo', 'mn',
                                'sr-me', 'my', 'ne', 'pcm', 'nso', 'no', 'nn', 'oc', 'or', 'om', 'ps', 'fa',
                                'xx-pirate', 'pl', 'pt', 'pt-br', 'pt-pt', 'pa', 'qu', 'ro', 'rm', 'nyn', 'ru', 'gd',
                                'sr', 'sh', 'st', 'tn', 'crs', 'sn', 'sd', 'si', 'sk', 'sl', 'so', 'es', 'es-419', 'su',
                                'sw', 'sv', 'tg', 'ta', 'tt', 'te', 'th', 'ti', 'to', 'lua', 'tum', 'tr', 'tk', 'tw',
                                'ug', 'uk', 'ur', 'uz', 'vu', 'vi', 'cy', 'wo', 'xh', 'yi', 'yo', 'zu']
                               )


class Google(ComponentBase, ABC):
    component_name = "Google"

    def _run(self, history, **kwargs):
        ans = self.get_input()
        ans = " - ".join(ans["content"]) if "content" in ans else ""
        if not ans:
            return Google.be_output("")

        try:
            client = GoogleSearch(
                {"engine": "google", "q": ans, "api_key": self._param.api_key, "gl": self._param.country,
                 "hl": self._param.language, "num": self._param.top_n})
            google_res = [{"content": '<a href="' + i["link"] + '">' + i["title"] + '</a> ' + i["snippet"]} for i in
                          client.get_dict()["organic_results"]]
        except Exception as e:
            return Google.be_output("**ERROR**: Existing Unavailable Parameters!")

        if not google_res:
            return Google.be_output("")

        df = pd.DataFrame(google_res)
        if DEBUG: print(df, ":::::::::::::::::::::::::::::::::")
        return df
agent/component/googlescholar.py CHANGED
@@ -1,70 +1,70 @@
1
- #
2
- # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- #
16
- from abc import ABC
17
- import pandas as pd
18
- from agent.settings import DEBUG
19
- from agent.component.base import ComponentBase, ComponentParamBase
20
- from scholarly import scholarly
21
-
22
-
23
- class GoogleScholarParam(ComponentParamBase):
24
- """
25
- Define the GoogleScholar component parameters.
26
- """
27
-
28
- def __init__(self):
29
- super().__init__()
30
- self.top_n = 6
31
- self.sort_by = 'relevance'
32
- self.year_low = None
33
- self.year_high = None
34
- self.patents = True
35
-
36
- def check(self):
37
- self.check_positive_integer(self.top_n, "Top N")
38
- self.check_valid_value(self.sort_by, "GoogleScholar Sort_by", ['date', 'relevance'])
39
- self.check_boolean(self.patents, "Whether or not to include patents, defaults to True")
40
-
41
-
42
- class GoogleScholar(ComponentBase, ABC):
43
- component_name = "GoogleScholar"
44
-
45
- def _run(self, history, **kwargs):
46
- ans = self.get_input()
47
- ans = " - ".join(ans["content"]) if "content" in ans else ""
48
- if not ans:
49
- return GoogleScholar.be_output("")
50
-
51
- scholar_client = scholarly.search_pubs(ans, patents=self._param.patents, year_low=self._param.year_low,
52
- year_high=self._param.year_high, sort_by=self._param.sort_by)
53
- scholar_res = []
54
- for i in range(self._param.top_n):
55
- try:
56
- pub = next(scholar_client)
57
- scholar_res.append({"content": 'Title: ' + pub['bib']['title'] + '\n_Url: <a href="' + pub[
58
- 'pub_url'] + '"></a> ' + "\n author: " + ",".join(pub['bib']['author']) + '\n Abstract: ' + pub[
59
- 'bib'].get('abstract', 'no abstract')})
60
-
61
- except StopIteration or Exception as e:
62
- print("**ERROR** " + str(e))
63
- break
64
-
65
- if not scholar_res:
66
- return GoogleScholar.be_output("")
67
-
68
- df = pd.DataFrame(scholar_res)
69
- if DEBUG: print(df, ":::::::::::::::::::::::::::::::::")
70
- return df
 
1
+ #
2
+ # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ #
16
+ from abc import ABC
17
+ import pandas as pd
18
+ from agent.settings import DEBUG
19
+ from agent.component.base import ComponentBase, ComponentParamBase
20
+ from scholarly import scholarly
21
+
22
+
23
+ class GoogleScholarParam(ComponentParamBase):
24
+ """
25
+ Define the GoogleScholar component parameters.
26
+ """
27
+
28
+ def __init__(self):
29
+ super().__init__()
30
+ self.top_n = 6
31
+ self.sort_by = 'relevance'
32
+ self.year_low = None
33
+ self.year_high = None
34
+ self.patents = True
35
+
36
+ def check(self):
37
+ self.check_positive_integer(self.top_n, "Top N")
38
+ self.check_valid_value(self.sort_by, "GoogleScholar Sort_by", ['date', 'relevance'])
39
+ self.check_boolean(self.patents, "Whether or not to include patents, defaults to True")
40
+
41
+
42
+ class GoogleScholar(ComponentBase, ABC):
43
+ component_name = "GoogleScholar"
44
+
45
+ def _run(self, history, **kwargs):
46
+ ans = self.get_input()
47
+ ans = " - ".join(ans["content"]) if "content" in ans else ""
48
+ if not ans:
49
+ return GoogleScholar.be_output("")
50
+
51
+ scholar_client = scholarly.search_pubs(ans, patents=self._param.patents, year_low=self._param.year_low,
52
+ year_high=self._param.year_high, sort_by=self._param.sort_by)
53
+ scholar_res = []
54
+ for i in range(self._param.top_n):
55
+ try:
56
+ pub = next(scholar_client)
57
+ scholar_res.append({"content": 'Title: ' + pub['bib']['title'] + '\n_Url: <a href="' + pub[
58
+ 'pub_url'] + '"></a> ' + "\n author: " + ",".join(pub['bib']['author']) + '\n Abstract: ' + pub[
59
+ 'bib'].get('abstract', 'no abstract')})
60
+
61
+ except (StopIteration, Exception) as e:
62
+ print("**ERROR** " + str(e))
63
+ break
64
+
65
+ if not scholar_res:
66
+ return GoogleScholar.be_output("")
67
+
68
+ df = pd.DataFrame(scholar_res)
69
+ if DEBUG: print(df, ":::::::::::::::::::::::::::::::::")
70
+ return df
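A note on the exception handler in _run above: Python needs a tuple to catch several exception types at once. Writing "except StopIteration or Exception" evaluates the or-expression first, which yields just StopIteration, so any other error raised by scholarly would propagate uncaught. A minimal, self-contained sketch of the corrected pattern (the drain helper is illustrative, not part of this repo):

def drain(iterator, limit):
    # Collect up to `limit` items; stop quietly on exhaustion or any error.
    items = []
    for _ in range(limit):
        try:
            items.append(next(iterator))
        except (StopIteration, Exception) as e:  # a tuple, not an `or` expression
            print("**ERROR** " + str(e))
            break
    return items

print(drain(iter([1, 2]), 5))  # -> [1, 2]

Strictly speaking, except Exception alone would also cover StopIteration (it is an Exception subclass); the tuple keeps the original intent explicit.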
agent/component/qweather.py CHANGED
@@ -1,111 +1,111 @@
1
- #
2
- # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- #
16
- from abc import ABC
17
- import pandas as pd
18
- import requests
19
- from agent.component.base import ComponentBase, ComponentParamBase
20
-
21
-
22
- class QWeatherParam(ComponentParamBase):
23
- """
24
- Define the QWeather component parameters.
25
- """
26
-
27
- def __init__(self):
28
- super().__init__()
29
- self.web_apikey = "xxx"
30
- self.lang = "zh"
31
- self.type = "weather"
32
- self.user_type = 'free'
33
- self.error_code = {
34
- "204": "The request was successful, but the region you are querying does not have the data you need at this time.",
35
- "400": "Request error, may contain incorrect request parameters or missing mandatory request parameters.",
36
- "401": "Authentication fails, possibly using the wrong KEY, wrong digital signature, wrong type of KEY (e.g. using the SDK's KEY to access the Web API).",
37
- "402": "Exceeded the number of accesses or the balance is not enough to support continued access to the service, you can recharge, upgrade the accesses or wait for the accesses to be reset.",
38
- "403": "No access, may be the binding PackageName, BundleID, domain IP address is inconsistent, or the data that requires additional payment.",
39
- "404": "The queried data or region does not exist.",
40
- "429": "Exceeded the limited QPM (number of accesses per minute), please refer to the QPM description",
41
- "500": "No response or timeout, interface service abnormality please contact us"
42
- }
43
- # Weather
44
- self.time_period = 'now'
45
-
46
- def check(self):
47
- self.check_empty(self.web_apikey, "QWeather API KEY")
48
- self.check_valid_value(self.type, "Type", ["weather", "indices", "airquality"])
49
- self.check_valid_value(self.user_type, "Free subscription or paid subscription", ["free", "paid"])
50
- self.check_valid_value(self.lang, "Use language",
51
- ['zh', 'zh-hant', 'en', 'de', 'es', 'fr', 'it', 'ja', 'ko', 'ru', 'hi', 'th', 'ar', 'pt',
52
- 'bn', 'ms', 'nl', 'el', 'la', 'sv', 'id', 'pl', 'tr', 'cs', 'et', 'vi', 'fil', 'fi',
53
- 'he', 'is', 'nb'])
54
- self.check_valid_value(self.time_period, "Time period", ['now', '3d', '7d', '10d', '15d', '30d'])
55
-
56
-
57
- class QWeather(ComponentBase, ABC):
58
- component_name = "QWeather"
59
-
60
- def _run(self, history, **kwargs):
61
- ans = self.get_input()
62
- ans = "".join(ans["content"]) if "content" in ans else ""
63
- if not ans:
64
- return QWeather.be_output("")
65
-
66
- try:
67
- response = requests.get(
68
- url="https://geoapi.qweather.com/v2/city/lookup?location=" + ans + "&key=" + self._param.web_apikey).json()
69
- if response["code"] == "200":
70
- location_id = response["location"][0]["id"]
71
- else:
72
- return QWeather.be_output("**Error**" + self._param.error_code[response["code"]])
73
-
74
- base_url = "https://api.qweather.com/v7/" if self._param.user_type == 'paid' else "https://devapi.qweather.com/v7/"
75
-
76
- if self._param.type == "weather":
77
- url = base_url + "weather/" + self._param.time_period + "?location=" + location_id + "&key=" + self._param.web_apikey + "&lang=" + self._param.lang
78
- response = requests.get(url=url).json()
79
- if response["code"] == "200":
80
- if self._param.time_period == "now":
81
- return QWeather.be_output(str(response["now"]))
82
- else:
83
- qweather_res = [{"content": str(i) + "\n"} for i in response["daily"]]
84
- if not qweather_res:
85
- return QWeather.be_output("")
86
-
87
- df = pd.DataFrame(qweather_res)
88
- return df
89
- else:
90
- return QWeather.be_output("**Error**" + self._param.error_code[response["code"]])
91
-
92
- elif self._param.type == "indices":
93
- url = base_url + "indices/1d?type=0&location=" + location_id + "&key=" + self._param.web_apikey + "&lang=" + self._param.lang
94
- response = requests.get(url=url).json()
95
- if response["code"] == "200":
96
- indices_res = response["daily"][0]["date"] + "\n" + "\n".join(
97
- [i["name"] + ": " + i["category"] + ", " + i["text"] for i in response["daily"]])
98
- return QWeather.be_output(indices_res)
99
-
100
- else:
101
- return QWeather.be_output("**Error**" + self._param.error_code[response["code"]])
102
-
103
- elif self._param.type == "airquality":
104
- url = base_url + "air/now?location=" + location_id + "&key=" + self._param.web_apikey + "&lang=" + self._param.lang
105
- response = requests.get(url=url).json()
106
- if response["code"] == "200":
107
- return QWeather.be_output(str(response["now"]))
108
- else:
109
- return QWeather.be_output("**Error**" + self._param.error_code[response["code"]])
110
- except Exception as e:
111
- return QWeather.be_output("**Error**" + str(e))
 
1
+ #
2
+ # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ #
16
+ from abc import ABC
17
+ import pandas as pd
18
+ import requests
19
+ from agent.component.base import ComponentBase, ComponentParamBase
20
+
21
+
22
+ class QWeatherParam(ComponentParamBase):
23
+ """
24
+ Define the QWeather component parameters.
25
+ """
26
+
27
+ def __init__(self):
28
+ super().__init__()
29
+ self.web_apikey = "xxx"
30
+ self.lang = "zh"
31
+ self.type = "weather"
32
+ self.user_type = 'free'
33
+ self.error_code = {
34
+ "204": "The request was successful, but the region you are querying does not have the data you need at this time.",
35
+ "400": "Request error, may contain incorrect request parameters or missing mandatory request parameters.",
36
+ "401": "Authentication fails, possibly using the wrong KEY, wrong digital signature, wrong type of KEY (e.g. using the SDK's KEY to access the Web API).",
37
+ "402": "Exceeded the number of accesses or the balance is not enough to support continued access to the service, you can recharge, upgrade the accesses or wait for the accesses to be reset.",
38
+ "403": "No access, may be the binding PackageName, BundleID, domain IP address is inconsistent, or the data that requires additional payment.",
39
+ "404": "The queried data or region does not exist.",
40
+ "429": "Exceeded the limited QPM (number of accesses per minute), please refer to the QPM description",
41
+ "500": "No response or timeout, interface service abnormality please contact us"
42
+ }
43
+ # Weather
44
+ self.time_period = 'now'
45
+
46
+ def check(self):
47
+ self.check_empty(self.web_apikey, "QWeather API KEY")
48
+ self.check_valid_value(self.type, "Type", ["weather", "indices", "airquality"])
49
+ self.check_valid_value(self.user_type, "Free subscription or paid subscription", ["free", "paid"])
50
+ self.check_valid_value(self.lang, "Use language",
51
+ ['zh', 'zh-hant', 'en', 'de', 'es', 'fr', 'it', 'ja', 'ko', 'ru', 'hi', 'th', 'ar', 'pt',
52
+ 'bn', 'ms', 'nl', 'el', 'la', 'sv', 'id', 'pl', 'tr', 'cs', 'et', 'vi', 'fil', 'fi',
53
+ 'he', 'is', 'nb'])
54
+ self.check_valid_value(self.time_period, "Time period", ['now', '3d', '7d', '10d', '15d', '30d'])
55
+
56
+
57
+ class QWeather(ComponentBase, ABC):
58
+ component_name = "QWeather"
59
+
60
+ def _run(self, history, **kwargs):
61
+ ans = self.get_input()
62
+ ans = "".join(ans["content"]) if "content" in ans else ""
63
+ if not ans:
64
+ return QWeather.be_output("")
65
+
66
+ try:
67
+ response = requests.get(
68
+ url="https://geoapi.qweather.com/v2/city/lookup?location=" + ans + "&key=" + self._param.web_apikey).json()
69
+ if response["code"] == "200":
70
+ location_id = response["location"][0]["id"]
71
+ else:
72
+ return QWeather.be_output("**Error**" + self._param.error_code[response["code"]])
73
+
74
+ base_url = "https://api.qweather.com/v7/" if self._param.user_type == 'paid' else "https://devapi.qweather.com/v7/"
75
+
76
+ if self._param.type == "weather":
77
+ url = base_url + "weather/" + self._param.time_period + "?location=" + location_id + "&key=" + self._param.web_apikey + "&lang=" + self._param.lang
78
+ response = requests.get(url=url).json()
79
+ if response["code"] == "200":
80
+ if self._param.time_period == "now":
81
+ return QWeather.be_output(str(response["now"]))
82
+ else:
83
+ qweather_res = [{"content": str(i) + "\n"} for i in response["daily"]]
84
+ if not qweather_res:
85
+ return QWeather.be_output("")
86
+
87
+ df = pd.DataFrame(qweather_res)
88
+ return df
89
+ else:
90
+ return QWeather.be_output("**Error**" + self._param.error_code[response["code"]])
91
+
92
+ elif self._param.type == "indices":
93
+ url = base_url + "indices/1d?type=0&location=" + location_id + "&key=" + self._param.web_apikey + "&lang=" + self._param.lang
94
+ response = requests.get(url=url).json()
95
+ if response["code"] == "200":
96
+ indices_res = response["daily"][0]["date"] + "\n" + "\n".join(
97
+ [i["name"] + ": " + i["category"] + ", " + i["text"] for i in response["daily"]])
98
+ return QWeather.be_output(indices_res)
99
+
100
+ else:
101
+ return QWeather.be_output("**Error**" + self._param.error_code[response["code"]])
102
+
103
+ elif self._param.type == "airquality":
104
+ url = base_url + "air/now?location=" + location_id + "&key=" + self._param.web_apikey + "&lang=" + self._param.lang
105
+ response = requests.get(url=url).json()
106
+ if response["code"] == "200":
107
+ return QWeather.be_output(str(response["now"]))
108
+ else:
109
+ return QWeather.be_output("**Error**" + self._param.error_code[response["code"]])
110
+ except Exception as e:
111
+ return QWeather.be_output("**Error**" + str(e))
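Two details in qweather.py are worth flagging. The check() method must call check_valid_value; a misspelling such as check_vaild_value would raise AttributeError the first time the parameters were validated. And the request URLs in _run are assembled by plain string concatenation, which breaks for location values containing spaces or non-ASCII text. A hedged sketch of the same city lookup letting requests encode the query string (endpoint and response fields mirror the code above; error handling is elided):

import requests

def lookup_city(location: str, api_key: str) -> str:
    # requests URL-encodes the query parameters; string concatenation would not.
    resp = requests.get(
        "https://geoapi.qweather.com/v2/city/lookup",
        params={"location": location, "key": api_key},
    ).json()
    if resp["code"] != "200":
        raise RuntimeError("QWeather error code " + resp["code"])
    return resp["location"][0]["id"]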
agent/templates/websearch_assistant.json CHANGED
The diff for this file is too large to render. See raw diff
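The change to this template is line-ending-only, like the rest of this PR, so reviewing it reduces to confirming the bytes are unchanged apart from CRLF becoming LF. A small sketch of that normalization (equivalent to what dos2unix does; the path is this repo's template file):

from pathlib import Path

def to_unix(path: str) -> None:
    # Rewrite CRLF (and any stray lone CR) as LF; all other bytes unchanged.
    p = Path(path)
    p.write_bytes(p.read_bytes().replace(b"\r\n", b"\n").replace(b"\r", b"\n"))

to_unix("agent/templates/websearch_assistant.json")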
 
agent/test/dsl_examples/keyword_wikipedia_and_generate.json CHANGED
@@ -1,62 +1,62 @@
1
- {
2
- "components": {
3
- "begin": {
4
- "obj":{
5
- "component_name": "Begin",
6
- "params": {
7
- "prologue": "Hi there!"
8
- }
9
- },
10
- "downstream": ["answer:0"],
11
- "upstream": []
12
- },
13
- "answer:0": {
14
- "obj": {
15
- "component_name": "Answer",
16
- "params": {}
17
- },
18
- "downstream": ["keyword:0"],
19
- "upstream": ["begin"]
20
- },
21
- "keyword:0": {
22
- "obj": {
23
- "component_name": "KeywordExtract",
24
- "params": {
25
- "llm_id": "deepseek-chat",
26
- "prompt": "- Role: You're a question analyzer.\n - Requirements:\n - Summarize user's question, and give top %s important keyword/phrase.\n - Use comma as a delimiter to separate keywords/phrases.\n - Answer format: (in language of user's question)\n - keyword: ",
27
- "temperature": 0.2,
28
- "top_n": 1
29
- }
30
- },
31
- "downstream": ["wikipedia:0"],
32
- "upstream": ["answer:0"]
33
- },
34
- "wikipedia:0": {
35
- "obj":{
36
- "component_name": "Wikipedia",
37
- "params": {
38
- "top_n": 10
39
- }
40
- },
41
- "downstream": ["generate:0"],
42
- "upstream": ["keyword:0"]
43
- },
44
- "generate:1": {
45
- "obj": {
46
- "component_name": "Generate",
47
- "params": {
48
- "llm_id": "deepseek-chat",
49
- "prompt": "You are an intelligent assistant. Please answer the question based on content from Wikipedia. When the answer from Wikipedia is incomplete, you need to output the URL link of the corresponding content as well. When all the content searched from Wikipedia is irrelevant to the question, your answer must include the sentence, \"The answer you are looking for is not found in the Wikipedia!\". Answers need to consider chat history.\n The content of Wikipedia is as follows:\n {input}\n The above is the content of Wikipedia.",
50
- "temperature": 0.2
51
- }
52
- },
53
- "downstream": ["answer:0"],
54
- "upstream": ["wikipedia:0"]
55
- }
56
- },
57
- "history": [],
58
- "path": [],
59
- "messages": [],
60
- "reference": {},
61
- "answer": []
62
- }
 
1
+ {
2
+ "components": {
3
+ "begin": {
4
+ "obj":{
5
+ "component_name": "Begin",
6
+ "params": {
7
+ "prologue": "Hi there!"
8
+ }
9
+ },
10
+ "downstream": ["answer:0"],
11
+ "upstream": []
12
+ },
13
+ "answer:0": {
14
+ "obj": {
15
+ "component_name": "Answer",
16
+ "params": {}
17
+ },
18
+ "downstream": ["keyword:0"],
19
+ "upstream": ["begin"]
20
+ },
21
+ "keyword:0": {
22
+ "obj": {
23
+ "component_name": "KeywordExtract",
24
+ "params": {
25
+ "llm_id": "deepseek-chat",
26
+ "prompt": "- Role: You're a question analyzer.\n - Requirements:\n - Summarize user's question, and give top %s important keyword/phrase.\n - Use comma as a delimiter to separate keywords/phrases.\n - Answer format: (in language of user's question)\n - keyword: ",
27
+ "temperature": 0.2,
28
+ "top_n": 1
29
+ }
30
+ },
31
+ "downstream": ["wikipedia:0"],
32
+ "upstream": ["answer:0"]
33
+ },
34
+ "wikipedia:0": {
35
+ "obj":{
36
+ "component_name": "Wikipedia",
37
+ "params": {
38
+ "top_n": 10
39
+ }
40
+ },
41
+ "downstream": ["generate:0"],
42
+ "upstream": ["keyword:0"]
43
+ },
44
+ "generate:1": {
45
+ "obj": {
46
+ "component_name": "Generate",
47
+ "params": {
48
+ "llm_id": "deepseek-chat",
49
+ "prompt": "You are an intelligent assistant. Please answer the question based on content from Wikipedia. When the answer from Wikipedia is incomplete, you need to output the URL link of the corresponding content as well. When all the content searched from Wikipedia is irrelevant to the question, your answer must include the sentence, \"The answer you are looking for is not found in the Wikipedia!\". Answers need to consider chat history.\n The content of Wikipedia is as follows:\n {input}\n The above is the content of Wikipedia.",
50
+ "temperature": 0.2
51
+ }
52
+ },
53
+ "downstream": ["answer:0"],
54
+ "upstream": ["wikipedia:0"]
55
+ }
56
+ },
57
+ "history": [],
58
+ "path": [],
59
+ "messages": [],
60
+ "reference": {},
61
+ "answer": []
62
+ }
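One consistency rule this DSL depends on: every name listed in a downstream or upstream array must also be a key under components, otherwise the edge dangles and the canvas cannot route messages (here, wikipedia:0 feeds generate:0, so the generate component must be registered under exactly that key). A small sketch of a referential check that catches such mismatches (the function name is illustrative):

import json

def check_dsl_edges(dsl: dict) -> list:
    # Flag every downstream/upstream reference to a component that is not defined.
    comps = dsl["components"]
    return [
        f"{name}.{direction} -> {ref} (missing)"
        for name, node in comps.items()
        for direction in ("downstream", "upstream")
        for ref in node.get(direction, [])
        if ref not in comps
    ]

with open("agent/test/dsl_examples/keyword_wikipedia_and_generate.json") as f:
    print(check_dsl_edges(json.load(f)) or "OK")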
api/apps/__init__.py CHANGED
@@ -1,125 +1,125 @@
1
- #
2
- # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- #
16
- import logging
17
- import os
18
- import sys
19
- from importlib.util import module_from_spec, spec_from_file_location
20
- from pathlib import Path
21
- from flask import Blueprint, Flask
22
- from werkzeug.wrappers.request import Request
23
- from flask_cors import CORS
24
-
25
- from api.db import StatusEnum
26
- from api.db.db_models import close_connection
27
- from api.db.services import UserService
28
- from api.utils import CustomJSONEncoder, commands
29
-
30
- from flask_session import Session
31
- from flask_login import LoginManager
32
- from api.settings import SECRET_KEY, stat_logger
33
- from api.settings import API_VERSION, access_logger
34
- from api.utils.api_utils import server_error_response
35
- from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
36
-
37
- __all__ = ['app']
38
-
39
-
40
- logger = logging.getLogger('flask.app')
41
- for h in access_logger.handlers:
42
- logger.addHandler(h)
43
-
44
- Request.json = property(lambda self: self.get_json(force=True, silent=True))
45
-
46
- app = Flask(__name__)
47
- CORS(app, supports_credentials=True, max_age=2592000)
48
- app.url_map.strict_slashes = False
49
- app.json_encoder = CustomJSONEncoder
50
- app.errorhandler(Exception)(server_error_response)
51
-
52
-
53
- ## convenience for dev and debug
54
- #app.config["LOGIN_DISABLED"] = True
55
- app.config["SESSION_PERMANENT"] = False
56
- app.config["SESSION_TYPE"] = "filesystem"
57
- app.config['MAX_CONTENT_LENGTH'] = int(os.environ.get("MAX_CONTENT_LENGTH", 128 * 1024 * 1024))
58
-
59
- Session(app)
60
- login_manager = LoginManager()
61
- login_manager.init_app(app)
62
-
63
- commands.register_commands(app)
64
-
65
-
66
- def search_pages_path(pages_dir):
67
- app_path_list = [path for path in pages_dir.glob('*_app.py') if not path.name.startswith('.')]
68
- api_path_list = [path for path in pages_dir.glob('*_api.py') if not path.name.startswith('.')]
69
- app_path_list.extend(api_path_list)
70
- return app_path_list
71
-
72
-
73
- def register_page(page_path):
74
- path = f'{page_path}'
75
-
76
- page_name = page_path.stem.removesuffix('_api') if "_api" in path else page_path.stem.removesuffix('_app')
77
- module_name = '.'.join(page_path.parts[page_path.parts.index('api'):-1] + (page_name,))
78
-
79
- spec = spec_from_file_location(module_name, page_path)
80
- page = module_from_spec(spec)
81
- page.app = app
82
- page.manager = Blueprint(page_name, module_name)
83
- sys.modules[module_name] = page
84
- spec.loader.exec_module(page)
85
- page_name = getattr(page, 'page_name', page_name)
86
- url_prefix = f'/api/{API_VERSION}/{page_name}' if "_api" in path else f'/{API_VERSION}/{page_name}'
87
-
88
- app.register_blueprint(page.manager, url_prefix=url_prefix)
89
- return url_prefix
90
-
91
-
92
- pages_dir = [
93
- Path(__file__).parent,
94
- Path(__file__).parent.parent / 'api' / 'apps', # FIXME: ragflow/api/api/apps, can be removed?
95
- ]
96
-
97
- client_urls_prefix = [
98
- register_page(path)
99
- for dir in pages_dir
100
- for path in search_pages_path(dir)
101
- ]
102
-
103
-
104
- @login_manager.request_loader
105
- def load_user(web_request):
106
- jwt = Serializer(secret_key=SECRET_KEY)
107
- authorization = web_request.headers.get("Authorization")
108
- if authorization:
109
- try:
110
- access_token = str(jwt.loads(authorization))
111
- user = UserService.query(access_token=access_token, status=StatusEnum.VALID.value)
112
- if user:
113
- return user[0]
114
- else:
115
- return None
116
- except Exception as e:
117
- stat_logger.exception(e)
118
- return None
119
- else:
120
- return None
121
-
122
-
123
- @app.teardown_request
124
- def _db_close(exc):
125
  close_connection()
 
1
+ #
2
+ # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ #
16
+ import logging
17
+ import os
18
+ import sys
19
+ from importlib.util import module_from_spec, spec_from_file_location
20
+ from pathlib import Path
21
+ from flask import Blueprint, Flask
22
+ from werkzeug.wrappers.request import Request
23
+ from flask_cors import CORS
24
+
25
+ from api.db import StatusEnum
26
+ from api.db.db_models import close_connection
27
+ from api.db.services import UserService
28
+ from api.utils import CustomJSONEncoder, commands
29
+
30
+ from flask_session import Session
31
+ from flask_login import LoginManager
32
+ from api.settings import SECRET_KEY, stat_logger
33
+ from api.settings import API_VERSION, access_logger
34
+ from api.utils.api_utils import server_error_response
35
+ from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
36
+
37
+ __all__ = ['app']
38
+
39
+
40
+ logger = logging.getLogger('flask.app')
41
+ for h in access_logger.handlers:
42
+ logger.addHandler(h)
43
+
44
+ Request.json = property(lambda self: self.get_json(force=True, silent=True))
45
+
46
+ app = Flask(__name__)
47
+ CORS(app, supports_credentials=True, max_age=2592000)
48
+ app.url_map.strict_slashes = False
49
+ app.json_encoder = CustomJSONEncoder
50
+ app.errorhandler(Exception)(server_error_response)
51
+
52
+
53
+ ## convenience for dev and debug
54
+ #app.config["LOGIN_DISABLED"] = True
55
+ app.config["SESSION_PERMANENT"] = False
56
+ app.config["SESSION_TYPE"] = "filesystem"
57
+ app.config['MAX_CONTENT_LENGTH'] = int(os.environ.get("MAX_CONTENT_LENGTH", 128 * 1024 * 1024))
58
+
59
+ Session(app)
60
+ login_manager = LoginManager()
61
+ login_manager.init_app(app)
62
+
63
+ commands.register_commands(app)
64
+
65
+
66
+ def search_pages_path(pages_dir):
67
+ app_path_list = [path for path in pages_dir.glob('*_app.py') if not path.name.startswith('.')]
68
+ api_path_list = [path for path in pages_dir.glob('*_api.py') if not path.name.startswith('.')]
69
+ app_path_list.extend(api_path_list)
70
+ return app_path_list
71
+
72
+
73
+ def register_page(page_path):
74
+ path = f'{page_path}'
75
+
76
+ page_name = page_path.stem.removesuffix('_api') if "_api" in path else page_path.stem.removesuffix('_app')
77
+ module_name = '.'.join(page_path.parts[page_path.parts.index('api'):-1] + (page_name,))
78
+
79
+ spec = spec_from_file_location(module_name, page_path)
80
+ page = module_from_spec(spec)
81
+ page.app = app
82
+ page.manager = Blueprint(page_name, module_name)
83
+ sys.modules[module_name] = page
84
+ spec.loader.exec_module(page)
85
+ page_name = getattr(page, 'page_name', page_name)
86
+ url_prefix = f'/api/{API_VERSION}/{page_name}' if "_api" in path else f'/{API_VERSION}/{page_name}'
87
+
88
+ app.register_blueprint(page.manager, url_prefix=url_prefix)
89
+ return url_prefix
90
+
91
+
92
+ pages_dir = [
93
+ Path(__file__).parent,
94
+ Path(__file__).parent.parent / 'api' / 'apps', # FIXME: ragflow/api/api/apps, can be removed?
95
+ ]
96
+
97
+ client_urls_prefix = [
98
+ register_page(path)
99
+ for dir in pages_dir
100
+ for path in search_pages_path(dir)
101
+ ]
102
+
103
+
104
+ @login_manager.request_loader
105
+ def load_user(web_request):
106
+ jwt = Serializer(secret_key=SECRET_KEY)
107
+ authorization = web_request.headers.get("Authorization")
108
+ if authorization:
109
+ try:
110
+ access_token = str(jwt.loads(authorization))
111
+ user = UserService.query(access_token=access_token, status=StatusEnum.VALID.value)
112
+ if user:
113
+ return user[0]
114
+ else:
115
+ return None
116
+ except Exception as e:
117
+ stat_logger.exception(e)
118
+ return None
119
+ else:
120
+ return None
121
+
122
+
123
+ @app.teardown_request
124
+ def _db_close(exc):
125
  close_connection()
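On the page_name derivation in register_page: str.removesuffix (Python 3.9+) drops the literal suffix, whereas the superficially similar str.rstrip treats its argument as a character set and strips any trailing run of those characters. The difference only bites for certain names, which makes it easy to miss; a minimal illustration (data_app.py is a hypothetical module name, not one in this repo):

# rstrip strips trailing characters drawn from the set {'_', 'a', 'p'}:
assert "data_app".rstrip("_app") == "dat"
# removesuffix strips the exact suffix only:
assert "data_app".removesuffix("_app") == "data"
# names like user_app happen to survive rstrip by luck:
assert "user_app".rstrip("_app") == "user"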
api/apps/api_app.py CHANGED
@@ -1,735 +1,735 @@
1
- #
2
- # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- #
16
- import json
17
- import os
18
- import re
19
- from datetime import datetime, timedelta
20
- from flask import request, Response
21
- from api.db.services.llm_service import TenantLLMService
22
- from flask_login import login_required, current_user
23
-
24
- from api.db import FileType, LLMType, ParserType, FileSource
25
- from api.db.db_models import APIToken, API4Conversation, Task, File
26
- from api.db.services import duplicate_name
27
- from api.db.services.api_service import APITokenService, API4ConversationService
28
- from api.db.services.dialog_service import DialogService, chat
29
- from api.db.services.document_service import DocumentService
30
- from api.db.services.file2document_service import File2DocumentService
31
- from api.db.services.file_service import FileService
32
- from api.db.services.knowledgebase_service import KnowledgebaseService
33
- from api.db.services.task_service import queue_tasks, TaskService
34
- from api.db.services.user_service import UserTenantService
35
- from api.settings import RetCode, retrievaler
36
- from api.utils import get_uuid, current_timestamp, datetime_format
37
- from api.utils.api_utils import server_error_response, get_data_error_result, get_json_result, validate_request
38
- from itsdangerous import URLSafeTimedSerializer
39
-
40
- from api.utils.file_utils import filename_type, thumbnail
41
- from rag.nlp import keyword_extraction
42
- from rag.utils.minio_conn import MINIO
43
-
44
- from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService
45
- from agent.canvas import Canvas
46
- from functools import partial
47
-
48
-
49
- def generate_confirmation_token(tenant_id):
50
- serializer = URLSafeTimedSerializer(tenant_id)
51
- return "ragflow-" + serializer.dumps(get_uuid(), salt=tenant_id)[2:34]
52
-
53
-
54
- @manager.route('/new_token', methods=['POST'])
55
- @login_required
56
- def new_token():
57
- req = request.json
58
- try:
59
- tenants = UserTenantService.query(user_id=current_user.id)
60
- if not tenants:
61
- return get_data_error_result(retmsg="Tenant not found!")
62
-
63
- tenant_id = tenants[0].tenant_id
64
- obj = {"tenant_id": tenant_id, "token": generate_confirmation_token(tenant_id),
65
- "create_time": current_timestamp(),
66
- "create_date": datetime_format(datetime.now()),
67
- "update_time": None,
68
- "update_date": None
69
- }
70
- if req.get("canvas_id"):
71
- obj["dialog_id"] = req["canvas_id"]
72
- obj["source"] = "agent"
73
- else:
74
- obj["dialog_id"] = req["dialog_id"]
75
-
76
- if not APITokenService.save(**obj):
77
- return get_data_error_result(retmsg="Fail to new a dialog!")
78
-
79
- return get_json_result(data=obj)
80
- except Exception as e:
81
- return server_error_response(e)
82
-
83
-
84
- @manager.route('/token_list', methods=['GET'])
85
- @login_required
86
- def token_list():
87
- try:
88
- tenants = UserTenantService.query(user_id=current_user.id)
89
- if not tenants:
90
- return get_data_error_result(retmsg="Tenant not found!")
91
-
92
- id = request.args["dialog_id"] if "dialog_id" in request.args else request.args["canvas_id"]
93
- objs = APITokenService.query(tenant_id=tenants[0].tenant_id, dialog_id=id)
94
- return get_json_result(data=[o.to_dict() for o in objs])
95
- except Exception as e:
96
- return server_error_response(e)
97
-
98
-
99
- @manager.route('/rm', methods=['POST'])
100
- @validate_request("tokens", "tenant_id")
101
- @login_required
102
- def rm():
103
- req = request.json
104
- try:
105
- for token in req["tokens"]:
106
- APITokenService.filter_delete(
107
- [APIToken.tenant_id == req["tenant_id"], APIToken.token == token])
108
- return get_json_result(data=True)
109
- except Exception as e:
110
- return server_error_response(e)
111
-
112
-
113
- @manager.route('/stats', methods=['GET'])
114
- @login_required
115
- def stats():
116
- try:
117
- tenants = UserTenantService.query(user_id=current_user.id)
118
- if not tenants:
119
- return get_data_error_result(retmsg="Tenant not found!")
120
- objs = API4ConversationService.stats(
121
- tenants[0].tenant_id,
122
- request.args.get(
123
- "from_date",
124
- (datetime.now() -
125
- timedelta(
126
- days=7)).strftime("%Y-%m-%d 00:00:00")),
127
- request.args.get(
128
- "to_date",
129
- datetime.now().strftime("%Y-%m-%d %H:%M:%S")),
130
- "agent" if "canvas_id" in request.args else None)
131
- res = {
132
- "pv": [(o["dt"], o["pv"]) for o in objs],
133
- "uv": [(o["dt"], o["uv"]) for o in objs],
134
- "speed": [(o["dt"], float(o["tokens"]) / (float(o["duration"] + 0.1))) for o in objs],
135
- "tokens": [(o["dt"], float(o["tokens"]) / 1000.) for o in objs],
136
- "round": [(o["dt"], o["round"]) for o in objs],
137
- "thumb_up": [(o["dt"], o["thumb_up"]) for o in objs]
138
- }
139
- return get_json_result(data=res)
140
- except Exception as e:
141
- return server_error_response(e)
142
-
143
-
144
- @manager.route('/new_conversation', methods=['GET'])
145
- def set_conversation():
146
- token = request.headers.get('Authorization').split()[1]
147
- objs = APIToken.query(token=token)
148
- if not objs:
149
- return get_json_result(
150
- data=False, retmsg='Token is not valid!"', retcode=RetCode.AUTHENTICATION_ERROR)
151
- req = request.json
152
- try:
153
- if objs[0].source == "agent":
154
- e, c = UserCanvasService.get_by_id(objs[0].dialog_id)
155
- if not e:
156
- return server_error_response("canvas not found.")
157
- conv = {
158
- "id": get_uuid(),
159
- "dialog_id": c.id,
160
- "user_id": request.args.get("user_id", ""),
161
- "message": [{"role": "assistant", "content": "Hi there!"}],
162
- "source": "agent"
163
- }
164
- API4ConversationService.save(**conv)
165
- return get_json_result(data=conv)
166
- else:
167
- e, dia = DialogService.get_by_id(objs[0].dialog_id)
168
- if not e:
169
- return get_data_error_result(retmsg="Dialog not found")
170
- conv = {
171
- "id": get_uuid(),
172
- "dialog_id": dia.id,
173
- "user_id": request.args.get("user_id", ""),
174
- "message": [{"role": "assistant", "content": dia.prompt_config["prologue"]}]
175
- }
176
- API4ConversationService.save(**conv)
177
- return get_json_result(data=conv)
178
- except Exception as e:
179
- return server_error_response(e)
180
-
181
-
182
- @manager.route('/completion', methods=['POST'])
183
- @validate_request("conversation_id", "messages")
184
- def completion():
185
- token = request.headers.get('Authorization').split()[1]
186
- objs = APIToken.query(token=token)
187
- if not objs:
188
- return get_json_result(
189
- data=False, retmsg='Token is not valid!"', retcode=RetCode.AUTHENTICATION_ERROR)
190
- req = request.json
191
- e, conv = API4ConversationService.get_by_id(req["conversation_id"])
192
- if not e:
193
- return get_data_error_result(retmsg="Conversation not found!")
194
- if "quote" not in req: req["quote"] = False
195
-
196
- msg = []
197
- for m in req["messages"]:
198
- if m["role"] == "system":
199
- continue
200
- if m["role"] == "assistant" and not msg:
201
- continue
202
- msg.append({"role": m["role"], "content": m["content"]})
203
-
204
- def fillin_conv(ans):
205
- nonlocal conv
206
- if not conv.reference:
207
- conv.reference.append(ans["reference"])
208
- else:
209
- conv.reference[-1] = ans["reference"]
210
- conv.message[-1] = {"role": "assistant", "content": ans["answer"]}
211
-
212
- def rename_field(ans):
213
- reference = ans['reference']
214
- if not isinstance(reference, dict):
215
- return
216
- for chunk_i in reference.get('chunks', []):
217
- if 'docnm_kwd' in chunk_i:
218
- chunk_i['doc_name'] = chunk_i['docnm_kwd']
219
- chunk_i.pop('docnm_kwd')
220
-
221
- try:
222
- if conv.source == "agent":
223
- stream = req.get("stream", True)
224
- conv.message.append(msg[-1])
225
- e, cvs = UserCanvasService.get_by_id(conv.dialog_id)
226
- if not e:
227
- return server_error_response("canvas not found.")
228
- del req["conversation_id"]
229
- del req["messages"]
230
-
231
- if not isinstance(cvs.dsl, str):
232
- cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
233
-
234
- if not conv.reference:
235
- conv.reference = []
236
- conv.message.append({"role": "assistant", "content": ""})
237
- conv.reference.append({"chunks": [], "doc_aggs": []})
238
-
239
- final_ans = {"reference": [], "content": ""}
240
- canvas = Canvas(cvs.dsl, objs[0].tenant_id)
241
-
242
- canvas.messages.append(msg[-1])
243
- canvas.add_user_input(msg[-1]["content"])
244
- answer = canvas.run(stream=stream)
245
-
246
- assert answer is not None, "Nothing. Is it over?"
247
-
248
- if stream:
249
- assert isinstance(answer, partial), "Nothing. Is it over?"
250
-
251
- def sse():
252
- nonlocal answer, cvs, conv
253
- try:
254
- for ans in answer():
255
- for k in ans.keys():
256
- final_ans[k] = ans[k]
257
- ans = {"answer": ans["content"], "reference": ans.get("reference", [])}
258
- fillin_conv(ans)
259
- rename_field(ans)
260
- yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": ans},
261
- ensure_ascii=False) + "\n\n"
262
-
263
- canvas.messages.append({"role": "assistant", "content": final_ans["content"]})
264
- if final_ans.get("reference"):
265
- canvas.reference.append(final_ans["reference"])
266
- cvs.dsl = json.loads(str(canvas))
267
- API4ConversationService.append_message(conv.id, conv.to_dict())
268
- except Exception as e:
269
- yield "data:" + json.dumps({"retcode": 500, "retmsg": str(e),
270
- "data": {"answer": "**ERROR**: " + str(e), "reference": []}},
271
- ensure_ascii=False) + "\n\n"
272
- yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": True}, ensure_ascii=False) + "\n\n"
273
-
274
- resp = Response(sse(), mimetype="text/event-stream")
275
- resp.headers.add_header("Cache-control", "no-cache")
276
- resp.headers.add_header("Connection", "keep-alive")
277
- resp.headers.add_header("X-Accel-Buffering", "no")
278
- resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
279
- return resp
280
-
281
- final_ans["content"] = "\n".join(answer["content"]) if "content" in answer else ""
282
- canvas.messages.append({"role": "assistant", "content": final_ans["content"]})
283
- if final_ans.get("reference"):
284
- canvas.reference.append(final_ans["reference"])
285
- cvs.dsl = json.loads(str(canvas))
286
-
287
- result = {"answer": final_ans["content"], "reference": final_ans.get("reference", [])}
288
- fillin_conv(result)
289
- API4ConversationService.append_message(conv.id, conv.to_dict())
290
- rename_field(result)
291
- return get_json_result(data=result)
292
-
293
- #******************For dialog******************
294
- conv.message.append(msg[-1])
295
- e, dia = DialogService.get_by_id(conv.dialog_id)
296
- if not e:
297
- return get_data_error_result(retmsg="Dialog not found!")
298
- del req["conversation_id"]
299
- del req["messages"]
300
-
301
- if not conv.reference:
302
- conv.reference = []
303
- conv.message.append({"role": "assistant", "content": ""})
304
- conv.reference.append({"chunks": [], "doc_aggs": []})
305
-
306
- def stream():
307
- nonlocal dia, msg, req, conv
308
- try:
309
- for ans in chat(dia, msg, True, **req):
310
- fillin_conv(ans)
311
- rename_field(ans)
312
- yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": ans},
313
- ensure_ascii=False) + "\n\n"
314
- API4ConversationService.append_message(conv.id, conv.to_dict())
315
- except Exception as e:
316
- yield "data:" + json.dumps({"retcode": 500, "retmsg": str(e),
317
- "data": {"answer": "**ERROR**: " + str(e), "reference": []}},
318
- ensure_ascii=False) + "\n\n"
319
- yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": True}, ensure_ascii=False) + "\n\n"
320
-
321
- if req.get("stream", True):
322
- resp = Response(stream(), mimetype="text/event-stream")
323
- resp.headers.add_header("Cache-control", "no-cache")
324
- resp.headers.add_header("Connection", "keep-alive")
325
- resp.headers.add_header("X-Accel-Buffering", "no")
326
- resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
327
- return resp
328
-
329
- answer = None
330
- for ans in chat(dia, msg, **req):
331
- answer = ans
332
- fillin_conv(ans)
333
- API4ConversationService.append_message(conv.id, conv.to_dict())
334
- break
335
- rename_field(answer)
336
- return get_json_result(data=answer)
337
-
338
- except Exception as e:
339
- return server_error_response(e)
340
-
341
-
342
- @manager.route('/conversation/<conversation_id>', methods=['GET'])
343
- # @login_required
344
- def get(conversation_id):
345
- try:
346
- e, conv = API4ConversationService.get_by_id(conversation_id)
347
- if not e:
348
- return get_data_error_result(retmsg="Conversation not found!")
349
-
350
- conv = conv.to_dict()
351
- for reference_i in conv['reference']:
352
- if reference_i is None or len(reference_i) == 0:
353
- continue
354
- for chunk_i in reference_i['chunks']:
355
- if 'docnm_kwd' in chunk_i.keys():
356
- chunk_i['doc_name'] = chunk_i['docnm_kwd']
357
- chunk_i.pop('docnm_kwd')
358
- return get_json_result(data=conv)
359
- except Exception as e:
360
- return server_error_response(e)
361
-
362
-
363
- @manager.route('/document/upload', methods=['POST'])
364
- @validate_request("kb_name")
365
- def upload():
366
- token = request.headers.get('Authorization').split()[1]
367
- objs = APIToken.query(token=token)
368
- if not objs:
369
- return get_json_result(
370
- data=False, retmsg='Token is not valid!"', retcode=RetCode.AUTHENTICATION_ERROR)
371
-
372
- kb_name = request.form.get("kb_name").strip()
373
- tenant_id = objs[0].tenant_id
374
-
375
- try:
376
- e, kb = KnowledgebaseService.get_by_name(kb_name, tenant_id)
377
- if not e:
378
- return get_data_error_result(
379
- retmsg="Can't find this knowledgebase!")
380
- kb_id = kb.id
381
- except Exception as e:
382
- return server_error_response(e)
383
-
384
- if 'file' not in request.files:
385
- return get_json_result(
386
- data=False, retmsg='No file part!', retcode=RetCode.ARGUMENT_ERROR)
387
-
388
- file = request.files['file']
389
- if file.filename == '':
390
- return get_json_result(
391
- data=False, retmsg='No file selected!', retcode=RetCode.ARGUMENT_ERROR)
392
-
393
- root_folder = FileService.get_root_folder(tenant_id)
394
- pf_id = root_folder["id"]
395
- FileService.init_knowledgebase_docs(pf_id, tenant_id)
396
- kb_root_folder = FileService.get_kb_folder(tenant_id)
397
- kb_folder = FileService.new_a_file_from_kb(kb.tenant_id, kb.name, kb_root_folder["id"])
398
-
399
- try:
400
- if DocumentService.get_doc_count(kb.tenant_id) >= int(os.environ.get('MAX_FILE_NUM_PER_USER', 8192)):
401
- return get_data_error_result(
402
- retmsg="Exceed the maximum file number of a free user!")
403
-
404
- filename = duplicate_name(
405
- DocumentService.query,
406
- name=file.filename,
407
- kb_id=kb_id)
408
- filetype = filename_type(filename)
409
- if not filetype:
410
- return get_data_error_result(
411
- retmsg="This type of file has not been supported yet!")
412
-
413
- location = filename
414
- while MINIO.obj_exist(kb_id, location):
415
- location += "_"
416
- blob = request.files['file'].read()
417
- MINIO.put(kb_id, location, blob)
418
- doc = {
419
- "id": get_uuid(),
420
- "kb_id": kb.id,
421
- "parser_id": kb.parser_id,
422
- "parser_config": kb.parser_config,
423
- "created_by": kb.tenant_id,
424
- "type": filetype,
425
- "name": filename,
426
- "location": location,
427
- "size": len(blob),
428
- "thumbnail": thumbnail(filename, blob)
429
- }
430
-
431
- form_data = request.form
432
- if "parser_id" in form_data.keys():
433
- if request.form.get("parser_id").strip() in list(vars(ParserType).values())[1:-3]:
434
- doc["parser_id"] = request.form.get("parser_id").strip()
435
- if doc["type"] == FileType.VISUAL:
436
- doc["parser_id"] = ParserType.PICTURE.value
437
- if doc["type"] == FileType.AURAL:
438
- doc["parser_id"] = ParserType.AUDIO.value
439
- if re.search(r"\.(ppt|pptx|pages)$", filename):
440
- doc["parser_id"] = ParserType.PRESENTATION.value
441
-
442
- doc_result = DocumentService.insert(doc)
443
- FileService.add_file_from_kb(doc, kb_folder["id"], kb.tenant_id)
444
- except Exception as e:
445
- return server_error_response(e)
446
-
447
- if "run" in form_data.keys():
448
- if request.form.get("run").strip() == "1":
449
- try:
450
- info = {"run": 1, "progress": 0}
451
- info["progress_msg"] = ""
452
- info["chunk_num"] = 0
453
- info["token_num"] = 0
454
- DocumentService.update_by_id(doc["id"], info)
455
- # if str(req["run"]) == TaskStatus.CANCEL.value:
456
- tenant_id = DocumentService.get_tenant_id(doc["id"])
457
- if not tenant_id:
458
- return get_data_error_result(retmsg="Tenant not found!")
459
-
460
- # e, doc = DocumentService.get_by_id(doc["id"])
461
- TaskService.filter_delete([Task.doc_id == doc["id"]])
462
- e, doc = DocumentService.get_by_id(doc["id"])
463
- doc = doc.to_dict()
464
- doc["tenant_id"] = tenant_id
465
- bucket, name = File2DocumentService.get_minio_address(doc_id=doc["id"])
466
- queue_tasks(doc, bucket, name)
467
- except Exception as e:
468
- return server_error_response(e)
469
-
470
- return get_json_result(data=doc_result.to_json())
471
-
472
-
473
- @manager.route('/list_chunks', methods=['POST'])
474
- # @login_required
475
- def list_chunks():
476
- token = request.headers.get('Authorization').split()[1]
477
- objs = APIToken.query(token=token)
478
- if not objs:
479
- return get_json_result(
480
- data=False, retmsg='Token is not valid!"', retcode=RetCode.AUTHENTICATION_ERROR)
481
-
482
- req = request.json
483
-
484
- try:
485
- if "doc_name" in req.keys():
486
- tenant_id = DocumentService.get_tenant_id_by_name(req['doc_name'])
487
- doc_id = DocumentService.get_doc_id_by_doc_name(req['doc_name'])
488
-
489
- elif "doc_id" in req.keys():
490
- tenant_id = DocumentService.get_tenant_id(req['doc_id'])
491
- doc_id = req['doc_id']
492
- else:
493
- return get_json_result(
494
- data=False, retmsg="Can't find doc_name or doc_id"
495
- )
496
-
497
- res = retrievaler.chunk_list(doc_id=doc_id, tenant_id=tenant_id)
498
- res = [
499
- {
500
- "content": res_item["content_with_weight"],
501
- "doc_name": res_item["docnm_kwd"],
502
- "img_id": res_item["img_id"]
503
- } for res_item in res
504
- ]
505
-
506
- except Exception as e:
507
- return server_error_response(e)
508
-
509
- return get_json_result(data=res)
510
-
511
-
512
- @manager.route('/list_kb_docs', methods=['POST'])
513
- # @login_required
514
- def list_kb_docs():
515
- token = request.headers.get('Authorization').split()[1]
516
- objs = APIToken.query(token=token)
517
- if not objs:
518
- return get_json_result(
519
- data=False, retmsg='Token is not valid!"', retcode=RetCode.AUTHENTICATION_ERROR)
520
-
521
- req = request.json
522
- tenant_id = objs[0].tenant_id
523
- kb_name = req.get("kb_name", "").strip()
524
-
525
- try:
526
- e, kb = KnowledgebaseService.get_by_name(kb_name, tenant_id)
527
- if not e:
528
- return get_data_error_result(
529
- retmsg="Can't find this knowledgebase!")
530
- kb_id = kb.id
531
-
532
- except Exception as e:
533
- return server_error_response(e)
534
-
535
- page_number = int(req.get("page", 1))
536
- items_per_page = int(req.get("page_size", 15))
537
- orderby = req.get("orderby", "create_time")
538
- desc = req.get("desc", True)
539
- keywords = req.get("keywords", "")
540
-
541
- try:
542
- docs, tol = DocumentService.get_by_kb_id(
543
- kb_id, page_number, items_per_page, orderby, desc, keywords)
544
- docs = [{"doc_id": doc['id'], "doc_name": doc['name']} for doc in docs]
545
-
546
- return get_json_result(data={"total": tol, "docs": docs})
547
-
548
- except Exception as e:
549
- return server_error_response(e)
550
-
551
-
552
- @manager.route('/document', methods=['DELETE'])
553
- # @login_required
554
- def document_rm():
555
- token = request.headers.get('Authorization').split()[1]
556
- objs = APIToken.query(token=token)
557
- if not objs:
558
- return get_json_result(
559
- data=False, retmsg='Token is not valid!"', retcode=RetCode.AUTHENTICATION_ERROR)
560
-
561
- tenant_id = objs[0].tenant_id
562
- req = request.json
563
- doc_ids = []
564
- try:
565
- doc_ids = [DocumentService.get_doc_id_by_doc_name(doc_name) for doc_name in req.get("doc_names", [])]
566
- for doc_id in req.get("doc_ids", []):
567
- if doc_id not in doc_ids:
568
- doc_ids.append(doc_id)
569
-
570
- if not doc_ids:
571
- return get_json_result(
572
- data=False, retmsg="Can't find doc_names or doc_ids"
573
- )
574
-
575
- except Exception as e:
576
- return server_error_response(e)
577
-
578
- root_folder = FileService.get_root_folder(tenant_id)
579
- pf_id = root_folder["id"]
580
- FileService.init_knowledgebase_docs(pf_id, tenant_id)
581
-
582
- errors = ""
583
- for doc_id in doc_ids:
584
- try:
585
- e, doc = DocumentService.get_by_id(doc_id)
586
- if not e:
587
- return get_data_error_result(retmsg="Document not found!")
588
- tenant_id = DocumentService.get_tenant_id(doc_id)
589
- if not tenant_id:
590
- return get_data_error_result(retmsg="Tenant not found!")
591
-
592
- b, n = File2DocumentService.get_minio_address(doc_id=doc_id)
593
-
594
- if not DocumentService.remove_document(doc, tenant_id):
595
- return get_data_error_result(
596
- retmsg="Database error (Document removal)!")
597
-
598
- f2d = File2DocumentService.get_by_document_id(doc_id)
599
- FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.id == f2d[0].file_id])
600
- File2DocumentService.delete_by_document_id(doc_id)
601
-
602
- MINIO.rm(b, n)
603
- except Exception as e:
604
- errors += str(e)
605
-
606
- if errors:
607
- return get_json_result(data=False, retmsg=errors, retcode=RetCode.SERVER_ERROR)
608
-
609
- return get_json_result(data=True)
610
-
611
-
612
- @manager.route('/completion_aibotk', methods=['POST'])
613
- @validate_request("Authorization", "conversation_id", "word")
614
- def completion_faq():
615
- import base64
616
- req = request.json
617
-
618
- token = req["Authorization"]
619
- objs = APIToken.query(token=token)
620
- if not objs:
621
- return get_json_result(
622
- data=False, retmsg='Token is not valid!"', retcode=RetCode.AUTHENTICATION_ERROR)
623
-
624
- e, conv = API4ConversationService.get_by_id(req["conversation_id"])
625
- if not e:
626
- return get_data_error_result(retmsg="Conversation not found!")
627
- if "quote" not in req: req["quote"] = True
628
-
629
- msg = []
630
- msg.append({"role": "user", "content": req["word"]})
631
-
632
- try:
633
- conv.message.append(msg[-1])
634
- e, dia = DialogService.get_by_id(conv.dialog_id)
635
- if not e:
636
- return get_data_error_result(retmsg="Dialog not found!")
637
- del req["conversation_id"]
638
-
639
- if not conv.reference:
640
- conv.reference = []
641
- conv.message.append({"role": "assistant", "content": ""})
642
- conv.reference.append({"chunks": [], "doc_aggs": []})
643
-
644
- def fillin_conv(ans):
645
- nonlocal conv
646
- if not conv.reference:
647
- conv.reference.append(ans["reference"])
648
- else:
649
- conv.reference[-1] = ans["reference"]
650
- conv.message[-1] = {"role": "assistant", "content": ans["answer"]}
651
-
652
- data_type_picture = {
653
- "type": 3,
654
- "url": "base64 content"
655
- }
656
- data = [
657
- {
658
- "type": 1,
659
- "content": ""
660
- }
661
- ]
662
- ans = ""
663
- for a in chat(dia, msg, stream=False, **req):
664
- ans = a
665
- break
666
- data[0]["content"] += re.sub(r'##\d\$\$', '', ans["answer"])
667
- fillin_conv(ans)
668
- API4ConversationService.append_message(conv.id, conv.to_dict())
669
-
670
- chunk_idxs = [int(match[2]) for match in re.findall(r'##\d\$\$', ans["answer"])]
671
- for chunk_idx in chunk_idxs[:1]:
672
- if ans["reference"]["chunks"][chunk_idx]["img_id"]:
673
- try:
674
- bkt, nm = ans["reference"]["chunks"][chunk_idx]["img_id"].split("-")
675
- response = MINIO.get(bkt, nm)
676
- data_type_picture["url"] = base64.b64encode(response).decode('utf-8')
677
- data.append(data_type_picture)
678
- break
679
- except Exception as e:
680
- return server_error_response(e)
681
-
682
- response = {"code": 200, "msg": "success", "data": data}
683
- return response
684
-
685
- except Exception as e:
686
- return server_error_response(e)
687
-
688
-
689
- @manager.route('/retrieval', methods=['POST'])
690
- @validate_request("kb_id", "question")
691
- def retrieval():
692
- token = request.headers.get('Authorization').split()[1]
693
- objs = APIToken.query(token=token)
694
- if not objs:
695
- return get_json_result(
696
- data=False, retmsg='Token is not valid!"', retcode=RetCode.AUTHENTICATION_ERROR)
697
-
698
- req = request.json
699
- kb_ids = req.get("kb_id",[])
700
- doc_ids = req.get("doc_ids", [])
701
- question = req.get("question")
702
- page = int(req.get("page", 1))
703
- size = int(req.get("size", 30))
704
- similarity_threshold = float(req.get("similarity_threshold", 0.2))
705
- vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
706
- top = int(req.get("top_k", 1024))
707
-
708
- try:
709
- kbs = KnowledgebaseService.get_by_ids(kb_ids)
710
- embd_nms = list(set([kb.embd_id for kb in kbs]))
711
- if len(embd_nms) != 1:
712
- return get_json_result(
713
- data=False, retmsg='Knowledge bases use different embedding models or do not exist.', retcode=RetCode.AUTHENTICATION_ERROR)
714
-
715
- embd_mdl = TenantLLMService.model_instance(
716
- kbs[0].tenant_id, LLMType.EMBEDDING.value, llm_name=kbs[0].embd_id)
717
- rerank_mdl = None
718
- if req.get("rerank_id"):
719
- rerank_mdl = TenantLLMService.model_instance(
720
- kbs[0].tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"])
721
- if req.get("keyword", False):
722
- chat_mdl = TenantLLMService.model_instance(kbs[0].tenant_id, LLMType.CHAT)
723
- question += keyword_extraction(chat_mdl, question)
724
- ranks = retrievaler.retrieval(question, embd_mdl, kbs[0].tenant_id, kb_ids, page, size,
725
- similarity_threshold, vector_similarity_weight, top,
726
- doc_ids, rerank_mdl=rerank_mdl)
727
- for c in ranks["chunks"]:
728
- if "vector" in c:
729
- del c["vector"]
730
- return get_json_result(data=ranks)
731
- except Exception as e:
732
- if str(e).find("not_found") > 0:
733
- return get_json_result(data=False, retmsg='No chunk found! Please check the chunk status.',
734
- retcode=RetCode.DATA_ERROR)
735
  return server_error_response(e)
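The completion endpoints above stream answers as server-sent events: each chunk is a "data:" line carrying a JSON payload, terminated by a blank line, and the X-Accel-Buffering header keeps reverse proxies from buffering the stream. A stripped-down sketch of the same pattern as a standalone Flask app (the route and tokens are illustrative):

import json
from flask import Flask, Response

app = Flask(__name__)

@app.route("/stream")
def stream():
    def sse():
        for token in ["Hello", " world"]:
            # One SSE event per chunk: a "data:" payload ended by a blank line.
            yield "data:" + json.dumps({"answer": token}) + "\n\n"
    resp = Response(sse(), mimetype="text/event-stream")
    resp.headers.add_header("Cache-control", "no-cache")
    resp.headers.add_header("X-Accel-Buffering", "no")  # disable proxy buffering
    return resp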
 
1
+ #
2
+ # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ #
16
+ import json
17
+ import os
18
+ import re
19
+ from datetime import datetime, timedelta
20
+ from flask import request, Response
21
+ from api.db.services.llm_service import TenantLLMService
22
+ from flask_login import login_required, current_user
23
+
24
+ from api.db import FileType, LLMType, ParserType, FileSource
25
+ from api.db.db_models import APIToken, API4Conversation, Task, File
26
+ from api.db.services import duplicate_name
27
+ from api.db.services.api_service import APITokenService, API4ConversationService
28
+ from api.db.services.dialog_service import DialogService, chat
29
+ from api.db.services.document_service import DocumentService
30
+ from api.db.services.file2document_service import File2DocumentService
31
+ from api.db.services.file_service import FileService
32
+ from api.db.services.knowledgebase_service import KnowledgebaseService
33
+ from api.db.services.task_service import queue_tasks, TaskService
34
+ from api.db.services.user_service import UserTenantService
35
+ from api.settings import RetCode, retrievaler
36
+ from api.utils import get_uuid, current_timestamp, datetime_format
37
+ from api.utils.api_utils import server_error_response, get_data_error_result, get_json_result, validate_request
38
+ from itsdangerous import URLSafeTimedSerializer
39
+
40
+ from api.utils.file_utils import filename_type, thumbnail
41
+ from rag.nlp import keyword_extraction
42
+ from rag.utils.minio_conn import MINIO
43
+
44
+ from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService
45
+ from agent.canvas import Canvas
46
+ from functools import partial
47
+
48
+
49
+ def generate_confirmation_token(tenant_id):
50
+ serializer = URLSafeTimedSerializer(tenant_id)
51
+ return "ragflow-" + serializer.dumps(get_uuid(), salt=tenant_id)[2:34]
52
+
53
+
54
+ @manager.route('/new_token', methods=['POST'])
55
+ @login_required
56
+ def new_token():
57
+ req = request.json
58
+ try:
59
+ tenants = UserTenantService.query(user_id=current_user.id)
60
+ if not tenants:
61
+ return get_data_error_result(retmsg="Tenant not found!")
62
+
63
+ tenant_id = tenants[0].tenant_id
64
+ obj = {"tenant_id": tenant_id, "token": generate_confirmation_token(tenant_id),
65
+ "create_time": current_timestamp(),
66
+ "create_date": datetime_format(datetime.now()),
67
+ "update_time": None,
68
+ "update_date": None
69
+ }
70
+ if req.get("canvas_id"):
71
+ obj["dialog_id"] = req["canvas_id"]
72
+ obj["source"] = "agent"
73
+ else:
74
+ obj["dialog_id"] = req["dialog_id"]
75
+
76
+ if not APITokenService.save(**obj):
77
+ return get_data_error_result(retmsg="Fail to new a dialog!")
78
+
79
+ return get_json_result(data=obj)
80
+ except Exception as e:
81
+ return server_error_response(e)
82
+
83
+
84
+ @manager.route('/token_list', methods=['GET'])
85
+ @login_required
86
+ def token_list():
87
+ try:
88
+ tenants = UserTenantService.query(user_id=current_user.id)
89
+ if not tenants:
90
+ return get_data_error_result(retmsg="Tenant not found!")
91
+
92
+ id = request.args["dialog_id"] if "dialog_id" in request.args else request.args["canvas_id"]
93
+ objs = APITokenService.query(tenant_id=tenants[0].tenant_id, dialog_id=id)
94
+ return get_json_result(data=[o.to_dict() for o in objs])
95
+ except Exception as e:
96
+ return server_error_response(e)
97
+
98
+
99
+ @manager.route('/rm', methods=['POST'])
100
+ @validate_request("tokens", "tenant_id")
101
+ @login_required
102
+ def rm():
103
+ req = request.json
104
+ try:
105
+ for token in req["tokens"]:
106
+ APITokenService.filter_delete(
107
+ [APIToken.tenant_id == req["tenant_id"], APIToken.token == token])
108
+ return get_json_result(data=True)
109
+ except Exception as e:
110
+ return server_error_response(e)
111
+
112
+
113
+ @manager.route('/stats', methods=['GET'])
114
+ @login_required
115
+ def stats():
116
+ try:
117
+ tenants = UserTenantService.query(user_id=current_user.id)
118
+ if not tenants:
119
+ return get_data_error_result(retmsg="Tenant not found!")
120
+ objs = API4ConversationService.stats(
121
+ tenants[0].tenant_id,
122
+ request.args.get(
123
+ "from_date",
124
+ (datetime.now() -
125
+ timedelta(
126
+ days=7)).strftime("%Y-%m-%d 00:00:00")),
127
+ request.args.get(
128
+ "to_date",
129
+ datetime.now().strftime("%Y-%m-%d %H:%M:%S")),
130
+ "agent" if "canvas_id" in request.args else None)
131
+ res = {
132
+ "pv": [(o["dt"], o["pv"]) for o in objs],
133
+ "uv": [(o["dt"], o["uv"]) for o in objs],
134
+ "speed": [(o["dt"], float(o["tokens"]) / (float(o["duration"] + 0.1))) for o in objs],
135
+ "tokens": [(o["dt"], float(o["tokens"]) / 1000.) for o in objs],
136
+ "round": [(o["dt"], o["round"]) for o in objs],
137
+ "thumb_up": [(o["dt"], o["thumb_up"]) for o in objs]
138
+ }
139
+ return get_json_result(data=res)
140
+ except Exception as e:
141
+ return server_error_response(e)
142
+
143
+
144
+ @manager.route('/new_conversation', methods=['GET'])
145
+ def set_conversation():
146
+ token = request.headers.get('Authorization').split()[1]
147
+ objs = APIToken.query(token=token)
148
+ if not objs:
149
+ return get_json_result(
150
+ data=False, retmsg='Token is not valid!"', retcode=RetCode.AUTHENTICATION_ERROR)
151
+ req = request.json
152
+ try:
153
+ if objs[0].source == "agent":
154
+ e, c = UserCanvasService.get_by_id(objs[0].dialog_id)
155
+ if not e:
156
+ return server_error_response("canvas not found.")
157
+ conv = {
158
+ "id": get_uuid(),
159
+ "dialog_id": c.id,
160
+ "user_id": request.args.get("user_id", ""),
161
+ "message": [{"role": "assistant", "content": "Hi there!"}],
162
+ "source": "agent"
163
+ }
164
+ API4ConversationService.save(**conv)
165
+ return get_json_result(data=conv)
166
+ else:
167
+ e, dia = DialogService.get_by_id(objs[0].dialog_id)
168
+ if not e:
169
+ return get_data_error_result(retmsg="Dialog not found")
170
+ conv = {
171
+ "id": get_uuid(),
172
+ "dialog_id": dia.id,
173
+ "user_id": request.args.get("user_id", ""),
174
+ "message": [{"role": "assistant", "content": dia.prompt_config["prologue"]}]
175
+ }
176
+ API4ConversationService.save(**conv)
177
+ return get_json_result(data=conv)
178
+ except Exception as e:
179
+ return server_error_response(e)
180
+
181
+
182
+ @manager.route('/completion', methods=['POST'])
183
+ @validate_request("conversation_id", "messages")
184
+ def completion():
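+ # Answer the latest user message. Streams SSE events when "stream" is true
+ # (the default); otherwise returns a single JSON payload.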
185
+ token = request.headers.get('Authorization').split()[1]
186
+ objs = APIToken.query(token=token)
187
+ if not objs:
188
+ return get_json_result(
189
+ data=False, retmsg='Token is not valid!"', retcode=RetCode.AUTHENTICATION_ERROR)
190
+ req = request.json
191
+ e, conv = API4ConversationService.get_by_id(req["conversation_id"])
192
+ if not e:
193
+ return get_data_error_result(retmsg="Conversation not found!")
194
+ if "quote" not in req: req["quote"] = False
195
+
196
+ msg = []
197
+ for m in req["messages"]:
198
+ if m["role"] == "system":
199
+ continue
200
+ if m["role"] == "assistant" and not msg:
201
+ continue
202
+ msg.append({"role": m["role"], "content": m["content"]})
203
+
204
+ def fillin_conv(ans):
205
+ nonlocal conv
206
+ if not conv.reference:
207
+ conv.reference.append(ans["reference"])
208
+ else:
209
+ conv.reference[-1] = ans["reference"]
210
+ conv.message[-1] = {"role": "assistant", "content": ans["answer"]}
211
+
212
+ def rename_field(ans):
213
+ reference = ans['reference']
214
+ if not isinstance(reference, dict):
215
+ return
216
+ for chunk_i in reference.get('chunks', []):
217
+ if 'docnm_kwd' in chunk_i:
218
+ chunk_i['doc_name'] = chunk_i['docnm_kwd']
219
+ chunk_i.pop('docnm_kwd')
220
+
221
+ try:
222
+ if conv.source == "agent":
223
+ stream = req.get("stream", True)
224
+ conv.message.append(msg[-1])
225
+ e, cvs = UserCanvasService.get_by_id(conv.dialog_id)
226
+ if not e:
227
+ return server_error_response("canvas not found.")
228
+ del req["conversation_id"]
229
+ del req["messages"]
230
+
231
+ if not isinstance(cvs.dsl, str):
232
+ cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
233
+
234
+ if not conv.reference:
235
+ conv.reference = []
236
+ conv.message.append({"role": "assistant", "content": ""})
237
+ conv.reference.append({"chunks": [], "doc_aggs": []})
238
+
239
+ final_ans = {"reference": [], "content": ""}
240
+ canvas = Canvas(cvs.dsl, objs[0].tenant_id)
241
+
242
+ canvas.messages.append(msg[-1])
243
+ canvas.add_user_input(msg[-1]["content"])
244
+ answer = canvas.run(stream=stream)
245
+
246
+ assert answer is not None, "Nothing. Is it over?"
247
+
248
+ if stream:
249
+ assert isinstance(answer, partial), "Nothing. Is it over?"
250
+
251
+ def sse():
252
+ nonlocal answer, cvs, conv
253
+ try:
254
+ for ans in answer():
255
+ for k in ans.keys():
256
+ final_ans[k] = ans[k]
257
+ ans = {"answer": ans["content"], "reference": ans.get("reference", [])}
258
+ fillin_conv(ans)
259
+ rename_field(ans)
260
+ yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": ans},
261
+ ensure_ascii=False) + "\n\n"
262
+
263
+ canvas.messages.append({"role": "assistant", "content": final_ans["content"]})
264
+ if final_ans.get("reference"):
265
+ canvas.reference.append(final_ans["reference"])
266
+ cvs.dsl = json.loads(str(canvas))
267
+ API4ConversationService.append_message(conv.id, conv.to_dict())
268
+ except Exception as e:
269
+ yield "data:" + json.dumps({"retcode": 500, "retmsg": str(e),
270
+ "data": {"answer": "**ERROR**: " + str(e), "reference": []}},
271
+ ensure_ascii=False) + "\n\n"
272
+ yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": True}, ensure_ascii=False) + "\n\n"
273
+
274
+ resp = Response(sse(), mimetype="text/event-stream")
275
+ resp.headers.add_header("Cache-control", "no-cache")
276
+ resp.headers.add_header("Connection", "keep-alive")
277
+ resp.headers.add_header("X-Accel-Buffering", "no")
278
+ resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
279
+ return resp
280
+
281
+ final_ans["content"] = "\n".join(answer["content"]) if "content" in answer else ""
282
+ canvas.messages.append({"role": "assistant", "content": final_ans["content"]})
283
+ if final_ans.get("reference"):
284
+ canvas.reference.append(final_ans["reference"])
285
+ cvs.dsl = json.loads(str(canvas))
286
+
287
+ result = {"answer": final_ans["content"], "reference": final_ans.get("reference", [])}
288
+ fillin_conv(result)
289
+ API4ConversationService.append_message(conv.id, conv.to_dict())
290
+ rename_field(result)
291
+ return get_json_result(data=result)
292
+
293
+ #******************For dialog******************
294
+ conv.message.append(msg[-1])
295
+ e, dia = DialogService.get_by_id(conv.dialog_id)
296
+ if not e:
297
+ return get_data_error_result(retmsg="Dialog not found!")
298
+ del req["conversation_id"]
299
+ del req["messages"]
300
+
301
+ if not conv.reference:
302
+ conv.reference = []
303
+ conv.message.append({"role": "assistant", "content": ""})
304
+ conv.reference.append({"chunks": [], "doc_aggs": []})
305
+
306
+ def stream():
307
+ nonlocal dia, msg, req, conv
308
+ try:
309
+ for ans in chat(dia, msg, True, **req):
310
+ fillin_conv(ans)
311
+ rename_field(ans)
312
+ yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": ans},
313
+ ensure_ascii=False) + "\n\n"
314
+ API4ConversationService.append_message(conv.id, conv.to_dict())
315
+ except Exception as e:
316
+ yield "data:" + json.dumps({"retcode": 500, "retmsg": str(e),
317
+ "data": {"answer": "**ERROR**: " + str(e), "reference": []}},
318
+ ensure_ascii=False) + "\n\n"
319
+ yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": True}, ensure_ascii=False) + "\n\n"
320
+
321
+ if req.get("stream", True):
322
+ resp = Response(stream(), mimetype="text/event-stream")
323
+ resp.headers.add_header("Cache-control", "no-cache")
324
+ resp.headers.add_header("Connection", "keep-alive")
325
+ resp.headers.add_header("X-Accel-Buffering", "no")
326
+ resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
327
+ return resp
328
+
329
+ answer = None
330
+ for ans in chat(dia, msg, **req):
331
+ answer = ans
332
+ fillin_conv(ans)
333
+ API4ConversationService.append_message(conv.id, conv.to_dict())
334
+ break
335
+ rename_field(answer)
336
+ return get_json_result(data=answer)
337
+
338
+ except Exception as e:
339
+ return server_error_response(e)
340
+
341
+
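+ # Example call (a sketch; host and URL prefix are assumptions, not part of this commit):
+ #   curl -X POST http://<ragflow-host>/<api-prefix>/completion \
+ #     -H 'Authorization: Bearer <API_TOKEN>' -H 'Content-Type: application/json' \
+ #     -d '{"conversation_id": "<id>", "messages": [{"role": "user", "content": "Hi"}]}'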
342
+ @manager.route('/conversation/<conversation_id>', methods=['GET'])
343
+ # @login_required
344
+ def get(conversation_id):
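+ # Fetch one conversation by id, renaming "docnm_kwd" to "doc_name" in its references.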
345
+ try:
346
+ e, conv = API4ConversationService.get_by_id(conversation_id)
347
+ if not e:
348
+ return get_data_error_result(retmsg="Conversation not found!")
349
+
350
+ conv = conv.to_dict()
351
+ for reference_i in conv['reference']:
352
+ if reference_i is None or len(reference_i) == 0:
353
+ continue
354
+ for chunk_i in reference_i['chunks']:
355
+ if 'docnm_kwd' in chunk_i.keys():
356
+ chunk_i['doc_name'] = chunk_i['docnm_kwd']
357
+ chunk_i.pop('docnm_kwd')
358
+ return get_json_result(data=conv)
359
+ except Exception as e:
360
+ return server_error_response(e)
361
+
362
+
363
+ @manager.route('/document/upload', methods=['POST'])
364
+ @validate_request("kb_name")
365
+ def upload():
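+ # Multipart upload of a single file into the knowledge base named by "kb_name";
+ # form field "run" == "1" queues parsing tasks immediately after insertion.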
366
+ token = request.headers.get('Authorization').split()[1]
367
+ objs = APIToken.query(token=token)
368
+ if not objs:
369
+ return get_json_result(
370
+ data=False, retmsg='Token is not valid!"', retcode=RetCode.AUTHENTICATION_ERROR)
371
+
372
+ kb_name = request.form.get("kb_name").strip()
373
+ tenant_id = objs[0].tenant_id
374
+
375
+ try:
376
+ e, kb = KnowledgebaseService.get_by_name(kb_name, tenant_id)
377
+ if not e:
378
+ return get_data_error_result(
379
+ retmsg="Can't find this knowledgebase!")
380
+ kb_id = kb.id
381
+ except Exception as e:
382
+ return server_error_response(e)
383
+
384
+ if 'file' not in request.files:
385
+ return get_json_result(
386
+ data=False, retmsg='No file part!', retcode=RetCode.ARGUMENT_ERROR)
387
+
388
+ file = request.files['file']
389
+ if file.filename == '':
390
+ return get_json_result(
391
+ data=False, retmsg='No file selected!', retcode=RetCode.ARGUMENT_ERROR)
392
+
393
+ root_folder = FileService.get_root_folder(tenant_id)
394
+ pf_id = root_folder["id"]
395
+ FileService.init_knowledgebase_docs(pf_id, tenant_id)
396
+ kb_root_folder = FileService.get_kb_folder(tenant_id)
397
+ kb_folder = FileService.new_a_file_from_kb(kb.tenant_id, kb.name, kb_root_folder["id"])
398
+
399
+ try:
400
+ if DocumentService.get_doc_count(kb.tenant_id) >= int(os.environ.get('MAX_FILE_NUM_PER_USER', 8192)):
401
+ return get_data_error_result(
402
+ retmsg="Exceed the maximum file number of a free user!")
403
+
404
+ filename = duplicate_name(
405
+ DocumentService.query,
406
+ name=file.filename,
407
+ kb_id=kb_id)
408
+ filetype = filename_type(filename)
409
+ if not filetype:
410
+ return get_data_error_result(
411
+ retmsg="This type of file has not been supported yet!")
412
+
413
+ location = filename
414
+ while MINIO.obj_exist(kb_id, location):
415
+ location += "_"
416
+ blob = request.files['file'].read()
417
+ MINIO.put(kb_id, location, blob)
418
+ doc = {
419
+ "id": get_uuid(),
420
+ "kb_id": kb.id,
421
+ "parser_id": kb.parser_id,
422
+ "parser_config": kb.parser_config,
423
+ "created_by": kb.tenant_id,
424
+ "type": filetype,
425
+ "name": filename,
426
+ "location": location,
427
+ "size": len(blob),
428
+ "thumbnail": thumbnail(filename, blob)
429
+ }
430
+
431
+ form_data = request.form
432
+ if "parser_id" in form_data.keys():
433
+ if request.form.get("parser_id").strip() in list(vars(ParserType).values())[1:-3]:
434
+ doc["parser_id"] = request.form.get("parser_id").strip()
435
+ if doc["type"] == FileType.VISUAL:
436
+ doc["parser_id"] = ParserType.PICTURE.value
437
+ if doc["type"] == FileType.AURAL:
438
+ doc["parser_id"] = ParserType.AUDIO.value
439
+ if re.search(r"\.(ppt|pptx|pages)$", filename):
440
+ doc["parser_id"] = ParserType.PRESENTATION.value
441
+
442
+ doc_result = DocumentService.insert(doc)
443
+ FileService.add_file_from_kb(doc, kb_folder["id"], kb.tenant_id)
444
+ except Exception as e:
445
+ return server_error_response(e)
446
+
447
+ if "run" in form_data.keys():
448
+ if request.form.get("run").strip() == "1":
449
+ try:
450
+ info = {"run": 1, "progress": 0}
451
+ info["progress_msg"] = ""
452
+ info["chunk_num"] = 0
453
+ info["token_num"] = 0
454
+ DocumentService.update_by_id(doc["id"], info)
455
+ # if str(req["run"]) == TaskStatus.CANCEL.value:
456
+ tenant_id = DocumentService.get_tenant_id(doc["id"])
457
+ if not tenant_id:
458
+ return get_data_error_result(retmsg="Tenant not found!")
459
+
460
+ # e, doc = DocumentService.get_by_id(doc["id"])
461
+ TaskService.filter_delete([Task.doc_id == doc["id"]])
462
+ e, doc = DocumentService.get_by_id(doc["id"])
463
+ doc = doc.to_dict()
464
+ doc["tenant_id"] = tenant_id
465
+ bucket, name = File2DocumentService.get_minio_address(doc_id=doc["id"])
466
+ queue_tasks(doc, bucket, name)
467
+ except Exception as e:
468
+ return server_error_response(e)
469
+
470
+ return get_json_result(data=doc_result.to_json())
471
+
472
+
473
+ @manager.route('/list_chunks', methods=['POST'])
474
+ # @login_required
475
+ def list_chunks():
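+ # Return the parsed chunks of one document, addressed by "doc_name" or "doc_id".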
476
+ token = request.headers.get('Authorization').split()[1]
477
+ objs = APIToken.query(token=token)
478
+ if not objs:
479
+ return get_json_result(
480
+ data=False, retmsg='Token is not valid!"', retcode=RetCode.AUTHENTICATION_ERROR)
481
+
482
+ req = request.json
483
+
484
+ try:
485
+ if "doc_name" in req.keys():
486
+ tenant_id = DocumentService.get_tenant_id_by_name(req['doc_name'])
487
+ doc_id = DocumentService.get_doc_id_by_doc_name(req['doc_name'])
488
+
489
+ elif "doc_id" in req.keys():
490
+ tenant_id = DocumentService.get_tenant_id(req['doc_id'])
491
+ doc_id = req['doc_id']
492
+ else:
493
+ return get_json_result(
494
+ data=False, retmsg="Can't find doc_name or doc_id"
495
+ )
496
+
497
+ res = retrievaler.chunk_list(doc_id=doc_id, tenant_id=tenant_id)
498
+ res = [
499
+ {
500
+ "content": res_item["content_with_weight"],
501
+ "doc_name": res_item["docnm_kwd"],
502
+ "img_id": res_item["img_id"]
503
+ } for res_item in res
504
+ ]
505
+
506
+ except Exception as e:
507
+ return server_error_response(e)
508
+
509
+ return get_json_result(data=res)
510
+
511
+
512
+ @manager.route('/list_kb_docs', methods=['POST'])
513
+ # @login_required
514
+ def list_kb_docs():
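+ # Paginated listing of the documents inside the knowledge base "kb_name".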
515
+ token = request.headers.get('Authorization').split()[1]
516
+ objs = APIToken.query(token=token)
517
+ if not objs:
518
+ return get_json_result(
519
+ data=False, retmsg='Token is not valid!"', retcode=RetCode.AUTHENTICATION_ERROR)
520
+
521
+ req = request.json
522
+ tenant_id = objs[0].tenant_id
523
+ kb_name = req.get("kb_name", "").strip()
524
+
525
+ try:
526
+ e, kb = KnowledgebaseService.get_by_name(kb_name, tenant_id)
527
+ if not e:
528
+ return get_data_error_result(
529
+ retmsg="Can't find this knowledgebase!")
530
+ kb_id = kb.id
531
+
532
+ except Exception as e:
533
+ return server_error_response(e)
534
+
535
+ page_number = int(req.get("page", 1))
536
+ items_per_page = int(req.get("page_size", 15))
537
+ orderby = req.get("orderby", "create_time")
538
+ desc = req.get("desc", True)
539
+ keywords = req.get("keywords", "")
540
+
541
+ try:
542
+ docs, tot = DocumentService.get_by_kb_id(
543
+ kb_id, page_number, items_per_page, orderby, desc, keywords)
544
+ docs = [{"doc_id": doc['id'], "doc_name": doc['name']} for doc in docs]
545
+
546
+ return get_json_result(data={"total": tol, "docs": docs})
547
+
548
+ except Exception as e:
549
+ return server_error_response(e)
550
+
551
+
552
+ @manager.route('/document', methods=['DELETE'])
553
+ # @login_required
554
+ def document_rm():
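+ # Remove documents (by name or id) along with their MinIO blobs and file
+ # records; per-document errors are collected and reported in one response.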
555
+ token = request.headers.get('Authorization').split()[1]
556
+ objs = APIToken.query(token=token)
557
+ if not objs:
558
+ return get_json_result(
559
+ data=False, retmsg='Token is not valid!"', retcode=RetCode.AUTHENTICATION_ERROR)
560
+
561
+ tenant_id = objs[0].tenant_id
562
+ req = request.json
563
+ doc_ids = []
564
+ try:
565
+ doc_ids = [DocumentService.get_doc_id_by_doc_name(doc_name) for doc_name in req.get("doc_names", [])]
566
+ for doc_id in req.get("doc_ids", []):
567
+ if doc_id not in doc_ids:
568
+ doc_ids.append(doc_id)
569
+
570
+ if not doc_ids:
571
+ return get_json_result(
572
+ data=False, retmsg="Can't find doc_names or doc_ids"
573
+ )
574
+
575
+ except Exception as e:
576
+ return server_error_response(e)
577
+
578
+ root_folder = FileService.get_root_folder(tenant_id)
579
+ pf_id = root_folder["id"]
580
+ FileService.init_knowledgebase_docs(pf_id, tenant_id)
581
+
582
+ errors = ""
583
+ for doc_id in doc_ids:
584
+ try:
585
+ e, doc = DocumentService.get_by_id(doc_id)
586
+ if not e:
587
+ return get_data_error_result(retmsg="Document not found!")
588
+ tenant_id = DocumentService.get_tenant_id(doc_id)
589
+ if not tenant_id:
590
+ return get_data_error_result(retmsg="Tenant not found!")
591
+
592
+ b, n = File2DocumentService.get_minio_address(doc_id=doc_id)
593
+
594
+ if not DocumentService.remove_document(doc, tenant_id):
595
+ return get_data_error_result(
596
+ retmsg="Database error (Document removal)!")
597
+
598
+ f2d = File2DocumentService.get_by_document_id(doc_id)
599
+ FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.id == f2d[0].file_id])
600
+ File2DocumentService.delete_by_document_id(doc_id)
601
+
602
+ MINIO.rm(b, n)
603
+ except Exception as e:
604
+ errors += str(e)
605
+
606
+ if errors:
607
+ return get_json_result(data=False, retmsg=errors, retcode=RetCode.SERVER_ERROR)
608
+
609
+ return get_json_result(data=True)
610
+
611
+
612
+ @manager.route('/completion_aibotk', methods=['POST'])
613
+ @validate_request("Authorization", "conversation_id", "word")
614
+ def completion_faq():
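+ # Non-streaming Q&A endpoint for the Aibotk bot. The token is read from the
+ # JSON body ("Authorization" field); the reply carries the answer text plus a
+ # base64 image when the first cited chunk has one.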
615
+ import base64
616
+ req = request.json
617
+
618
+ token = req["Authorization"]
619
+ objs = APIToken.query(token=token)
620
+ if not objs:
621
+ return get_json_result(
622
+ data=False, retmsg='Token is not valid!"', retcode=RetCode.AUTHENTICATION_ERROR)
623
+
624
+ e, conv = API4ConversationService.get_by_id(req["conversation_id"])
625
+ if not e:
626
+ return get_data_error_result(retmsg="Conversation not found!")
627
+ if "quote" not in req: req["quote"] = True
628
+
629
+ msg = []
630
+ msg.append({"role": "user", "content": req["word"]})
631
+
632
+ try:
633
+ conv.message.append(msg[-1])
634
+ e, dia = DialogService.get_by_id(conv.dialog_id)
635
+ if not e:
636
+ return get_data_error_result(retmsg="Dialog not found!")
637
+ del req["conversation_id"]
638
+
639
+ if not conv.reference:
640
+ conv.reference = []
641
+ conv.message.append({"role": "assistant", "content": ""})
642
+ conv.reference.append({"chunks": [], "doc_aggs": []})
643
+
644
+ def fillin_conv(ans):
645
+ nonlocal conv
646
+ if not conv.reference:
647
+ conv.reference.append(ans["reference"])
648
+ else:
649
+ conv.reference[-1] = ans["reference"]
650
+ conv.message[-1] = {"role": "assistant", "content": ans["answer"]}
651
+
652
+ data_type_picture = {
653
+ "type": 3,
654
+ "url": "base64 content"
655
+ }
656
+ data = [
657
+ {
658
+ "type": 1,
659
+ "content": ""
660
+ }
661
+ ]
662
+ ans = ""
663
+ for a in chat(dia, msg, stream=False, **req):
664
+ ans = a
665
+ break
666
+ data[0]["content"] += re.sub(r'##\d\$\$', '', ans["answer"])
667
+ fillin_conv(ans)
668
+ API4ConversationService.append_message(conv.id, conv.to_dict())
669
+
670
+ chunk_idxs = [int(match[2]) for match in re.findall(r'##\d\$\$', ans["answer"])]
671
+ for chunk_idx in chunk_idxs[:1]:
672
+ if ans["reference"]["chunks"][chunk_idx]["img_id"]:
673
+ try:
674
+ bkt, nm = ans["reference"]["chunks"][chunk_idx]["img_id"].split("-")
675
+ response = MINIO.get(bkt, nm)
676
+ data_type_picture["url"] = base64.b64encode(response).decode('utf-8')
677
+ data.append(data_type_picture)
678
+ break
679
+ except Exception as e:
680
+ return server_error_response(e)
681
+
682
+ response = {"code": 200, "msg": "success", "data": data}
683
+ return response
684
+
685
+ except Exception as e:
686
+ return server_error_response(e)
687
+
688
+
689
+ @manager.route('/retrieval', methods=['POST'])
690
+ @validate_request("kb_id", "question")
691
+ def retrieval():
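+ # Raw retrieval across one or more knowledge bases that must share a single
+ # embedding model; supports optional reranking and LLM keyword expansion.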
692
+ token = request.headers.get('Authorization').split()[1]
693
+ objs = APIToken.query(token=token)
694
+ if not objs:
695
+ return get_json_result(
696
+ data=False, retmsg='Token is not valid!"', retcode=RetCode.AUTHENTICATION_ERROR)
697
+
698
+ req = request.json
699
+ kb_ids = req.get("kb_id",[])
700
+ doc_ids = req.get("doc_ids", [])
701
+ question = req.get("question")
702
+ page = int(req.get("page", 1))
703
+ size = int(req.get("size", 30))
704
+ similarity_threshold = float(req.get("similarity_threshold", 0.2))
705
+ vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
706
+ top = int(req.get("top_k", 1024))
707
+
708
+ try:
709
+ kbs = KnowledgebaseService.get_by_ids(kb_ids)
710
+ embd_nms = list(set([kb.embd_id for kb in kbs]))
711
+ if len(embd_nms) != 1:
712
+ return get_json_result(
713
+ data=False, retmsg='Knowledge bases use different embedding models or do not exist.', retcode=RetCode.AUTHENTICATION_ERROR)
714
+
715
+ embd_mdl = TenantLLMService.model_instance(
716
+ kbs[0].tenant_id, LLMType.EMBEDDING.value, llm_name=kbs[0].embd_id)
717
+ rerank_mdl = None
718
+ if req.get("rerank_id"):
719
+ rerank_mdl = TenantLLMService.model_instance(
720
+ kbs[0].tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"])
721
+ if req.get("keyword", False):
722
+ chat_mdl = TenantLLMService.model_instance(kbs[0].tenant_id, LLMType.CHAT)
723
+ question += keyword_extraction(chat_mdl, question)
724
+ ranks = retrievaler.retrieval(question, embd_mdl, kbs[0].tenant_id, kb_ids, page, size,
725
+ similarity_threshold, vector_similarity_weight, top,
726
+ doc_ids, rerank_mdl=rerank_mdl)
727
+ for c in ranks["chunks"]:
728
+ if "vector" in c:
729
+ del c["vector"]
730
+ return get_json_result(data=ranks)
731
+ except Exception as e:
732
+ if str(e).find("not_found") > 0:
733
+ return get_json_result(data=False, retmsg=f'No chunk found! Check the chunk status please!',
734
+ retcode=RetCode.DATA_ERROR)
735
  return server_error_response(e)
api/apps/chunk_app.py CHANGED
@@ -1,318 +1,318 @@
1
- #
2
- # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- #
16
- import datetime
17
- import json
18
- import traceback
19
-
20
- from flask import request
21
- from flask_login import login_required, current_user
22
- from elasticsearch_dsl import Q
23
-
24
- from rag.app.qa import rmPrefix, beAdoc
25
- from rag.nlp import search, rag_tokenizer, keyword_extraction
26
- from rag.utils.es_conn import ELASTICSEARCH
27
- from rag.utils import rmSpace
28
- from api.db import LLMType, ParserType
29
- from api.db.services.knowledgebase_service import KnowledgebaseService
30
- from api.db.services.llm_service import TenantLLMService
31
- from api.db.services.user_service import UserTenantService
32
- from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
33
- from api.db.services.document_service import DocumentService
34
- from api.settings import RetCode, retrievaler, kg_retrievaler
35
- from api.utils.api_utils import get_json_result
36
- import hashlib
37
- import re
38
-
39
-
40
- @manager.route('/list', methods=['POST'])
41
- @login_required
42
- @validate_request("doc_id")
43
- def list_chunk():
44
- req = request.json
45
- doc_id = req["doc_id"]
46
- page = int(req.get("page", 1))
47
- size = int(req.get("size", 30))
48
- question = req.get("keywords", "")
49
- try:
50
- tenant_id = DocumentService.get_tenant_id(req["doc_id"])
51
- if not tenant_id:
52
- return get_data_error_result(retmsg="Tenant not found!")
53
- e, doc = DocumentService.get_by_id(doc_id)
54
- if not e:
55
- return get_data_error_result(retmsg="Document not found!")
56
- query = {
57
- "doc_ids": [doc_id], "page": page, "size": size, "question": question, "sort": True
58
- }
59
- if "available_int" in req:
60
- query["available_int"] = int(req["available_int"])
61
- sres = retrievaler.search(query, search.index_name(tenant_id))
62
- res = {"total": sres.total, "chunks": [], "doc": doc.to_dict()}
63
- for id in sres.ids:
64
- d = {
65
- "chunk_id": id,
66
- "content_with_weight": rmSpace(sres.highlight[id]) if question and id in sres.highlight else sres.field[
67
- id].get(
68
- "content_with_weight", ""),
69
- "doc_id": sres.field[id]["doc_id"],
70
- "docnm_kwd": sres.field[id]["docnm_kwd"],
71
- "important_kwd": sres.field[id].get("important_kwd", []),
72
- "img_id": sres.field[id].get("img_id", ""),
73
- "available_int": sres.field[id].get("available_int", 1),
74
- "positions": sres.field[id].get("position_int", "").split("\t")
75
- }
76
- if len(d["positions"]) % 5 == 0:
77
- poss = []
78
- for i in range(0, len(d["positions"]), 5):
79
- poss.append([float(d["positions"][i]), float(d["positions"][i + 1]), float(d["positions"][i + 2]),
80
- float(d["positions"][i + 3]), float(d["positions"][i + 4])])
81
- d["positions"] = poss
82
- res["chunks"].append(d)
83
- return get_json_result(data=res)
84
- except Exception as e:
85
- if str(e).find("not_found") > 0:
86
- return get_json_result(data=False, retmsg=f'No chunk found!',
87
- retcode=RetCode.DATA_ERROR)
88
- return server_error_response(e)
89
-
90
-
91
- @manager.route('/get', methods=['GET'])
92
- @login_required
93
- def get():
94
- chunk_id = request.args["chunk_id"]
95
- try:
96
- tenants = UserTenantService.query(user_id=current_user.id)
97
- if not tenants:
98
- return get_data_error_result(retmsg="Tenant not found!")
99
- res = ELASTICSEARCH.get(
100
- chunk_id, search.index_name(
101
- tenants[0].tenant_id))
102
- if not res.get("found"):
103
- return server_error_response("Chunk not found")
104
- id = res["_id"]
105
- res = res["_source"]
106
- res["chunk_id"] = id
107
- k = []
108
- for n in res.keys():
109
- if re.search(r"(_vec$|_sm_|_tks|_ltks)", n):
110
- k.append(n)
111
- for n in k:
112
- del res[n]
113
-
114
- return get_json_result(data=res)
115
- except Exception as e:
116
- if str(e).find("NotFoundError") >= 0:
117
- return get_json_result(data=False, retmsg=f'Chunk not found!',
118
- retcode=RetCode.DATA_ERROR)
119
- return server_error_response(e)
120
-
121
-
122
- @manager.route('/set', methods=['POST'])
123
- @login_required
124
- @validate_request("doc_id", "chunk_id", "content_with_weight",
125
- "important_kwd")
126
- def set():
127
- req = request.json
128
- d = {
129
- "id": req["chunk_id"],
130
- "content_with_weight": req["content_with_weight"]}
131
- d["content_ltks"] = rag_tokenizer.tokenize(req["content_with_weight"])
132
- d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
133
- d["important_kwd"] = req["important_kwd"]
134
- d["important_tks"] = rag_tokenizer.tokenize(" ".join(req["important_kwd"]))
135
- if "available_int" in req:
136
- d["available_int"] = req["available_int"]
137
-
138
- try:
139
- tenant_id = DocumentService.get_tenant_id(req["doc_id"])
140
- if not tenant_id:
141
- return get_data_error_result(retmsg="Tenant not found!")
142
-
143
- embd_id = DocumentService.get_embd_id(req["doc_id"])
144
- embd_mdl = TenantLLMService.model_instance(
145
- tenant_id, LLMType.EMBEDDING.value, embd_id)
146
-
147
- e, doc = DocumentService.get_by_id(req["doc_id"])
148
- if not e:
149
- return get_data_error_result(retmsg="Document not found!")
150
-
151
- if doc.parser_id == ParserType.QA:
152
- arr = [
153
- t for t in re.split(
154
- r"[\n\t]",
155
- req["content_with_weight"]) if len(t) > 1]
156
- if len(arr) != 2:
157
- return get_data_error_result(
158
- retmsg="Q&A must be separated by TAB/ENTER key.")
159
- q, a = rmPrefix(arr[0]), rmPrefix(arr[1])
160
- d = beAdoc(d, arr[0], arr[1], not any(
161
- [rag_tokenizer.is_chinese(t) for t in q + a]))
162
-
163
- v, c = embd_mdl.encode([doc.name, req["content_with_weight"]])
164
- v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
165
- d["q_%d_vec" % len(v)] = v.tolist()
166
- ELASTICSEARCH.upsert([d], search.index_name(tenant_id))
167
- return get_json_result(data=True)
168
- except Exception as e:
169
- return server_error_response(e)
170
-
171
-
172
- @manager.route('/switch', methods=['POST'])
173
- @login_required
174
- @validate_request("chunk_ids", "available_int", "doc_id")
175
- def switch():
176
- req = request.json
177
- try:
178
- tenant_id = DocumentService.get_tenant_id(req["doc_id"])
179
- if not tenant_id:
180
- return get_data_error_result(retmsg="Tenant not found!")
181
- if not ELASTICSEARCH.upsert([{"id": i, "available_int": int(req["available_int"])} for i in req["chunk_ids"]],
182
- search.index_name(tenant_id)):
183
- return get_data_error_result(retmsg="Index updating failure")
184
- return get_json_result(data=True)
185
- except Exception as e:
186
- return server_error_response(e)
187
-
188
-
189
- @manager.route('/rm', methods=['POST'])
190
- @login_required
191
- @validate_request("chunk_ids", "doc_id")
192
- def rm():
193
- req = request.json
194
- try:
195
- if not ELASTICSEARCH.deleteByQuery(
196
- Q("ids", values=req["chunk_ids"]), search.index_name(current_user.id)):
197
- return get_data_error_result(retmsg="Index updating failure")
198
- e, doc = DocumentService.get_by_id(req["doc_id"])
199
- if not e:
200
- return get_data_error_result(retmsg="Document not found!")
201
- deleted_chunk_ids = req["chunk_ids"]
202
- chunk_number = len(deleted_chunk_ids)
203
- DocumentService.decrement_chunk_num(doc.id, doc.kb_id, 1, chunk_number, 0)
204
- return get_json_result(data=True)
205
- except Exception as e:
206
- return server_error_response(e)
207
-
208
-
209
- @manager.route('/create', methods=['POST'])
210
- @login_required
211
- @validate_request("doc_id", "content_with_weight")
212
- def create():
213
- req = request.json
214
- md5 = hashlib.md5()
215
- md5.update((req["content_with_weight"] + req["doc_id"]).encode("utf-8"))
216
- chunck_id = md5.hexdigest()
217
- d = {"id": chunck_id, "content_ltks": rag_tokenizer.tokenize(req["content_with_weight"]),
218
- "content_with_weight": req["content_with_weight"]}
219
- d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
220
- d["important_kwd"] = req.get("important_kwd", [])
221
- d["important_tks"] = rag_tokenizer.tokenize(" ".join(req.get("important_kwd", [])))
222
- d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19]
223
- d["create_timestamp_flt"] = datetime.datetime.now().timestamp()
224
-
225
- try:
226
- e, doc = DocumentService.get_by_id(req["doc_id"])
227
- if not e:
228
- return get_data_error_result(retmsg="Document not found!")
229
- d["kb_id"] = [doc.kb_id]
230
- d["docnm_kwd"] = doc.name
231
- d["doc_id"] = doc.id
232
-
233
- tenant_id = DocumentService.get_tenant_id(req["doc_id"])
234
- if not tenant_id:
235
- return get_data_error_result(retmsg="Tenant not found!")
236
-
237
- embd_id = DocumentService.get_embd_id(req["doc_id"])
238
- embd_mdl = TenantLLMService.model_instance(
239
- tenant_id, LLMType.EMBEDDING.value, embd_id)
240
-
241
- v, c = embd_mdl.encode([doc.name, req["content_with_weight"]])
242
- v = 0.1 * v[0] + 0.9 * v[1]
243
- d["q_%d_vec" % len(v)] = v.tolist()
244
- ELASTICSEARCH.upsert([d], search.index_name(tenant_id))
245
-
246
- DocumentService.increment_chunk_num(
247
- doc.id, doc.kb_id, c, 1, 0)
248
- return get_json_result(data={"chunk_id": chunck_id})
249
- except Exception as e:
250
- return server_error_response(e)
251
-
252
-
253
- @manager.route('/retrieval_test', methods=['POST'])
254
- @login_required
255
- @validate_request("kb_id", "question")
256
- def retrieval_test():
257
- req = request.json
258
- page = int(req.get("page", 1))
259
- size = int(req.get("size", 30))
260
- question = req["question"]
261
- kb_id = req["kb_id"]
262
- doc_ids = req.get("doc_ids", [])
263
- similarity_threshold = float(req.get("similarity_threshold", 0.2))
264
- vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
265
- top = int(req.get("top_k", 1024))
266
- try:
267
- e, kb = KnowledgebaseService.get_by_id(kb_id)
268
- if not e:
269
- return get_data_error_result(retmsg="Knowledgebase not found!")
270
-
271
- embd_mdl = TenantLLMService.model_instance(
272
- kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
273
-
274
- rerank_mdl = None
275
- if req.get("rerank_id"):
276
- rerank_mdl = TenantLLMService.model_instance(
277
- kb.tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"])
278
-
279
- if req.get("keyword", False):
280
- chat_mdl = TenantLLMService.model_instance(kb.tenant_id, LLMType.CHAT)
281
- question += keyword_extraction(chat_mdl, question)
282
-
283
- retr = retrievaler if kb.parser_id != ParserType.KG else kg_retrievaler
284
- ranks = retr.retrieval(question, embd_mdl, kb.tenant_id, [kb_id], page, size,
285
- similarity_threshold, vector_similarity_weight, top,
286
- doc_ids, rerank_mdl=rerank_mdl)
287
- for c in ranks["chunks"]:
288
- if "vector" in c:
289
- del c["vector"]
290
-
291
- return get_json_result(data=ranks)
292
- except Exception as e:
293
- if str(e).find("not_found") > 0:
294
- return get_json_result(data=False, retmsg=f'No chunk found! Check the chunk status please!',
295
- retcode=RetCode.DATA_ERROR)
296
- return server_error_response(e)
297
-
298
-
299
- @manager.route('/knowledge_graph', methods=['GET'])
300
- @login_required
301
- def knowledge_graph():
302
- doc_id = request.args["doc_id"]
303
- req = {
304
- "doc_ids":[doc_id],
305
- "knowledge_graph_kwd": ["graph", "mind_map"]
306
- }
307
- tenant_id = DocumentService.get_tenant_id(doc_id)
308
- sres = retrievaler.search(req, search.index_name(tenant_id))
309
- obj = {"graph": {}, "mind_map": {}}
310
- for id in sres.ids[:2]:
311
- ty = sres.field[id]["knowledge_graph_kwd"]
312
- try:
313
- obj[ty] = json.loads(sres.field[id]["content_with_weight"])
314
- except Exception as e:
315
- print(traceback.format_exc(), flush=True)
316
-
317
- return get_json_result(data=obj)
318
-
 
1
+ #
2
+ # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ #
16
+ import datetime
17
+ import json
18
+ import traceback
19
+
20
+ from flask import request
21
+ from flask_login import login_required, current_user
22
+ from elasticsearch_dsl import Q
23
+
24
+ from rag.app.qa import rmPrefix, beAdoc
25
+ from rag.nlp import search, rag_tokenizer, keyword_extraction
26
+ from rag.utils.es_conn import ELASTICSEARCH
27
+ from rag.utils import rmSpace
28
+ from api.db import LLMType, ParserType
29
+ from api.db.services.knowledgebase_service import KnowledgebaseService
30
+ from api.db.services.llm_service import TenantLLMService
31
+ from api.db.services.user_service import UserTenantService
32
+ from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
33
+ from api.db.services.document_service import DocumentService
34
+ from api.settings import RetCode, retrievaler, kg_retrievaler
35
+ from api.utils.api_utils import get_json_result
36
+ import hashlib
37
+ import re
38
+
39
+
40
+ @manager.route('/list', methods=['POST'])
41
+ @login_required
42
+ @validate_request("doc_id")
43
+ def list_chunk():
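+ # Paginated chunk listing for a document, with keyword highlighting and
+ # decoding of the tab-separated "position_int" field into 5-float tuples.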
44
+ req = request.json
45
+ doc_id = req["doc_id"]
46
+ page = int(req.get("page", 1))
47
+ size = int(req.get("size", 30))
48
+ question = req.get("keywords", "")
49
+ try:
50
+ tenant_id = DocumentService.get_tenant_id(req["doc_id"])
51
+ if not tenant_id:
52
+ return get_data_error_result(retmsg="Tenant not found!")
53
+ e, doc = DocumentService.get_by_id(doc_id)
54
+ if not e:
55
+ return get_data_error_result(retmsg="Document not found!")
56
+ query = {
57
+ "doc_ids": [doc_id], "page": page, "size": size, "question": question, "sort": True
58
+ }
59
+ if "available_int" in req:
60
+ query["available_int"] = int(req["available_int"])
61
+ sres = retrievaler.search(query, search.index_name(tenant_id))
62
+ res = {"total": sres.total, "chunks": [], "doc": doc.to_dict()}
63
+ for id in sres.ids:
64
+ d = {
65
+ "chunk_id": id,
66
+ "content_with_weight": rmSpace(sres.highlight[id]) if question and id in sres.highlight else sres.field[
67
+ id].get(
68
+ "content_with_weight", ""),
69
+ "doc_id": sres.field[id]["doc_id"],
70
+ "docnm_kwd": sres.field[id]["docnm_kwd"],
71
+ "important_kwd": sres.field[id].get("important_kwd", []),
72
+ "img_id": sres.field[id].get("img_id", ""),
73
+ "available_int": sres.field[id].get("available_int", 1),
74
+ "positions": sres.field[id].get("position_int", "").split("\t")
75
+ }
76
+ if len(d["positions"]) % 5 == 0:
77
+ poss = []
78
+ for i in range(0, len(d["positions"]), 5):
79
+ poss.append([float(d["positions"][i]), float(d["positions"][i + 1]), float(d["positions"][i + 2]),
80
+ float(d["positions"][i + 3]), float(d["positions"][i + 4])])
81
+ d["positions"] = poss
82
+ res["chunks"].append(d)
83
+ return get_json_result(data=res)
84
+ except Exception as e:
85
+ if str(e).find("not_found") > 0:
86
+ return get_json_result(data=False, retmsg=f'No chunk found!',
87
+ retcode=RetCode.DATA_ERROR)
88
+ return server_error_response(e)
89
+
90
+
91
+ @manager.route('/get', methods=['GET'])
92
+ @login_required
93
+ def get():
94
+ chunk_id = request.args["chunk_id"]
95
+ try:
96
+ tenants = UserTenantService.query(user_id=current_user.id)
97
+ if not tenants:
98
+ return get_data_error_result(retmsg="Tenant not found!")
99
+ res = ELASTICSEARCH.get(
100
+ chunk_id, search.index_name(
101
+ tenants[0].tenant_id))
102
+ if not res.get("found"):
103
+ return server_error_response("Chunk not found")
104
+ id = res["_id"]
105
+ res = res["_source"]
106
+ res["chunk_id"] = id
107
+ k = []
108
+ for n in res.keys():
109
+ if re.search(r"(_vec$|_sm_|_tks|_ltks)", n):
110
+ k.append(n)
111
+ for n in k:
112
+ del res[n]
113
+
114
+ return get_json_result(data=res)
115
+ except Exception as e:
116
+ if str(e).find("NotFoundError") >= 0:
117
+ return get_json_result(data=False, retmsg='Chunk not found!',
118
+ retcode=RetCode.DATA_ERROR)
119
+ return server_error_response(e)
120
+
121
+
122
+ @manager.route('/set', methods=['POST'])
123
+ @login_required
124
+ @validate_request("doc_id", "chunk_id", "content_with_weight",
125
+ "important_kwd")
126
+ def set():
127
+ req = request.json
128
+ d = {
129
+ "id": req["chunk_id"],
130
+ "content_with_weight": req["content_with_weight"]}
131
+ d["content_ltks"] = rag_tokenizer.tokenize(req["content_with_weight"])
132
+ d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
133
+ d["important_kwd"] = req["important_kwd"]
134
+ d["important_tks"] = rag_tokenizer.tokenize(" ".join(req["important_kwd"]))
135
+ if "available_int" in req:
136
+ d["available_int"] = req["available_int"]
137
+
138
+ try:
139
+ tenant_id = DocumentService.get_tenant_id(req["doc_id"])
140
+ if not tenant_id:
141
+ return get_data_error_result(retmsg="Tenant not found!")
142
+
143
+ embd_id = DocumentService.get_embd_id(req["doc_id"])
144
+ embd_mdl = TenantLLMService.model_instance(
145
+ tenant_id, LLMType.EMBEDDING.value, embd_id)
146
+
147
+ e, doc = DocumentService.get_by_id(req["doc_id"])
148
+ if not e:
149
+ return get_data_error_result(retmsg="Document not found!")
150
+
151
+ if doc.parser_id == ParserType.QA:
152
+ arr = [
153
+ t for t in re.split(
154
+ r"[\n\t]",
155
+ req["content_with_weight"]) if len(t) > 1]
156
+ if len(arr) != 2:
157
+ return get_data_error_result(
158
+ retmsg="Q&A must be separated by TAB/ENTER key.")
159
+ q, a = rmPrefix(arr[0]), rmPrefix(arr[1])
160
+ d = beAdoc(d, arr[0], arr[1], not any(
161
+ [rag_tokenizer.is_chinese(t) for t in q + a]))
162
+
163
+ v, c = embd_mdl.encode([doc.name, req["content_with_weight"]])
164
+ v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
165
+ d["q_%d_vec" % len(v)] = v.tolist()
166
+ ELASTICSEARCH.upsert([d], search.index_name(tenant_id))
167
+ return get_json_result(data=True)
168
+ except Exception as e:
169
+ return server_error_response(e)
170
+
171
+
172
+ @manager.route('/switch', methods=['POST'])
173
+ @login_required
174
+ @validate_request("chunk_ids", "available_int", "doc_id")
175
+ def switch():
176
+ req = request.json
177
+ try:
178
+ tenant_id = DocumentService.get_tenant_id(req["doc_id"])
179
+ if not tenant_id:
180
+ return get_data_error_result(retmsg="Tenant not found!")
181
+ if not ELASTICSEARCH.upsert([{"id": i, "available_int": int(req["available_int"])} for i in req["chunk_ids"]],
182
+ search.index_name(tenant_id)):
183
+ return get_data_error_result(retmsg="Index updating failure")
184
+ return get_json_result(data=True)
185
+ except Exception as e:
186
+ return server_error_response(e)
187
+
188
+
189
+ @manager.route('/rm', methods=['POST'])
190
+ @login_required
191
+ @validate_request("chunk_ids", "doc_id")
192
+ def rm():
193
+ req = request.json
194
+ try:
195
+ if not ELASTICSEARCH.deleteByQuery(
196
+ Q("ids", values=req["chunk_ids"]), search.index_name(current_user.id)):
197
+ return get_data_error_result(retmsg="Index updating failure")
198
+ e, doc = DocumentService.get_by_id(req["doc_id"])
199
+ if not e:
200
+ return get_data_error_result(retmsg="Document not found!")
201
+ deleted_chunk_ids = req["chunk_ids"]
202
+ chunk_number = len(deleted_chunk_ids)
203
+ DocumentService.decrement_chunk_num(doc.id, doc.kb_id, 1, chunk_number, 0)
204
+ return get_json_result(data=True)
205
+ except Exception as e:
206
+ return server_error_response(e)
207
+
208
+
209
+ @manager.route('/create', methods=['POST'])
210
+ @login_required
211
+ @validate_request("doc_id", "content_with_weight")
212
+ def create():
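+ # Manually add a chunk: its id is the MD5 of content + doc_id, and the stored
+ # vector blends the title and body embeddings (0.1 / 0.9).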
213
+ req = request.json
214
+ md5 = hashlib.md5()
215
+ md5.update((req["content_with_weight"] + req["doc_id"]).encode("utf-8"))
216
+ chunk_id = md5.hexdigest()
217
+ d = {"id": chunck_id, "content_ltks": rag_tokenizer.tokenize(req["content_with_weight"]),
218
+ "content_with_weight": req["content_with_weight"]}
219
+ d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
220
+ d["important_kwd"] = req.get("important_kwd", [])
221
+ d["important_tks"] = rag_tokenizer.tokenize(" ".join(req.get("important_kwd", [])))
222
+ d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19]
223
+ d["create_timestamp_flt"] = datetime.datetime.now().timestamp()
224
+
225
+ try:
226
+ e, doc = DocumentService.get_by_id(req["doc_id"])
227
+ if not e:
228
+ return get_data_error_result(retmsg="Document not found!")
229
+ d["kb_id"] = [doc.kb_id]
230
+ d["docnm_kwd"] = doc.name
231
+ d["doc_id"] = doc.id
232
+
233
+ tenant_id = DocumentService.get_tenant_id(req["doc_id"])
234
+ if not tenant_id:
235
+ return get_data_error_result(retmsg="Tenant not found!")
236
+
237
+ embd_id = DocumentService.get_embd_id(req["doc_id"])
238
+ embd_mdl = TenantLLMService.model_instance(
239
+ tenant_id, LLMType.EMBEDDING.value, embd_id)
240
+
241
+ v, c = embd_mdl.encode([doc.name, req["content_with_weight"]])
242
+ v = 0.1 * v[0] + 0.9 * v[1]
243
+ d["q_%d_vec" % len(v)] = v.tolist()
244
+ ELASTICSEARCH.upsert([d], search.index_name(tenant_id))
245
+
246
+ DocumentService.increment_chunk_num(
247
+ doc.id, doc.kb_id, c, 1, 0)
248
+ return get_json_result(data={"chunk_id": chunck_id})
249
+ except Exception as e:
250
+ return server_error_response(e)
251
+
252
+
253
+ @manager.route('/retrieval_test', methods=['POST'])
254
+ @login_required
255
+ @validate_request("kb_id", "question")
256
+ def retrieval_test():
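+ # Dry-run retrieval against one knowledge base; switches to the knowledge
+ # graph retriever when the KB's parser is ParserType.KG.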
257
+ req = request.json
258
+ page = int(req.get("page", 1))
259
+ size = int(req.get("size", 30))
260
+ question = req["question"]
261
+ kb_id = req["kb_id"]
262
+ doc_ids = req.get("doc_ids", [])
263
+ similarity_threshold = float(req.get("similarity_threshold", 0.2))
264
+ vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
265
+ top = int(req.get("top_k", 1024))
266
+ try:
267
+ e, kb = KnowledgebaseService.get_by_id(kb_id)
268
+ if not e:
269
+ return get_data_error_result(retmsg="Knowledgebase not found!")
270
+
271
+ embd_mdl = TenantLLMService.model_instance(
272
+ kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
273
+
274
+ rerank_mdl = None
275
+ if req.get("rerank_id"):
276
+ rerank_mdl = TenantLLMService.model_instance(
277
+ kb.tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"])
278
+
279
+ if req.get("keyword", False):
280
+ chat_mdl = TenantLLMService.model_instance(kb.tenant_id, LLMType.CHAT)
281
+ question += keyword_extraction(chat_mdl, question)
282
+
283
+ retr = retrievaler if kb.parser_id != ParserType.KG else kg_retrievaler
284
+ ranks = retr.retrieval(question, embd_mdl, kb.tenant_id, [kb_id], page, size,
285
+ similarity_threshold, vector_similarity_weight, top,
286
+ doc_ids, rerank_mdl=rerank_mdl)
287
+ for c in ranks["chunks"]:
288
+ if "vector" in c:
289
+ del c["vector"]
290
+
291
+ return get_json_result(data=ranks)
292
+ except Exception as e:
293
+ if str(e).find("not_found") > 0:
294
+ return get_json_result(data=False, retmsg=f'No chunk found! Check the chunk status please!',
295
+ retcode=RetCode.DATA_ERROR)
296
+ return server_error_response(e)
297
+
298
+
299
+ @manager.route('/knowledge_graph', methods=['GET'])
300
+ @login_required
301
+ def knowledge_graph():
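+ # Return the stored "graph" and "mind_map" JSON blobs for a document.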
302
+ doc_id = request.args["doc_id"]
303
+ req = {
304
+ "doc_ids":[doc_id],
305
+ "knowledge_graph_kwd": ["graph", "mind_map"]
306
+ }
307
+ tenant_id = DocumentService.get_tenant_id(doc_id)
308
+ sres = retrievaler.search(req, search.index_name(tenant_id))
309
+ obj = {"graph": {}, "mind_map": {}}
310
+ for id in sres.ids[:2]:
311
+ ty = sres.field[id]["knowledge_graph_kwd"]
312
+ try:
313
+ obj[ty] = json.loads(sres.field[id]["content_with_weight"])
314
+ except Exception as e:
315
+ print(traceback.format_exc(), flush=True)
316
+
317
+ return get_json_result(data=obj)
318
+
api/apps/conversation_app.py CHANGED
@@ -1,177 +1,177 @@
1
- #
2
- # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- #
16
- from copy import deepcopy
17
- from flask import request, Response
18
- from flask_login import login_required
19
- from api.db.services.dialog_service import DialogService, ConversationService, chat
20
- from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
21
- from api.utils import get_uuid
22
- from api.utils.api_utils import get_json_result
23
- import json
24
-
25
-
26
- @manager.route('/set', methods=['POST'])
27
- @login_required
28
- def set_conversation():
29
- req = request.json
30
- conv_id = req.get("conversation_id")
31
- if conv_id:
32
- del req["conversation_id"]
33
- try:
34
- if not ConversationService.update_by_id(conv_id, req):
35
- return get_data_error_result(retmsg="Conversation not found!")
36
- e, conv = ConversationService.get_by_id(conv_id)
37
- if not e:
38
- return get_data_error_result(
39
- retmsg="Fail to update a conversation!")
40
- conv = conv.to_dict()
41
- return get_json_result(data=conv)
42
- except Exception as e:
43
- return server_error_response(e)
44
-
45
- try:
46
- e, dia = DialogService.get_by_id(req["dialog_id"])
47
- if not e:
48
- return get_data_error_result(retmsg="Dialog not found")
49
- conv = {
50
- "id": get_uuid(),
51
- "dialog_id": req["dialog_id"],
52
- "name": req.get("name", "New conversation"),
53
- "message": [{"role": "assistant", "content": dia.prompt_config["prologue"]}]
54
- }
55
- ConversationService.save(**conv)
56
- e, conv = ConversationService.get_by_id(conv["id"])
57
- if not e:
58
- return get_data_error_result(retmsg="Fail to new a conversation!")
59
- conv = conv.to_dict()
60
- return get_json_result(data=conv)
61
- except Exception as e:
62
- return server_error_response(e)
63
-
64
-
65
- @manager.route('/get', methods=['GET'])
66
- @login_required
67
- def get():
68
- conv_id = request.args["conversation_id"]
69
- try:
70
- e, conv = ConversationService.get_by_id(conv_id)
71
- if not e:
72
- return get_data_error_result(retmsg="Conversation not found!")
73
- conv = conv.to_dict()
74
- return get_json_result(data=conv)
75
- except Exception as e:
76
- return server_error_response(e)
77
-
78
-
79
- @manager.route('/rm', methods=['POST'])
80
- @login_required
81
- def rm():
82
- conv_ids = request.json["conversation_ids"]
83
- try:
84
- for cid in conv_ids:
85
- ConversationService.delete_by_id(cid)
86
- return get_json_result(data=True)
87
- except Exception as e:
88
- return server_error_response(e)
89
-
90
-
91
- @manager.route('/list', methods=['GET'])
92
- @login_required
93
- def list_convsersation():
94
- dialog_id = request.args["dialog_id"]
95
- try:
96
- convs = ConversationService.query(
97
- dialog_id=dialog_id,
98
- order_by=ConversationService.model.create_time,
99
- reverse=True)
100
- convs = [d.to_dict() for d in convs]
101
- return get_json_result(data=convs)
102
- except Exception as e:
103
- return server_error_response(e)
104
-
105
-
106
- @manager.route('/completion', methods=['POST'])
107
- @login_required
108
- #@validate_request("conversation_id", "messages")
109
- def completion():
110
- req = request.json
111
- #req = {"conversation_id": "9aaaca4c11d311efa461fa163e197198", "messages": [
112
- # {"role": "user", "content": "上海有吗?"}
113
- #]}
114
- msg = []
115
- for m in req["messages"]:
116
- if m["role"] == "system":
117
- continue
118
- if m["role"] == "assistant" and not msg:
119
- continue
120
- msg.append({"role": m["role"], "content": m["content"]})
121
- if "doc_ids" in m:
122
- msg[-1]["doc_ids"] = m["doc_ids"]
123
- try:
124
- e, conv = ConversationService.get_by_id(req["conversation_id"])
125
- if not e:
126
- return get_data_error_result(retmsg="Conversation not found!")
127
- conv.message.append(deepcopy(msg[-1]))
128
- e, dia = DialogService.get_by_id(conv.dialog_id)
129
- if not e:
130
- return get_data_error_result(retmsg="Dialog not found!")
131
- del req["conversation_id"]
132
- del req["messages"]
133
-
134
- if not conv.reference:
135
- conv.reference = []
136
- conv.message.append({"role": "assistant", "content": ""})
137
- conv.reference.append({"chunks": [], "doc_aggs": []})
138
-
139
- def fillin_conv(ans):
140
- nonlocal conv
141
- if not conv.reference:
142
- conv.reference.append(ans["reference"])
143
- else: conv.reference[-1] = ans["reference"]
144
- conv.message[-1] = {"role": "assistant", "content": ans["answer"]}
145
-
146
- def stream():
147
- nonlocal dia, msg, req, conv
148
- try:
149
- for ans in chat(dia, msg, True, **req):
150
- fillin_conv(ans)
151
- yield "data:"+json.dumps({"retcode": 0, "retmsg": "", "data": ans}, ensure_ascii=False) + "\n\n"
152
- ConversationService.update_by_id(conv.id, conv.to_dict())
153
- except Exception as e:
154
- yield "data:" + json.dumps({"retcode": 500, "retmsg": str(e),
155
- "data": {"answer": "**ERROR**: "+str(e), "reference": []}},
156
- ensure_ascii=False) + "\n\n"
157
- yield "data:"+json.dumps({"retcode": 0, "retmsg": "", "data": True}, ensure_ascii=False) + "\n\n"
158
-
159
- if req.get("stream", True):
160
- resp = Response(stream(), mimetype="text/event-stream")
161
- resp.headers.add_header("Cache-control", "no-cache")
162
- resp.headers.add_header("Connection", "keep-alive")
163
- resp.headers.add_header("X-Accel-Buffering", "no")
164
- resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
165
- return resp
166
-
167
- else:
168
- answer = None
169
- for ans in chat(dia, msg, **req):
170
- answer = ans
171
- fillin_conv(ans)
172
- ConversationService.update_by_id(conv.id, conv.to_dict())
173
- break
174
- return get_json_result(data=answer)
175
- except Exception as e:
176
- return server_error_response(e)
177
-
 
1
+ #
2
+ # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ #
16
+ from copy import deepcopy
17
+ from flask import request, Response
18
+ from flask_login import login_required
19
+ from api.db.services.dialog_service import DialogService, ConversationService, chat
20
+ from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
21
+ from api.utils import get_uuid
22
+ from api.utils.api_utils import get_json_result
23
+ import json
24
+
25
+
26
+ @manager.route('/set', methods=['POST'])
27
+ @login_required
28
+ def set_conversation():
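+ # Update an existing conversation when "conversation_id" is given; otherwise
+ # create one seeded with the dialog's prologue message.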
29
+ req = request.json
30
+ conv_id = req.get("conversation_id")
31
+ if conv_id:
32
+ del req["conversation_id"]
33
+ try:
34
+ if not ConversationService.update_by_id(conv_id, req):
35
+ return get_data_error_result(retmsg="Conversation not found!")
36
+ e, conv = ConversationService.get_by_id(conv_id)
37
+ if not e:
38
+ return get_data_error_result(
39
+ retmsg="Fail to update a conversation!")
40
+ conv = conv.to_dict()
41
+ return get_json_result(data=conv)
42
+ except Exception as e:
43
+ return server_error_response(e)
44
+
45
+ try:
46
+ e, dia = DialogService.get_by_id(req["dialog_id"])
47
+ if not e:
48
+ return get_data_error_result(retmsg="Dialog not found")
49
+ conv = {
50
+ "id": get_uuid(),
51
+ "dialog_id": req["dialog_id"],
52
+ "name": req.get("name", "New conversation"),
53
+ "message": [{"role": "assistant", "content": dia.prompt_config["prologue"]}]
54
+ }
55
+ ConversationService.save(**conv)
56
+ e, conv = ConversationService.get_by_id(conv["id"])
57
+ if not e:
58
+ return get_data_error_result(retmsg="Fail to new a conversation!")
59
+ conv = conv.to_dict()
60
+ return get_json_result(data=conv)
61
+ except Exception as e:
62
+ return server_error_response(e)
63
+
64
+
65
+ @manager.route('/get', methods=['GET'])
66
+ @login_required
67
+ def get():
68
+ conv_id = request.args["conversation_id"]
69
+ try:
70
+ e, conv = ConversationService.get_by_id(conv_id)
71
+ if not e:
72
+ return get_data_error_result(retmsg="Conversation not found!")
73
+ conv = conv.to_dict()
74
+ return get_json_result(data=conv)
75
+ except Exception as e:
76
+ return server_error_response(e)
77
+
78
+
79
+ @manager.route('/rm', methods=['POST'])
80
+ @login_required
81
+ def rm():
82
+ conv_ids = request.json["conversation_ids"]
83
+ try:
84
+ for cid in conv_ids:
85
+ ConversationService.delete_by_id(cid)
86
+ return get_json_result(data=True)
87
+ except Exception as e:
88
+ return server_error_response(e)
89
+
90
+
91
+ @manager.route('/list', methods=['GET'])
92
+ @login_required
93
+ def list_conversation():
94
+ dialog_id = request.args["dialog_id"]
95
+ try:
96
+ convs = ConversationService.query(
97
+ dialog_id=dialog_id,
98
+ order_by=ConversationService.model.create_time,
99
+ reverse=True)
100
+ convs = [d.to_dict() for d in convs]
101
+ return get_json_result(data=convs)
102
+ except Exception as e:
103
+ return server_error_response(e)
104
+
105
+
106
+ @manager.route('/completion', methods=['POST'])
107
+ @login_required
108
+ #@validate_request("conversation_id", "messages")
109
+ def completion():
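+ # Session-authenticated counterpart of the token-based /completion endpoint;
+ # persists messages through ConversationService.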
110
+ req = request.json
111
+ #req = {"conversation_id": "9aaaca4c11d311efa461fa163e197198", "messages": [
112
+ # {"role": "user", "content": "上海有吗?"}
113
+ #]}
114
+ msg = []
115
+ for m in req["messages"]:
116
+ if m["role"] == "system":
117
+ continue
118
+ if m["role"] == "assistant" and not msg:
119
+ continue
120
+ msg.append({"role": m["role"], "content": m["content"]})
121
+ if "doc_ids" in m:
122
+ msg[-1]["doc_ids"] = m["doc_ids"]
123
+ try:
124
+ e, conv = ConversationService.get_by_id(req["conversation_id"])
125
+ if not e:
126
+ return get_data_error_result(retmsg="Conversation not found!")
127
+ conv.message.append(deepcopy(msg[-1]))
128
+ e, dia = DialogService.get_by_id(conv.dialog_id)
129
+ if not e:
130
+ return get_data_error_result(retmsg="Dialog not found!")
131
+ del req["conversation_id"]
132
+ del req["messages"]
133
+
134
+ if not conv.reference:
135
+ conv.reference = []
136
+ conv.message.append({"role": "assistant", "content": ""})
137
+ conv.reference.append({"chunks": [], "doc_aggs": []})
138
+
139
+ def fillin_conv(ans):
140
+ nonlocal conv
141
+ if not conv.reference:
142
+ conv.reference.append(ans["reference"])
143
+ else: conv.reference[-1] = ans["reference"]
144
+ conv.message[-1] = {"role": "assistant", "content": ans["answer"]}
145
+
146
+ def stream():
147
+ nonlocal dia, msg, req, conv
148
+ try:
149
+ for ans in chat(dia, msg, True, **req):
150
+ fillin_conv(ans)
151
+ yield "data:"+json.dumps({"retcode": 0, "retmsg": "", "data": ans}, ensure_ascii=False) + "\n\n"
152
+ ConversationService.update_by_id(conv.id, conv.to_dict())
153
+ except Exception as e:
154
+ yield "data:" + json.dumps({"retcode": 500, "retmsg": str(e),
155
+ "data": {"answer": "**ERROR**: "+str(e), "reference": []}},
156
+ ensure_ascii=False) + "\n\n"
157
+ yield "data:"+json.dumps({"retcode": 0, "retmsg": "", "data": True}, ensure_ascii=False) + "\n\n"
158
+
159
+ if req.get("stream", True):
160
+ resp = Response(stream(), mimetype="text/event-stream")
161
+ resp.headers.add_header("Cache-control", "no-cache")
162
+ resp.headers.add_header("Connection", "keep-alive")
163
+ resp.headers.add_header("X-Accel-Buffering", "no")
164
+ resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
165
+ return resp
166
+
167
+ else:
168
+ answer = None
169
+ for ans in chat(dia, msg, **req):
170
+ answer = ans
171
+ fillin_conv(ans)
172
+ ConversationService.update_by_id(conv.id, conv.to_dict())
173
+ break
174
+ return get_json_result(data=answer)
175
+ except Exception as e:
176
+ return server_error_response(e)
177
+
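The `/completion` handler above streams its reply as server-sent events when `stream` is true (the default): each frame is a `data: {...}` line carrying the cumulative answer plus its references, and a final `data: true` frame marks the end. A minimal client sketch under assumptions — the `http://localhost:80/v1/conversation` base URL and the pre-authenticated session cookie are hypothetical, not part of this diff:

```python
import json
import requests

# Assumption: the session already carries a valid login cookie for @login_required.
session = requests.Session()
BASE = "http://localhost:80/v1/conversation"  # hypothetical deployment URL

payload = {
    "conversation_id": "<conversation-id>",  # placeholder
    "messages": [{"role": "user", "content": "What does the knowledge base say about X?"}],
}

# stream=True keeps the connection open so SSE frames can be read incrementally.
with session.post(f"{BASE}/completion", json=payload, stream=True) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data:"):
            continue  # skip the blank separators between SSE frames
        frame = json.loads(line[len("data:"):])
        if frame["data"] is True:
            break  # terminal sentinel: the answer is complete
        print(frame["data"]["answer"])  # full answer-so-far, re-sent each frame
```

Note that the handler only persists the conversation after its generator finishes, so a client that disconnects mid-stream may leave the stored answer behind the one it displayed.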
api/apps/dialog_app.py CHANGED
@@ -1,172 +1,172 @@
+#
+# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from flask import request
+from flask_login import login_required, current_user
+from api.db.services.dialog_service import DialogService
+from api.db import StatusEnum
+from api.db.services.knowledgebase_service import KnowledgebaseService
+from api.db.services.user_service import TenantService
+from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
+from api.utils import get_uuid
+from api.utils.api_utils import get_json_result
+
+
+@manager.route('/set', methods=['POST'])
+@login_required
+def set_dialog():
+    req = request.json
+    dialog_id = req.get("dialog_id")
+    name = req.get("name", "New Dialog")
+    description = req.get("description", "A helpful Dialog")
+    icon = req.get("icon", "")
+    top_n = req.get("top_n", 6)
+    top_k = req.get("top_k", 1024)
+    rerank_id = req.get("rerank_id", "")
+    if not rerank_id: req["rerank_id"] = ""
+    similarity_threshold = req.get("similarity_threshold", 0.1)
+    vector_similarity_weight = req.get("vector_similarity_weight", 0.3)
+    if vector_similarity_weight is None: vector_similarity_weight = 0.3
+    llm_setting = req.get("llm_setting", {})
+    default_prompt = {
+        "system": """你是一个智能助手,请总结知识库的内容来回答问题,请列举知识库中的数据详细回答。当所有知识库内容都与问题无关时,你的回答必须包括“知识库中未找到您要的答案!”这句话。回答需要考虑聊天历史。
+以下是知识库:
+{knowledge}
+以上是知识库。""",
+        "prologue": "您好,我是您的助手小樱,长得可爱又善良,can I help you?",
+        "parameters": [
+            {"key": "knowledge", "optional": False}
+        ],
+        "empty_response": "Sorry! 知识库中未找到相关内容!"
+    }
+    prompt_config = req.get("prompt_config", default_prompt)
+
+    if not prompt_config["system"]:
+        prompt_config["system"] = default_prompt["system"]
+    # if len(prompt_config["parameters"]) < 1:
+    #    prompt_config["parameters"] = default_prompt["parameters"]
+    # for p in prompt_config["parameters"]:
+    #    if p["key"] == "knowledge":break
+    # else: prompt_config["parameters"].append(default_prompt["parameters"][0])
+
+    for p in prompt_config["parameters"]:
+        if p["optional"]:
+            continue
+        if prompt_config["system"].find("{%s}" % p["key"]) < 0:
+            return get_data_error_result(
+                retmsg="Parameter '{}' is not used".format(p["key"]))
+
+    try:
+        e, tenant = TenantService.get_by_id(current_user.id)
+        if not e:
+            return get_data_error_result(retmsg="Tenant not found!")
+        llm_id = req.get("llm_id", tenant.llm_id)
+        if not dialog_id:
+            if not req.get("kb_ids"):
+                return get_data_error_result(
+                    retmsg="Fail! Please select knowledgebase!")
+            dia = {
+                "id": get_uuid(),
+                "tenant_id": current_user.id,
+                "name": name,
+                "kb_ids": req["kb_ids"],
+                "description": description,
+                "llm_id": llm_id,
+                "llm_setting": llm_setting,
+                "prompt_config": prompt_config,
+                "top_n": top_n,
+                "top_k": top_k,
+                "rerank_id": rerank_id,
+                "similarity_threshold": similarity_threshold,
+                "vector_similarity_weight": vector_similarity_weight,
+                "icon": icon
+            }
+            if not DialogService.save(**dia):
+                return get_data_error_result(retmsg="Fail to new a dialog!")
+            e, dia = DialogService.get_by_id(dia["id"])
+            if not e:
+                return get_data_error_result(retmsg="Fail to new a dialog!")
+            return get_json_result(data=dia.to_json())
+        else:
+            del req["dialog_id"]
+            if "kb_names" in req:
+                del req["kb_names"]
+            if not DialogService.update_by_id(dialog_id, req):
+                return get_data_error_result(retmsg="Dialog not found!")
+            e, dia = DialogService.get_by_id(dialog_id)
+            if not e:
+                return get_data_error_result(retmsg="Fail to update a dialog!")
+            dia = dia.to_dict()
+            dia["kb_ids"], dia["kb_names"] = get_kb_names(dia["kb_ids"])
+            return get_json_result(data=dia)
+    except Exception as e:
+        return server_error_response(e)
+
+
+@manager.route('/get', methods=['GET'])
+@login_required
+def get():
+    dialog_id = request.args["dialog_id"]
+    try:
+        e, dia = DialogService.get_by_id(dialog_id)
+        if not e:
+            return get_data_error_result(retmsg="Dialog not found!")
+        dia = dia.to_dict()
+        dia["kb_ids"], dia["kb_names"] = get_kb_names(dia["kb_ids"])
+        return get_json_result(data=dia)
+    except Exception as e:
+        return server_error_response(e)
+
+
+def get_kb_names(kb_ids):
+    ids, nms = [], []
+    for kid in kb_ids:
+        e, kb = KnowledgebaseService.get_by_id(kid)
+        if not e or kb.status != StatusEnum.VALID.value:
+            continue
+        ids.append(kid)
+        nms.append(kb.name)
+    return ids, nms
+
+
+@manager.route('/list', methods=['GET'])
+@login_required
+def list_dialogs():
+    try:
+        diags = DialogService.query(
+            tenant_id=current_user.id,
+            status=StatusEnum.VALID.value,
+            reverse=True,
+            order_by=DialogService.model.create_time)
+        diags = [d.to_dict() for d in diags]
+        for d in diags:
+            d["kb_ids"], d["kb_names"] = get_kb_names(d["kb_ids"])
+        return get_json_result(data=diags)
+    except Exception as e:
+        return server_error_response(e)
+
+
+@manager.route('/rm', methods=['POST'])
+@login_required
+@validate_request("dialog_ids")
+def rm():
+    req = request.json
+    try:
+        DialogService.update_many_by_id(
+            [{"id": id, "status": StatusEnum.INVALID.value} for id in req["dialog_ids"]])
+        return get_json_result(data=True)
+    except Exception as e:
+        return server_error_response(e)
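A point worth noting in `set_dialog` above: every non-optional entry in `prompt_config["parameters"]` must appear as a `{key}` placeholder in the system prompt, or the request is rejected with "Parameter '...' is not used". A request sketch that satisfies that check (the URL prefix and the authenticated session are assumptions, as is the response shape printed at the end):

```python
import requests

session = requests.Session()  # assumption: already carries a login cookie

dialog = {
    "name": "Support assistant",
    "kb_ids": ["<knowledgebase-id>"],  # required on create, i.e. when dialog_id is absent
    "prompt_config": {
        # {knowledge} must appear here because the parameter below is non-optional
        "system": "Answer strictly from the knowledge base:\n{knowledge}",
        "prologue": "Hi! How can I help?",
        "parameters": [{"key": "knowledge", "optional": False}],
        "empty_response": "Nothing relevant was found.",
    },
}

r = session.post("http://localhost:80/v1/dialog/set", json=dialog)  # hypothetical URL
print(r.json())  # on success, "data" holds the newly created dialog record
```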
api/apps/document_app.py CHANGED
@@ -1,586 +1,586 @@
+#
+# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License
+#
+import datetime
+import hashlib
+import json
+import os
+import pathlib
+import re
+import traceback
+from concurrent.futures import ThreadPoolExecutor
+from copy import deepcopy
+from io import BytesIO
+
+import flask
+from elasticsearch_dsl import Q
+from flask import request
+from flask_login import login_required, current_user
+
+from api.db.db_models import Task, File
+from api.db.services.dialog_service import DialogService, ConversationService
+from api.db.services.file2document_service import File2DocumentService
+from api.db.services.file_service import FileService
+from api.db.services.llm_service import LLMBundle
+from api.db.services.task_service import TaskService, queue_tasks
+from api.db.services.user_service import TenantService
+from graphrag.mind_map_extractor import MindMapExtractor
+from rag.app import naive
+from rag.nlp import search
+from rag.utils.es_conn import ELASTICSEARCH
+from api.db.services import duplicate_name
+from api.db.services.knowledgebase_service import KnowledgebaseService
+from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
+from api.utils import get_uuid
+from api.db import FileType, TaskStatus, ParserType, FileSource, LLMType
+from api.db.services.document_service import DocumentService
+from api.settings import RetCode, stat_logger
+from api.utils.api_utils import get_json_result
+from rag.utils.minio_conn import MINIO
+from api.utils.file_utils import filename_type, thumbnail, get_project_base_directory
+from api.utils.web_utils import html2pdf, is_valid_url
+
+
+@manager.route('/upload', methods=['POST'])
+@login_required
+@validate_request("kb_id")
+def upload():
+    kb_id = request.form.get("kb_id")
+    if not kb_id:
+        return get_json_result(
+            data=False, retmsg='Lack of "KB ID"', retcode=RetCode.ARGUMENT_ERROR)
+    if 'file' not in request.files:
+        return get_json_result(
+            data=False, retmsg='No file part!', retcode=RetCode.ARGUMENT_ERROR)
+
+    file_objs = request.files.getlist('file')
+    for file_obj in file_objs:
+        if file_obj.filename == '':
+            return get_json_result(
+                data=False, retmsg='No file selected!', retcode=RetCode.ARGUMENT_ERROR)
+
+    e, kb = KnowledgebaseService.get_by_id(kb_id)
+    if not e:
+        raise LookupError("Can't find this knowledgebase!")
+
+    err, _ = FileService.upload_document(kb, file_objs)
+    if err:
+        return get_json_result(
+            data=False, retmsg="\n".join(err), retcode=RetCode.SERVER_ERROR)
+    return get_json_result(data=True)
+
+
+@manager.route('/web_crawl', methods=['POST'])
+@login_required
+@validate_request("kb_id", "name", "url")
+def web_crawl():
+    kb_id = request.form.get("kb_id")
+    if not kb_id:
+        return get_json_result(
+            data=False, retmsg='Lack of "KB ID"', retcode=RetCode.ARGUMENT_ERROR)
+    name = request.form.get("name")
+    url = request.form.get("url")
+    if not is_valid_url(url):
+        return get_json_result(
+            data=False, retmsg='The URL format is invalid', retcode=RetCode.ARGUMENT_ERROR)
+    e, kb = KnowledgebaseService.get_by_id(kb_id)
+    if not e:
+        raise LookupError("Can't find this knowledgebase!")
+
+    blob = html2pdf(url)
+    if not blob: return server_error_response(ValueError("Download failure."))
+
+    root_folder = FileService.get_root_folder(current_user.id)
+    pf_id = root_folder["id"]
+    FileService.init_knowledgebase_docs(pf_id, current_user.id)
+    kb_root_folder = FileService.get_kb_folder(current_user.id)
+    kb_folder = FileService.new_a_file_from_kb(kb.tenant_id, kb.name, kb_root_folder["id"])
+
+    try:
+        filename = duplicate_name(
+            DocumentService.query,
+            name=name + ".pdf",
+            kb_id=kb.id)
+        filetype = filename_type(filename)
+        if filetype == FileType.OTHER.value:
+            raise RuntimeError("This type of file has not been supported yet!")
+
+        location = filename
+        while MINIO.obj_exist(kb_id, location):
+            location += "_"
+        MINIO.put(kb_id, location, blob)
+        doc = {
+            "id": get_uuid(),
+            "kb_id": kb.id,
+            "parser_id": kb.parser_id,
+            "parser_config": kb.parser_config,
+            "created_by": current_user.id,
+            "type": filetype,
+            "name": filename,
+            "location": location,
+            "size": len(blob),
+            "thumbnail": thumbnail(filename, blob)
+        }
+        if doc["type"] == FileType.VISUAL:
+            doc["parser_id"] = ParserType.PICTURE.value
+        if doc["type"] == FileType.AURAL:
+            doc["parser_id"] = ParserType.AUDIO.value
+        if re.search(r"\.(ppt|pptx|pages)$", filename):
+            doc["parser_id"] = ParserType.PRESENTATION.value
+        DocumentService.insert(doc)
+        FileService.add_file_from_kb(doc, kb_folder["id"], kb.tenant_id)
+    except Exception as e:
+        return server_error_response(e)
+    return get_json_result(data=True)
+
+
+@manager.route('/create', methods=['POST'])
+@login_required
+@validate_request("name", "kb_id")
+def create():
+    req = request.json
+    kb_id = req["kb_id"]
+    if not kb_id:
+        return get_json_result(
+            data=False, retmsg='Lack of "KB ID"', retcode=RetCode.ARGUMENT_ERROR)
+
+    try:
+        e, kb = KnowledgebaseService.get_by_id(kb_id)
+        if not e:
+            return get_data_error_result(
+                retmsg="Can't find this knowledgebase!")
+
+        if DocumentService.query(name=req["name"], kb_id=kb_id):
+            return get_data_error_result(
+                retmsg="Duplicated document name in the same knowledgebase.")
+
+        doc = DocumentService.insert({
+            "id": get_uuid(),
+            "kb_id": kb.id,
+            "parser_id": kb.parser_id,
+            "parser_config": kb.parser_config,
+            "created_by": current_user.id,
+            "type": FileType.VIRTUAL,
+            "name": req["name"],
+            "location": "",
+            "size": 0
+        })
+        return get_json_result(data=doc.to_json())
+    except Exception as e:
+        return server_error_response(e)
+
+
+@manager.route('/list', methods=['GET'])
+@login_required
+def list_docs():
+    kb_id = request.args.get("kb_id")
+    if not kb_id:
+        return get_json_result(
+            data=False, retmsg='Lack of "KB ID"', retcode=RetCode.ARGUMENT_ERROR)
+    keywords = request.args.get("keywords", "")
+
+    page_number = int(request.args.get("page", 1))
+    items_per_page = int(request.args.get("page_size", 15))
+    orderby = request.args.get("orderby", "create_time")
+    desc = request.args.get("desc", True)
+    try:
+        docs, tol = DocumentService.get_by_kb_id(
+            kb_id, page_number, items_per_page, orderby, desc, keywords)
+        return get_json_result(data={"total": tol, "docs": docs})
+    except Exception as e:
+        return server_error_response(e)
+
+
+@manager.route('/thumbnails', methods=['GET'])
+@login_required
+def thumbnails():
+    doc_ids = request.args.get("doc_ids").split(",")
+    if not doc_ids:
+        return get_json_result(
+            data=False, retmsg='Lack of "Document ID"', retcode=RetCode.ARGUMENT_ERROR)
+
+    try:
+        docs = DocumentService.get_thumbnails(doc_ids)
+        return get_json_result(data={d["id"]: d["thumbnail"] for d in docs})
+    except Exception as e:
+        return server_error_response(e)
+
+
+@manager.route('/change_status', methods=['POST'])
+@login_required
+@validate_request("doc_id", "status")
+def change_status():
+    req = request.json
+    if str(req["status"]) not in ["0", "1"]:
+        get_json_result(
+            data=False,
+            retmsg='"Status" must be either 0 or 1!',
+            retcode=RetCode.ARGUMENT_ERROR)
+
+    try:
+        e, doc = DocumentService.get_by_id(req["doc_id"])
+        if not e:
+            return get_data_error_result(retmsg="Document not found!")
+        e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
+        if not e:
+            return get_data_error_result(
+                retmsg="Can't find this knowledgebase!")
+
+        if not DocumentService.update_by_id(
+                req["doc_id"], {"status": str(req["status"])}):
+            return get_data_error_result(
+                retmsg="Database error (Document update)!")
+
+        if str(req["status"]) == "0":
+            ELASTICSEARCH.updateScriptByQuery(Q("term", doc_id=req["doc_id"]),
+                                              scripts="ctx._source.available_int=0;",
+                                              idxnm=search.index_name(
+                                                  kb.tenant_id)
+                                              )
+        else:
+            ELASTICSEARCH.updateScriptByQuery(Q("term", doc_id=req["doc_id"]),
+                                              scripts="ctx._source.available_int=1;",
+                                              idxnm=search.index_name(
+                                                  kb.tenant_id)
+                                              )
+        return get_json_result(data=True)
+    except Exception as e:
+        return server_error_response(e)
+
+
+@manager.route('/rm', methods=['POST'])
+@login_required
+@validate_request("doc_id")
+def rm():
+    req = request.json
+    doc_ids = req["doc_id"]
+    if isinstance(doc_ids, str): doc_ids = [doc_ids]
+    root_folder = FileService.get_root_folder(current_user.id)
+    pf_id = root_folder["id"]
+    FileService.init_knowledgebase_docs(pf_id, current_user.id)
+    errors = ""
+    for doc_id in doc_ids:
+        try:
+            e, doc = DocumentService.get_by_id(doc_id)
+            if not e:
+                return get_data_error_result(retmsg="Document not found!")
+            tenant_id = DocumentService.get_tenant_id(doc_id)
+            if not tenant_id:
+                return get_data_error_result(retmsg="Tenant not found!")
+
+            b, n = File2DocumentService.get_minio_address(doc_id=doc_id)
+
+            if not DocumentService.remove_document(doc, tenant_id):
+                return get_data_error_result(
+                    retmsg="Database error (Document removal)!")
+
+            f2d = File2DocumentService.get_by_document_id(doc_id)
+            FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.id == f2d[0].file_id])
+            File2DocumentService.delete_by_document_id(doc_id)
+
+            MINIO.rm(b, n)
+        except Exception as e:
+            errors += str(e)
+
+    if errors:
+        return get_json_result(data=False, retmsg=errors, retcode=RetCode.SERVER_ERROR)
+
+    return get_json_result(data=True)
+
+
+@manager.route('/run', methods=['POST'])
+@login_required
+@validate_request("doc_ids", "run")
+def run():
+    req = request.json
+    try:
+        for id in req["doc_ids"]:
+            info = {"run": str(req["run"]), "progress": 0}
+            if str(req["run"]) == TaskStatus.RUNNING.value:
+                info["progress_msg"] = ""
+                info["chunk_num"] = 0
+                info["token_num"] = 0
+            DocumentService.update_by_id(id, info)
+            # if str(req["run"]) == TaskStatus.CANCEL.value:
+            tenant_id = DocumentService.get_tenant_id(id)
+            if not tenant_id:
+                return get_data_error_result(retmsg="Tenant not found!")
+            ELASTICSEARCH.deleteByQuery(
+                Q("match", doc_id=id), idxnm=search.index_name(tenant_id))
+
+            if str(req["run"]) == TaskStatus.RUNNING.value:
+                TaskService.filter_delete([Task.doc_id == id])
+                e, doc = DocumentService.get_by_id(id)
+                doc = doc.to_dict()
+                doc["tenant_id"] = tenant_id
+                bucket, name = File2DocumentService.get_minio_address(doc_id=doc["id"])
+                queue_tasks(doc, bucket, name)
+
+        return get_json_result(data=True)
+    except Exception as e:
+        return server_error_response(e)
+
+
+@manager.route('/rename', methods=['POST'])
+@login_required
+@validate_request("doc_id", "name")
+def rename():
+    req = request.json
+    try:
+        e, doc = DocumentService.get_by_id(req["doc_id"])
+        if not e:
+            return get_data_error_result(retmsg="Document not found!")
+        if pathlib.Path(req["name"].lower()).suffix != pathlib.Path(
+                doc.name.lower()).suffix:
+            return get_json_result(
+                data=False,
+                retmsg="The extension of file can't be changed",
+                retcode=RetCode.ARGUMENT_ERROR)
+        for d in DocumentService.query(name=req["name"], kb_id=doc.kb_id):
+            if d.name == req["name"]:
+                return get_data_error_result(
+                    retmsg="Duplicated document name in the same knowledgebase.")
+
+        if not DocumentService.update_by_id(
+                req["doc_id"], {"name": req["name"]}):
+            return get_data_error_result(
+                retmsg="Database error (Document rename)!")
+
+        informs = File2DocumentService.get_by_document_id(req["doc_id"])
+        if informs:
+            e, file = FileService.get_by_id(informs[0].file_id)
+            FileService.update_by_id(file.id, {"name": req["name"]})
+
+        return get_json_result(data=True)
+    except Exception as e:
+        return server_error_response(e)
+
+
+@manager.route('/get/<doc_id>', methods=['GET'])
+# @login_required
+def get(doc_id):
+    try:
+        e, doc = DocumentService.get_by_id(doc_id)
+        if not e:
+            return get_data_error_result(retmsg="Document not found!")
+
+        b, n = File2DocumentService.get_minio_address(doc_id=doc_id)
+        response = flask.make_response(MINIO.get(b, n))
+
+        ext = re.search(r"\.([^.]+)$", doc.name)
+        if ext:
+            if doc.type == FileType.VISUAL.value:
+                response.headers.set('Content-Type', 'image/%s' % ext.group(1))
+            else:
+                response.headers.set(
+                    'Content-Type',
+                    'application/%s' %
+                    ext.group(1))
+        return response
+    except Exception as e:
+        return server_error_response(e)
+
+
+@manager.route('/change_parser', methods=['POST'])
+@login_required
+@validate_request("doc_id", "parser_id")
+def change_parser():
+    req = request.json
+    try:
+        e, doc = DocumentService.get_by_id(req["doc_id"])
+        if not e:
+            return get_data_error_result(retmsg="Document not found!")
+        if doc.parser_id.lower() == req["parser_id"].lower():
+            if "parser_config" in req:
+                if req["parser_config"] == doc.parser_config:
+                    return get_json_result(data=True)
+            else:
+                return get_json_result(data=True)
+
+        if doc.type == FileType.VISUAL or re.search(
+                r"\.(ppt|pptx|pages)$", doc.name):
+            return get_data_error_result(retmsg="Not supported yet!")
+
+        e = DocumentService.update_by_id(doc.id,
+                                         {"parser_id": req["parser_id"], "progress": 0, "progress_msg": "",
+                                          "run": TaskStatus.UNSTART.value})
+        if not e:
+            return get_data_error_result(retmsg="Document not found!")
+        if "parser_config" in req:
+            DocumentService.update_parser_config(doc.id, req["parser_config"])
+        if doc.token_num > 0:
+            e = DocumentService.increment_chunk_num(doc.id, doc.kb_id, doc.token_num * -1, doc.chunk_num * -1,
+                                                    doc.process_duation * -1)
+            if not e:
+                return get_data_error_result(retmsg="Document not found!")
+            tenant_id = DocumentService.get_tenant_id(req["doc_id"])
+            if not tenant_id:
+                return get_data_error_result(retmsg="Tenant not found!")
+            ELASTICSEARCH.deleteByQuery(
+                Q("match", doc_id=doc.id), idxnm=search.index_name(tenant_id))
+
+        return get_json_result(data=True)
+    except Exception as e:
+        return server_error_response(e)
+
+
+@manager.route('/image/<image_id>', methods=['GET'])
+# @login_required
+def get_image(image_id):
+    try:
+        bkt, nm = image_id.split("-")
+        response = flask.make_response(MINIO.get(bkt, nm))
+        response.headers.set('Content-Type', 'image/JPEG')
+        return response
+    except Exception as e:
+        return server_error_response(e)
+
+
+@manager.route('/upload_and_parse', methods=['POST'])
+@login_required
+@validate_request("conversation_id")
+def upload_and_parse():
+    from rag.app import presentation, picture, naive, audio, email
+    if 'file' not in request.files:
+        return get_json_result(
+            data=False, retmsg='No file part!', retcode=RetCode.ARGUMENT_ERROR)
+
+    file_objs = request.files.getlist('file')
+    for file_obj in file_objs:
+        if file_obj.filename == '':
+            return get_json_result(
+                data=False, retmsg='No file selected!', retcode=RetCode.ARGUMENT_ERROR)
+
+    e, conv = ConversationService.get_by_id(request.form.get("conversation_id"))
+    if not e:
+        return get_data_error_result(retmsg="Conversation not found!")
+    e, dia = DialogService.get_by_id(conv.dialog_id)
+    kb_id = dia.kb_ids[0]
+    e, kb = KnowledgebaseService.get_by_id(kb_id)
+    if not e:
+        raise LookupError("Can't find this knowledgebase!")
+
+    idxnm = search.index_name(kb.tenant_id)
+    if not ELASTICSEARCH.indexExist(idxnm):
+        ELASTICSEARCH.createIdx(idxnm, json.load(
+            open(os.path.join(get_project_base_directory(), "conf", "mapping.json"), "r")))
+
+    embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING, llm_name=kb.embd_id, lang=kb.language)
+
+    err, files = FileService.upload_document(kb, file_objs)
+    if err:
+        return get_json_result(
+            data=False, retmsg="\n".join(err), retcode=RetCode.SERVER_ERROR)
+
+    def dummy(prog=None, msg=""):
+        pass
+
+    FACTORY = {
+        ParserType.PRESENTATION.value: presentation,
+        ParserType.PICTURE.value: picture,
+        ParserType.AUDIO.value: audio,
+        ParserType.EMAIL.value: email
+    }
+    parser_config = {"chunk_token_num": 4096, "delimiter": "\n!?;。;!?", "layout_recognize": False}
+    exe = ThreadPoolExecutor(max_workers=12)
+    threads = []
+    for d, blob in files:
+        kwargs = {
+            "callback": dummy,
+            "parser_config": parser_config,
+            "from_page": 0,
+            "to_page": 100000,
+            "tenant_id": kb.tenant_id,
+            "lang": kb.language
+        }
+        threads.append(exe.submit(FACTORY.get(d["parser_id"], naive).chunk, d["name"], blob, **kwargs))
+
+    for (docinfo,_), th in zip(files, threads):
+        docs = []
+        doc = {
+            "doc_id": docinfo["id"],
+            "kb_id": [kb.id]
+        }
+        for ck in th.result():
+            d = deepcopy(doc)
+            d.update(ck)
+            md5 = hashlib.md5()
+            md5.update((ck["content_with_weight"] +
+                        str(d["doc_id"])).encode("utf-8"))
+            d["_id"] = md5.hexdigest()
+            d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19]
+            d["create_timestamp_flt"] = datetime.datetime.now().timestamp()
+            if not d.get("image"):
+                docs.append(d)
+                continue
+
+            output_buffer = BytesIO()
+            if isinstance(d["image"], bytes):
+                output_buffer = BytesIO(d["image"])
+            else:
+                d["image"].save(output_buffer, format='JPEG')
+
+            MINIO.put(kb.id, d["_id"], output_buffer.getvalue())
+            d["img_id"] = "{}-{}".format(kb.id, d["_id"])
+            del d["image"]
+            docs.append(d)
+
+    parser_ids = {d["id"]: d["parser_id"] for d, _ in files}
+    docids = [d["id"] for d, _ in files]
+    chunk_counts = {id: 0 for id in docids}
+    token_counts = {id: 0 for id in docids}
+    es_bulk_size = 64
+
+    def embedding(doc_id, cnts, batch_size=16):
+        nonlocal embd_mdl, chunk_counts, token_counts
+        vects = []
+        for i in range(0, len(cnts), batch_size):
+            vts, c = embd_mdl.encode(cnts[i: i + batch_size])
+            vects.extend(vts.tolist())
+            chunk_counts[doc_id] += len(cnts[i:i + batch_size])
+            token_counts[doc_id] += c
+        return vects
+
+    _, tenant = TenantService.get_by_id(kb.tenant_id)
+    llm_bdl = LLMBundle(kb.tenant_id, LLMType.CHAT, tenant.llm_id)
+    for doc_id in docids:
+        cks = [c for c in docs if c["doc_id"] == doc_id]
+
+        if False and parser_ids[doc_id] != ParserType.PICTURE.value:
+            mindmap = MindMapExtractor(llm_bdl)
+            try:
+                mind_map = json.dumps(mindmap([c["content_with_weight"] for c in docs if c["doc_id"] == doc_id]).output, ensure_ascii=False, indent=2)
+                if len(mind_map) < 32: raise Exception("Few content: "+mind_map)
+                cks.append({
+                    "doc_id": doc_id,
+                    "kb_id": [kb.id],
+                    "content_with_weight": mind_map,
+                    "knowledge_graph_kwd": "mind_map"
+                })
+            except Exception as e:
+                stat_logger.error("Mind map generation error:", traceback.format_exc())
+
+        vects = embedding(doc_id, [c["content_with_weight"] for c in cks])
+        assert len(cks) == len(vects)
+        for i, d in enumerate(cks):
+            v = vects[i]
+            d["q_%d_vec" % len(v)] = v
+        for b in range(0, len(cks), es_bulk_size):
+            ELASTICSEARCH.bulk(cks[b:b + es_bulk_size], idxnm)
+
+        DocumentService.increment_chunk_num(
+            doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0)
+
+    return get_json_result(data=[d["id"] for d,_ in files])
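For reference, `upload` expects a multipart form with a `kb_id` field plus one or more `file` parts, and returns `data: true` on success. A request sketch (the base URL and the authenticated session are assumptions):

```python
import requests

session = requests.Session()  # assumption: already carries a login cookie

with open("report.pdf", "rb") as f:
    r = session.post(
        "http://localhost:80/v1/document/upload",  # hypothetical URL prefix
        data={"kb_id": "<knowledgebase-id>"},      # form field the handler checks first
        files=[("file", ("report.pdf", f))],       # the part name must be 'file'
    )
print(r.json())
```

`upload_and_parse`, by contrast, chunks the files inline and derives each chunk's Elasticsearch `_id` as `md5(content_with_weight + doc_id)`, so re-parsing identical content maps back to the same IDs rather than accumulating duplicate chunks.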
api/apps/kb_app.py CHANGED
@@ -1,153 +1,153 @@
+ #
+ #  Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
+ #
+ #  Licensed under the Apache License, Version 2.0 (the "License");
+ #  you may not use this file except in compliance with the License.
+ #  You may obtain a copy of the License at
+ #
+ #      http://www.apache.org/licenses/LICENSE-2.0
+ #
+ #  Unless required by applicable law or agreed to in writing, software
+ #  distributed under the License is distributed on an "AS IS" BASIS,
+ #  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ #  See the License for the specific language governing permissions and
+ #  limitations under the License.
+ #
+ from elasticsearch_dsl import Q
+ from flask import request
+ from flask_login import login_required, current_user
+
+ from api.db.services import duplicate_name
+ from api.db.services.document_service import DocumentService
+ from api.db.services.file2document_service import File2DocumentService
+ from api.db.services.file_service import FileService
+ from api.db.services.user_service import TenantService, UserTenantService
+ from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
+ from api.utils import get_uuid, get_format_time
+ from api.db import StatusEnum, UserTenantRole, FileSource
+ from api.db.services.knowledgebase_service import KnowledgebaseService
+ from api.db.db_models import Knowledgebase, File
+ from api.settings import stat_logger, RetCode
+ from api.utils.api_utils import get_json_result
+ from rag.nlp import search
+ from rag.utils.es_conn import ELASTICSEARCH
+
+
+ @manager.route('/create', methods=['post'])
+ @login_required
+ @validate_request("name")
+ def create():
+     req = request.json
+     req["name"] = req["name"].strip()
+     req["name"] = duplicate_name(
+         KnowledgebaseService.query,
+         name=req["name"],
+         tenant_id=current_user.id,
+         status=StatusEnum.VALID.value)
+     try:
+         req["id"] = get_uuid()
+         req["tenant_id"] = current_user.id
+         req["created_by"] = current_user.id
+         e, t = TenantService.get_by_id(current_user.id)
+         if not e:
+             return get_data_error_result(retmsg="Tenant not found.")
+         req["embd_id"] = t.embd_id
+         if not KnowledgebaseService.save(**req):
+             return get_data_error_result()
+         return get_json_result(data={"kb_id": req["id"]})
+     except Exception as e:
+         return server_error_response(e)
+
+
+ @manager.route('/update', methods=['post'])
+ @login_required
+ @validate_request("kb_id", "name", "description", "permission", "parser_id")
+ def update():
+     req = request.json
+     req["name"] = req["name"].strip()
+     try:
+         if not KnowledgebaseService.query(
+                 created_by=current_user.id, id=req["kb_id"]):
+             return get_json_result(
+                 data=False, retmsg=f'Only owner of knowledgebase authorized for this operation.', retcode=RetCode.OPERATING_ERROR)
+
+         e, kb = KnowledgebaseService.get_by_id(req["kb_id"])
+         if not e:
+             return get_data_error_result(
+                 retmsg="Can't find this knowledgebase!")
+
+         if req["name"].lower() != kb.name.lower() \
+                 and len(KnowledgebaseService.query(name=req["name"], tenant_id=current_user.id, status=StatusEnum.VALID.value)) > 1:
+             return get_data_error_result(
+                 retmsg="Duplicated knowledgebase name.")
+
+         del req["kb_id"]
+         if not KnowledgebaseService.update_by_id(kb.id, req):
+             return get_data_error_result()
+
+         e, kb = KnowledgebaseService.get_by_id(kb.id)
+         if not e:
+             return get_data_error_result(
+                 retmsg="Database error (Knowledgebase rename)!")
+
+         return get_json_result(data=kb.to_json())
+     except Exception as e:
+         return server_error_response(e)
+
+
+ @manager.route('/detail', methods=['GET'])
+ @login_required
+ def detail():
+     kb_id = request.args["kb_id"]
+     try:
+         kb = KnowledgebaseService.get_detail(kb_id)
+         if not kb:
+             return get_data_error_result(
+                 retmsg="Can't find this knowledgebase!")
+         return get_json_result(data=kb)
+     except Exception as e:
+         return server_error_response(e)
+
+
+ @manager.route('/list', methods=['GET'])
+ @login_required
+ def list_kbs():
+     page_number = request.args.get("page", 1)
+     items_per_page = request.args.get("page_size", 150)
+     orderby = request.args.get("orderby", "create_time")
+     desc = request.args.get("desc", True)
+     try:
+         tenants = TenantService.get_joined_tenants_by_user_id(current_user.id)
+         kbs = KnowledgebaseService.get_by_tenant_ids(
+             [m["tenant_id"] for m in tenants], current_user.id, page_number, items_per_page, orderby, desc)
+         return get_json_result(data=kbs)
+     except Exception as e:
+         return server_error_response(e)
+
+
+ @manager.route('/rm', methods=['post'])
+ @login_required
+ @validate_request("kb_id")
+ def rm():
+     req = request.json
+     try:
+         kbs = KnowledgebaseService.query(
+             created_by=current_user.id, id=req["kb_id"])
+         if not kbs:
+             return get_json_result(
+                 data=False, retmsg=f'Only owner of knowledgebase authorized for this operation.', retcode=RetCode.OPERATING_ERROR)
+
+         for doc in DocumentService.query(kb_id=req["kb_id"]):
+             if not DocumentService.remove_document(doc, kbs[0].tenant_id):
+                 return get_data_error_result(
+                     retmsg="Database error (Document removal)!")
+             f2d = File2DocumentService.get_by_document_id(doc.id)
+             FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.id == f2d[0].file_id])
+             File2DocumentService.delete_by_document_id(doc.id)
+
+         if not KnowledgebaseService.delete_by_id(req["kb_id"]):
+             return get_data_error_result(
+                 retmsg="Database error (Knowledgebase removal)!")
+         return get_json_result(data=True)
+     except Exception as e:
+         return server_error_response(e)
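For reference, a minimal client-side sketch of the create/rm round trip defined above. The server address, the `/v1/kb` route prefix, and the `Authorization` header value are assumptions about a typical deployment, not guaranteed by this file:

    import requests

    BASE = "http://localhost:9380/v1/kb"                 # assumed address and route prefix
    HEADERS = {"Authorization": "<token from login>"}    # hypothetical auth token

    # create() returns {"data": {"kb_id": ...}} via get_json_result above.
    r = requests.post(f"{BASE}/create", json={"name": "demo-kb"}, headers=HEADERS)
    kb_id = r.json()["data"]["kb_id"]

    # rm() above also removes the knowledge base's documents and their file links.
    requests.post(f"{BASE}/rm", json={"kb_id": kb_id}, headers=HEADERS)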
api/apps/llm_app.py CHANGED
@@ -1,279 +1,279 @@
+ #
+ #  Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
+ #
+ #  Licensed under the Apache License, Version 2.0 (the "License");
+ #  you may not use this file except in compliance with the License.
+ #  You may obtain a copy of the License at
+ #
+ #      http://www.apache.org/licenses/LICENSE-2.0
+ #
+ #  Unless required by applicable law or agreed to in writing, software
+ #  distributed under the License is distributed on an "AS IS" BASIS,
+ #  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ #  See the License for the specific language governing permissions and
+ #  limitations under the License.
+ #
+ from flask import request
+ from flask_login import login_required, current_user
+ from api.db.services.llm_service import LLMFactoriesService, TenantLLMService, LLMService
+ from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
+ from api.db import StatusEnum, LLMType
+ from api.db.db_models import TenantLLM
+ from api.utils.api_utils import get_json_result
+ from rag.llm import EmbeddingModel, ChatModel, RerankModel,CvModel
+ import requests
+ import ast
+
+ @manager.route('/factories', methods=['GET'])
+ @login_required
+ def factories():
+     try:
+         fac = LLMFactoriesService.get_all()
+         return get_json_result(data=[f.to_dict() for f in fac if f.name not in ["Youdao", "FastEmbed", "BAAI"]])
+     except Exception as e:
+         return server_error_response(e)
+
+
+ @manager.route('/set_api_key', methods=['POST'])
+ @login_required
+ @validate_request("llm_factory", "api_key")
+ def set_api_key():
+     req = request.json
+     # test if api key works
+     chat_passed, embd_passed, rerank_passed = False, False, False
+     factory = req["llm_factory"]
+     msg = ""
+     for llm in LLMService.query(fid=factory):
+         if not embd_passed and llm.model_type == LLMType.EMBEDDING.value:
+             mdl = EmbeddingModel[factory](
+                 req["api_key"], llm.llm_name, base_url=req.get("base_url"))
+             try:
+                 arr, tc = mdl.encode(["Test if the api key is available"])
+                 if len(arr[0]) == 0:
+                     raise Exception("Fail")
+                 embd_passed = True
+             except Exception as e:
+                 msg += f"\nFail to access embedding model({llm.llm_name}) using this api key." + str(e)
+         elif not chat_passed and llm.model_type == LLMType.CHAT.value:
+             mdl = ChatModel[factory](
+                 req["api_key"], llm.llm_name, base_url=req.get("base_url"))
+             try:
+                 m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}],
+                                  {"temperature": 0.9,'max_tokens':50})
+                 if m.find("**ERROR**") >=0:
+                     raise Exception(m)
+             except Exception as e:
+                 msg += f"\nFail to access model({llm.llm_name}) using this api key." + str(
+                     e)
+             chat_passed = True
+         elif not rerank_passed and llm.model_type == LLMType.RERANK:
+             mdl = RerankModel[factory](
+                 req["api_key"], llm.llm_name, base_url=req.get("base_url"))
+             try:
+                 arr, tc = mdl.similarity("What's the weather?", ["Is it sunny today?"])
+                 if len(arr) == 0 or tc == 0:
+                     raise Exception("Fail")
+             except Exception as e:
+                 msg += f"\nFail to access model({llm.llm_name}) using this api key." + str(
+                     e)
+             rerank_passed = True
+
+     if msg:
+         return get_data_error_result(retmsg=msg)
+
+     llm = {
+         "api_key": req["api_key"],
+         "api_base": req.get("base_url", "")
+     }
+     for n in ["model_type", "llm_name"]:
+         if n in req:
+             llm[n] = req[n]
+
+     if not TenantLLMService.filter_update(
+             [TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == factory], llm):
+         for llm in LLMService.query(fid=factory):
+             TenantLLMService.save(
+                 tenant_id=current_user.id,
+                 llm_factory=factory,
+                 llm_name=llm.llm_name,
+                 model_type=llm.model_type,
+                 api_key=req["api_key"],
+                 api_base=req.get("base_url", "")
+             )
+
+     return get_json_result(data=True)
+
+
+ @manager.route('/add_llm', methods=['POST'])
+ @login_required
+ @validate_request("llm_factory", "llm_name", "model_type")
+ def add_llm():
+     req = request.json
+     factory = req["llm_factory"]
+
+     if factory == "VolcEngine":
+         # For VolcEngine, due to its special authentication method
+         # Assemble volc_ak, volc_sk, endpoint_id into api_key
+         temp = list(ast.literal_eval(req["llm_name"]).items())[0]
+         llm_name = temp[0]
+         endpoint_id = temp[1]
+         api_key = '{' + f'"volc_ak": "{req.get("volc_ak", "")}", ' \
+                         f'"volc_sk": "{req.get("volc_sk", "")}", ' \
+                         f'"ep_id": "{endpoint_id}", ' + '}'
+     elif factory == "Bedrock":
+         # For Bedrock, due to its special authentication method
+         # Assemble bedrock_ak, bedrock_sk, bedrock_region
+         llm_name = req["llm_name"]
+         api_key = '{' + f'"bedrock_ak": "{req.get("bedrock_ak", "")}", ' \
+                         f'"bedrock_sk": "{req.get("bedrock_sk", "")}", ' \
+                         f'"bedrock_region": "{req.get("bedrock_region", "")}", ' + '}'
+     elif factory == "LocalAI":
+         llm_name = req["llm_name"]+"___LocalAI"
+         api_key = "xxxxxxxxxxxxxxx"
+     elif factory == "OpenAI-API-Compatible":
+         llm_name = req["llm_name"]+"___OpenAI-API"
+         api_key = req.get("api_key","xxxxxxxxxxxxxxx")
+     else:
+         llm_name = req["llm_name"]
+         api_key = req.get("api_key","xxxxxxxxxxxxxxx")
+
+     llm = {
+         "tenant_id": current_user.id,
+         "llm_factory": factory,
+         "model_type": req["model_type"],
+         "llm_name": llm_name,
+         "api_base": req.get("api_base", ""),
+         "api_key": api_key
+     }
+
+     msg = ""
+     if llm["model_type"] == LLMType.EMBEDDING.value:
+         mdl = EmbeddingModel[factory](
+             key=llm['api_key'] if factory in ["VolcEngine", "Bedrock","OpenAI-API-Compatible"] else None,
+             model_name=llm["llm_name"],
+             base_url=llm["api_base"])
+         try:
+             arr, tc = mdl.encode(["Test if the api key is available"])
+             if len(arr[0]) == 0 or tc == 0:
+                 raise Exception("Fail")
+         except Exception as e:
+             msg += f"\nFail to access embedding model({llm['llm_name']})." + str(e)
+     elif llm["model_type"] == LLMType.CHAT.value:
+         mdl = ChatModel[factory](
+             key=llm['api_key'] if factory in ["VolcEngine", "Bedrock","OpenAI-API-Compatible"] else None,
+             model_name=llm["llm_name"],
+             base_url=llm["api_base"]
+         )
+         try:
+             m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {
+                 "temperature": 0.9})
+             if not tc:
+                 raise Exception(m)
+         except Exception as e:
+             msg += f"\nFail to access model({llm['llm_name']})." + str(
+                 e)
+     elif llm["model_type"] == LLMType.RERANK:
+         mdl = RerankModel[factory](
+             key=None, model_name=llm["llm_name"], base_url=llm["api_base"]
+         )
+         try:
+             arr, tc = mdl.similarity("Hello~ Ragflower!", ["Hi, there!"])
+             if len(arr) == 0 or tc == 0:
+                 raise Exception("Not known.")
+         except Exception as e:
+             msg += f"\nFail to access model({llm['llm_name']})." + str(
+                 e)
+     elif llm["model_type"] == LLMType.IMAGE2TEXT.value:
+         mdl = CvModel[factory](
+             key=llm["api_key"] if factory in ["OpenAI-API-Compatible"] else None, model_name=llm["llm_name"], base_url=llm["api_base"]
+         )
+         try:
+             img_url = (
+                 "https://upload.wikimedia.org/wikipedia/comm"
+                 "ons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/256"
+                 "0px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
+             )
+             res = requests.get(img_url)
+             if res.status_code == 200:
+                 m, tc = mdl.describe(res.content)
+                 if not tc:
+                     raise Exception(m)
+             else:
+                 pass
+         except Exception as e:
+             msg += f"\nFail to access model({llm['llm_name']})." + str(e)
+     else:
+         # TODO: check other type of models
+         pass
+
+     if msg:
+         return get_data_error_result(retmsg=msg)
+
+     if not TenantLLMService.filter_update(
+             [TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == factory, TenantLLM.llm_name == llm["llm_name"]], llm):
+         TenantLLMService.save(**llm)
+
+     return get_json_result(data=True)
+
+
+ @manager.route('/delete_llm', methods=['POST'])
+ @login_required
+ @validate_request("llm_factory", "llm_name")
+ def delete_llm():
+     req = request.json
+     TenantLLMService.filter_delete(
+         [TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"], TenantLLM.llm_name == req["llm_name"]])
+     return get_json_result(data=True)
+
+
+ @manager.route('/my_llms', methods=['GET'])
+ @login_required
+ def my_llms():
+     try:
+         res = {}
+         for o in TenantLLMService.get_my_llms(current_user.id):
+             if o["llm_factory"] not in res:
+                 res[o["llm_factory"]] = {
+                     "tags": o["tags"],
+                     "llm": []
+                 }
+             res[o["llm_factory"]]["llm"].append({
+                 "type": o["model_type"],
+                 "name": o["llm_name"],
+                 "used_token": o["used_tokens"]
+             })
+         return get_json_result(data=res)
+     except Exception as e:
+         return server_error_response(e)
+
+
+ @manager.route('/list', methods=['GET'])
+ @login_required
+ def list_app():
+     model_type = request.args.get("model_type")
+     try:
+         objs = TenantLLMService.query(tenant_id=current_user.id)
+         facts = set([o.to_dict()["llm_factory"] for o in objs if o.api_key])
+         llms = LLMService.get_all()
+         llms = [m.to_dict()
+                 for m in llms if m.status == StatusEnum.VALID.value]
+         for m in llms:
+             m["available"] = m["fid"] in facts or m["llm_name"].lower() == "flag-embedding" or m["fid"] in ["Youdao","FastEmbed", "BAAI"]
+
+         llm_set = set([m["llm_name"] for m in llms])
+         for o in objs:
+             if not o.api_key:continue
+             if o.llm_name in llm_set:continue
+             llms.append({"llm_name": o.llm_name, "model_type": o.model_type, "fid": o.llm_factory, "available": True})
+
+         res = {}
+         for m in llms:
+             if model_type and m["model_type"].find(model_type)<0:
+                 continue
+             if m["fid"] not in res:
+                 res[m["fid"]] = []
+             res[m["fid"]].append(m)
+
+         return get_json_result(data=res)
+     except Exception as e:
+         return server_error_response(e)
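A minimal sketch of exercising set_api_key() above from a client. The `/v1/llm` route prefix, the auth header, and the placeholder key are assumptions for illustration only:

    import requests

    BASE = "http://localhost:9380/v1/llm"              # assumed prefix for this blueprint
    HEADERS = {"Authorization": "<auth token>"}        # hypothetical auth token

    # set_api_key() probes one embedding, one chat and one rerank model of the
    # factory before persisting the key, so an invalid key fails fast here with
    # the accumulated "Fail to access ..." message in retmsg.
    payload = {"llm_factory": "OpenAI", "api_key": "sk-<your key>", "base_url": ""}
    r = requests.post(f"{BASE}/set_api_key", json=payload, headers=HEADERS)
    print(r.json())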
api/apps/user_app.py CHANGED
@@ -1,391 +1,391 @@
1
+ #
2
+ # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ #
16
+ import json
17
+ import re
18
+ from datetime import datetime
19
+
20
+ from flask import request, session, redirect
21
+ from werkzeug.security import generate_password_hash, check_password_hash
22
+ from flask_login import login_required, current_user, login_user, logout_user
23
+
24
+ from api.db.db_models import TenantLLM
25
+ from api.db.services.llm_service import TenantLLMService, LLMService
26
+ from api.utils.api_utils import server_error_response, validate_request
27
+ from api.utils import get_uuid, get_format_time, decrypt, download_img, current_timestamp, datetime_format
28
+ from api.db import UserTenantRole, LLMType, FileType
29
+ from api.settings import RetCode, GITHUB_OAUTH, FEISHU_OAUTH, CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS, \
30
+ API_KEY, \
31
+ LLM_FACTORY, LLM_BASE_URL, RERANK_MDL
32
+ from api.db.services.user_service import UserService, TenantService, UserTenantService
33
+ from api.db.services.file_service import FileService
34
+ from api.settings import stat_logger
35
+ from api.utils.api_utils import get_json_result, cors_reponse
36
+
37
+
38
+ @manager.route('/login', methods=['POST', 'GET'])
39
+ def login():
40
+ login_channel = "password"
41
+ if not request.json:
42
+ return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR,
43
+ retmsg='Unautherized!')
44
+
45
+ email = request.json.get('email', "")
46
+ users = UserService.query(email=email)
47
+ if not users:
48
+ return get_json_result(
49
+ data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg=f'This Email is not registered!')
50
+
51
+ password = request.json.get('password')
52
+ try:
53
+ password = decrypt(password)
54
+ except BaseException:
55
+ return get_json_result(
56
+ data=False, retcode=RetCode.SERVER_ERROR, retmsg='Fail to crypt password')
57
+
58
+ user = UserService.query_user(email, password)
59
+ if user:
60
+ response_data = user.to_json()
61
+ user.access_token = get_uuid()
62
+ login_user(user)
63
+ user.update_time = current_timestamp(),
64
+ user.update_date = datetime_format(datetime.now()),
65
+ user.save()
66
+ msg = "Welcome back!"
67
+ return cors_reponse(data=response_data, auth=user.get_id(), retmsg=msg)
68
+ else:
69
+ return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR,
70
+ retmsg='Email and Password do not match!')
71
+
72
+
73
+ @manager.route('/github_callback', methods=['GET'])
74
+ def github_callback():
75
+ import requests
76
+ res = requests.post(GITHUB_OAUTH.get("url"), data={
77
+ "client_id": GITHUB_OAUTH.get("client_id"),
78
+ "client_secret": GITHUB_OAUTH.get("secret_key"),
79
+ "code": request.args.get('code')
80
+ }, headers={"Accept": "application/json"})
81
+ res = res.json()
82
+ if "error" in res:
83
+ return redirect("/?error=%s" % res["error_description"])
84
+
85
+ if "user:email" not in res["scope"].split(","):
86
+ return redirect("/?error=user:email not in scope")
87
+
88
+ session["access_token"] = res["access_token"]
89
+ session["access_token_from"] = "github"
90
+ userinfo = user_info_from_github(session["access_token"])
91
+ users = UserService.query(email=userinfo["email"])
92
+ user_id = get_uuid()
93
+ if not users:
94
+ try:
95
+ try:
96
+ avatar = download_img(userinfo["avatar_url"])
97
+ except Exception as e:
98
+ stat_logger.exception(e)
99
+ avatar = ""
100
+ users = user_register(user_id, {
101
+ "access_token": session["access_token"],
102
+ "email": userinfo["email"],
103
+ "avatar": avatar,
104
+ "nickname": userinfo["login"],
105
+ "login_channel": "github",
106
+ "last_login_time": get_format_time(),
107
+ "is_superuser": False,
108
+ })
109
+ if not users:
110
+ raise Exception('Register user failure.')
111
+ if len(users) > 1:
112
+ raise Exception('Same E-mail exist!')
113
+ user = users[0]
114
+ login_user(user)
115
+ return redirect("/?auth=%s" % user.get_id())
116
+ except Exception as e:
117
+ rollback_user_registration(user_id)
118
+ stat_logger.exception(e)
119
+ return redirect("/?error=%s" % str(e))
120
+ user = users[0]
121
+ user.access_token = get_uuid()
122
+ login_user(user)
123
+ user.save()
124
+ return redirect("/?auth=%s" % user.get_id())
125
+
126
+
127
+ @manager.route('/feishu_callback', methods=['GET'])
128
+ def feishu_callback():
129
+ import requests
130
+ app_access_token_res = requests.post(FEISHU_OAUTH.get("app_access_token_url"), data=json.dumps({
131
+ "app_id": FEISHU_OAUTH.get("app_id"),
132
+ "app_secret": FEISHU_OAUTH.get("app_secret")
133
+ }), headers={"Content-Type": "application/json; charset=utf-8"})
134
+ app_access_token_res = app_access_token_res.json()
135
+ if app_access_token_res['code'] != 0:
136
+ return redirect("/?error=%s" % app_access_token_res)
137
+
138
+ res = requests.post(FEISHU_OAUTH.get("user_access_token_url"), data=json.dumps({
139
+ "grant_type": FEISHU_OAUTH.get("grant_type"),
140
+ "code": request.args.get('code')
141
+ }), headers={"Content-Type": "application/json; charset=utf-8",
142
+ 'Authorization': f"Bearer {app_access_token_res['app_access_token']}"})
143
+ res = res.json()
144
+ if res['code'] != 0:
145
+ return redirect("/?error=%s" % res["message"])
146
+
147
+ if "contact:user.email:readonly" not in res["data"]["scope"].split(" "):
148
+ return redirect("/?error=contact:user.email:readonly not in scope")
149
+ session["access_token"] = res["data"]["access_token"]
150
+ session["access_token_from"] = "feishu"
151
+ userinfo = user_info_from_feishu(session["access_token"])
152
+ users = UserService.query(email=userinfo["email"])
153
+ user_id = get_uuid()
154
+ if not users:
155
+ try:
156
+ try:
157
+ avatar = download_img(userinfo["avatar_url"])
158
+ except Exception as e:
159
+ stat_logger.exception(e)
160
+ avatar = ""
161
+ users = user_register(user_id, {
162
+ "access_token": session["access_token"],
163
+ "email": userinfo["email"],
164
+ "avatar": avatar,
165
+ "nickname": userinfo["en_name"],
166
+ "login_channel": "feishu",
167
+ "last_login_time": get_format_time(),
168
+ "is_superuser": False,
169
+ })
170
+ if not users:
171
+ raise Exception('Register user failure.')
172
+ if len(users) > 1:
173
+ raise Exception('Same E-mail exist!')
174
+ user = users[0]
175
+ login_user(user)
176
+ return redirect("/?auth=%s" % user.get_id())
177
+ except Exception as e:
178
+ rollback_user_registration(user_id)
179
+ stat_logger.exception(e)
180
+ return redirect("/?error=%s" % str(e))
181
+ user = users[0]
182
+ user.access_token = get_uuid()
183
+ login_user(user)
184
+ user.save()
185
+ return redirect("/?auth=%s" % user.get_id())
186
+
187
+
188
+ def user_info_from_feishu(access_token):
189
+ import requests
190
+ headers = {"Content-Type": "application/json; charset=utf-8",
191
+ 'Authorization': f"Bearer {access_token}"}
192
+ res = requests.get(
193
+ f"https://open.feishu.cn/open-apis/authen/v1/user_info",
194
+ headers=headers)
195
+ user_info = res.json()["data"]
196
+ user_info["email"] = None if user_info.get("email") == "" else user_info["email"]
197
+ return user_info
198
+
199
+
200
+ def user_info_from_github(access_token):
201
+ import requests
202
+ headers = {"Accept": "application/json",
203
+ 'Authorization': f"token {access_token}"}
204
+ res = requests.get(
205
+ f"https://api.github.com/user?access_token={access_token}",
206
+ headers=headers)
207
+ user_info = res.json()
208
+ email_info = requests.get(
209
+ f"https://api.github.com/user/emails?access_token={access_token}",
210
+ headers=headers).json()
211
+ user_info["email"] = next(
212
+ (email for email in email_info if email['primary'] == True),
213
+ None)["email"]
214
+ return user_info
215
+
216
+
217
+ @manager.route("/logout", methods=['GET'])
218
+ @login_required
219
+ def log_out():
220
+ current_user.access_token = ""
221
+ current_user.save()
222
+ logout_user()
223
+ return get_json_result(data=True)
224
+
225
+
226
+ @manager.route("/setting", methods=["POST"])
227
+ @login_required
228
+ def setting_user():
229
+ update_dict = {}
230
+ request_data = request.json
231
+ if request_data.get("password"):
232
+ new_password = request_data.get("new_password")
233
+ if not check_password_hash(
234
+ current_user.password, decrypt(request_data["password"])):
235
+ return get_json_result(
236
+ data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg='Password error!')
237
+
238
+ if new_password:
239
+ update_dict["password"] = generate_password_hash(
240
+ decrypt(new_password))
241
+
242
+ for k in request_data.keys():
243
+ if k in ["password", "new_password"]:
244
+ continue
245
+ update_dict[k] = request_data[k]
246
+
247
+ try:
248
+ UserService.update_by_id(current_user.id, update_dict)
249
+ return get_json_result(data=True)
250
+ except Exception as e:
251
+ stat_logger.exception(e)
252
+ return get_json_result(
253
+ data=False, retmsg='Update failure!', retcode=RetCode.EXCEPTION_ERROR)
254
+
255
+
256
+ @manager.route("/info", methods=["GET"])
257
+ @login_required
258
+ def user_info():
259
+ return get_json_result(data=current_user.to_dict())
260
+
261
+
262
+ def rollback_user_registration(user_id):
263
+ try:
264
+ UserService.delete_by_id(user_id)
265
+ except Exception as e:
266
+ pass
267
+ try:
268
+ TenantService.delete_by_id(user_id)
269
+ except Exception as e:
270
+ pass
271
+ try:
272
+ u = UserTenantService.query(tenant_id=user_id)
273
+ if u:
274
+ UserTenantService.delete_by_id(u[0].id)
275
+ except Exception as e:
276
+ pass
277
+ try:
278
+ TenantLLM.delete().where(TenantLLM.tenant_id == user_id).execute()
279
+ except Exception as e:
280
+ pass
281
+
282
+
283
+ def user_register(user_id, user):
284
+ user["id"] = user_id
285
+ tenant = {
286
+ "id": user_id,
287
+ "name": user["nickname"] + "‘s Kingdom",
288
+ "llm_id": CHAT_MDL,
289
+ "embd_id": EMBEDDING_MDL,
290
+ "asr_id": ASR_MDL,
291
+ "parser_ids": PARSERS,
292
+ "img2txt_id": IMAGE2TEXT_MDL,
293
+ "rerank_id": RERANK_MDL
294
+ }
295
+ usr_tenant = {
296
+ "tenant_id": user_id,
297
+ "user_id": user_id,
298
+ "invited_by": user_id,
299
+ "role": UserTenantRole.OWNER
300
+ }
301
+ file_id = get_uuid()
302
+ file = {
303
+ "id": file_id,
304
+ "parent_id": file_id,
305
+ "tenant_id": user_id,
306
+ "created_by": user_id,
307
+ "name": "/",
308
+ "type": FileType.FOLDER.value,
309
+ "size": 0,
310
+ "location": "",
311
+ }
312
+ tenant_llm = []
313
+ for llm in LLMService.query(fid=LLM_FACTORY):
314
+ tenant_llm.append({"tenant_id": user_id,
315
+ "llm_factory": LLM_FACTORY,
316
+ "llm_name": llm.llm_name,
317
+ "model_type": llm.model_type,
318
+ "api_key": API_KEY,
319
+ "api_base": LLM_BASE_URL
320
+ })
321
+
322
+ if not UserService.save(**user):
323
+ return
324
+ TenantService.insert(**tenant)
325
+ UserTenantService.insert(**usr_tenant)
326
+ TenantLLMService.insert_many(tenant_llm)
327
+ FileService.insert(file)
328
+ return UserService.query(email=user["email"])
329
+
330
+
331
+ @manager.route("/register", methods=["POST"])
332
+ @validate_request("nickname", "email", "password")
333
+ def user_add():
334
+ req = request.json
335
+ if UserService.query(email=req["email"]):
336
+ return get_json_result(
337
+ data=False, retmsg=f'Email: {req["email"]} has already been registered!', retcode=RetCode.OPERATING_ERROR)
338
+ if not re.match(r"^[\w\._-]+@([\w_-]+\.)+[\w-]{2,4}$", req["email"]):
339
+ return get_json_result(data=False, retmsg=f'Invalid e-mail: {req["email"]}!',
340
+ retcode=RetCode.OPERATING_ERROR)
341
+
342
+ user_dict = {
343
+ "access_token": get_uuid(),
344
+ "email": req["email"],
345
+ "nickname": req["nickname"],
346
+ "password": decrypt(req["password"]),
347
+ "login_channel": "password",
348
+ "last_login_time": get_format_time(),
349
+ "is_superuser": False,
350
+ }
351
+
352
+ user_id = get_uuid()
353
+ try:
354
+ users = user_register(user_id, user_dict)
355
+ if not users:
356
+ raise Exception('User registration failed.')
357
+ if len(users) > 1:
358
+ raise Exception('Same e-mail already exists!')
359
+ user = users[0]
360
+ login_user(user)
361
+ return cors_reponse(data=user.to_json(),
362
+ auth=user.get_id(), retmsg="Welcome aboard!")
363
+ except Exception as e:
364
+ rollback_user_registration(user_id)
365
+ stat_logger.exception(e)
366
+ return get_json_result(
367
+ data=False, retmsg='User registration failure!', retcode=RetCode.EXCEPTION_ERROR)
368
+
369
+
370
+ @manager.route("/tenant_info", methods=["GET"])
371
+ @login_required
372
+ def tenant_info():
373
+ try:
374
+ tenants = TenantService.get_by_user_id(current_user.id)[0]
375
+ return get_json_result(data=tenants)
376
+ except Exception as e:
377
+ return server_error_response(e)
378
+
379
+
380
+ @manager.route("/set_tenant_info", methods=["POST"])
381
+ @login_required
382
+ @validate_request("tenant_id", "asr_id", "embd_id", "img2txt_id", "llm_id")
383
+ def set_tenant_info():
384
+ req = request.json
385
+ try:
386
+ tid = req["tenant_id"]
387
+ del req["tenant_id"]
388
+ TenantService.update_by_id(tid, req)
389
+ return get_json_result(data=True)
390
+ except Exception as e:
391
+ return server_error_response(e)
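Note on the registration flow above: user_register() writes rows to several tables (user, tenant, user_tenant, tenant_llm, file) without a wrapping transaction, so rollback_user_registration() compensates with best-effort deletes that swallow individual failures. A minimal sketch of that compensation pattern, using hypothetical no-op callables in place of the real service calls:

def best_effort_rollback(*cleanups):
    # Run every cleanup step; a failure in one step must not
    # stop the remaining steps from running.
    for cleanup in cleanups:
        try:
            cleanup()
        except Exception:
            pass  # best effort: keep cleaning up

# Hypothetical usage mirroring rollback_user_registration(user_id):
best_effort_rollback(
    lambda: None,  # stands in for UserService.delete_by_id(user_id)
    lambda: None,  # stands in for TenantService.delete_by_id(user_id)
)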
api/db/__init__.py CHANGED
@@ -1,102 +1,102 @@
1
- #
2
- # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- #
16
- from enum import Enum
17
- from enum import IntEnum
18
- from strenum import StrEnum
19
-
20
-
21
- class StatusEnum(Enum):
22
- VALID = "1"
23
- INVALID = "0"
24
-
25
-
26
- class UserTenantRole(StrEnum):
27
- OWNER = 'owner'
28
- ADMIN = 'admin'
29
- NORMAL = 'normal'
30
-
31
-
32
- class TenantPermission(StrEnum):
33
- ME = 'me'
34
- TEAM = 'team'
35
-
36
-
37
- class SerializedType(IntEnum):
38
- PICKLE = 1
39
- JSON = 2
40
-
41
-
42
- class FileType(StrEnum):
43
- PDF = 'pdf'
44
- DOC = 'doc'
45
- VISUAL = 'visual'
46
- AURAL = 'aural'
47
- VIRTUAL = 'virtual'
48
- FOLDER = 'folder'
49
- OTHER = "other"
50
-
51
-
52
- class LLMType(StrEnum):
53
- CHAT = 'chat'
54
- EMBEDDING = 'embedding'
55
- SPEECH2TEXT = 'speech2text'
56
- IMAGE2TEXT = 'image2text'
57
- RERANK = 'rerank'
58
-
59
-
60
- class ChatStyle(StrEnum):
61
- CREATIVE = 'Creative'
62
- PRECISE = 'Precise'
63
- EVENLY = 'Evenly'
64
- CUSTOM = 'Custom'
65
-
66
-
67
- class TaskStatus(StrEnum):
68
- UNSTART = "0"
69
- RUNNING = "1"
70
- CANCEL = "2"
71
- DONE = "3"
72
- FAIL = "4"
73
-
74
-
75
- class ParserType(StrEnum):
76
- PRESENTATION = "presentation"
77
- LAWS = "laws"
78
- MANUAL = "manual"
79
- PAPER = "paper"
80
- RESUME = "resume"
81
- BOOK = "book"
82
- QA = "qa"
83
- TABLE = "table"
84
- NAIVE = "naive"
85
- PICTURE = "picture"
86
- ONE = "one"
87
- AUDIO = "audio"
88
- EMAIL = "email"
89
- KG = "knowledge_graph"
90
-
91
-
92
- class FileSource(StrEnum):
93
- LOCAL = ""
94
- KNOWLEDGEBASE = "knowledgebase"
95
- S3 = "s3"
96
-
97
-
98
- class CanvasType(StrEnum):
99
- ChatBot = "chatbot"
100
- DocBot = "docbot"
101
-
102
- KNOWLEDGEBASE_FOLDER_NAME=".knowledgebase"
 
1
+ #
2
+ # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ #
16
+ from enum import Enum
17
+ from enum import IntEnum
18
+ from strenum import StrEnum
19
+
20
+
21
+ class StatusEnum(Enum):
22
+ VALID = "1"
23
+ INVALID = "0"
24
+
25
+
26
+ class UserTenantRole(StrEnum):
27
+ OWNER = 'owner'
28
+ ADMIN = 'admin'
29
+ NORMAL = 'normal'
30
+
31
+
32
+ class TenantPermission(StrEnum):
33
+ ME = 'me'
34
+ TEAM = 'team'
35
+
36
+
37
+ class SerializedType(IntEnum):
38
+ PICKLE = 1
39
+ JSON = 2
40
+
41
+
42
+ class FileType(StrEnum):
43
+ PDF = 'pdf'
44
+ DOC = 'doc'
45
+ VISUAL = 'visual'
46
+ AURAL = 'aural'
47
+ VIRTUAL = 'virtual'
48
+ FOLDER = 'folder'
49
+ OTHER = "other"
50
+
51
+
52
+ class LLMType(StrEnum):
53
+ CHAT = 'chat'
54
+ EMBEDDING = 'embedding'
55
+ SPEECH2TEXT = 'speech2text'
56
+ IMAGE2TEXT = 'image2text'
57
+ RERANK = 'rerank'
58
+
59
+
60
+ class ChatStyle(StrEnum):
61
+ CREATIVE = 'Creative'
62
+ PRECISE = 'Precise'
63
+ EVENLY = 'Evenly'
64
+ CUSTOM = 'Custom'
65
+
66
+
67
+ class TaskStatus(StrEnum):
68
+ UNSTART = "0"
69
+ RUNNING = "1"
70
+ CANCEL = "2"
71
+ DONE = "3"
72
+ FAIL = "4"
73
+
74
+
75
+ class ParserType(StrEnum):
76
+ PRESENTATION = "presentation"
77
+ LAWS = "laws"
78
+ MANUAL = "manual"
79
+ PAPER = "paper"
80
+ RESUME = "resume"
81
+ BOOK = "book"
82
+ QA = "qa"
83
+ TABLE = "table"
84
+ NAIVE = "naive"
85
+ PICTURE = "picture"
86
+ ONE = "one"
87
+ AUDIO = "audio"
88
+ EMAIL = "email"
89
+ KG = "knowledge_graph"
90
+
91
+
92
+ class FileSource(StrEnum):
93
+ LOCAL = ""
94
+ KNOWLEDGEBASE = "knowledgebase"
95
+ S3 = "s3"
96
+
97
+
98
+ class CanvasType(StrEnum):
99
+ ChatBot = "chatbot"
100
+ DocBot = "docbot"
101
+
102
+ KNOWLEDGEBASE_FOLDER_NAME=".knowledgebase"
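The constants above are mostly StrEnum subclasses. A StrEnum member is itself a str, so it compares equal to the raw value stored in the CharField columns defined in db_models.py below, and no .value conversion is needed when saving or querying. A small self-contained sketch, assuming the strenum package the file itself imports:

from strenum import StrEnum

class UserTenantRole(StrEnum):
    OWNER = 'owner'

# Members are str subclasses, so they round-trip through VARCHAR
# columns and plain string comparisons without touching .value:
assert UserTenantRole.OWNER == 'owner'
assert str(UserTenantRole.OWNER) == 'owner'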
api/db/db_models.py CHANGED
@@ -1,972 +1,972 @@
1
- #
2
- # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- #
16
- import inspect
17
- import os
18
- import sys
19
- import typing
20
- import operator
21
- from functools import wraps
22
- from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
23
- from flask_login import UserMixin
24
- from playhouse.migrate import MySQLMigrator, migrate
25
- from peewee import (
26
- BigIntegerField, BooleanField, CharField,
27
- CompositeKey, IntegerField, TextField, FloatField, DateTimeField,
28
- Field, Model, Metadata
29
- )
30
- from playhouse.pool import PooledMySQLDatabase
31
- from api.db import SerializedType, ParserType
32
- from api.settings import DATABASE, stat_logger, SECRET_KEY
33
- from api.utils.log_utils import getLogger
34
- from api import utils
35
-
36
- LOGGER = getLogger()
37
-
38
-
39
- def singleton(cls, *args, **kw):
40
- instances = {}
41
-
42
- def _singleton():
43
- key = str(cls) + str(os.getpid())
44
- if key not in instances:
45
- instances[key] = cls(*args, **kw)
46
- return instances[key]
47
-
48
- return _singleton
49
-
50
-
51
- CONTINUOUS_FIELD_TYPE = {IntegerField, FloatField, DateTimeField}
52
- AUTO_DATE_TIMESTAMP_FIELD_PREFIX = {
53
- "create",
54
- "start",
55
- "end",
56
- "update",
57
- "read_access",
58
- "write_access"}
59
-
60
-
61
- class LongTextField(TextField):
62
- field_type = 'LONGTEXT'
63
-
64
-
65
- class JSONField(LongTextField):
66
- default_value = {}
67
-
68
- def __init__(self, object_hook=None, object_pairs_hook=None, **kwargs):
69
- self._object_hook = object_hook
70
- self._object_pairs_hook = object_pairs_hook
71
- super().__init__(**kwargs)
72
-
73
- def db_value(self, value):
74
- if value is None:
75
- value = self.default_value
76
- return utils.json_dumps(value)
77
-
78
- def python_value(self, value):
79
- if not value:
80
- return self.default_value
81
- return utils.json_loads(
82
- value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook)
83
-
84
-
85
- class ListField(JSONField):
86
- default_value = []
87
-
88
-
89
- class SerializedField(LongTextField):
90
- def __init__(self, serialized_type=SerializedType.PICKLE,
91
- object_hook=None, object_pairs_hook=None, **kwargs):
92
- self._serialized_type = serialized_type
93
- self._object_hook = object_hook
94
- self._object_pairs_hook = object_pairs_hook
95
- super().__init__(**kwargs)
96
-
97
- def db_value(self, value):
98
- if self._serialized_type == SerializedType.PICKLE:
99
- return utils.serialize_b64(value, to_str=True)
100
- elif self._serialized_type == SerializedType.JSON:
101
- if value is None:
102
- return None
103
- return utils.json_dumps(value, with_type=True)
104
- else:
105
- raise ValueError(
106
- f"the serialized type {self._serialized_type} is not supported")
107
-
108
- def python_value(self, value):
109
- if self._serialized_type == SerializedType.PICKLE:
110
- return utils.deserialize_b64(value)
111
- elif self._serialized_type == SerializedType.JSON:
112
- if value is None:
113
- return {}
114
- return utils.json_loads(
115
- value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook)
116
- else:
117
- raise ValueError(
118
- f"the serialized type {self._serialized_type} is not supported")
119
-
120
-
121
- def is_continuous_field(cls: typing.Type) -> bool:
122
- if cls in CONTINUOUS_FIELD_TYPE:
123
- return True
124
- for p in cls.__bases__:
125
- if p in CONTINUOUS_FIELD_TYPE:
126
- return True
127
- elif p != Field and p != object:
128
- if is_continuous_field(p):
129
- return True
130
- else:
131
- return False
132
-
133
-
134
- def auto_date_timestamp_field():
135
- return {f"{f}_time" for f in AUTO_DATE_TIMESTAMP_FIELD_PREFIX}
136
-
137
-
138
- def auto_date_timestamp_db_field():
139
- return {f"f_{f}_time" for f in AUTO_DATE_TIMESTAMP_FIELD_PREFIX}
140
-
141
-
142
- def remove_field_name_prefix(field_name):
143
- return field_name[2:] if field_name.startswith('f_') else field_name
144
-
145
-
146
- class BaseModel(Model):
147
- create_time = BigIntegerField(null=True, index=True)
148
- create_date = DateTimeField(null=True, index=True)
149
- update_time = BigIntegerField(null=True, index=True)
150
- update_date = DateTimeField(null=True, index=True)
151
-
152
- def to_json(self):
153
- # This function is obsolete
154
- return self.to_dict()
155
-
156
- def to_dict(self):
157
- return self.__dict__['__data__']
158
-
159
- def to_human_model_dict(self, only_primary_with: list = None):
160
- model_dict = self.__dict__['__data__']
161
-
162
- if not only_primary_with:
163
- return {remove_field_name_prefix(
164
- k): v for k, v in model_dict.items()}
165
-
166
- human_model_dict = {}
167
- for k in self._meta.primary_key.field_names:
168
- human_model_dict[remove_field_name_prefix(k)] = model_dict[k]
169
- for k in only_primary_with:
170
- human_model_dict[k] = model_dict[f'f_{k}']
171
- return human_model_dict
172
-
173
- @property
174
- def meta(self) -> Metadata:
175
- return self._meta
176
-
177
- @classmethod
178
- def get_primary_keys_name(cls):
179
- return cls._meta.primary_key.field_names if isinstance(cls._meta.primary_key, CompositeKey) else [
180
- cls._meta.primary_key.name]
181
-
182
- @classmethod
183
- def getter_by(cls, attr):
184
- return operator.attrgetter(attr)(cls)
185
-
186
- @classmethod
187
- def query(cls, reverse=None, order_by=None, **kwargs):
188
- filters = []
189
- for f_n, f_v in kwargs.items():
190
- attr_name = '%s' % f_n
191
- if not hasattr(cls, attr_name) or f_v is None:
192
- continue
193
- if type(f_v) in {list, set}:
194
- f_v = list(f_v)
195
- if is_continuous_field(type(getattr(cls, attr_name))):
196
- if len(f_v) == 2:
197
- for i, v in enumerate(f_v):
198
- if isinstance(
199
- v, str) and f_n in auto_date_timestamp_field():
200
- # time type: %Y-%m-%d %H:%M:%S
201
- f_v[i] = utils.date_string_to_timestamp(v)
202
- lt_value = f_v[0]
203
- gt_value = f_v[1]
204
- if lt_value is not None and gt_value is not None:
205
- filters.append(
206
- cls.getter_by(attr_name).between(
207
- lt_value, gt_value))
208
- elif lt_value is not None:
209
- filters.append(
210
- operator.attrgetter(attr_name)(cls) >= lt_value)
211
- elif gt_value is not None:
212
- filters.append(
213
- operator.attrgetter(attr_name)(cls) <= gt_value)
214
- else:
215
- filters.append(operator.attrgetter(attr_name)(cls) << f_v)
216
- else:
217
- filters.append(operator.attrgetter(attr_name)(cls) == f_v)
218
- if filters:
219
- query_records = cls.select().where(*filters)
220
- if reverse is not None:
221
- if not order_by or not hasattr(cls, f"{order_by}"):
222
- order_by = "create_time"
223
- if reverse is True:
224
- query_records = query_records.order_by(
225
- cls.getter_by(f"{order_by}").desc())
226
- elif reverse is False:
227
- query_records = query_records.order_by(
228
- cls.getter_by(f"{order_by}").asc())
229
- return [query_record for query_record in query_records]
230
- else:
231
- return []
232
-
233
- @classmethod
234
- def insert(cls, __data=None, **insert):
235
- if isinstance(__data, dict) and __data:
236
- __data[cls._meta.combined["create_time"]
237
- ] = utils.current_timestamp()
238
- if insert:
239
- insert["create_time"] = utils.current_timestamp()
240
-
241
- return super().insert(__data, **insert)
242
-
243
- # update and insert will call this method
244
- @classmethod
245
- def _normalize_data(cls, data, kwargs):
246
- normalized = super()._normalize_data(data, kwargs)
247
- if not normalized:
248
- return {}
249
-
250
- normalized[cls._meta.combined["update_time"]
251
- ] = utils.current_timestamp()
252
-
253
- for f_n in AUTO_DATE_TIMESTAMP_FIELD_PREFIX:
254
- if {f"{f_n}_time", f"{f_n}_date"}.issubset(cls._meta.combined.keys()) and \
255
- cls._meta.combined[f"{f_n}_time"] in normalized and \
256
- normalized[cls._meta.combined[f"{f_n}_time"]] is not None:
257
- normalized[cls._meta.combined[f"{f_n}_date"]] = utils.timestamp_to_date(
258
- normalized[cls._meta.combined[f"{f_n}_time"]])
259
-
260
- return normalized
261
-
262
-
263
- class JsonSerializedField(SerializedField):
264
- def __init__(self, object_hook=utils.from_dict_hook,
265
- object_pairs_hook=None, **kwargs):
266
- super(JsonSerializedField, self).__init__(serialized_type=SerializedType.JSON, object_hook=object_hook,
267
- object_pairs_hook=object_pairs_hook, **kwargs)
268
-
269
-
270
- @singleton
271
- class BaseDataBase:
272
- def __init__(self):
273
- database_config = DATABASE.copy()
274
- db_name = database_config.pop("name")
275
- self.database_connection = PooledMySQLDatabase(
276
- db_name, **database_config)
277
- stat_logger.info('initialized mysql database in cluster mode successfully')
278
-
279
-
280
- class DatabaseLock:
281
- def __init__(self, lock_name, timeout=10, db=None):
282
- self.lock_name = lock_name
283
- self.timeout = int(timeout)
284
- self.db = db if db else DB
285
-
286
- def lock(self):
287
- # SQL parameters only support %s format placeholders
288
- cursor = self.db.execute_sql(
289
- "SELECT GET_LOCK(%s, %s)", (self.lock_name, self.timeout))
290
- ret = cursor.fetchone()
291
- if ret[0] == 0:
292
- raise Exception(f'acquire mysql lock {self.lock_name} timeout')
293
- elif ret[0] == 1:
294
- return True
295
- else:
296
- raise Exception(f'failed to acquire lock {self.lock_name}')
297
-
298
- def unlock(self):
299
- cursor = self.db.execute_sql(
300
- "SELECT RELEASE_LOCK(%s)", (self.lock_name,))
301
- ret = cursor.fetchone()
302
- if ret[0] == 0:
303
- raise Exception(
304
- f'mysql lock {self.lock_name} was not established by this thread')
305
- elif ret[0] == 1:
306
- return True
307
- else:
308
- raise Exception(f'mysql lock {self.lock_name} does not exist')
309
-
310
- def __enter__(self):
311
- if isinstance(self.db, PooledMySQLDatabase):
312
- self.lock()
313
- return self
314
-
315
- def __exit__(self, exc_type, exc_val, exc_tb):
316
- if isinstance(self.db, PooledMySQLDatabase):
317
- self.unlock()
318
-
319
- def __call__(self, func):
320
- @wraps(func)
321
- def magic(*args, **kwargs):
322
- with self:
323
- return func(*args, **kwargs)
324
-
325
- return magic
326
-
327
-
328
- DB = BaseDataBase().database_connection
329
- DB.lock = DatabaseLock
330
-
331
-
332
- def close_connection():
333
- try:
334
- if DB:
335
- DB.close_stale(age=30)
336
- except Exception as e:
337
- LOGGER.exception(e)
338
-
339
-
340
- class DataBaseModel(BaseModel):
341
- class Meta:
342
- database = DB
343
-
344
-
345
- @DB.connection_context()
346
- def init_database_tables(alter_fields=[]):
347
- members = inspect.getmembers(sys.modules[__name__], inspect.isclass)
348
- table_objs = []
349
- create_failed_list = []
350
- for name, obj in members:
351
- if obj != DataBaseModel and issubclass(obj, DataBaseModel):
352
- table_objs.append(obj)
353
- LOGGER.info(f"start create table {obj.__name__}")
354
- try:
355
- obj.create_table()
356
- LOGGER.info(f"create table success: {obj.__name__}")
357
- except Exception as e:
358
- LOGGER.exception(e)
359
- create_failed_list.append(obj.__name__)
360
- if create_failed_list:
361
- LOGGER.info(f"create tables failed: {create_failed_list}")
362
- raise Exception(f"create tables failed: {create_failed_list}")
363
- migrate_db()
364
-
365
-
366
- def fill_db_model_object(model_object, human_model_dict):
367
- for k, v in human_model_dict.items():
368
- attr_name = '%s' % k
369
- if hasattr(model_object.__class__, attr_name):
370
- setattr(model_object, attr_name, v)
371
- return model_object
372
-
373
-
374
- class User(DataBaseModel, UserMixin):
375
- id = CharField(max_length=32, primary_key=True)
376
- access_token = CharField(max_length=255, null=True, index=True)
377
- nickname = CharField(max_length=100, null=False, help_text="nickname", index=True)
378
- password = CharField(max_length=255, null=True, help_text="password", index=True)
379
- email = CharField(
380
- max_length=255,
381
- null=False,
382
- help_text="email",
383
- index=True)
384
- avatar = TextField(null=True, help_text="avatar base64 string")
385
- language = CharField(
386
- max_length=32,
387
- null=True,
388
- help_text="English|Chinese",
389
- default="Chinese" if "zh_CN" in os.getenv("LANG", "") else "English",
390
- index=True)
391
- color_schema = CharField(
392
- max_length=32,
393
- null=True,
394
- help_text="Bright|Dark",
395
- default="Bright",
396
- index=True)
397
- timezone = CharField(
398
- max_length=64,
399
- null=True,
400
- help_text="Timezone",
401
- default="UTC+8\tAsia/Shanghai",
402
- index=True)
403
- last_login_time = DateTimeField(null=True, index=True)
404
- is_authenticated = CharField(max_length=1, null=False, default="1", index=True)
405
- is_active = CharField(max_length=1, null=False, default="1", index=True)
406
- is_anonymous = CharField(max_length=1, null=False, default="0", index=True)
407
- login_channel = CharField(null=True, help_text="channel the user logged in from", index=True)
408
- status = CharField(
409
- max_length=1,
410
- null=True,
411
- help_text="is it validate(0: wasted,1: validate)",
412
- default="1",
413
- index=True)
414
- is_superuser = BooleanField(null=True, help_text="is root", default=False, index=True)
415
-
416
- def __str__(self):
417
- return self.email
418
-
419
- def get_id(self):
420
- jwt = Serializer(secret_key=SECRET_KEY)
421
- return jwt.dumps(str(self.access_token))
422
-
423
- class Meta:
424
- db_table = "user"
425
-
426
-
427
- class Tenant(DataBaseModel):
428
- id = CharField(max_length=32, primary_key=True)
429
- name = CharField(max_length=100, null=True, help_text="Tenant name", index=True)
430
- public_key = CharField(max_length=255, null=True, index=True)
431
- llm_id = CharField(max_length=128, null=False, help_text="default llm ID", index=True)
432
- embd_id = CharField(
433
- max_length=128,
434
- null=False,
435
- help_text="default embedding model ID",
436
- index=True)
437
- asr_id = CharField(
438
- max_length=128,
439
- null=False,
440
- help_text="default ASR model ID",
441
- index=True)
442
- img2txt_id = CharField(
443
- max_length=128,
444
- null=False,
445
- help_text="default image to text model ID",
446
- index=True)
447
- rerank_id = CharField(
448
- max_length=128,
449
- null=False,
450
- help_text="default rerank model ID",
451
- index=True)
452
- parser_ids = CharField(
453
- max_length=256,
454
- null=False,
455
- help_text="document processors",
456
- index=True)
457
- credit = IntegerField(default=512, index=True)
458
- status = CharField(
459
- max_length=1,
460
- null=True,
461
- help_text="is it validate(0: wasted,1: validate)",
462
- default="1",
463
- index=True)
464
-
465
- class Meta:
466
- db_table = "tenant"
467
-
468
-
469
- class UserTenant(DataBaseModel):
470
- id = CharField(max_length=32, primary_key=True)
471
- user_id = CharField(max_length=32, null=False, index=True)
472
- tenant_id = CharField(max_length=32, null=False, index=True)
473
- role = CharField(max_length=32, null=False, help_text="UserTenantRole", index=True)
474
- invited_by = CharField(max_length=32, null=False, index=True)
475
- status = CharField(
476
- max_length=1,
477
- null=True,
478
- help_text="is it validate(0: wasted,1: validate)",
479
- default="1",
480
- index=True)
481
-
482
- class Meta:
483
- db_table = "user_tenant"
484
-
485
-
486
- class InvitationCode(DataBaseModel):
487
- id = CharField(max_length=32, primary_key=True)
488
- code = CharField(max_length=32, null=False, index=True)
489
- visit_time = DateTimeField(null=True, index=True)
490
- user_id = CharField(max_length=32, null=True, index=True)
491
- tenant_id = CharField(max_length=32, null=True, index=True)
492
- status = CharField(
493
- max_length=1,
494
- null=True,
495
- help_text="is it validate(0: wasted,1: validate)",
496
- default="1",
497
- index=True)
498
-
499
- class Meta:
500
- db_table = "invitation_code"
501
-
502
-
503
- class LLMFactories(DataBaseModel):
504
- name = CharField(
505
- max_length=128,
506
- null=False,
507
- help_text="LLM factory name",
508
- primary_key=True)
509
- logo = TextField(null=True, help_text="llm logo base64")
510
- tags = CharField(
511
- max_length=255,
512
- null=False,
513
- help_text="LLM, Text Embedding, Image2Text, ASR",
514
- index=True)
515
- status = CharField(
516
- max_length=1,
517
- null=True,
518
- help_text="is it validate(0: wasted,1: validate)",
519
- default="1",
520
- index=True)
521
-
522
- def __str__(self):
523
- return self.name
524
-
525
- class Meta:
526
- db_table = "llm_factories"
527
-
528
-
529
- class LLM(DataBaseModel):
530
- # LLMs dictionary
531
- llm_name = CharField(
532
- max_length=128,
533
- null=False,
534
- help_text="LLM name",
535
- index=True)
536
- model_type = CharField(
537
- max_length=128,
538
- null=False,
539
- help_text="LLM, Text Embedding, Image2Text, ASR",
540
- index=True)
541
- fid = CharField(max_length=128, null=False, help_text="LLM factory id", index=True)
542
- max_tokens = IntegerField(default=0)
543
-
544
- tags = CharField(
545
- max_length=255,
546
- null=False,
547
- help_text="LLM, Text Embedding, Image2Text, Chat, 32k...",
548
- index=True)
549
- status = CharField(
550
- max_length=1,
551
- null=True,
552
- help_text="is it validate(0: wasted,1: validate)",
553
- default="1",
554
- index=True)
555
-
556
- def __str__(self):
557
- return self.llm_name
558
-
559
- class Meta:
560
- primary_key = CompositeKey('fid', 'llm_name')
561
- db_table = "llm"
562
-
563
-
564
- class TenantLLM(DataBaseModel):
565
- tenant_id = CharField(max_length=32, null=False, index=True)
566
- llm_factory = CharField(
567
- max_length=128,
568
- null=False,
569
- help_text="LLM factory name",
570
- index=True)
571
- model_type = CharField(
572
- max_length=128,
573
- null=True,
574
- help_text="LLM, Text Embedding, Image2Text, ASR",
575
- index=True)
576
- llm_name = CharField(
577
- max_length=128,
578
- null=True,
579
- help_text="LLM name",
580
- default="",
581
- index=True)
582
- api_key = CharField(max_length=1024, null=True, help_text="API KEY", index=True)
583
- api_base = CharField(max_length=255, null=True, help_text="API Base")
584
-
585
- used_tokens = IntegerField(default=0, index=True)
586
-
587
- def __str__(self):
588
- return self.llm_name
589
-
590
- class Meta:
591
- db_table = "tenant_llm"
592
- primary_key = CompositeKey('tenant_id', 'llm_factory', 'llm_name')
593
-
594
-
595
- class Knowledgebase(DataBaseModel):
596
- id = CharField(max_length=32, primary_key=True)
597
- avatar = TextField(null=True, help_text="avatar base64 string")
598
- tenant_id = CharField(max_length=32, null=False, index=True)
599
- name = CharField(
600
- max_length=128,
601
- null=False,
602
- help_text="KB name",
603
- index=True)
604
- language = CharField(
605
- max_length=32,
606
- null=True,
607
- default="Chinese" if "zh_CN" in os.getenv("LANG", "") else "English",
608
- help_text="English|Chinese",
609
- index=True)
610
- description = TextField(null=True, help_text="KB description")
611
- embd_id = CharField(
612
- max_length=128,
613
- null=False,
614
- help_text="default embedding model ID",
615
- index=True)
616
- permission = CharField(
617
- max_length=16,
618
- null=False,
619
- help_text="me|team",
620
- default="me",
621
- index=True)
622
- created_by = CharField(max_length=32, null=False, index=True)
623
- doc_num = IntegerField(default=0, index=True)
624
- token_num = IntegerField(default=0, index=True)
625
- chunk_num = IntegerField(default=0, index=True)
626
- similarity_threshold = FloatField(default=0.2, index=True)
627
- vector_similarity_weight = FloatField(default=0.3, index=True)
628
-
629
- parser_id = CharField(
630
- max_length=32,
631
- null=False,
632
- help_text="default parser ID",
633
- default=ParserType.NAIVE.value,
634
- index=True)
635
- parser_config = JSONField(null=False, default={"pages": [[1, 1000000]]})
636
- status = CharField(
637
- max_length=1,
638
- null=True,
639
- help_text="is it validate(0: wasted,1: validate)",
640
- default="1",
641
- index=True)
642
-
643
- def __str__(self):
644
- return self.name
645
-
646
- class Meta:
647
- db_table = "knowledgebase"
648
-
649
-
650
- class Document(DataBaseModel):
651
- id = CharField(max_length=32, primary_key=True)
652
- thumbnail = TextField(null=True, help_text="thumbnail base64 string")
653
- kb_id = CharField(max_length=256, null=False, index=True)
654
- parser_id = CharField(
655
- max_length=32,
656
- null=False,
657
- help_text="default parser ID",
658
- index=True)
659
- parser_config = JSONField(null=False, default={"pages": [[1, 1000000]]})
660
- source_type = CharField(
661
- max_length=128,
662
- null=False,
663
- default="local",
664
- help_text="where dose this document come from",
665
- index=True)
666
- type = CharField(max_length=32, null=False, help_text="file extension",
667
- index=True)
668
- created_by = CharField(
669
- max_length=32,
670
- null=False,
671
- help_text="who created it",
672
- index=True)
673
- name = CharField(
674
- max_length=255,
675
- null=True,
676
- help_text="file name",
677
- index=True)
678
- location = CharField(
679
- max_length=255,
680
- null=True,
681
- help_text="where dose it store",
682
- index=True)
683
- size = IntegerField(default=0, index=True)
684
- token_num = IntegerField(default=0, index=True)
685
- chunk_num = IntegerField(default=0, index=True)
686
- progress = FloatField(default=0, index=True)
687
- progress_msg = TextField(
688
- null=True,
689
- help_text="process message",
690
- default="")
691
- process_begin_at = DateTimeField(null=True, index=True)
692
- process_duation = FloatField(default=0)
693
-
694
- run = CharField(
695
- max_length=1,
696
- null=True,
697
- help_text="start to run processing or cancel.(1: run it; 2: cancel)",
698
- default="0",
699
- index=True)
700
- status = CharField(
701
- max_length=1,
702
- null=True,
703
- help_text="is it validate(0: wasted,1: validate)",
704
- default="1",
705
- index=True)
706
-
707
- class Meta:
708
- db_table = "document"
709
-
710
-
711
- class File(DataBaseModel):
712
- id = CharField(
713
- max_length=32,
714
- primary_key=True)
715
- parent_id = CharField(
716
- max_length=32,
717
- null=False,
718
- help_text="parent folder id",
719
- index=True)
720
- tenant_id = CharField(
721
- max_length=32,
722
- null=False,
723
- help_text="tenant id",
724
- index=True)
725
- created_by = CharField(
726
- max_length=32,
727
- null=False,
728
- help_text="who created it",
729
- index=True)
730
- name = CharField(
731
- max_length=255,
732
- null=False,
733
- help_text="file name or folder name",
734
- index=True)
735
- location = CharField(
736
- max_length=255,
737
- null=True,
738
- help_text="where dose it store",
739
- index=True)
740
- size = IntegerField(default=0, index=True)
741
- type = CharField(max_length=32, null=False, help_text="file extension", index=True)
742
- source_type = CharField(
743
- max_length=128,
744
- null=False,
745
- default="",
746
- help_text="where dose this document come from", index=True)
747
-
748
- class Meta:
749
- db_table = "file"
750
-
751
-
752
- class File2Document(DataBaseModel):
753
- id = CharField(
754
- max_length=32,
755
- primary_key=True)
756
- file_id = CharField(
757
- max_length=32,
758
- null=True,
759
- help_text="file id",
760
- index=True)
761
- document_id = CharField(
762
- max_length=32,
763
- null=True,
764
- help_text="document id",
765
- index=True)
766
-
767
- class Meta:
768
- db_table = "file2document"
769
-
770
-
771
- class Task(DataBaseModel):
772
- id = CharField(max_length=32, primary_key=True)
773
- doc_id = CharField(max_length=32, null=False, index=True)
774
- from_page = IntegerField(default=0)
775
-
776
- to_page = IntegerField(default=-1)
777
-
778
- begin_at = DateTimeField(null=True, index=True)
779
- process_duation = FloatField(default=0)
780
-
781
- progress = FloatField(default=0, index=True)
782
- progress_msg = TextField(
783
- null=True,
784
- help_text="process message",
785
- default="")
786
-
787
-
788
- class Dialog(DataBaseModel):
789
- id = CharField(max_length=32, primary_key=True)
790
- tenant_id = CharField(max_length=32, null=False, index=True)
791
- name = CharField(
792
- max_length=255,
793
- null=True,
794
- help_text="dialog application name",
795
- index=True)
796
- description = TextField(null=True, help_text="Dialog description")
797
- icon = TextField(null=True, help_text="icon base64 string")
798
- language = CharField(
799
- max_length=32,
800
- null=True,
801
- default="Chinese" if "zh_CN" in os.getenv("LANG", "") else "English",
802
- help_text="English|Chinese",
803
- index=True)
804
- llm_id = CharField(max_length=128, null=False, help_text="default llm ID")
805
-
806
- llm_setting = JSONField(null=False, default={"temperature": 0.1, "top_p": 0.3, "frequency_penalty": 0.7,
807
- "presence_penalty": 0.4, "max_tokens": 512})
808
- prompt_type = CharField(
809
- max_length=16,
810
- null=False,
811
- default="simple",
812
- help_text="simple|advanced",
813
- index=True)
814
- prompt_config = JSONField(null=False, default={"system": "", "prologue": "您好,我是您的助手小樱,长得可爱又善良,can I help you?",
815
- "parameters": [], "empty_response": "Sorry! 知识库中未找到相关内容!"})
816
-
817
- similarity_threshold = FloatField(default=0.2)
818
- vector_similarity_weight = FloatField(default=0.3)
819
-
820
- top_n = IntegerField(default=6)
821
-
822
- top_k = IntegerField(default=1024)
823
-
824
- do_refer = CharField(
825
- max_length=1,
826
- null=False,
827
- help_text="it needs to insert reference index into answer or not")
828
-
829
- rerank_id = CharField(
830
- max_length=128,
831
- null=False,
832
- help_text="default rerank model ID")
833
-
834
- kb_ids = JSONField(null=False, default=[])
835
- status = CharField(
836
- max_length=1,
837
- null=True,
838
- help_text="is it validate(0: wasted,1: validate)",
839
- default="1",
840
- index=True)
841
-
842
- class Meta:
843
- db_table = "dialog"
844
-
845
-
846
- class Conversation(DataBaseModel):
847
- id = CharField(max_length=32, primary_key=True)
848
- dialog_id = CharField(max_length=32, null=False, index=True)
849
- name = CharField(max_length=255, null=True, help_text="converastion name", index=True)
850
- message = JSONField(null=True)
851
- reference = JSONField(null=True, default=[])
852
-
853
- class Meta:
854
- db_table = "conversation"
855
-
856
-
857
- class APIToken(DataBaseModel):
858
- tenant_id = CharField(max_length=32, null=False, index=True)
859
- token = CharField(max_length=255, null=False, index=True)
860
- dialog_id = CharField(max_length=32, null=False, index=True)
861
- source = CharField(max_length=16, null=True, help_text="none|agent|dialog", index=True)
862
-
863
- class Meta:
864
- db_table = "api_token"
865
- primary_key = CompositeKey('tenant_id', 'token')
866
-
867
-
868
- class API4Conversation(DataBaseModel):
869
- id = CharField(max_length=32, primary_key=True)
870
- dialog_id = CharField(max_length=32, null=False, index=True)
871
- user_id = CharField(max_length=255, null=False, help_text="user_id", index=True)
872
- message = JSONField(null=True)
873
- reference = JSONField(null=True, default=[])
874
- tokens = IntegerField(default=0)
875
- source = CharField(max_length=16, null=True, help_text="none|agent|dialog", index=True)
876
-
877
- duration = FloatField(default=0, index=True)
878
- round = IntegerField(default=0, index=True)
879
- thumb_up = IntegerField(default=0, index=True)
880
-
881
- class Meta:
882
- db_table = "api_4_conversation"
883
-
884
-
885
- class UserCanvas(DataBaseModel):
886
- id = CharField(max_length=32, primary_key=True)
887
- avatar = TextField(null=True, help_text="avatar base64 string")
888
- user_id = CharField(max_length=255, null=False, help_text="user_id", index=True)
889
- title = CharField(max_length=255, null=True, help_text="Canvas title")
890
-
891
- description = TextField(null=True, help_text="Canvas description")
892
- canvas_type = CharField(max_length=32, null=True, help_text="Canvas type", index=True)
893
- dsl = JSONField(null=True, default={})
894
-
895
- class Meta:
896
- db_table = "user_canvas"
897
-
898
-
899
- class CanvasTemplate(DataBaseModel):
900
- id = CharField(max_length=32, primary_key=True)
901
- avatar = TextField(null=True, help_text="avatar base64 string")
902
- title = CharField(max_length=255, null=True, help_text="Canvas title")
903
-
904
- description = TextField(null=True, help_text="Canvas description")
905
- canvas_type = CharField(max_length=32, null=True, help_text="Canvas type", index=True)
906
- dsl = JSONField(null=True, default={})
907
-
908
- class Meta:
909
- db_table = "canvas_template"
910
-
911
-
912
- def migrate_db():
913
- with DB.transaction():
914
- migrator = MySQLMigrator(DB)
915
- try:
916
- migrate(
917
- migrator.add_column('file', 'source_type', CharField(max_length=128, null=False, default="",
918
- help_text="where dose this document come from",
919
- index=True))
920
- )
921
- except Exception as e:
922
- pass
923
- try:
924
- migrate(
925
- migrator.add_column('tenant', 'rerank_id',
926
- CharField(max_length=128, null=False, default="BAAI/bge-reranker-v2-m3",
927
- help_text="default rerank model ID"))
928
-
929
- )
930
- except Exception as e:
931
- pass
932
- try:
933
- migrate(
934
- migrator.add_column('dialog', 'rerank_id', CharField(max_length=128, null=False, default="",
935
- help_text="default rerank model ID"))
936
-
937
- )
938
- except Exception as e:
939
- pass
940
- try:
941
- migrate(
942
- migrator.add_column('dialog', 'top_k', IntegerField(default=1024))
943
-
944
- )
945
- except Exception as e:
946
- pass
947
- try:
948
- migrate(
949
- migrator.alter_column_type('tenant_llm', 'api_key',
950
- CharField(max_length=1024, null=True, help_text="API KEY", index=True))
951
- )
952
- except Exception as e:
953
- pass
954
- try:
955
- migrate(
956
- migrator.add_column('api_token', 'source',
957
- CharField(max_length=16, null=True, help_text="none|agent|dialog", index=True))
958
- )
959
- except Exception as e:
960
- pass
961
- try:
962
- migrate(
963
- migrator.add_column('api_4_conversation', 'source',
964
- CharField(max_length=16, null=True, help_text="none|agent|dialog", index=True))
965
- )
966
- except Exception as e:
967
- pass
968
- try:
969
- DB.execute_sql('ALTER TABLE llm DROP PRIMARY KEY;')
970
- DB.execute_sql('ALTER TABLE llm ADD PRIMARY KEY (llm_name,fid);')
971
- except Exception as e:
972
- pass
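migrate_db() above wraps every schema change in its own try/except so it can run on every startup: alterations that were already applied (for example a duplicate column) simply raise and are skipped. A minimal sketch of that idempotent-migration idea, with a hypothetical list of (name, thunk) pairs standing in for the playhouse migrate()/migrator calls:

def apply_migrations(migrations):
    applied = []
    for name, run in migrations:
        try:
            run()  # e.g. migrate(migrator.add_column(...))
            applied.append(name)
        except Exception:
            pass   # already applied (or unsupported): skip and continue
    return applied

# Hypothetical usage:
apply_migrations([
    ("file.source_type", lambda: None),
    ("dialog.top_k", lambda: None),
])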
 
1
+ #
2
+ # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ #
16
+ import inspect
17
+ import os
18
+ import sys
19
+ import typing
20
+ import operator
21
+ from functools import wraps
22
+ from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
23
+ from flask_login import UserMixin
24
+ from playhouse.migrate import MySQLMigrator, migrate
25
+ from peewee import (
26
+ BigIntegerField, BooleanField, CharField,
27
+ CompositeKey, IntegerField, TextField, FloatField, DateTimeField,
28
+ Field, Model, Metadata
29
+ )
30
+ from playhouse.pool import PooledMySQLDatabase
31
+ from api.db import SerializedType, ParserType
32
+ from api.settings import DATABASE, stat_logger, SECRET_KEY
33
+ from api.utils.log_utils import getLogger
34
+ from api import utils
35
+
36
+ LOGGER = getLogger()
37
+
38
+
39
+ def singleton(cls, *args, **kw):
40
+ instances = {}
41
+
42
+ def _singleton():
43
+ key = str(cls) + str(os.getpid())
44
+ if key not in instances:
45
+ instances[key] = cls(*args, **kw)
46
+ return instances[key]
47
+
48
+ return _singleton
49
+
50
+
51
+ CONTINUOUS_FIELD_TYPE = {IntegerField, FloatField, DateTimeField}
52
+ AUTO_DATE_TIMESTAMP_FIELD_PREFIX = {
53
+ "create",
54
+ "start",
55
+ "end",
56
+ "update",
57
+ "read_access",
58
+ "write_access"}
59
+
60
+
61
+ class LongTextField(TextField):
62
+ field_type = 'LONGTEXT'
63
+
64
+
65
+ class JSONField(LongTextField):
66
+ default_value = {}
67
+
68
+ def __init__(self, object_hook=None, object_pairs_hook=None, **kwargs):
69
+ self._object_hook = object_hook
70
+ self._object_pairs_hook = object_pairs_hook
71
+ super().__init__(**kwargs)
72
+
73
+ def db_value(self, value):
74
+ if value is None:
75
+ value = self.default_value
76
+ return utils.json_dumps(value)
77
+
78
+ def python_value(self, value):
79
+ if not value:
80
+ return self.default_value
81
+ return utils.json_loads(
82
+ value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook)
83
+
84
+
85
+ class ListField(JSONField):
86
+ default_value = []
87
+
88
+
89
+ class SerializedField(LongTextField):
90
+ def __init__(self, serialized_type=SerializedType.PICKLE,
91
+ object_hook=None, object_pairs_hook=None, **kwargs):
92
+ self._serialized_type = serialized_type
93
+ self._object_hook = object_hook
94
+ self._object_pairs_hook = object_pairs_hook
95
+ super().__init__(**kwargs)
96
+
97
+ def db_value(self, value):
98
+ if self._serialized_type == SerializedType.PICKLE:
99
+ return utils.serialize_b64(value, to_str=True)
100
+ elif self._serialized_type == SerializedType.JSON:
101
+ if value is None:
102
+ return None
103
+ return utils.json_dumps(value, with_type=True)
104
+ else:
105
+ raise ValueError(
106
+ f"the serialized type {self._serialized_type} is not supported")
107
+
108
+ def python_value(self, value):
109
+ if self._serialized_type == SerializedType.PICKLE:
110
+ return utils.deserialize_b64(value)
111
+ elif self._serialized_type == SerializedType.JSON:
112
+ if value is None:
113
+ return {}
114
+ return utils.json_loads(
115
+ value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook)
116
+ else:
117
+ raise ValueError(
118
+ f"the serialized type {self._serialized_type} is not supported")
119
+
120
+
121
+ def is_continuous_field(cls: typing.Type) -> bool:
122
+ if cls in CONTINUOUS_FIELD_TYPE:
123
+ return True
124
+ for p in cls.__bases__:
125
+ if p in CONTINUOUS_FIELD_TYPE:
126
+ return True
127
+ elif p != Field and p != object:
128
+ if is_continuous_field(p):
129
+ return True
130
+ else:
131
+ return False
132
+
133
+
134
+ def auto_date_timestamp_field():
135
+ return {f"{f}_time" for f in AUTO_DATE_TIMESTAMP_FIELD_PREFIX}
136
+
137
+
138
+ def auto_date_timestamp_db_field():
139
+ return {f"f_{f}_time" for f in AUTO_DATE_TIMESTAMP_FIELD_PREFIX}
140
+
141
+
142
+ def remove_field_name_prefix(field_name):
143
+ return field_name[2:] if field_name.startswith('f_') else field_name
144
+
145
+
146
+ class BaseModel(Model):
147
+ create_time = BigIntegerField(null=True, index=True)
148
+ create_date = DateTimeField(null=True, index=True)
149
+ update_time = BigIntegerField(null=True, index=True)
150
+ update_date = DateTimeField(null=True, index=True)
151
+
152
+ def to_json(self):
153
+ # This function is obsolete
154
+ return self.to_dict()
155
+
156
+ def to_dict(self):
157
+ return self.__dict__['__data__']
158
+
159
+ def to_human_model_dict(self, only_primary_with: list = None):
160
+ model_dict = self.__dict__['__data__']
161
+
162
+ if not only_primary_with:
163
+ return {remove_field_name_prefix(
164
+ k): v for k, v in model_dict.items()}
165
+
166
+ human_model_dict = {}
167
+ for k in self._meta.primary_key.field_names:
168
+ human_model_dict[remove_field_name_prefix(k)] = model_dict[k]
169
+ for k in only_primary_with:
170
+ human_model_dict[k] = model_dict[f'f_{k}']
171
+ return human_model_dict
172
+
173
+ @property
174
+ def meta(self) -> Metadata:
175
+ return self._meta
176
+
177
+ @classmethod
178
+ def get_primary_keys_name(cls):
179
+ return cls._meta.primary_key.field_names if isinstance(cls._meta.primary_key, CompositeKey) else [
180
+ cls._meta.primary_key.name]
181
+
182
+ @classmethod
183
+ def getter_by(cls, attr):
184
+ return operator.attrgetter(attr)(cls)
185
+
186
+ @classmethod
187
+ def query(cls, reverse=None, order_by=None, **kwargs):
188
+ filters = []
189
+ for f_n, f_v in kwargs.items():
190
+ attr_name = '%s' % f_n
191
+ if not hasattr(cls, attr_name) or f_v is None:
192
+ continue
193
+ if type(f_v) in {list, set}:
194
+ f_v = list(f_v)
195
+ if is_continuous_field(type(getattr(cls, attr_name))):
196
+ if len(f_v) == 2:
197
+ for i, v in enumerate(f_v):
198
+ if isinstance(
199
+ v, str) and f_n in auto_date_timestamp_field():
200
+ # time type: %Y-%m-%d %H:%M:%S
201
+ f_v[i] = utils.date_string_to_timestamp(v)
202
+ lt_value = f_v[0]
203
+ gt_value = f_v[1]
204
+ if lt_value is not None and gt_value is not None:
205
+ filters.append(
206
+ cls.getter_by(attr_name).between(
207
+ lt_value, gt_value))
208
+ elif lt_value is not None:
209
+ filters.append(
210
+ operator.attrgetter(attr_name)(cls) >= lt_value)
211
+ elif gt_value is not None:
212
+ filters.append(
213
+ operator.attrgetter(attr_name)(cls) <= gt_value)
214
+ else:
215
+ filters.append(operator.attrgetter(attr_name)(cls) << f_v)
216
+ else:
217
+ filters.append(operator.attrgetter(attr_name)(cls) == f_v)
218
+ if filters:
219
+ query_records = cls.select().where(*filters)
220
+ if reverse is not None:
221
+ if not order_by or not hasattr(cls, f"{order_by}"):
222
+ order_by = "create_time"
223
+ if reverse is True:
224
+ query_records = query_records.order_by(
225
+ cls.getter_by(f"{order_by}").desc())
226
+ elif reverse is False:
227
+ query_records = query_records.order_by(
228
+ cls.getter_by(f"{order_by}").asc())
229
+ return [query_record for query_record in query_records]
230
+ else:
231
+ return []
232
+
233
+ @classmethod
234
+ def insert(cls, __data=None, **insert):
235
+ if isinstance(__data, dict) and __data:
236
+ __data[cls._meta.combined["create_time"]
237
+ ] = utils.current_timestamp()
238
+ if insert:
239
+ insert["create_time"] = utils.current_timestamp()
240
+
241
+ return super().insert(__data, **insert)
242
+
243
+ # update and insert will call this method
244
+ @classmethod
245
+ def _normalize_data(cls, data, kwargs):
246
+ normalized = super()._normalize_data(data, kwargs)
247
+ if not normalized:
248
+ return {}
249
+
250
+ normalized[cls._meta.combined["update_time"]
251
+ ] = utils.current_timestamp()
252
+
253
+ for f_n in AUTO_DATE_TIMESTAMP_FIELD_PREFIX:
254
+ if {f"{f_n}_time", f"{f_n}_date"}.issubset(cls._meta.combined.keys()) and \
255
+ cls._meta.combined[f"{f_n}_time"] in normalized and \
256
+ normalized[cls._meta.combined[f"{f_n}_time"]] is not None:
257
+ normalized[cls._meta.combined[f"{f_n}_date"]] = utils.timestamp_to_date(
258
+ normalized[cls._meta.combined[f"{f_n}_time"]])
259
+
260
+ return normalized
261
+
262
+
263
+ class JsonSerializedField(SerializedField):
264
+ def __init__(self, object_hook=utils.from_dict_hook,
265
+ object_pairs_hook=None, **kwargs):
266
+ super(JsonSerializedField, self).__init__(serialized_type=SerializedType.JSON, object_hook=object_hook,
267
+ object_pairs_hook=object_pairs_hook, **kwargs)
268
+
269
+
270
+ @singleton
271
+ class BaseDataBase:
272
+ def __init__(self):
273
+ database_config = DATABASE.copy()
274
+ db_name = database_config.pop("name")
275
+ self.database_connection = PooledMySQLDatabase(
276
+ db_name, **database_config)
277
+ stat_logger.info('initialized mysql database in cluster mode successfully')
278
+
279
+
280
+ class DatabaseLock:
281
+ def __init__(self, lock_name, timeout=10, db=None):
282
+ self.lock_name = lock_name
283
+ self.timeout = int(timeout)
284
+ self.db = db if db else DB
285
+
286
+ def lock(self):
287
+ # SQL parameters only support %s format placeholders
288
+ cursor = self.db.execute_sql(
289
+ "SELECT GET_LOCK(%s, %s)", (self.lock_name, self.timeout))
290
+ ret = cursor.fetchone()
291
+ if ret[0] == 0:
292
+ raise Exception(f'acquire mysql lock {self.lock_name} timeout')
293
+ elif ret[0] == 1:
294
+ return True
295
+ else:
296
+ raise Exception(f'failed to acquire lock {self.lock_name}')
297
+
298
+ def unlock(self):
299
+ cursor = self.db.execute_sql(
300
+ "SELECT RELEASE_LOCK(%s)", (self.lock_name,))
301
+ ret = cursor.fetchone()
302
+ if ret[0] == 0:
303
+ raise Exception(
304
+ f'mysql lock {self.lock_name} was not established by this thread')
305
+ elif ret[0] == 1:
306
+ return True
307
+ else:
308
+ raise Exception(f'mysql lock {self.lock_name} does not exist')
309
+
310
+ def __enter__(self):
311
+ if isinstance(self.db, PooledMySQLDatabase):
312
+ self.lock()
313
+ return self
314
+
315
+ def __exit__(self, exc_type, exc_val, exc_tb):
316
+ if isinstance(self.db, PooledMySQLDatabase):
317
+ self.unlock()
318
+
319
+ def __call__(self, func):
320
+ @wraps(func)
321
+ def magic(*args, **kwargs):
322
+ with self:
323
+ return func(*args, **kwargs)
324
+
325
+ return magic
326
+
327
+
328
+ DB = BaseDataBase().database_connection
329
+ DB.lock = DatabaseLock
330
+
331
+
332
+ def close_connection():
333
+ try:
334
+ if DB:
335
+ DB.close_stale(age=30)
336
+ except Exception as e:
337
+ LOGGER.exception(e)
338
+
339
+
340
+ class DataBaseModel(BaseModel):
341
+ class Meta:
342
+ database = DB
343
+
344
+
345
+ @DB.connection_context()
346
+ def init_database_tables(alter_fields=[]):
347
+ members = inspect.getmembers(sys.modules[__name__], inspect.isclass)
348
+ table_objs = []
349
+ create_failed_list = []
350
+ for name, obj in members:
351
+ if obj != DataBaseModel and issubclass(obj, DataBaseModel):
352
+ table_objs.append(obj)
353
+ LOGGER.info(f"start create table {obj.__name__}")
354
+ try:
355
+ obj.create_table()
356
+ LOGGER.info(f"create table success: {obj.__name__}")
357
+ except Exception as e:
358
+ LOGGER.exception(e)
359
+ create_failed_list.append(obj.__name__)
360
+ if create_failed_list:
361
+ LOGGER.info(f"create tables failed: {create_failed_list}")
362
+ raise Exception(f"create tables failed: {create_failed_list}")
363
+ migrate_db()
364
+
365
+
366
+ def fill_db_model_object(model_object, human_model_dict):
367
+ for k, v in human_model_dict.items():
368
+ attr_name = '%s' % k
369
+ if hasattr(model_object.__class__, attr_name):
370
+ setattr(model_object, attr_name, v)
371
+ return model_object
372
+
373
+
374
+ class User(DataBaseModel, UserMixin):
375
+ id = CharField(max_length=32, primary_key=True)
376
+ access_token = CharField(max_length=255, null=True, index=True)
377
+ nickname = CharField(max_length=100, null=False, help_text="nickname", index=True)
378
+ password = CharField(max_length=255, null=True, help_text="password", index=True)
379
+ email = CharField(
380
+ max_length=255,
381
+ null=False,
382
+ help_text="email",
383
+ index=True)
384
+ avatar = TextField(null=True, help_text="avatar base64 string")
385
+ language = CharField(
386
+ max_length=32,
387
+ null=True,
388
+ help_text="English|Chinese",
389
+ default="Chinese" if "zh_CN" in os.getenv("LANG", "") else "English",
390
+ index=True)
391
+ color_schema = CharField(
392
+ max_length=32,
393
+ null=True,
394
+ help_text="Bright|Dark",
395
+ default="Bright",
396
+ index=True)
397
+ timezone = CharField(
398
+ max_length=64,
399
+ null=True,
400
+ help_text="Timezone",
401
+ default="UTC+8\tAsia/Shanghai",
402
+ index=True)
403
+ last_login_time = DateTimeField(null=True, index=True)
404
+ is_authenticated = CharField(max_length=1, null=False, default="1", index=True)
405
+ is_active = CharField(max_length=1, null=False, default="1", index=True)
406
+ is_anonymous = CharField(max_length=1, null=False, default="0", index=True)
407
+ login_channel = CharField(null=True, help_text="channel the user logged in from", index=True)
408
+ status = CharField(
409
+ max_length=1,
410
+ null=True,
411
+ help_text="is it validate(0: wasted,1: validate)",
412
+ default="1",
413
+ index=True)
414
+ is_superuser = BooleanField(null=True, help_text="is root", default=False, index=True)
415
+
416
+ def __str__(self):
417
+ return self.email
418
+
419
+ def get_id(self):
420
+ jwt = Serializer(secret_key=SECRET_KEY)
421
+ return jwt.dumps(str(self.access_token))
422
+
423
+ class Meta:
424
+ db_table = "user"
425
+
426
+
427
+ class Tenant(DataBaseModel):
428
+ id = CharField(max_length=32, primary_key=True)
429
+ name = CharField(max_length=100, null=True, help_text="Tenant name", index=True)
430
+ public_key = CharField(max_length=255, null=True, index=True)
431
+ llm_id = CharField(max_length=128, null=False, help_text="default llm ID", index=True)
432
+ embd_id = CharField(
433
+ max_length=128,
434
+ null=False,
435
+ help_text="default embedding model ID",
436
+ index=True)
437
+ asr_id = CharField(
438
+ max_length=128,
439
+ null=False,
440
+ help_text="default ASR model ID",
441
+ index=True)
442
+ img2txt_id = CharField(
443
+ max_length=128,
444
+ null=False,
445
+ help_text="default image to text model ID",
446
+ index=True)
447
+ rerank_id = CharField(
448
+ max_length=128,
449
+ null=False,
450
+ help_text="default rerank model ID",
451
+ index=True)
452
+ parser_ids = CharField(
453
+ max_length=256,
454
+ null=False,
455
+ help_text="document processors",
456
+ index=True)
457
+ credit = IntegerField(default=512, index=True)
458
+ status = CharField(
459
+ max_length=1,
460
+ null=True,
461
+ help_text="is it validate(0: wasted,1: validate)",
462
+ default="1",
463
+ index=True)
464
+
465
+ class Meta:
466
+ db_table = "tenant"
467
+
468
+
469
+ class UserTenant(DataBaseModel):
470
+ id = CharField(max_length=32, primary_key=True)
471
+ user_id = CharField(max_length=32, null=False, index=True)
472
+ tenant_id = CharField(max_length=32, null=False, index=True)
473
+ role = CharField(max_length=32, null=False, help_text="UserTenantRole", index=True)
474
+ invited_by = CharField(max_length=32, null=False, index=True)
475
+ status = CharField(
476
+ max_length=1,
477
+ null=True,
478
+ help_text="is it validate(0: wasted,1: validate)",
479
+ default="1",
480
+ index=True)
481
+
482
+ class Meta:
483
+ db_table = "user_tenant"
484
+
485
+
486
+ class InvitationCode(DataBaseModel):
487
+ id = CharField(max_length=32, primary_key=True)
488
+ code = CharField(max_length=32, null=False, index=True)
489
+ visit_time = DateTimeField(null=True, index=True)
490
+ user_id = CharField(max_length=32, null=True, index=True)
491
+ tenant_id = CharField(max_length=32, null=True, index=True)
492
+ status = CharField(
493
+ max_length=1,
494
+ null=True,
495
+ help_text="is it validate(0: wasted,1: validate)",
496
+ default="1",
497
+ index=True)
498
+
499
+ class Meta:
500
+ db_table = "invitation_code"
501
+
502
+
503
+ class LLMFactories(DataBaseModel):
504
+ name = CharField(
505
+ max_length=128,
506
+ null=False,
507
+ help_text="LLM factory name",
508
+ primary_key=True)
509
+ logo = TextField(null=True, help_text="llm logo base64")
510
+ tags = CharField(
511
+ max_length=255,
512
+ null=False,
513
+ help_text="LLM, Text Embedding, Image2Text, ASR",
514
+ index=True)
515
+ status = CharField(
516
+ max_length=1,
517
+ null=True,
518
+ help_text="is it validate(0: wasted,1: validate)",
519
+ default="1",
520
+ index=True)
521
+
522
+ def __str__(self):
523
+ return self.name
524
+
525
+ class Meta:
526
+ db_table = "llm_factories"
527
+
528
+
529
+ class LLM(DataBaseModel):
530
+ # LLMs dictionary
531
+ llm_name = CharField(
532
+ max_length=128,
533
+ null=False,
534
+ help_text="LLM name",
535
+ index=True)
536
+ model_type = CharField(
537
+ max_length=128,
538
+ null=False,
539
+ help_text="LLM, Text Embedding, Image2Text, ASR",
540
+ index=True)
541
+ fid = CharField(max_length=128, null=False, help_text="LLM factory id", index=True)
542
+ max_tokens = IntegerField(default=0)
543
+
544
+ tags = CharField(
545
+ max_length=255,
546
+ null=False,
547
+ help_text="LLM, Text Embedding, Image2Text, Chat, 32k...",
548
+ index=True)
549
+ status = CharField(
550
+ max_length=1,
551
+ null=True,
552
+ help_text="is it validate(0: wasted,1: validate)",
553
+ default="1",
554
+ index=True)
555
+
556
+ def __str__(self):
557
+ return self.llm_name
558
+
559
+ class Meta:
560
+ primary_key = CompositeKey('fid', 'llm_name')
561
+ db_table = "llm"
562
+
563
+
564
+ class TenantLLM(DataBaseModel):
565
+ tenant_id = CharField(max_length=32, null=False, index=True)
566
+ llm_factory = CharField(
567
+ max_length=128,
568
+ null=False,
569
+ help_text="LLM factory name",
570
+ index=True)
571
+ model_type = CharField(
572
+ max_length=128,
573
+ null=True,
574
+ help_text="LLM, Text Embedding, Image2Text, ASR",
575
+ index=True)
576
+ llm_name = CharField(
577
+ max_length=128,
578
+ null=True,
579
+ help_text="LLM name",
580
+ default="",
581
+ index=True)
582
+ api_key = CharField(max_length=1024, null=True, help_text="API KEY", index=True)
583
+ api_base = CharField(max_length=255, null=True, help_text="API Base")
584
+
585
+ used_tokens = IntegerField(default=0, index=True)
586
+
587
+ def __str__(self):
588
+ return self.llm_name
589
+
590
+ class Meta:
591
+ db_table = "tenant_llm"
592
+ primary_key = CompositeKey('tenant_id', 'llm_factory', 'llm_name')
593
+
594
+
595
+ class Knowledgebase(DataBaseModel):
596
+ id = CharField(max_length=32, primary_key=True)
597
+ avatar = TextField(null=True, help_text="avatar base64 string")
598
+ tenant_id = CharField(max_length=32, null=False, index=True)
599
+ name = CharField(
600
+ max_length=128,
601
+ null=False,
602
+ help_text="KB name",
603
+ index=True)
604
+ language = CharField(
605
+ max_length=32,
606
+ null=True,
607
+ default="Chinese" if "zh_CN" in os.getenv("LANG", "") else "English",
608
+ help_text="English|Chinese",
609
+ index=True)
610
+ description = TextField(null=True, help_text="KB description")
611
+ embd_id = CharField(
612
+ max_length=128,
613
+ null=False,
614
+ help_text="default embedding model ID",
615
+ index=True)
616
+ permission = CharField(
617
+ max_length=16,
618
+ null=False,
619
+ help_text="me|team",
620
+ default="me",
621
+ index=True)
622
+ created_by = CharField(max_length=32, null=False, index=True)
623
+ doc_num = IntegerField(default=0, index=True)
624
+ token_num = IntegerField(default=0, index=True)
625
+ chunk_num = IntegerField(default=0, index=True)
626
+ similarity_threshold = FloatField(default=0.2, index=True)
627
+ vector_similarity_weight = FloatField(default=0.3, index=True)
628
+
629
+ parser_id = CharField(
630
+ max_length=32,
631
+ null=False,
632
+ help_text="default parser ID",
633
+ default=ParserType.NAIVE.value,
634
+ index=True)
635
+ parser_config = JSONField(null=False, default={"pages": [[1, 1000000]]})
636
+ status = CharField(
637
+ max_length=1,
638
+ null=True,
639
+ help_text="is it validate(0: wasted,1: validate)",
640
+ default="1",
641
+ index=True)
642
+
643
+ def __str__(self):
644
+ return self.name
645
+
646
+ class Meta:
647
+ db_table = "knowledgebase"
648
+
649
+
650
+ class Document(DataBaseModel):
651
+ id = CharField(max_length=32, primary_key=True)
652
+ thumbnail = TextField(null=True, help_text="thumbnail base64 string")
653
+ kb_id = CharField(max_length=256, null=False, index=True)
654
+ parser_id = CharField(
655
+ max_length=32,
656
+ null=False,
657
+ help_text="default parser ID",
658
+ index=True)
659
+ parser_config = JSONField(null=False, default={"pages": [[1, 1000000]]})
660
+ source_type = CharField(
661
+ max_length=128,
662
+ null=False,
663
+ default="local",
664
+ help_text="where dose this document come from",
665
+ index=True)
666
+ type = CharField(max_length=32, null=False, help_text="file extension",
667
+ index=True)
668
+ created_by = CharField(
669
+ max_length=32,
670
+ null=False,
671
+ help_text="who created it",
672
+ index=True)
673
+ name = CharField(
674
+ max_length=255,
675
+ null=True,
676
+ help_text="file name",
677
+ index=True)
678
+ location = CharField(
679
+ max_length=255,
680
+ null=True,
681
+ help_text="where dose it store",
682
+ index=True)
683
+ size = IntegerField(default=0, index=True)
684
+ token_num = IntegerField(default=0, index=True)
685
+ chunk_num = IntegerField(default=0, index=True)
686
+ progress = FloatField(default=0, index=True)
687
+ progress_msg = TextField(
688
+ null=True,
689
+ help_text="process message",
690
+ default="")
691
+ process_begin_at = DateTimeField(null=True, index=True)
692
+ process_duation = FloatField(default=0)
693
+
694
+ run = CharField(
695
+ max_length=1,
696
+ null=True,
697
+ help_text="start to run processing or cancel.(1: run it; 2: cancel)",
698
+ default="0",
699
+ index=True)
700
+ status = CharField(
701
+ max_length=1,
702
+ null=True,
703
+ help_text="is it validate(0: wasted,1: validate)",
704
+ default="1",
705
+ index=True)
706
+
707
+ class Meta:
708
+ db_table = "document"
709
+
710
+
711
+ class File(DataBaseModel):
712
+ id = CharField(
713
+ max_length=32,
714
+ primary_key=True)
715
+ parent_id = CharField(
716
+ max_length=32,
717
+ null=False,
718
+ help_text="parent folder id",
719
+ index=True)
720
+ tenant_id = CharField(
721
+ max_length=32,
722
+ null=False,
723
+ help_text="tenant id",
724
+ index=True)
725
+ created_by = CharField(
726
+ max_length=32,
727
+ null=False,
728
+ help_text="who created it",
729
+ index=True)
730
+ name = CharField(
731
+ max_length=255,
732
+ null=False,
733
+ help_text="file name or folder name",
734
+ index=True)
735
+ location = CharField(
736
+ max_length=255,
737
+ null=True,
738
+ help_text="where dose it store",
739
+ index=True)
740
+ size = IntegerField(default=0, index=True)
741
+ type = CharField(max_length=32, null=False, help_text="file extension", index=True)
742
+ source_type = CharField(
743
+ max_length=128,
744
+ null=False,
745
+ default="",
746
+ help_text="where dose this document come from", index=True)
747
+
748
+ class Meta:
749
+ db_table = "file"
750
+
751
+
752
+ class File2Document(DataBaseModel):
753
+ id = CharField(
754
+ max_length=32,
755
+ primary_key=True)
756
+ file_id = CharField(
757
+ max_length=32,
758
+ null=True,
759
+ help_text="file id",
760
+ index=True)
761
+ document_id = CharField(
762
+ max_length=32,
763
+ null=True,
764
+ help_text="document id",
765
+ index=True)
766
+
767
+ class Meta:
768
+ db_table = "file2document"
769
+
770
+
771
+ class Task(DataBaseModel):
772
+ id = CharField(max_length=32, primary_key=True)
773
+ doc_id = CharField(max_length=32, null=False, index=True)
774
+ from_page = IntegerField(default=0)
775
+
776
+ to_page = IntegerField(default=-1)
777
+
778
+ begin_at = DateTimeField(null=True, index=True)
779
+ process_duation = FloatField(default=0)
780
+
781
+ progress = FloatField(default=0, index=True)
782
+ progress_msg = TextField(
783
+ null=True,
784
+ help_text="process message",
785
+ default="")
786
+
787
+
788
+ class Dialog(DataBaseModel):
789
+ id = CharField(max_length=32, primary_key=True)
790
+ tenant_id = CharField(max_length=32, null=False, index=True)
791
+ name = CharField(
792
+ max_length=255,
793
+ null=True,
794
+ help_text="dialog application name",
795
+ index=True)
796
+ description = TextField(null=True, help_text="Dialog description")
797
+ icon = TextField(null=True, help_text="icon base64 string")
798
+ language = CharField(
799
+ max_length=32,
800
+ null=True,
801
+ default="Chinese" if "zh_CN" in os.getenv("LANG", "") else "English",
802
+ help_text="English|Chinese",
803
+ index=True)
804
+ llm_id = CharField(max_length=128, null=False, help_text="default llm ID")
805
+
806
+ llm_setting = JSONField(null=False, default={"temperature": 0.1, "top_p": 0.3, "frequency_penalty": 0.7,
807
+ "presence_penalty": 0.4, "max_tokens": 512})
808
+ prompt_type = CharField(
809
+ max_length=16,
810
+ null=False,
811
+ default="simple",
812
+ help_text="simple|advanced",
813
+ index=True)
814
+ prompt_config = JSONField(null=False, default={"system": "", "prologue": "您好,我是您的助手小樱,长得可爱又善良,can I help you?",
815
+ "parameters": [], "empty_response": "Sorry! 知识库中未找到相关内容!"})
816
+
817
+ similarity_threshold = FloatField(default=0.2)
818
+ vector_similarity_weight = FloatField(default=0.3)
819
+
820
+ top_n = IntegerField(default=6)
821
+
822
+ top_k = IntegerField(default=1024)
823
+
824
+ do_refer = CharField(
825
+ max_length=1,
826
+ null=False,
827
+ help_text="it needs to insert reference index into answer or not")
828
+
829
+ rerank_id = CharField(
830
+ max_length=128,
831
+ null=False,
832
+ help_text="default rerank model ID")
833
+
834
+ kb_ids = JSONField(null=False, default=[])
835
+ status = CharField(
836
+ max_length=1,
837
+ null=True,
838
+ help_text="is it validate(0: wasted,1: validate)",
839
+ default="1",
840
+ index=True)
841
+
842
+ class Meta:
843
+ db_table = "dialog"
844
+
845
+
846
+ class Conversation(DataBaseModel):
847
+ id = CharField(max_length=32, primary_key=True)
848
+ dialog_id = CharField(max_length=32, null=False, index=True)
849
+ name = CharField(max_length=255, null=True, help_text="converastion name", index=True)
850
+ message = JSONField(null=True)
851
+ reference = JSONField(null=True, default=[])
852
+
853
+ class Meta:
854
+ db_table = "conversation"
855
+
856
+
857
+ class APIToken(DataBaseModel):
858
+ tenant_id = CharField(max_length=32, null=False, index=True)
859
+ token = CharField(max_length=255, null=False, index=True)
860
+ dialog_id = CharField(max_length=32, null=False, index=True)
861
+ source = CharField(max_length=16, null=True, help_text="none|agent|dialog", index=True)
862
+
863
+ class Meta:
864
+ db_table = "api_token"
865
+ primary_key = CompositeKey('tenant_id', 'token')
866
+
867
+
868
+ class API4Conversation(DataBaseModel):
869
+ id = CharField(max_length=32, primary_key=True)
870
+ dialog_id = CharField(max_length=32, null=False, index=True)
871
+ user_id = CharField(max_length=255, null=False, help_text="user_id", index=True)
872
+ message = JSONField(null=True)
873
+ reference = JSONField(null=True, default=[])
874
+ tokens = IntegerField(default=0)
875
+ source = CharField(max_length=16, null=True, help_text="none|agent|dialog", index=True)
876
+
877
+ duration = FloatField(default=0, index=True)
878
+ round = IntegerField(default=0, index=True)
879
+ thumb_up = IntegerField(default=0, index=True)
880
+
881
+ class Meta:
882
+ db_table = "api_4_conversation"
883
+
884
+
885
+ class UserCanvas(DataBaseModel):
886
+ id = CharField(max_length=32, primary_key=True)
887
+ avatar = TextField(null=True, help_text="avatar base64 string")
888
+ user_id = CharField(max_length=255, null=False, help_text="user_id", index=True)
889
+ title = CharField(max_length=255, null=True, help_text="Canvas title")
890
+
891
+ description = TextField(null=True, help_text="Canvas description")
892
+ canvas_type = CharField(max_length=32, null=True, help_text="Canvas type", index=True)
893
+ dsl = JSONField(null=True, default={})
894
+
895
+ class Meta:
896
+ db_table = "user_canvas"
897
+
898
+
899
+ class CanvasTemplate(DataBaseModel):
900
+ id = CharField(max_length=32, primary_key=True)
901
+ avatar = TextField(null=True, help_text="avatar base64 string")
902
+ title = CharField(max_length=255, null=True, help_text="Canvas title")
903
+
904
+ description = TextField(null=True, help_text="Canvas description")
905
+ canvas_type = CharField(max_length=32, null=True, help_text="Canvas type", index=True)
906
+ dsl = JSONField(null=True, default={})
907
+
908
+ class Meta:
909
+ db_table = "canvas_template"
910
+
911
+
912
+ def migrate_db():
913
+ with DB.transaction():
914
+ migrator = MySQLMigrator(DB)
915
+ try:
916
+ migrate(
917
+ migrator.add_column('file', 'source_type', CharField(max_length=128, null=False, default="",
918
+ help_text="where dose this document come from",
919
+ index=True))
920
+ )
921
+ except Exception as e:
922
+ pass
923
+ try:
924
+ migrate(
925
+ migrator.add_column('tenant', 'rerank_id',
926
+ CharField(max_length=128, null=False, default="BAAI/bge-reranker-v2-m3",
927
+ help_text="default rerank model ID"))
928
+
929
+ )
930
+ except Exception as e:
931
+ pass
932
+ try:
933
+ migrate(
934
+ migrator.add_column('dialog', 'rerank_id', CharField(max_length=128, null=False, default="",
935
+ help_text="default rerank model ID"))
936
+
937
+ )
938
+ except Exception as e:
939
+ pass
940
+ try:
941
+ migrate(
942
+ migrator.add_column('dialog', 'top_k', IntegerField(default=1024))
943
+
944
+ )
945
+ except Exception as e:
946
+ pass
947
+ try:
948
+ migrate(
949
+ migrator.alter_column_type('tenant_llm', 'api_key',
950
+ CharField(max_length=1024, null=True, help_text="API KEY", index=True))
951
+ )
952
+ except Exception as e:
953
+ pass
954
+ try:
955
+ migrate(
956
+ migrator.add_column('api_token', 'source',
957
+ CharField(max_length=16, null=True, help_text="none|agent|dialog", index=True))
958
+ )
959
+ except Exception as e:
960
+ pass
961
+ try:
962
+ migrate(
963
+ migrator.add_column('api_4_conversation', 'source',
964
+ CharField(max_length=16, null=True, help_text="none|agent|dialog", index=True))
965
+ )
966
+ except Exception as e:
967
+ pass
968
+ try:
969
+ DB.execute_sql('ALTER TABLE llm DROP PRIMARY KEY;')
970
+ DB.execute_sql('ALTER TABLE llm ADD PRIMARY KEY (llm_name,fid);')
971
+ except Exception as e:
972
+ pass
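The migrate_db() block above follows a simple idempotency pattern: each schema change goes through playhouse.migrate, and the exception peewee raises when the column already exists is swallowed, so re-running migrate_db() against an up-to-date database is a no-op. A minimal sketch of that pattern, assuming a reachable MySQL instance; the connection parameters and the demo column are illustrative, not from this repository:

    from peewee import CharField, MySQLDatabase
    from playhouse.migrate import MySQLMigrator, migrate

    # Illustrative connection parameters only.
    DB = MySQLDatabase("rag_flow", user="root", password="change-me", host="127.0.0.1")

    migrator = MySQLMigrator(DB)
    try:
        # add_column raises if 'demo_col' is already there, which is exactly
        # what keeps repeated migration runs harmless.
        migrate(migrator.add_column("file", "demo_col",
                                    CharField(max_length=16, null=True, default="")))
    except Exception:
        pass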
api/db/db_utils.py CHANGED
@@ -1,130 +1,130 @@
1
- #
2
- # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- #
16
- import operator
17
- from functools import reduce
18
- from typing import Dict, Type, Union
19
-
20
- from api.utils import current_timestamp, timestamp_to_date
21
-
22
- from api.db.db_models import DB, DataBaseModel
23
- from api.db.runtime_config import RuntimeConfig
24
- from api.utils.log_utils import getLogger
25
- from enum import Enum
26
-
27
-
28
- LOGGER = getLogger()
29
-
30
-
31
- @DB.connection_context()
32
- def bulk_insert_into_db(model, data_source, replace_on_conflict=False):
33
- DB.create_tables([model])
34
-
35
- for i, data in enumerate(data_source):
36
- current_time = current_timestamp() + i
37
- current_date = timestamp_to_date(current_time)
38
- if 'create_time' not in data:
39
- data['create_time'] = current_time
40
- data['create_date'] = timestamp_to_date(data['create_time'])
41
- data['update_time'] = current_time
42
- data['update_date'] = current_date
43
-
44
- preserve = tuple(data_source[0].keys() - {'create_time', 'create_date'})
45
-
46
- batch_size = 1000
47
-
48
- for i in range(0, len(data_source), batch_size):
49
- with DB.atomic():
50
- query = model.insert_many(data_source[i:i + batch_size])
51
- if replace_on_conflict:
52
- query = query.on_conflict(preserve=preserve)
53
- query.execute()
54
-
55
-
56
- def get_dynamic_db_model(base, job_id):
57
- return type(base.model(
58
- table_index=get_dynamic_tracking_table_index(job_id=job_id)))
59
-
60
-
61
- def get_dynamic_tracking_table_index(job_id):
62
- return job_id[:8]
63
-
64
-
65
- def fill_db_model_object(model_object, human_model_dict):
66
- for k, v in human_model_dict.items():
67
- attr_name = 'f_%s' % k
68
- if hasattr(model_object.__class__, attr_name):
69
- setattr(model_object, attr_name, v)
70
- return model_object
71
-
72
-
73
- # https://docs.peewee-orm.com/en/latest/peewee/query_operators.html
74
- supported_operators = {
75
- '==': operator.eq,
76
- '<': operator.lt,
77
- '<=': operator.le,
78
- '>': operator.gt,
79
- '>=': operator.ge,
80
- '!=': operator.ne,
81
- '<<': operator.lshift,
82
- '>>': operator.rshift,
83
- '%': operator.mod,
84
- '**': operator.pow,
85
- '^': operator.xor,
86
- '~': operator.inv,
87
- }
88
-
89
-
90
- def query_dict2expression(
91
- model: Type[DataBaseModel], query: Dict[str, Union[bool, int, str, list, tuple]]):
92
- expression = []
93
-
94
- for field, value in query.items():
95
- if not isinstance(value, (list, tuple)):
96
- value = ('==', value)
97
- op, *val = value
98
-
99
- field = getattr(model, f'f_{field}')
100
- value = supported_operators[op](
101
- field, val[0]) if op in supported_operators else getattr(
102
- field, op)(
103
- *val)
104
- expression.append(value)
105
-
106
- return reduce(operator.iand, expression)
107
-
108
-
109
- def query_db(model: Type[DataBaseModel], limit: int = 0, offset: int = 0,
110
- query: dict = None, order_by: Union[str, list, tuple] = None):
111
- data = model.select()
112
- if query:
113
- data = data.where(query_dict2expression(model, query))
114
- count = data.count()
115
-
116
- if not order_by:
117
- order_by = 'create_time'
118
- if not isinstance(order_by, (list, tuple)):
119
- order_by = (order_by, 'asc')
120
- order_by, order = order_by
121
- order_by = getattr(model, f'f_{order_by}')
122
- order_by = getattr(order_by, order)()
123
- data = data.order_by(order_by)
124
-
125
- if limit > 0:
126
- data = data.limit(limit)
127
- if offset > 0:
128
- data = data.offset(offset)
129
-
130
- return list(data), count
 
1
+ #
2
+ # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ #
16
+ import operator
17
+ from functools import reduce
18
+ from typing import Dict, Type, Union
19
+
20
+ from api.utils import current_timestamp, timestamp_to_date
21
+
22
+ from api.db.db_models import DB, DataBaseModel
23
+ from api.db.runtime_config import RuntimeConfig
24
+ from api.utils.log_utils import getLogger
25
+ from enum import Enum
26
+
27
+
28
+ LOGGER = getLogger()
29
+
30
+
31
+ @DB.connection_context()
32
+ def bulk_insert_into_db(model, data_source, replace_on_conflict=False):
33
+ DB.create_tables([model])
34
+
35
+ for i, data in enumerate(data_source):
36
+ current_time = current_timestamp() + i
37
+ current_date = timestamp_to_date(current_time)
38
+ if 'create_time' not in data:
39
+ data['create_time'] = current_time
40
+ data['create_date'] = timestamp_to_date(data['create_time'])
41
+ data['update_time'] = current_time
42
+ data['update_date'] = current_date
43
+
44
+ preserve = tuple(data_source[0].keys() - {'create_time', 'create_date'})
45
+
46
+ batch_size = 1000
47
+
48
+ for i in range(0, len(data_source), batch_size):
49
+ with DB.atomic():
50
+ query = model.insert_many(data_source[i:i + batch_size])
51
+ if replace_on_conflict:
52
+ query = query.on_conflict(preserve=preserve)
53
+ query.execute()
54
+
55
+
56
+ def get_dynamic_db_model(base, job_id):
57
+ return type(base.model(
58
+ table_index=get_dynamic_tracking_table_index(job_id=job_id)))
59
+
60
+
61
+ def get_dynamic_tracking_table_index(job_id):
62
+ return job_id[:8]
63
+
64
+
65
+ def fill_db_model_object(model_object, human_model_dict):
66
+ for k, v in human_model_dict.items():
67
+ attr_name = 'f_%s' % k
68
+ if hasattr(model_object.__class__, attr_name):
69
+ setattr(model_object, attr_name, v)
70
+ return model_object
71
+
72
+
73
+ # https://docs.peewee-orm.com/en/latest/peewee/query_operators.html
74
+ supported_operators = {
75
+ '==': operator.eq,
76
+ '<': operator.lt,
77
+ '<=': operator.le,
78
+ '>': operator.gt,
79
+ '>=': operator.ge,
80
+ '!=': operator.ne,
81
+ '<<': operator.lshift,
82
+ '>>': operator.rshift,
83
+ '%': operator.mod,
84
+ '**': operator.pow,
85
+ '^': operator.xor,
86
+ '~': operator.inv,
87
+ }
88
+
89
+
90
+ def query_dict2expression(
91
+ model: Type[DataBaseModel], query: Dict[str, Union[bool, int, str, list, tuple]]):
92
+ expression = []
93
+
94
+ for field, value in query.items():
95
+ if not isinstance(value, (list, tuple)):
96
+ value = ('==', value)
97
+ op, *val = value
98
+
99
+ field = getattr(model, f'f_{field}')
100
+ value = supported_operators[op](
101
+ field, val[0]) if op in supported_operators else getattr(
102
+ field, op)(
103
+ *val)
104
+ expression.append(value)
105
+
106
+ return reduce(operator.iand, expression)
107
+
108
+
109
+ def query_db(model: Type[DataBaseModel], limit: int = 0, offset: int = 0,
110
+ query: dict = None, order_by: Union[str, list, tuple] = None):
111
+ data = model.select()
112
+ if query:
113
+ data = data.where(query_dict2expression(model, query))
114
+ count = data.count()
115
+
116
+ if not order_by:
117
+ order_by = 'create_time'
118
+ if not isinstance(order_by, (list, tuple)):
119
+ order_by = (order_by, 'asc')
120
+ order_by, order = order_by
121
+ order_by = getattr(model, f'f_{order_by}')
122
+ order_by = getattr(order_by, order)()
123
+ data = data.order_by(order_by)
124
+
125
+ if limit > 0:
126
+ data = data.limit(limit)
127
+ if offset > 0:
128
+ data = data.offset(offset)
129
+
130
+ return list(data), count
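For orientation, a hedged usage sketch of the two helpers above: query_db() resolves model columns under an f_ prefix (f_{field}), and each value in the query dict is either a bare value, meaning equality, or an (operator, value) tuple looked up in supported_operators. The model and field names below are hypothetical, not from this repository:

    import peewee

    class DemoJob(peewee.Model):  # hypothetical model following the f_ column convention
        f_status = peewee.CharField()
        f_create_time = peewee.BigIntegerField()

    # Bare value -> equality; tuple -> operator lookup in supported_operators.
    records, total = query_db(
        DemoJob,
        limit=10,
        query={"status": "1", "create_time": (">=", 1700000000)},
        order_by=("create_time", "desc"),
    )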
api/db/init_data.py CHANGED
@@ -1,184 +1,184 @@
1
- #
2
- # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- #
16
- import json
17
- import os
18
- import time
19
- import uuid
20
- from copy import deepcopy
21
-
22
- from api.db import LLMType, UserTenantRole
23
- from api.db.db_models import init_database_tables as init_web_db, LLMFactories, LLM, TenantLLM
24
- from api.db.services import UserService
25
- from api.db.services.canvas_service import CanvasTemplateService
26
- from api.db.services.document_service import DocumentService
27
- from api.db.services.knowledgebase_service import KnowledgebaseService
28
- from api.db.services.llm_service import LLMFactoriesService, LLMService, TenantLLMService, LLMBundle
29
- from api.db.services.user_service import TenantService, UserTenantService
30
- from api.settings import CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS, LLM_FACTORY, API_KEY, LLM_BASE_URL
31
- from api.utils.file_utils import get_project_base_directory
32
-
33
-
34
- def init_superuser():
35
- user_info = {
36
- "id": uuid.uuid1().hex,
37
- "password": "admin",
38
- "nickname": "admin",
39
- "is_superuser": True,
40
- "email": "[email protected]",
41
- "creator": "system",
42
- "status": "1",
43
- }
44
- tenant = {
45
- "id": user_info["id"],
46
- "name": user_info["nickname"] + "‘s Kingdom",
47
- "llm_id": CHAT_MDL,
48
- "embd_id": EMBEDDING_MDL,
49
- "asr_id": ASR_MDL,
50
- "parser_ids": PARSERS,
51
- "img2txt_id": IMAGE2TEXT_MDL
52
- }
53
- usr_tenant = {
54
- "tenant_id": user_info["id"],
55
- "user_id": user_info["id"],
56
- "invited_by": user_info["id"],
57
- "role": UserTenantRole.OWNER
58
- }
59
- tenant_llm = []
60
- for llm in LLMService.query(fid=LLM_FACTORY):
61
- tenant_llm.append(
62
- {"tenant_id": user_info["id"], "llm_factory": LLM_FACTORY, "llm_name": llm.llm_name, "model_type": llm.model_type,
63
- "api_key": API_KEY, "api_base": LLM_BASE_URL})
64
-
65
- if not UserService.save(**user_info):
66
- print("\033[93m【ERROR】\033[0mcan't init admin.")
67
- return
68
- TenantService.insert(**tenant)
69
- UserTenantService.insert(**usr_tenant)
70
- TenantLLMService.insert_many(tenant_llm)
71
- print(
72
- "【INFO】Super user initialized. \033[93memail: [email protected], password: admin\033[0m. Changing the password after logining is strongly recomanded.")
73
-
74
- chat_mdl = LLMBundle(tenant["id"], LLMType.CHAT, tenant["llm_id"])
75
- msg = chat_mdl.chat(system="", history=[
76
- {"role": "user", "content": "Hello!"}], gen_conf={})
77
- if msg.find("ERROR: ") == 0:
78
- print(
79
- "\33[91m【ERROR】\33[0m: ",
80
- "'{}' dosen't work. {}".format(
81
- tenant["llm_id"],
82
- msg))
83
- embd_mdl = LLMBundle(tenant["id"], LLMType.EMBEDDING, tenant["embd_id"])
84
- v, c = embd_mdl.encode(["Hello!"])
85
- if c == 0:
86
- print(
87
- "\33[91m【ERROR】\33[0m:",
88
- " '{}' dosen't work!".format(
89
- tenant["embd_id"]))
90
-
91
-
92
- def init_llm_factory():
93
- try:
94
- LLMService.filter_delete([(LLM.fid == "MiniMax" or LLM.fid == "Minimax")])
95
- except Exception as e:
96
- pass
97
-
98
- factory_llm_infos = json.load(
99
- open(
100
- os.path.join(get_project_base_directory(), "conf", "llm_factories.json"),
101
- "r",
102
- )
103
- )
104
- for factory_llm_info in factory_llm_infos["factory_llm_infos"]:
105
- llm_infos = factory_llm_info.pop("llm")
106
- try:
107
- LLMFactoriesService.save(**factory_llm_info)
108
- except Exception as e:
109
- pass
110
- LLMService.filter_delete([LLM.fid == factory_llm_info["name"]])
111
- for llm_info in llm_infos:
112
- llm_info["fid"] = factory_llm_info["name"]
113
- try:
114
- LLMService.save(**llm_info)
115
- except Exception as e:
116
- pass
117
-
118
- LLMFactoriesService.filter_delete([LLMFactories.name == "Local"])
119
- LLMService.filter_delete([LLM.fid == "Local"])
120
- LLMService.filter_delete([LLM.llm_name == "qwen-vl-max"])
121
- LLMService.filter_delete([LLM.fid == "Moonshot", LLM.llm_name == "flag-embedding"])
122
- TenantLLMService.filter_delete([TenantLLM.llm_factory == "Moonshot", TenantLLM.llm_name == "flag-embedding"])
123
- LLMFactoriesService.filter_delete([LLMFactoriesService.model.name == "QAnything"])
124
- LLMService.filter_delete([LLMService.model.fid == "QAnything"])
125
- TenantLLMService.filter_update([TenantLLMService.model.llm_factory == "QAnything"], {"llm_factory": "Youdao"})
126
- TenantService.filter_update([1 == 1], {
127
- "parser_ids": "naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,knowledge_graph:Knowledge Graph,email:Email"})
128
- ## insert the two OpenAI embedding models for every tenant that already uses OpenAI.
129
- print("Start to insert 2 OpenAI embedding models...")
130
- tenant_ids = set([row["tenant_id"] for row in TenantLLMService.get_openai_models()])
131
- for tid in tenant_ids:
132
- for row in TenantLLMService.query(llm_factory="OpenAI", tenant_id=tid):
133
- row = row.to_dict()
134
- row["model_type"] = LLMType.EMBEDDING.value
135
- row["llm_name"] = "text-embedding-3-small"
136
- row["used_tokens"] = 0
137
- try:
138
- TenantLLMService.save(**row)
139
- row = deepcopy(row)
140
- row["llm_name"] = "text-embedding-3-large"
141
- TenantLLMService.save(**row)
142
- except Exception as e:
143
- pass
144
- break
145
- for kb_id in KnowledgebaseService.get_all_ids():
146
- KnowledgebaseService.update_by_id(kb_id, {"doc_num": DocumentService.get_kb_doc_count(kb_id)})
147
- """
148
- drop table llm;
149
- drop table llm_factories;
150
- update tenant set parser_ids='naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,knowledge_graph:Knowledge Graph';
151
- alter table knowledgebase modify avatar longtext;
152
- alter table user modify avatar longtext;
153
- alter table dialog modify icon longtext;
154
- """
155
-
156
-
157
- def add_graph_templates():
158
- dir = os.path.join(get_project_base_directory(), "agent", "templates")
159
- for fnm in os.listdir(dir):
160
- try:
161
- cnvs = json.load(open(os.path.join(dir, fnm), "r"))
162
- try:
163
- CanvasTemplateService.save(**cnvs)
164
- except:
165
- CanvasTemplateService.update_by_id(cnvs["id"], cnvs)
166
- except Exception as e:
167
- print("Add graph templates error: ", e)
168
- print("------------", flush=True)
169
-
170
-
171
- def init_web_data():
172
- start_time = time.time()
173
-
174
- init_llm_factory()
175
- if not UserService.get_all().count():
176
- init_superuser()
177
-
178
- add_graph_templates()
179
- print("init web data success:{}".format(time.time() - start_time))
180
-
181
-
182
- if __name__ == '__main__':
183
- init_web_db()
184
- init_web_data()
 
1
+ #
2
+ # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ #
16
+ import json
17
+ import os
18
+ import time
19
+ import uuid
20
+ from copy import deepcopy
21
+
22
+ from api.db import LLMType, UserTenantRole
23
+ from api.db.db_models import init_database_tables as init_web_db, LLMFactories, LLM, TenantLLM
24
+ from api.db.services import UserService
25
+ from api.db.services.canvas_service import CanvasTemplateService
26
+ from api.db.services.document_service import DocumentService
27
+ from api.db.services.knowledgebase_service import KnowledgebaseService
28
+ from api.db.services.llm_service import LLMFactoriesService, LLMService, TenantLLMService, LLMBundle
29
+ from api.db.services.user_service import TenantService, UserTenantService
30
+ from api.settings import CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS, LLM_FACTORY, API_KEY, LLM_BASE_URL
31
+ from api.utils.file_utils import get_project_base_directory
32
+
33
+
34
+ def init_superuser():
35
+ user_info = {
36
+ "id": uuid.uuid1().hex,
37
+ "password": "admin",
38
+ "nickname": "admin",
39
+ "is_superuser": True,
40
+ "email": "[email protected]",
41
+ "creator": "system",
42
+ "status": "1",
43
+ }
44
+ tenant = {
45
+ "id": user_info["id"],
46
+ "name": user_info["nickname"] + "‘s Kingdom",
47
+ "llm_id": CHAT_MDL,
48
+ "embd_id": EMBEDDING_MDL,
49
+ "asr_id": ASR_MDL,
50
+ "parser_ids": PARSERS,
51
+ "img2txt_id": IMAGE2TEXT_MDL
52
+ }
53
+ usr_tenant = {
54
+ "tenant_id": user_info["id"],
55
+ "user_id": user_info["id"],
56
+ "invited_by": user_info["id"],
57
+ "role": UserTenantRole.OWNER
58
+ }
59
+ tenant_llm = []
60
+ for llm in LLMService.query(fid=LLM_FACTORY):
61
+ tenant_llm.append(
62
+ {"tenant_id": user_info["id"], "llm_factory": LLM_FACTORY, "llm_name": llm.llm_name, "model_type": llm.model_type,
63
+ "api_key": API_KEY, "api_base": LLM_BASE_URL})
64
+
65
+ if not UserService.save(**user_info):
66
+ print("\033[93m【ERROR】\033[0mcan't init admin.")
67
+ return
68
+ TenantService.insert(**tenant)
69
+ UserTenantService.insert(**usr_tenant)
70
+ TenantLLMService.insert_many(tenant_llm)
71
+ print(
72
+ "【INFO】Super user initialized. \033[93memail: [email protected], password: admin\033[0m. Changing the password after logining is strongly recomanded.")
73
+
74
+ chat_mdl = LLMBundle(tenant["id"], LLMType.CHAT, tenant["llm_id"])
75
+ msg = chat_mdl.chat(system="", history=[
76
+ {"role": "user", "content": "Hello!"}], gen_conf={})
77
+ if msg.find("ERROR: ") == 0:
78
+ print(
79
+ "\33[91m【ERROR】\33[0m: ",
80
+ "'{}' dosen't work. {}".format(
81
+ tenant["llm_id"],
82
+ msg))
83
+ embd_mdl = LLMBundle(tenant["id"], LLMType.EMBEDDING, tenant["embd_id"])
84
+ v, c = embd_mdl.encode(["Hello!"])
85
+ if c == 0:
86
+ print(
87
+ "\33[91m【ERROR】\33[0m:",
88
+ " '{}' dosen't work!".format(
89
+ tenant["embd_id"]))
90
+
91
+
92
+ def init_llm_factory():
93
+ try:
94
+ LLMService.filter_delete([(LLM.fid == "MiniMax" or LLM.fid == "Minimax")])
95
+ except Exception as e:
96
+ pass
97
+
98
+ factory_llm_infos = json.load(
99
+ open(
100
+ os.path.join(get_project_base_directory(), "conf", "llm_factories.json"),
101
+ "r",
102
+ )
103
+ )
104
+ for factory_llm_info in factory_llm_infos["factory_llm_infos"]:
105
+ llm_infos = factory_llm_info.pop("llm")
106
+ try:
107
+ LLMFactoriesService.save(**factory_llm_info)
108
+ except Exception as e:
109
+ pass
110
+ LLMService.filter_delete([LLM.fid == factory_llm_info["name"]])
111
+ for llm_info in llm_infos:
112
+ llm_info["fid"] = factory_llm_info["name"]
113
+ try:
114
+ LLMService.save(**llm_info)
115
+ except Exception as e:
116
+ pass
117
+
118
+ LLMFactoriesService.filter_delete([LLMFactories.name == "Local"])
119
+ LLMService.filter_delete([LLM.fid == "Local"])
120
+ LLMService.filter_delete([LLM.llm_name == "qwen-vl-max"])
121
+ LLMService.filter_delete([LLM.fid == "Moonshot", LLM.llm_name == "flag-embedding"])
122
+ TenantLLMService.filter_delete([TenantLLM.llm_factory == "Moonshot", TenantLLM.llm_name == "flag-embedding"])
123
+ LLMFactoriesService.filter_delete([LLMFactoriesService.model.name == "QAnything"])
124
+ LLMService.filter_delete([LLMService.model.fid == "QAnything"])
125
+ TenantLLMService.filter_update([TenantLLMService.model.llm_factory == "QAnything"], {"llm_factory": "Youdao"})
126
+ TenantService.filter_update([1 == 1], {
127
+ "parser_ids": "naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,knowledge_graph:Knowledge Graph,email:Email"})
128
+ ## insert the two OpenAI embedding models for every tenant that already uses OpenAI.
129
+ print("Start to insert 2 OpenAI embedding models...")
130
+ tenant_ids = set([row["tenant_id"] for row in TenantLLMService.get_openai_models()])
131
+ for tid in tenant_ids:
132
+ for row in TenantLLMService.query(llm_factory="OpenAI", tenant_id=tid):
133
+ row = row.to_dict()
134
+ row["model_type"] = LLMType.EMBEDDING.value
135
+ row["llm_name"] = "text-embedding-3-small"
136
+ row["used_tokens"] = 0
137
+ try:
138
+ TenantLLMService.save(**row)
139
+ row = deepcopy(row)
140
+ row["llm_name"] = "text-embedding-3-large"
141
+ TenantLLMService.save(**row)
142
+ except Exception as e:
143
+ pass
144
+ break
145
+ for kb_id in KnowledgebaseService.get_all_ids():
146
+ KnowledgebaseService.update_by_id(kb_id, {"doc_num": DocumentService.get_kb_doc_count(kb_id)})
147
+ """
148
+ drop table llm;
149
+ drop table llm_factories;
150
+ update tenant set parser_ids='naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,knowledge_graph:Knowledge Graph';
151
+ alter table knowledgebase modify avatar longtext;
152
+ alter table user modify avatar longtext;
153
+ alter table dialog modify icon longtext;
154
+ """
155
+
156
+
157
+ def add_graph_templates():
158
+ dir = os.path.join(get_project_base_directory(), "agent", "templates")
159
+ for fnm in os.listdir(dir):
160
+ try:
161
+ cnvs = json.load(open(os.path.join(dir, fnm), "r"))
162
+ try:
163
+ CanvasTemplateService.save(**cnvs)
164
+ except:
165
+ CanvasTemplateService.update_by_id(cnvs["id"], cnvs)
166
+ except Exception as e:
167
+ print("Add graph templates error: ", e)
168
+ print("------------", flush=True)
169
+
170
+
171
+ def init_web_data():
172
+ start_time = time.time()
173
+
174
+ init_llm_factory()
175
+ if not UserService.get_all().count():
176
+ init_superuser()
177
+
178
+ add_graph_templates()
179
+ print("init web data success:{}".format(time.time() - start_time))
180
+
181
+
182
+ if __name__ == '__main__':
183
+ init_web_db()
184
+ init_web_data()
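A note on the MiniMax/Minimax filter above: peewee builds SQL through operator overloading, and Python's or/and short-circuit on truthiness, so an expression like (A or B) silently keeps only A. Combined conditions therefore need | and &, with every comparison parenthesized, since those operators bind more tightly than ==; the join predicate in api/db/services/api_service.py below follows the same rule. A one-line sketch:

    broken = (LLM.fid == "MiniMax" or LLM.fid == "Minimax")   # keeps only the first comparison
    fixed = (LLM.fid == "MiniMax") | (LLM.fid == "Minimax")   # a genuine SQL OR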
api/db/operatioins.py CHANGED
@@ -1,21 +1,21 @@
1
- #
2
- # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- #
16
-
17
- import operator
18
- import time
19
- import typing
20
- from api.utils.log_utils import sql_logger
21
- import peewee
 
1
+ #
2
+ # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ #
16
+
17
+ import operator
18
+ import time
19
+ import typing
20
+ from api.utils.log_utils import sql_logger
21
+ import peewee
api/db/reload_config_base.py CHANGED
@@ -1,28 +1,28 @@
1
- #
2
- # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- #
16
- class ReloadConfigBase:
17
- @classmethod
18
- def get_all(cls):
19
- configs = {}
20
- for k, v in cls.__dict__.items():
21
- if not callable(getattr(cls, k)) and not k.startswith(
22
- "__") and not k.startswith("_"):
23
- configs[k] = v
24
- return configs
25
-
26
- @classmethod
27
- def get(cls, config_name):
28
- return getattr(cls, config_name) if hasattr(cls, config_name) else None
 
1
+ #
2
+ # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ #
16
+ class ReloadConfigBase:
17
+ @classmethod
18
+ def get_all(cls):
19
+ configs = {}
20
+ for k, v in cls.__dict__.items():
21
+ if not callable(getattr(cls, k)) and not k.startswith(
22
+ "__") and not k.startswith("_"):
23
+ configs[k] = v
24
+ return configs
25
+
26
+ @classmethod
27
+ def get(cls, config_name):
28
+ return getattr(cls, config_name) if hasattr(cls, config_name) else None
api/db/runtime_config.py CHANGED
@@ -1,54 +1,54 @@
1
- #
2
- # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- #
16
- from api.versions import get_versions
17
- from .reload_config_base import ReloadConfigBase
18
-
19
-
20
- class RuntimeConfig(ReloadConfigBase):
21
- DEBUG = None
22
- WORK_MODE = None
23
- HTTP_PORT = None
24
- JOB_SERVER_HOST = None
25
- JOB_SERVER_VIP = None
26
- ENV = dict()
27
- SERVICE_DB = None
28
- LOAD_CONFIG_MANAGER = False
29
-
30
- @classmethod
31
- def init_config(cls, **kwargs):
32
- for k, v in kwargs.items():
33
- if hasattr(cls, k):
34
- setattr(cls, k, v)
35
-
36
- @classmethod
37
- def init_env(cls):
38
- cls.ENV.update(get_versions())
39
-
40
- @classmethod
41
- def load_config_manager(cls):
42
- cls.LOAD_CONFIG_MANAGER = True
43
-
44
- @classmethod
45
- def get_env(cls, key):
46
- return cls.ENV.get(key, None)
47
-
48
- @classmethod
49
- def get_all_env(cls):
50
- return cls.ENV
51
-
52
- @classmethod
53
- def set_service_db(cls, service_db):
54
- cls.SERVICE_DB = service_db
 
1
+ #
2
+ # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ #
16
+ from api.versions import get_versions
17
+ from .reload_config_base import ReloadConfigBase
18
+
19
+
20
+ class RuntimeConfig(ReloadConfigBase):
21
+ DEBUG = None
22
+ WORK_MODE = None
23
+ HTTP_PORT = None
24
+ JOB_SERVER_HOST = None
25
+ JOB_SERVER_VIP = None
26
+ ENV = dict()
27
+ SERVICE_DB = None
28
+ LOAD_CONFIG_MANAGER = False
29
+
30
+ @classmethod
31
+ def init_config(cls, **kwargs):
32
+ for k, v in kwargs.items():
33
+ if hasattr(cls, k):
34
+ setattr(cls, k, v)
35
+
36
+ @classmethod
37
+ def init_env(cls):
38
+ cls.ENV.update(get_versions())
39
+
40
+ @classmethod
41
+ def load_config_manager(cls):
42
+ cls.LOAD_CONFIG_MANAGER = True
43
+
44
+ @classmethod
45
+ def get_env(cls, key):
46
+ return cls.ENV.get(key, None)
47
+
48
+ @classmethod
49
+ def get_all_env(cls):
50
+ return cls.ENV
51
+
52
+ @classmethod
53
+ def set_service_db(cls, service_db):
54
+ cls.SERVICE_DB = service_db
api/db/services/__init__.py CHANGED
@@ -1,38 +1,38 @@
1
- #
2
- # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- #
16
- import pathlib
17
- import re
18
- from .user_service import UserService
19
-
20
-
21
- def duplicate_name(query_func, **kwargs):
22
- fnm = kwargs["name"]
23
- objs = query_func(**kwargs)
24
- if not objs: return fnm
25
- ext = pathlib.Path(fnm).suffix  # e.g. ".jpg"
26
- nm = re.sub(r"%s$"%ext, "", fnm)
27
- r = re.search(r"\(([0-9]+)\)$", nm)
28
- c = 0
29
- if r:
30
- c = int(r.group(1))
31
- nm = re.sub(r"\([0-9]+\)$", "", nm)
32
- c += 1
33
- nm = f"{nm}({c})"
34
- if ext: nm += f"{ext}"
35
-
36
- kwargs["name"] = nm
37
- return duplicate_name(query_func, **kwargs)
38
-
 
1
+ #
2
+ # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ #
16
+ import pathlib
17
+ import re
18
+ from .user_service import UserService
19
+
20
+
21
+ def duplicate_name(query_func, **kwargs):
22
+ fnm = kwargs["name"]
23
+ objs = query_func(**kwargs)
24
+ if not objs: return fnm
25
+ ext = pathlib.Path(fnm).suffix  # e.g. ".jpg"
26
+ nm = re.sub(r"%s$"%ext, "", fnm)
27
+ r = re.search(r"\(([0-9]+)\)$", nm)
28
+ c = 0
29
+ if r:
30
+ c = int(r.group(1))
31
+ nm = re.sub(r"\([0-9]+\)$", "", nm)
32
+ c += 1
33
+ nm = f"{nm}({c})"
34
+ if ext: nm += f"{ext}"
35
+
36
+ kwargs["name"] = nm
37
+ return duplicate_name(query_func, **kwargs)
38
+
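duplicate_name() above recurses until query_func reports no clash, adding or incrementing a (n) suffix ahead of the file extension. A self-contained illustration with a stubbed query function; all names here are made up:

    taken = {"report.pdf", "report(1).pdf"}

    def fake_query(**kwargs):
        return kwargs["name"] in taken  # truthy means the name is already used

    print(duplicate_name(fake_query, name="report.pdf"))  # -> report(2).pdf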
api/db/services/api_service.py CHANGED
@@ -1,68 +1,68 @@
1
- #
2
- # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- #
16
- from datetime import datetime
17
- import peewee
18
- from api.db.db_models import DB, API4Conversation, APIToken, Dialog
19
- from api.db.services.common_service import CommonService
20
- from api.utils import current_timestamp, datetime_format
21
-
22
-
23
- class APITokenService(CommonService):
24
- model = APIToken
25
-
26
- @classmethod
27
- @DB.connection_context()
28
- def used(cls, token):
29
- return cls.model.update({
30
- "update_time": current_timestamp(),
31
- "update_date": datetime_format(datetime.now()),
32
- }).where(
33
- cls.model.token == token
34
- )
35
-
36
-
37
- class API4ConversationService(CommonService):
38
- model = API4Conversation
39
-
40
- @classmethod
41
- @DB.connection_context()
42
- def append_message(cls, id, conversation):
43
- cls.update_by_id(id, conversation)
44
- return cls.model.update(round=cls.model.round + 1).where(cls.model.id==id).execute()
45
-
46
- @classmethod
47
- @DB.connection_context()
48
- def stats(cls, tenant_id, from_date, to_date, source=None):
49
- if len(to_date) == 10: to_date += " 23:59:59"
50
- return cls.model.select(
51
- cls.model.create_date.truncate("day").alias("dt"),
52
- peewee.fn.COUNT(
53
- cls.model.id).alias("pv"),
54
- peewee.fn.COUNT(
55
- cls.model.user_id.distinct()).alias("uv"),
56
- peewee.fn.SUM(
57
- cls.model.tokens).alias("tokens"),
58
- peewee.fn.SUM(
59
- cls.model.duration).alias("duration"),
60
- peewee.fn.AVG(
61
- cls.model.round).alias("round"),
62
- peewee.fn.SUM(
63
- cls.model.thumb_up).alias("thumb_up")
64
- ).join(Dialog, on=((cls.model.dialog_id == Dialog.id) & (Dialog.tenant_id == tenant_id))).where(
65
- cls.model.create_date >= from_date,
66
- cls.model.create_date <= to_date,
67
- cls.model.source == source
68
- ).group_by(cls.model.create_date.truncate("day")).dicts()
 
1
+ #
2
+ # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ #
16
+ from datetime import datetime
17
+ import peewee
18
+ from api.db.db_models import DB, API4Conversation, APIToken, Dialog
19
+ from api.db.services.common_service import CommonService
20
+ from api.utils import current_timestamp, datetime_format
21
+
22
+
23
+ class APITokenService(CommonService):
24
+ model = APIToken
25
+
26
+ @classmethod
27
+ @DB.connection_context()
28
+ def used(cls, token):
29
+ return cls.model.update({
30
+ "update_time": current_timestamp(),
31
+ "update_date": datetime_format(datetime.now()),
32
+ }).where(
33
+ cls.model.token == token
34
+ )
35
+
36
+
37
+ class API4ConversationService(CommonService):
38
+ model = API4Conversation
39
+
40
+ @classmethod
41
+ @DB.connection_context()
42
+ def append_message(cls, id, conversation):
43
+ cls.update_by_id(id, conversation)
44
+ return cls.model.update(round=cls.model.round + 1).where(cls.model.id==id).execute()
45
+
46
+ @classmethod
47
+ @DB.connection_context()
48
+ def stats(cls, tenant_id, from_date, to_date, source=None):
49
+ if len(to_date) == 10: to_date += " 23:59:59"
50
+ return cls.model.select(
51
+ cls.model.create_date.truncate("day").alias("dt"),
52
+ peewee.fn.COUNT(
53
+ cls.model.id).alias("pv"),
54
+ peewee.fn.COUNT(
55
+ cls.model.user_id.distinct()).alias("uv"),
56
+ peewee.fn.SUM(
57
+ cls.model.tokens).alias("tokens"),
58
+ peewee.fn.SUM(
59
+ cls.model.duration).alias("duration"),
60
+ peewee.fn.AVG(
61
+ cls.model.round).alias("round"),
62
+ peewee.fn.SUM(
63
+ cls.model.thumb_up).alias("thumb_up")
64
+ ).join(Dialog, on=((cls.model.dialog_id == Dialog.id) & (Dialog.tenant_id == tenant_id))).where(
65
+ cls.model.create_date >= from_date,
66
+ cls.model.create_date <= to_date,
67
+ cls.model.source == source
68
+ ).group_by(cls.model.create_date.truncate("day")).dicts()
api/db/services/common_service.py CHANGED
@@ -1,183 +1,183 @@
1
- #
2
- # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- #
16
- from datetime import datetime
17
-
18
- import peewee
19
-
20
- from api.db.db_models import DB
21
- from api.utils import datetime_format, current_timestamp, get_uuid
22
-
23
-
24
- class CommonService:
25
- model = None
26
-
27
- @classmethod
28
- @DB.connection_context()
29
- def query(cls, cols=None, reverse=None, order_by=None, **kwargs):
30
- return cls.model.query(cols=cols, reverse=reverse,
31
- order_by=order_by, **kwargs)
32
-
33
- @classmethod
34
- @DB.connection_context()
35
- def get_all(cls, cols=None, reverse=None, order_by=None):
36
- if cols:
37
- query_records = cls.model.select(*cols)
38
- else:
39
- query_records = cls.model.select()
40
- if reverse is not None:
41
- if not order_by or not hasattr(cls, order_by):
42
- order_by = "create_time"
43
- if reverse is True:
44
- query_records = query_records.order_by(
45
- cls.model.getter_by(order_by).desc())
46
- elif reverse is False:
47
- query_records = query_records.order_by(
48
- cls.model.getter_by(order_by).asc())
49
- return query_records
50
-
51
- @classmethod
52
- @DB.connection_context()
53
- def get(cls, **kwargs):
54
- return cls.model.get(**kwargs)
55
-
56
- @classmethod
57
- @DB.connection_context()
58
- def get_or_none(cls, **kwargs):
59
- try:
60
- return cls.model.get(**kwargs)
61
- except peewee.DoesNotExist:
62
- return None
63
-
64
- @classmethod
65
- @DB.connection_context()
66
- def save(cls, **kwargs):
67
- # if "id" not in kwargs:
68
- # kwargs["id"] = get_uuid()
69
- sample_obj = cls.model(**kwargs).save(force_insert=True)
70
- return sample_obj
71
-
72
- @classmethod
73
- @DB.connection_context()
74
- def insert(cls, **kwargs):
75
- if "id" not in kwargs:
76
- kwargs["id"] = get_uuid()
77
- kwargs["create_time"] = current_timestamp()
78
- kwargs["create_date"] = datetime_format(datetime.now())
79
- kwargs["update_time"] = current_timestamp()
80
- kwargs["update_date"] = datetime_format(datetime.now())
81
- sample_obj = cls.model(**kwargs).save(force_insert=True)
82
- return sample_obj
83
-
84
- @classmethod
85
- @DB.connection_context()
86
- def insert_many(cls, data_list, batch_size=100):
87
- with DB.atomic():
88
- for d in data_list:
89
- d["create_time"] = current_timestamp()
90
- d["create_date"] = datetime_format(datetime.now())
91
- for i in range(0, len(data_list), batch_size):
92
- cls.model.insert_many(data_list[i:i + batch_size]).execute()
93
-
94
- @classmethod
95
- @DB.connection_context()
96
- def update_many_by_id(cls, data_list):
97
- with DB.atomic():
98
- for data in data_list:
99
- data["update_time"] = current_timestamp()
100
- data["update_date"] = datetime_format(datetime.now())
101
- cls.model.update(data).where(
102
- cls.model.id == data["id"]).execute()
103
-
104
- @classmethod
105
- @DB.connection_context()
106
- def update_by_id(cls, pid, data):
107
- data["update_time"] = current_timestamp()
108
- data["update_date"] = datetime_format(datetime.now())
109
- num = cls.model.update(data).where(cls.model.id == pid).execute()
110
- return num
111
-
112
- @classmethod
113
- @DB.connection_context()
114
- def get_by_id(cls, pid):
115
- try:
116
- obj = cls.model.query(id=pid)[0]
117
- return True, obj
118
- except Exception as e:
119
- return False, None
120
-
121
- @classmethod
122
- @DB.connection_context()
123
- def get_by_ids(cls, pids, cols=None):
124
- if cols:
125
- objs = cls.model.select(*cols)
126
- else:
127
- objs = cls.model.select()
128
- return objs.where(cls.model.id.in_(pids))
129
-
130
- @classmethod
131
- @DB.connection_context()
132
- def delete_by_id(cls, pid):
133
- return cls.model.delete().where(cls.model.id == pid).execute()
134
-
135
- @classmethod
136
- @DB.connection_context()
137
- def filter_delete(cls, filters):
138
- with DB.atomic():
139
- num = cls.model.delete().where(*filters).execute()
140
- return num
141
-
142
- @classmethod
143
- @DB.connection_context()
144
- def filter_update(cls, filters, update_data):
145
- with DB.atomic():
146
- return cls.model.update(update_data).where(*filters).execute()
147
-
148
- @staticmethod
149
- def cut_list(tar_list, n):
150
- length = len(tar_list)
151
- arr = range(length)
152
- result = [tuple(tar_list[x:(x + n)]) for x in arr[::n]]
153
- return result
154
-
155
- @classmethod
156
- @DB.connection_context()
157
- def filter_scope_list(cls, in_key, in_filters_list,
158
- filters=None, cols=None):
159
- in_filters_tuple_list = cls.cut_list(in_filters_list, 20)
160
- if not filters:
161
- filters = []
162
- res_list = []
163
- if cols:
164
- for i in in_filters_tuple_list:
165
- query_records = cls.model.select(
166
- *
167
- cols).where(
168
- getattr(
169
- cls.model,
170
- in_key).in_(i),
171
- *
172
- filters)
173
- if query_records:
174
- res_list.extend(
175
- [query_record for query_record in query_records])
176
- else:
177
- for i in in_filters_tuple_list:
178
- query_records = cls.model.select().where(
179
- getattr(cls.model, in_key).in_(i), *filters)
180
- if query_records:
181
- res_list.extend(
182
- [query_record for query_record in query_records])
183
- return res_list
 
1
+ #
2
+ # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ #
16
+ from datetime import datetime
17
+
18
+ import peewee
19
+
20
+ from api.db.db_models import DB
21
+ from api.utils import datetime_format, current_timestamp, get_uuid
22
+
23
+
24
+ class CommonService:
25
+ model = None
26
+
27
+ @classmethod
28
+ @DB.connection_context()
29
+ def query(cls, cols=None, reverse=None, order_by=None, **kwargs):
30
+ return cls.model.query(cols=cols, reverse=reverse,
31
+ order_by=order_by, **kwargs)
32
+
33
+ @classmethod
34
+ @DB.connection_context()
35
+ def get_all(cls, cols=None, reverse=None, order_by=None):
36
+ if cols:
37
+ query_records = cls.model.select(*cols)
38
+ else:
39
+ query_records = cls.model.select()
40
+ if reverse is not None:
41
+ if not order_by or not hasattr(cls.model, order_by):
42
+ order_by = "create_time"
43
+ if reverse is True:
44
+ query_records = query_records.order_by(
45
+ cls.model.getter_by(order_by).desc())
46
+ elif reverse is False:
47
+ query_records = query_records.order_by(
48
+ cls.model.getter_by(order_by).asc())
49
+ return query_records
50
+
51
+ @classmethod
52
+ @DB.connection_context()
53
+ def get(cls, **kwargs):
54
+ return cls.model.get(**kwargs)
55
+
56
+ @classmethod
57
+ @DB.connection_context()
58
+ def get_or_none(cls, **kwargs):
59
+ try:
60
+ return cls.model.get(**kwargs)
61
+ except peewee.DoesNotExist:
62
+ return None
63
+
64
+ @classmethod
65
+ @DB.connection_context()
66
+ def save(cls, **kwargs):
67
+ # if "id" not in kwargs:
68
+ # kwargs["id"] = get_uuid()
69
+ sample_obj = cls.model(**kwargs).save(force_insert=True)
70
+ return sample_obj
71
+
72
+ @classmethod
73
+ @DB.connection_context()
74
+ def insert(cls, **kwargs):
75
+ if "id" not in kwargs:
76
+ kwargs["id"] = get_uuid()
77
+ kwargs["create_time"] = current_timestamp()
78
+ kwargs["create_date"] = datetime_format(datetime.now())
79
+ kwargs["update_time"] = current_timestamp()
80
+ kwargs["update_date"] = datetime_format(datetime.now())
81
+ sample_obj = cls.model(**kwargs).save(force_insert=True)
82
+ return sample_obj
83
+
84
+ @classmethod
85
+ @DB.connection_context()
86
+ def insert_many(cls, data_list, batch_size=100):
87
+ with DB.atomic():
88
+ for d in data_list:
89
+ d["create_time"] = current_timestamp()
90
+ d["create_date"] = datetime_format(datetime.now())
91
+ for i in range(0, len(data_list), batch_size):
92
+ cls.model.insert_many(data_list[i:i + batch_size]).execute()
93
+
94
+ @classmethod
95
+ @DB.connection_context()
96
+ def update_many_by_id(cls, data_list):
97
+ with DB.atomic():
98
+ for data in data_list:
99
+ data["update_time"] = current_timestamp()
100
+ data["update_date"] = datetime_format(datetime.now())
101
+ cls.model.update(data).where(
102
+ cls.model.id == data["id"]).execute()
103
+
104
+ @classmethod
105
+ @DB.connection_context()
106
+ def update_by_id(cls, pid, data):
107
+ data["update_time"] = current_timestamp()
108
+ data["update_date"] = datetime_format(datetime.now())
109
+ num = cls.model.update(data).where(cls.model.id == pid).execute()
110
+ return num
111
+
112
+ @classmethod
113
+ @DB.connection_context()
114
+ def get_by_id(cls, pid):
115
+ try:
116
+ obj = cls.model.query(id=pid)[0]
117
+ return True, obj
118
+ except Exception as e:
119
+ return False, None
120
+
121
+ @classmethod
122
+ @DB.connection_context()
123
+ def get_by_ids(cls, pids, cols=None):
124
+ if cols:
125
+ objs = cls.model.select(*cols)
126
+ else:
127
+ objs = cls.model.select()
128
+ return objs.where(cls.model.id.in_(pids))
129
+
130
+ @classmethod
131
+ @DB.connection_context()
132
+ def delete_by_id(cls, pid):
133
+ return cls.model.delete().where(cls.model.id == pid).execute()
134
+
135
+ @classmethod
136
+ @DB.connection_context()
137
+ def filter_delete(cls, filters):
138
+ with DB.atomic():
139
+ num = cls.model.delete().where(*filters).execute()
140
+ return num
141
+
142
+ @classmethod
143
+ @DB.connection_context()
144
+ def filter_update(cls, filters, update_data):
145
+ with DB.atomic():
146
+ return cls.model.update(update_data).where(*filters).execute()
147
+
148
+ @staticmethod
149
+ def cut_list(tar_list, n):
150
+ length = len(tar_list)
151
+ arr = range(length)
152
+ result = [tuple(tar_list[x:(x + n)]) for x in arr[::n]]
153
+ return result
154
+
155
+ @classmethod
156
+ @DB.connection_context()
157
+ def filter_scope_list(cls, in_key, in_filters_list,
158
+ filters=None, cols=None):
159
+ in_filters_tuple_list = cls.cut_list(in_filters_list, 20)
160
+ if not filters:
161
+ filters = []
162
+ res_list = []
163
+ if cols:
164
+ for i in in_filters_tuple_list:
165
+ query_records = cls.model.select(
166
+ *
167
+ cols).where(
168
+ getattr(
169
+ cls.model,
170
+ in_key).in_(i),
171
+ *
172
+ filters)
173
+ if query_records:
174
+ res_list.extend(
175
+ [query_record for query_record in query_records])
176
+ else:
177
+ for i in in_filters_tuple_list:
178
+ query_records = cls.model.select().where(
179
+ getattr(cls.model, in_key).in_(i), *filters)
180
+ if query_records:
181
+ res_list.extend(
182
+ [query_record for query_record in query_records])
183
+ return res_list
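Apart from the CRLF-to-LF conversion, common_service.py is otherwise untouched, so the block above is the complete peewee-backed base class that every service in api/db/services inherits. A minimal sketch of how a concrete service builds on it follows; Widget, its name field, and the db_table value are hypothetical stand-ins (the real services bind models such as Dialog or Document further down this diff), and DataBaseModel is assumed to be the base class exported by api.db.db_models that supplies the id and create/update stamp columns:

from peewee import CharField

from api.db.db_models import DataBaseModel  # assumed base with id + time stamp fields
from api.db.services.common_service import CommonService


class Widget(DataBaseModel):  # hypothetical model, not part of the repository
    name = CharField(max_length=128, index=True)

    class Meta:
        db_table = "widget"


class WidgetService(CommonService):
    # Binding the model is the only service-specific line; insert(),
    # get_or_none() and update_by_id() below are all inherited unchanged.
    model = Widget


WidgetService.insert(name="first")           # fills id and create/update stamps, then saves
w = WidgetService.get_or_none(name="first")  # returns None instead of raising peewee.DoesNotExist
if w:
    WidgetService.update_by_id(w.id, {"name": "renamed"})  # refreshes update_time/update_date

Binding model on the subclass is the whole contract; the @DB.connection_context() decorator on every inherited helper means callers never open or close database connections themselves.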
api/db/services/dialog_service.py CHANGED
@@ -1,392 +1,392 @@
1
- #
2
- # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- #
16
- import os
17
- import json
18
- import re
19
- from copy import deepcopy
20
-
21
- from api.db import LLMType, ParserType
22
- from api.db.db_models import Dialog, Conversation
23
- from api.db.services.common_service import CommonService
24
- from api.db.services.knowledgebase_service import KnowledgebaseService
25
- from api.db.services.llm_service import LLMService, TenantLLMService, LLMBundle
26
- from api.settings import chat_logger, retrievaler, kg_retrievaler
27
- from rag.app.resume import forbidden_select_fields4resume
28
- from rag.nlp import keyword_extraction
29
- from rag.nlp.search import index_name
30
- from rag.utils import rmSpace, num_tokens_from_string, encoder
31
- from api.utils.file_utils import get_project_base_directory
32
-
33
-
34
- class DialogService(CommonService):
35
- model = Dialog
36
-
37
-
38
- class ConversationService(CommonService):
39
- model = Conversation
40
-
41
-
42
- def message_fit_in(msg, max_length=4000):
43
- def count():
44
- nonlocal msg
45
- tks_cnts = []
46
- for m in msg:
47
- tks_cnts.append(
48
- {"role": m["role"], "count": num_tokens_from_string(m["content"])})
49
- total = 0
50
- for m in tks_cnts:
51
- total += m["count"]
52
- return total
53
-
54
- c = count()
55
- if c < max_length:
56
- return c, msg
57
-
58
- msg_ = [m for m in msg[:-1] if m["role"] == "system"]
59
- msg_.append(msg[-1])
60
- msg = msg_
61
- c = count()
62
- if c < max_length:
63
- return c, msg
64
-
65
- ll = num_tokens_from_string(msg_[0]["content"])
66
- l = num_tokens_from_string(msg_[-1]["content"])
67
- if ll / (ll + l) > 0.8:
68
- m = msg_[0]["content"]
69
- m = encoder.decode(encoder.encode(m)[:max_length - l])
70
- msg[0]["content"] = m
71
- return max_length, msg
72
-
73
- m = msg_[1]["content"]
74
- m = encoder.decode(encoder.encode(m)[:max_length - l])
75
- msg[1]["content"] = m
76
- return max_length, msg
77
-
78
-
79
- def llm_id2llm_type(llm_id):
80
- fnm = os.path.join(get_project_base_directory(), "conf")
81
- llm_factories = json.load(open(os.path.join(fnm, "llm_factories.json"), "r"))
82
- for llm_factory in llm_factories["factory_llm_infos"]:
83
- for llm in llm_factory["llm"]:
84
- if llm_id == llm["llm_name"]:
85
- return llm["model_type"].strip(",")[-1]
86
-
87
-
88
- def chat(dialog, messages, stream=True, **kwargs):
89
- assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
90
- llm = LLMService.query(llm_name=dialog.llm_id)
91
- if not llm:
92
- llm = TenantLLMService.query(tenant_id=dialog.tenant_id, llm_name=dialog.llm_id)
93
- if not llm:
94
- raise LookupError("LLM(%s) not found" % dialog.llm_id)
95
- max_tokens = 8192
96
- else:
97
- max_tokens = llm[0].max_tokens
98
- kbs = KnowledgebaseService.get_by_ids(dialog.kb_ids)
99
- embd_nms = list(set([kb.embd_id for kb in kbs]))
100
- if len(embd_nms) != 1:
101
- yield {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []}
102
- return {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []}
103
-
104
- is_kg = all([kb.parser_id == ParserType.KG for kb in kbs])
105
- retr = retrievaler if not is_kg else kg_retrievaler
106
-
107
- questions = [m["content"] for m in messages if m["role"] == "user"][-3:]
108
- attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None
109
- if "doc_ids" in messages[-1]:
110
- attachments = messages[-1]["doc_ids"]
111
- for m in messages[:-1]:
112
- if "doc_ids" in m:
113
- attachments.extend(m["doc_ids"])
114
-
115
- embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING, embd_nms[0])
116
- if llm_id2llm_type(dialog.llm_id) == "image2text":
117
- chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
118
- else:
119
- chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
120
-
121
- prompt_config = dialog.prompt_config
122
- field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
123
- # try to use sql if field mapping is good to go
124
- if field_map:
125
- chat_logger.info("Use SQL to retrieval:{}".format(questions[-1]))
126
- ans = use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True))
127
- if ans:
128
- yield ans
129
- return
130
-
131
- for p in prompt_config["parameters"]:
132
- if p["key"] == "knowledge":
133
- continue
134
- if p["key"] not in kwargs and not p["optional"]:
135
- raise KeyError("Miss parameter: " + p["key"])
136
- if p["key"] not in kwargs:
137
- prompt_config["system"] = prompt_config["system"].replace(
138
- "{%s}" % p["key"], " ")
139
-
140
- rerank_mdl = None
141
- if dialog.rerank_id:
142
- rerank_mdl = LLMBundle(dialog.tenant_id, LLMType.RERANK, dialog.rerank_id)
143
-
144
- for _ in range(len(questions) // 2):
145
- questions.append(questions[-1])
146
- if "knowledge" not in [p["key"] for p in prompt_config["parameters"]]:
147
- kbinfos = {"total": 0, "chunks": [], "doc_aggs": []}
148
- else:
149
- if prompt_config.get("keyword", False):
150
- questions[-1] += keyword_extraction(chat_mdl, questions[-1])
151
- kbinfos = retr.retrieval(" ".join(questions), embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n,
152
- dialog.similarity_threshold,
153
- dialog.vector_similarity_weight,
154
- doc_ids=attachments,
155
- top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl)
156
- knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
157
- # self-RAG
158
- if dialog.prompt_config.get("self_rag") and not relevant(dialog.tenant_id, dialog.llm_id, questions[-1], knowledges):
159
- questions[-1] = rewrite(dialog.tenant_id, dialog.llm_id, questions[-1])
160
- kbinfos = retr.retrieval(" ".join(questions), embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n,
161
- dialog.similarity_threshold,
162
- dialog.vector_similarity_weight,
163
- doc_ids=attachments,
164
- top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl)
165
- knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
166
-
167
- chat_logger.info(
168
- "{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
169
-
170
- if not knowledges and prompt_config.get("empty_response"):
171
- yield {"answer": prompt_config["empty_response"], "reference": kbinfos}
172
- return {"answer": prompt_config["empty_response"], "reference": kbinfos}
173
-
174
- kwargs["knowledge"] = "\n".join(knowledges)
175
- gen_conf = dialog.llm_setting
176
-
177
- msg = [{"role": "system", "content": prompt_config["system"].format(**kwargs)}]
178
- msg.extend([{"role": m["role"], "content": re.sub(r"##\d+\$\$", "", m["content"])}
179
- for m in messages if m["role"] != "system"])
180
- used_token_count, msg = message_fit_in(msg, int(max_tokens * 0.97))
181
- assert len(msg) >= 2, f"message_fit_in has bug: {msg}"
182
-
183
- if "max_tokens" in gen_conf:
184
- gen_conf["max_tokens"] = min(
185
- gen_conf["max_tokens"],
186
- max_tokens - used_token_count)
187
-
188
- def decorate_answer(answer):
189
- nonlocal prompt_config, knowledges, kwargs, kbinfos
190
- refs = []
191
- if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
192
- answer, idx = retr.insert_citations(answer,
193
- [ck["content_ltks"]
194
- for ck in kbinfos["chunks"]],
195
- [ck["vector"]
196
- for ck in kbinfos["chunks"]],
197
- embd_mdl,
198
- tkweight=1 - dialog.vector_similarity_weight,
199
- vtweight=dialog.vector_similarity_weight)
200
- idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
201
- recall_docs = [
202
- d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
203
- if not recall_docs: recall_docs = kbinfos["doc_aggs"]
204
- kbinfos["doc_aggs"] = recall_docs
205
-
206
- refs = deepcopy(kbinfos)
207
- for c in refs["chunks"]:
208
- if c.get("vector"):
209
- del c["vector"]
210
-
211
- if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
212
- answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'"
213
- return {"answer": answer, "reference": refs}
214
-
215
- if stream:
216
- answer = ""
217
- for ans in chat_mdl.chat_streamly(msg[0]["content"], msg[1:], gen_conf):
218
- answer = ans
219
- yield {"answer": answer, "reference": {}}
220
- yield decorate_answer(answer)
221
- else:
222
- answer = chat_mdl.chat(
223
- msg[0]["content"], msg[1:], gen_conf)
224
- chat_logger.info("User: {}|Assistant: {}".format(
225
- msg[-1]["content"], answer))
226
- yield decorate_answer(answer)
227
-
228
-
229
- def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
230
- sys_prompt = "你是一个DBA。你需要这对以下表的字段结构,根据用户的问题列表,写出最后一个问题对应的SQL。"
231
- user_promt = """
232
- 表名:{};
233
- 数据库表字段说明如下:
234
- {}
235
-
236
- 问题如下:
237
- {}
238
- 请写出SQL, 且只要SQL,不要有其他说明及文字。
239
- """.format(
240
- index_name(tenant_id),
241
- "\n".join([f"{k}: {v}" for k, v in field_map.items()]),
242
- question
243
- )
244
- tried_times = 0
245
-
246
- def get_table():
247
- nonlocal sys_prompt, user_promt, question, tried_times
248
- sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_promt}], {
249
- "temperature": 0.06})
250
- print(user_promt, sql)
251
- chat_logger.info(f"“{question}”==>{user_promt} get SQL: {sql}")
252
- sql = re.sub(r"[\r\n]+", " ", sql.lower())
253
- sql = re.sub(r".*select ", "select ", sql.lower())
254
- sql = re.sub(r" +", " ", sql)
255
- sql = re.sub(r"([;;]|```).*", "", sql)
256
- if sql[:len("select ")] != "select ":
257
- return None, None
258
- if not re.search(r"((sum|avg|max|min)\(|group by )", sql.lower()):
259
- if sql[:len("select *")] != "select *":
260
- sql = "select doc_id,docnm_kwd," + sql[6:]
261
- else:
262
- flds = []
263
- for k in field_map.keys():
264
- if k in forbidden_select_fields4resume:
265
- continue
266
- if len(flds) > 11:
267
- break
268
- flds.append(k)
269
- sql = "select doc_id,docnm_kwd," + ",".join(flds) + sql[8:]
270
-
271
- print(f"“{question}” get SQL(refined): {sql}")
272
-
273
- chat_logger.info(f"“{question}” get SQL(refined): {sql}")
274
- tried_times += 1
275
- return retrievaler.sql_retrieval(sql, format="json"), sql
276
-
277
- tbl, sql = get_table()
278
- if tbl is None:
279
- return None
280
- if tbl.get("error") and tried_times <= 2:
281
- user_promt = """
282
- 表名:{};
283
- 数据库表字段说明如下:
284
- {}
285
-
286
- 问题如下:
287
- {}
288
-
289
- 你上一次给出的错误SQL如下:
290
- {}
291
-
292
- 后台报错如下:
293
- {}
294
-
295
- 请纠正SQL中的错误再写一遍,且只要SQL,不要有其他说明及文字。
296
- """.format(
297
- index_name(tenant_id),
298
- "\n".join([f"{k}: {v}" for k, v in field_map.items()]),
299
- question, sql, tbl["error"]
300
- )
301
- tbl, sql = get_table()
302
- chat_logger.info("TRY it again: {}".format(sql))
303
-
304
- chat_logger.info("GET table: {}".format(tbl))
305
- print(tbl)
306
- if tbl.get("error") or len(tbl["rows"]) == 0:
307
- return None
308
-
309
- docid_idx = set([ii for ii, c in enumerate(
310
- tbl["columns"]) if c["name"] == "doc_id"])
311
- docnm_idx = set([ii for ii, c in enumerate(
312
- tbl["columns"]) if c["name"] == "docnm_kwd"])
313
- clmn_idx = [ii for ii in range(
314
- len(tbl["columns"])) if ii not in (docid_idx | docnm_idx)]
315
-
316
- # compose markdown table
317
- clmns = "|" + "|".join([re.sub(r"(/.*|([^()]+))", "", field_map.get(tbl["columns"][i]["name"],
318
- tbl["columns"][i]["name"])) for i in
319
- clmn_idx]) + ("|Source|" if docid_idx and docid_idx else "|")
320
-
321
- line = "|" + "|".join(["------" for _ in range(len(clmn_idx))]) + \
322
- ("|------|" if docid_idx and docid_idx else "")
323
-
324
- rows = ["|" +
325
- "|".join([rmSpace(str(r[i])) for i in clmn_idx]).replace("None", " ") +
326
- "|" for r in tbl["rows"]]
327
- if quota:
328
- rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
329
- else:
330
- rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
331
- rows = re.sub(r"T[0-9]{2}:[0-9]{2}:[0-9]{2}(\.[0-9]+Z)?\|", "|", rows)
332
-
333
- if not docid_idx or not docnm_idx:
334
- chat_logger.warning("SQL missing field: " + sql)
335
- return {
336
- "answer": "\n".join([clmns, line, rows]),
337
- "reference": {"chunks": [], "doc_aggs": []}
338
- }
339
-
340
- docid_idx = list(docid_idx)[0]
341
- docnm_idx = list(docnm_idx)[0]
342
- doc_aggs = {}
343
- for r in tbl["rows"]:
344
- if r[docid_idx] not in doc_aggs:
345
- doc_aggs[r[docid_idx]] = {"doc_name": r[docnm_idx], "count": 0}
346
- doc_aggs[r[docid_idx]]["count"] += 1
347
- return {
348
- "answer": "\n".join([clmns, line, rows]),
349
- "reference": {"chunks": [{"doc_id": r[docid_idx], "docnm_kwd": r[docnm_idx]} for r in tbl["rows"]],
350
- "doc_aggs": [{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in
351
- doc_aggs.items()]}
352
- }
353
-
354
-
355
- def relevant(tenant_id, llm_id, question, contents: list):
356
- if llm_id2llm_type(llm_id) == "image2text":
357
- chat_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id)
358
- else:
359
- chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)
360
- prompt = """
361
- You are a grader assessing relevance of a retrieved document to a user question.
362
- It does not need to be a stringent test. The goal is to filter out erroneous retrievals.
363
- If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant.
364
- Give a binary 'yes' or 'no' score to indicate whether the document is relevant to the question.
365
- No other words needed except 'yes' or 'no'.
366
- """
367
- if not contents: return False
368
- contents = "Documents: \n" + " - ".join(contents)
369
- contents = f"Question: {question}\n" + contents
370
- if num_tokens_from_string(contents) >= chat_mdl.max_length - 4:
371
- contents = encoder.decode(encoder.encode(contents)[:chat_mdl.max_length - 4])
372
- ans = chat_mdl.chat(prompt, [{"role": "user", "content": contents}], {"temperature": 0.01})
373
- if ans.lower().find("yes") >= 0: return True
374
- return False
375
-
376
-
377
- def rewrite(tenant_id, llm_id, question):
378
- if llm_id2llm_type(llm_id) == "image2text":
379
- chat_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id)
380
- else:
381
- chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)
382
- prompt = """
383
- You are an expert at query expansion to generate a paraphrasing of a question.
384
- I can't retrieve relevant information from the knowledge base by using the user's question directly.
385
- You need to expand or paraphrase the user's question in multiple ways, such as using synonym words/phrases,
386
- writing abbreviations out in full, adding extra descriptions or explanations,
387
- changing the way of expression, translating the original question into another language (English/Chinese), etc.
388
- Return 5 versions of the question, one of which is a translation.
389
- Just list the questions. No other words are needed.
390
- """
391
- ans = chat_mdl.chat(prompt, [{"role": "user", "content": question}], {"temperature": 0.8})
392
- return ans
 
1
+ #
2
+ # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ #
16
+ import os
17
+ import json
18
+ import re
19
+ from copy import deepcopy
20
+
21
+ from api.db import LLMType, ParserType
22
+ from api.db.db_models import Dialog, Conversation
23
+ from api.db.services.common_service import CommonService
24
+ from api.db.services.knowledgebase_service import KnowledgebaseService
25
+ from api.db.services.llm_service import LLMService, TenantLLMService, LLMBundle
26
+ from api.settings import chat_logger, retrievaler, kg_retrievaler
27
+ from rag.app.resume import forbidden_select_fields4resume
28
+ from rag.nlp import keyword_extraction
29
+ from rag.nlp.search import index_name
30
+ from rag.utils import rmSpace, num_tokens_from_string, encoder
31
+ from api.utils.file_utils import get_project_base_directory
32
+
33
+
34
+ class DialogService(CommonService):
35
+ model = Dialog
36
+
37
+
38
+ class ConversationService(CommonService):
39
+ model = Conversation
40
+
41
+
42
+ def message_fit_in(msg, max_length=4000):
43
+ def count():
44
+ nonlocal msg
45
+ tks_cnts = []
46
+ for m in msg:
47
+ tks_cnts.append(
48
+ {"role": m["role"], "count": num_tokens_from_string(m["content"])})
49
+ total = 0
50
+ for m in tks_cnts:
51
+ total += m["count"]
52
+ return total
53
+
54
+ c = count()
55
+ if c < max_length:
56
+ return c, msg
57
+
58
+ msg_ = [m for m in msg[:-1] if m["role"] == "system"]
59
+ msg_.append(msg[-1])
60
+ msg = msg_
61
+ c = count()
62
+ if c < max_length:
63
+ return c, msg
64
+
65
+ ll = num_tokens_from_string(msg_[0]["content"])
66
+ l = num_tokens_from_string(msg_[-1]["content"])
67
+ if ll / (ll + l) > 0.8:
68
+ m = msg_[0]["content"]
69
+ m = encoder.decode(encoder.encode(m)[:max_length - l])
70
+ msg[0]["content"] = m
71
+ return max_length, msg
72
+
73
+ m = msg_[1]["content"]
74
+ m = encoder.decode(encoder.encode(m)[:max_length - l])
75
+ msg[1]["content"] = m
76
+ return max_length, msg
77
+
78
+
79
+ def llm_id2llm_type(llm_id):
80
+ fnm = os.path.join(get_project_base_directory(), "conf")
81
+ llm_factories = json.load(open(os.path.join(fnm, "llm_factories.json"), "r"))
82
+ for llm_factory in llm_factories["factory_llm_infos"]:
83
+ for llm in llm_factory["llm"]:
84
+ if llm_id == llm["llm_name"]:
85
+ return llm["model_type"].strip(",")[-1]
86
+
87
+
88
+ def chat(dialog, messages, stream=True, **kwargs):
89
+ assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
90
+ llm = LLMService.query(llm_name=dialog.llm_id)
91
+ if not llm:
92
+ llm = TenantLLMService.query(tenant_id=dialog.tenant_id, llm_name=dialog.llm_id)
93
+ if not llm:
94
+ raise LookupError("LLM(%s) not found" % dialog.llm_id)
95
+ max_tokens = 8192
96
+ else:
97
+ max_tokens = llm[0].max_tokens
98
+ kbs = KnowledgebaseService.get_by_ids(dialog.kb_ids)
99
+ embd_nms = list(set([kb.embd_id for kb in kbs]))
100
+ if len(embd_nms) != 1:
101
+ yield {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []}
102
+ return {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []}
103
+
104
+ is_kg = all([kb.parser_id == ParserType.KG for kb in kbs])
105
+ retr = retrievaler if not is_kg else kg_retrievaler
106
+
107
+ questions = [m["content"] for m in messages if m["role"] == "user"][-3:]
108
+ attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None
109
+ if "doc_ids" in messages[-1]:
110
+ attachments = messages[-1]["doc_ids"]
111
+ for m in messages[:-1]:
112
+ if "doc_ids" in m:
113
+ attachments.extend(m["doc_ids"])
114
+
115
+ embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING, embd_nms[0])
116
+ if llm_id2llm_type(dialog.llm_id) == "image2text":
117
+ chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
118
+ else:
119
+ chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
120
+
121
+ prompt_config = dialog.prompt_config
122
+ field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
123
+ # try to use sql if field mapping is good to go
124
+ if field_map:
125
+ chat_logger.info("Use SQL to retrieval:{}".format(questions[-1]))
126
+ ans = use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True))
127
+ if ans:
128
+ yield ans
129
+ return
130
+
131
+ for p in prompt_config["parameters"]:
132
+ if p["key"] == "knowledge":
133
+ continue
134
+ if p["key"] not in kwargs and not p["optional"]:
135
+ raise KeyError("Miss parameter: " + p["key"])
136
+ if p["key"] not in kwargs:
137
+ prompt_config["system"] = prompt_config["system"].replace(
138
+ "{%s}" % p["key"], " ")
139
+
140
+ rerank_mdl = None
141
+ if dialog.rerank_id:
142
+ rerank_mdl = LLMBundle(dialog.tenant_id, LLMType.RERANK, dialog.rerank_id)
143
+
144
+ for _ in range(len(questions) // 2):
145
+ questions.append(questions[-1])
146
+ if "knowledge" not in [p["key"] for p in prompt_config["parameters"]]:
147
+ kbinfos = {"total": 0, "chunks": [], "doc_aggs": []}
148
+ else:
149
+ if prompt_config.get("keyword", False):
150
+ questions[-1] += keyword_extraction(chat_mdl, questions[-1])
151
+ kbinfos = retr.retrieval(" ".join(questions), embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n,
152
+ dialog.similarity_threshold,
153
+ dialog.vector_similarity_weight,
154
+ doc_ids=attachments,
155
+ top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl)
156
+ knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
157
+ # self-RAG
158
+ if dialog.prompt_config.get("self_rag") and not relevant(dialog.tenant_id, dialog.llm_id, questions[-1], knowledges):
159
+ questions[-1] = rewrite(dialog.tenant_id, dialog.llm_id, questions[-1])
160
+ kbinfos = retr.retrieval(" ".join(questions), embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n,
161
+ dialog.similarity_threshold,
162
+ dialog.vector_similarity_weight,
163
+ doc_ids=attachments,
164
+ top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl)
165
+ knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
166
+
167
+ chat_logger.info(
168
+ "{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
169
+
170
+ if not knowledges and prompt_config.get("empty_response"):
171
+ yield {"answer": prompt_config["empty_response"], "reference": kbinfos}
172
+ return {"answer": prompt_config["empty_response"], "reference": kbinfos}
173
+
174
+ kwargs["knowledge"] = "\n".join(knowledges)
175
+ gen_conf = dialog.llm_setting
176
+
177
+ msg = [{"role": "system", "content": prompt_config["system"].format(**kwargs)}]
178
+ msg.extend([{"role": m["role"], "content": re.sub(r"##\d+\$\$", "", m["content"])}
179
+ for m in messages if m["role"] != "system"])
180
+ used_token_count, msg = message_fit_in(msg, int(max_tokens * 0.97))
181
+ assert len(msg) >= 2, f"message_fit_in has bug: {msg}"
182
+
183
+ if "max_tokens" in gen_conf:
184
+ gen_conf["max_tokens"] = min(
185
+ gen_conf["max_tokens"],
186
+ max_tokens - used_token_count)
187
+
188
+ def decorate_answer(answer):
189
+ nonlocal prompt_config, knowledges, kwargs, kbinfos
190
+ refs = []
191
+ if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
192
+ answer, idx = retr.insert_citations(answer,
193
+ [ck["content_ltks"]
194
+ for ck in kbinfos["chunks"]],
195
+ [ck["vector"]
196
+ for ck in kbinfos["chunks"]],
197
+ embd_mdl,
198
+ tkweight=1 - dialog.vector_similarity_weight,
199
+ vtweight=dialog.vector_similarity_weight)
200
+ idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
201
+ recall_docs = [
202
+ d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
203
+ if not recall_docs: recall_docs = kbinfos["doc_aggs"]
204
+ kbinfos["doc_aggs"] = recall_docs
205
+
206
+ refs = deepcopy(kbinfos)
207
+ for c in refs["chunks"]:
208
+ if c.get("vector"):
209
+ del c["vector"]
210
+
211
+ if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
212
+ answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'"
213
+ return {"answer": answer, "reference": refs}
214
+
215
+ if stream:
216
+ answer = ""
217
+ for ans in chat_mdl.chat_streamly(msg[0]["content"], msg[1:], gen_conf):
218
+ answer = ans
219
+ yield {"answer": answer, "reference": {}}
220
+ yield decorate_answer(answer)
221
+ else:
222
+ answer = chat_mdl.chat(
223
+ msg[0]["content"], msg[1:], gen_conf)
224
+ chat_logger.info("User: {}|Assistant: {}".format(
225
+ msg[-1]["content"], answer))
226
+ yield decorate_answer(answer)
227
+
228
+
229
+ def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
230
+ sys_prompt = "你是一个DBA。你需要这对以下表的字段结构,根据用户的问题列表,写出最后一个问题对应的SQL。"
231
+ user_promt = """
232
+ 表名:{};
233
+ 数据库表字段说明如下:
234
+ {}
235
+
236
+ 问题如下:
237
+ {}
238
+ 请写出SQL, 且只要SQL,不要有其他说明及文字。
239
+ """.format(
240
+ index_name(tenant_id),
241
+ "\n".join([f"{k}: {v}" for k, v in field_map.items()]),
242
+ question
243
+ )
244
+ tried_times = 0
245
+
246
+ def get_table():
247
+ nonlocal sys_prompt, user_promt, question, tried_times
248
+ sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_promt}], {
249
+ "temperature": 0.06})
250
+ print(user_promt, sql)
251
+ chat_logger.info(f"“{question}”==>{user_promt} get SQL: {sql}")
252
+ sql = re.sub(r"[\r\n]+", " ", sql.lower())
253
+ sql = re.sub(r".*select ", "select ", sql.lower())
254
+ sql = re.sub(r" +", " ", sql)
255
+ sql = re.sub(r"([;;]|```).*", "", sql)
256
+ if sql[:len("select ")] != "select ":
257
+ return None, None
258
+ if not re.search(r"((sum|avg|max|min)\(|group by )", sql.lower()):
259
+ if sql[:len("select *")] != "select *":
260
+ sql = "select doc_id,docnm_kwd," + sql[6:]
261
+ else:
262
+ flds = []
263
+ for k in field_map.keys():
264
+ if k in forbidden_select_fields4resume:
265
+ continue
266
+ if len(flds) > 11:
267
+ break
268
+ flds.append(k)
269
+ sql = "select doc_id,docnm_kwd," + ",".join(flds) + sql[8:]
270
+
271
+ print(f"“{question}” get SQL(refined): {sql}")
272
+
273
+ chat_logger.info(f"“{question}” get SQL(refined): {sql}")
274
+ tried_times += 1
275
+ return retrievaler.sql_retrieval(sql, format="json"), sql
276
+
277
+ tbl, sql = get_table()
278
+ if tbl is None:
279
+ return None
280
+ if tbl.get("error") and tried_times <= 2:
281
+ user_promt = """
282
+ 表名:{};
283
+ 数据库表字段说明如下:
284
+ {}
285
+
286
+ 问题如下:
287
+ {}
288
+
289
+ 你上一次给出的错误SQL如下:
290
+ {}
291
+
292
+ 后台报错如下:
293
+ {}
294
+
295
+ 请纠正SQL中的错误再写一遍,且只要SQL,不要有其他说明及文字。
296
+ """.format(
297
+ index_name(tenant_id),
298
+ "\n".join([f"{k}: {v}" for k, v in field_map.items()]),
299
+ question, sql, tbl["error"]
300
+ )
301
+ tbl, sql = get_table()
302
+ chat_logger.info("TRY it again: {}".format(sql))
303
+
304
+ chat_logger.info("GET table: {}".format(tbl))
305
+ print(tbl)
306
+ if tbl.get("error") or len(tbl["rows"]) == 0:
307
+ return None
308
+
309
+ docid_idx = set([ii for ii, c in enumerate(
310
+ tbl["columns"]) if c["name"] == "doc_id"])
311
+ docnm_idx = set([ii for ii, c in enumerate(
312
+ tbl["columns"]) if c["name"] == "docnm_kwd"])
313
+ clmn_idx = [ii for ii in range(
314
+ len(tbl["columns"])) if ii not in (docid_idx | docnm_idx)]
315
+
316
+ # compose markdown table
317
+ clmns = "|" + "|".join([re.sub(r"(/.*|([^()]+))", "", field_map.get(tbl["columns"][i]["name"],
318
+ tbl["columns"][i]["name"])) for i in
319
+ clmn_idx]) + ("|Source|" if docid_idx and docid_idx else "|")
320
+
321
+ line = "|" + "|".join(["------" for _ in range(len(clmn_idx))]) + \
322
+ ("|------|" if docid_idx and docid_idx else "")
323
+
324
+ rows = ["|" +
325
+ "|".join([rmSpace(str(r[i])) for i in clmn_idx]).replace("None", " ") +
326
+ "|" for r in tbl["rows"]]
327
+ if quota:
328
+ rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
329
+ else:
330
+ rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
331
+ rows = re.sub(r"T[0-9]{2}:[0-9]{2}:[0-9]{2}(\.[0-9]+Z)?\|", "|", rows)
332
+
333
+ if not docid_idx or not docnm_idx:
334
+ chat_logger.warning("SQL missing field: " + sql)
335
+ return {
336
+ "answer": "\n".join([clmns, line, rows]),
337
+ "reference": {"chunks": [], "doc_aggs": []}
338
+ }
339
+
340
+ docid_idx = list(docid_idx)[0]
341
+ docnm_idx = list(docnm_idx)[0]
342
+ doc_aggs = {}
343
+ for r in tbl["rows"]:
344
+ if r[docid_idx] not in doc_aggs:
345
+ doc_aggs[r[docid_idx]] = {"doc_name": r[docnm_idx], "count": 0}
346
+ doc_aggs[r[docid_idx]]["count"] += 1
347
+ return {
348
+ "answer": "\n".join([clmns, line, rows]),
349
+ "reference": {"chunks": [{"doc_id": r[docid_idx], "docnm_kwd": r[docnm_idx]} for r in tbl["rows"]],
350
+ "doc_aggs": [{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in
351
+ doc_aggs.items()]}
352
+ }
353
+
354
+
355
+ def relevant(tenant_id, llm_id, question, contents: list):
356
+ if llm_id2llm_type(llm_id) == "image2text":
357
+ chat_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id)
358
+ else:
359
+ chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)
360
+ prompt = """
361
+ You are a grader assessing relevance of a retrieved document to a user question.
362
+ It does not need to be a stringent test. The goal is to filter out erroneous retrievals.
363
+ If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant.
364
+ Give a binary 'yes' or 'no' score to indicate whether the document is relevant to the question.
365
+ No other words needed except 'yes' or 'no'.
366
+ """
367
+ if not contents: return False
368
+ contents = "Documents: \n" + " - ".join(contents)
369
+ contents = f"Question: {question}\n" + contents
370
+ if num_tokens_from_string(contents) >= chat_mdl.max_length - 4:
371
+ contents = encoder.decode(encoder.encode(contents)[:chat_mdl.max_length - 4])
372
+ ans = chat_mdl.chat(prompt, [{"role": "user", "content": contents}], {"temperature": 0.01})
373
+ if ans.lower().find("yes") >= 0: return True
374
+ return False
375
+
376
+
377
+ def rewrite(tenant_id, llm_id, question):
378
+ if llm_id2llm_type(llm_id) == "image2text":
379
+ chat_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id)
380
+ else:
381
+ chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)
382
+ prompt = """
383
+ You are an expert at query expansion to generate a paraphrasing of a question.
384
+ I can't retrieve relevant information from the knowledge base by using the user's question directly.
385
+ You need to expand or paraphrase the user's question in multiple ways, such as using synonym words/phrases,
386
+ writing abbreviations out in full, adding extra descriptions or explanations,
387
+ changing the way of expression, translating the original question into another language (English/Chinese), etc.
388
+ Return 5 versions of the question, one of which is a translation.
389
+ Just list the questions. No other words are needed.
390
+ """
391
+ ans = chat_mdl.chat(prompt, [{"role": "user", "content": question}], {"temperature": 0.8})
392
+ return ans
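Again only the line endings changed above, but dialog_service.py carries the most intricate logic in this commit: message_fit_in shrinks a conversation to an LLM token budget by first dropping middle turns, then clipping whichever of the system prompt or latest user message dominates. Below is a self-contained sketch of that algorithm under a deliberately crude assumption: one whitespace-separated word counts as one token, standing in for rag.utils' num_tokens_from_string and encoder. The names fit_messages and tokens are invented here, not the repository API.

# Simplified sketch of the message_fit_in idea; all token counts are illustrative.
def fit_messages(msg, max_length=4000):
    def tokens(m):
        return len(m["content"].split())

    if sum(tokens(m) for m in msg) < max_length:
        return msg  # everything already fits the budget

    # First pass: keep system messages plus the latest turn, as the service does.
    kept = [m for m in msg[:-1] if m["role"] == "system"] + [msg[-1]]
    if sum(tokens(m) for m in kept) < max_length:
        return kept

    # Second pass: clip one side. The system prompt is sacrificed only when it
    # holds more than 80% of the tokens; otherwise the user message is clipped.
    sys_len, usr_len = tokens(kept[0]), tokens(kept[-1])
    clip = 0 if sys_len / (sys_len + usr_len) > 0.8 else -1
    budget = max(0, max_length - (usr_len if clip == 0 else sys_len))
    kept[clip]["content"] = " ".join(kept[clip]["content"].split()[:budget])
    return kept


history = [{"role": "system", "content": "context " * 5000},
           {"role": "user", "content": "What changed in this commit?"}]
fitted = fit_messages(history)  # the 5000-word system prompt is clipped to 3995 words

The 0.8 threshold mirrors the original: the system prompt is trimmed only when it alone blows the budget, which preserves retrieval context for ordinary oversized chats.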
api/db/services/document_service.py CHANGED
@@ -1,382 +1,382 @@
1
- #
2
- # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- #
16
- import random
17
- from datetime import datetime
18
- from elasticsearch_dsl import Q
19
- from peewee import fn
20
-
21
- from api.db.db_utils import bulk_insert_into_db
22
- from api.settings import stat_logger
23
- from api.utils import current_timestamp, get_format_time, get_uuid
24
- from rag.settings import SVR_QUEUE_NAME
25
- from rag.utils.es_conn import ELASTICSEARCH
26
- from rag.utils.minio_conn import MINIO
27
- from rag.nlp import search
28
-
29
- from api.db import FileType, TaskStatus, ParserType
30
- from api.db.db_models import DB, Knowledgebase, Tenant, Task
31
- from api.db.db_models import Document
32
- from api.db.services.common_service import CommonService
33
- from api.db.services.knowledgebase_service import KnowledgebaseService
34
- from api.db import StatusEnum
35
- from rag.utils.redis_conn import REDIS_CONN
36
-
37
-
38
- class DocumentService(CommonService):
39
- model = Document
40
-
41
- @classmethod
42
- @DB.connection_context()
43
- def get_by_kb_id(cls, kb_id, page_number, items_per_page,
44
- orderby, desc, keywords):
45
- if keywords:
46
- docs = cls.model.select().where(
47
- (cls.model.kb_id == kb_id),
48
- (fn.LOWER(cls.model.name).contains(keywords.lower()))
49
- )
50
- else:
51
- docs = cls.model.select().where(cls.model.kb_id == kb_id)
52
- count = docs.count()
53
- if desc:
54
- docs = docs.order_by(cls.model.getter_by(orderby).desc())
55
- else:
56
- docs = docs.order_by(cls.model.getter_by(orderby).asc())
57
-
58
- docs = docs.paginate(page_number, items_per_page)
59
-
60
- return list(docs.dicts()), count
61
-
62
- @classmethod
63
- @DB.connection_context()
64
- def list_documents_in_dataset(cls, dataset_id, offset, count, order_by, descend, keywords):
65
- if keywords:
66
- docs = cls.model.select().where(
67
- (cls.model.kb_id == dataset_id),
68
- (fn.LOWER(cls.model.name).contains(keywords.lower()))
69
- )
70
- else:
71
- docs = cls.model.select().where(cls.model.kb_id == dataset_id)
72
-
73
- total = docs.count()
74
-
75
- if descend == 'True':
76
- docs = docs.order_by(cls.model.getter_by(order_by).desc())
77
- if descend == 'False':
78
- docs = docs.order_by(cls.model.getter_by(order_by).asc())
79
-
80
- docs = list(docs.dicts())
81
- docs_length = len(docs)
82
-
83
- if offset < 0 or offset > docs_length:
84
- raise IndexError("Offset is out of the valid range.")
85
-
86
- if count == -1:
87
- return docs[offset:], total
88
-
89
- return docs[offset:offset + count], total
90
-
91
- @classmethod
92
- @DB.connection_context()
93
- def insert(cls, doc):
94
- if not cls.save(**doc):
95
- raise RuntimeError("Database error (Document)!")
96
- e, doc = cls.get_by_id(doc["id"])
97
- if not e:
98
- raise RuntimeError("Database error (Document retrieval)!")
99
- e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
100
- if not KnowledgebaseService.update_by_id(
101
- kb.id, {"doc_num": kb.doc_num + 1}):
102
- raise RuntimeError("Database error (Knowledgebase)!")
103
- return doc
104
-
105
- @classmethod
106
- @DB.connection_context()
107
- def remove_document(cls, doc, tenant_id):
108
- ELASTICSEARCH.deleteByQuery(
109
- Q("match", doc_id=doc.id), idxnm=search.index_name(tenant_id))
110
- cls.clear_chunk_num(doc.id)
111
- return cls.delete_by_id(doc.id)
112
-
113
- @classmethod
114
- @DB.connection_context()
115
- def get_newly_uploaded(cls):
116
- fields = [
117
- cls.model.id,
118
- cls.model.kb_id,
119
- cls.model.parser_id,
120
- cls.model.parser_config,
121
- cls.model.name,
122
- cls.model.type,
123
- cls.model.location,
124
- cls.model.size,
125
- Knowledgebase.tenant_id,
126
- Tenant.embd_id,
127
- Tenant.img2txt_id,
128
- Tenant.asr_id,
129
- cls.model.update_time]
130
- docs = cls.model.select(*fields) \
131
- .join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id)) \
132
- .join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))\
133
- .where(
134
- cls.model.status == StatusEnum.VALID.value,
135
- ~(cls.model.type == FileType.VIRTUAL.value),
136
- cls.model.progress == 0,
137
- cls.model.update_time >= current_timestamp() - 1000 * 600,
138
- cls.model.run == TaskStatus.RUNNING.value)\
139
- .order_by(cls.model.update_time.asc())
140
- return list(docs.dicts())
141
-
142
- @classmethod
143
- @DB.connection_context()
144
- def get_unfinished_docs(cls):
145
- fields = [cls.model.id, cls.model.process_begin_at, cls.model.parser_config, cls.model.progress_msg, cls.model.run]
146
- docs = cls.model.select(*fields) \
147
- .where(
148
- cls.model.status == StatusEnum.VALID.value,
149
- ~(cls.model.type == FileType.VIRTUAL.value),
150
- cls.model.progress < 1,
151
- cls.model.progress > 0)
152
- return list(docs.dicts())
153
-
154
- @classmethod
155
- @DB.connection_context()
156
- def increment_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duation):
157
- num = cls.model.update(token_num=cls.model.token_num + token_num,
158
- chunk_num=cls.model.chunk_num + chunk_num,
159
- process_duation=cls.model.process_duation + duation).where(
160
- cls.model.id == doc_id).execute()
161
- if num == 0:
162
- raise LookupError(
163
- "Document not found which is supposed to be there")
164
- num = Knowledgebase.update(
165
- token_num=Knowledgebase.token_num +
166
- token_num,
167
- chunk_num=Knowledgebase.chunk_num +
168
- chunk_num).where(
169
- Knowledgebase.id == kb_id).execute()
170
- return num
171
-
172
- @classmethod
173
- @DB.connection_context()
174
- def decrement_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duation):
175
- num = cls.model.update(token_num=cls.model.token_num - token_num,
176
- chunk_num=cls.model.chunk_num - chunk_num,
177
- process_duation=cls.model.process_duation + duation).where(
178
- cls.model.id == doc_id).execute()
179
- if num == 0:
180
- raise LookupError(
181
- "Document not found which is supposed to be there")
182
- num = Knowledgebase.update(
183
- token_num=Knowledgebase.token_num -
184
- token_num,
185
- chunk_num=Knowledgebase.chunk_num -
186
- chunk_num
187
- ).where(
188
- Knowledgebase.id == kb_id).execute()
189
- return num
190
-
191
- @classmethod
192
- @DB.connection_context()
193
- def clear_chunk_num(cls, doc_id):
194
- doc = cls.model.get_by_id(doc_id)
195
- assert doc, "Can't fine document in database."
196
-
197
- num = Knowledgebase.update(
198
- token_num=Knowledgebase.token_num -
199
- doc.token_num,
200
- chunk_num=Knowledgebase.chunk_num -
201
- doc.chunk_num,
202
- doc_num=Knowledgebase.doc_num-1
203
- ).where(
204
- Knowledgebase.id == doc.kb_id).execute()
205
- return num
206
-
207
- @classmethod
208
- @DB.connection_context()
209
- def get_tenant_id(cls, doc_id):
210
- docs = cls.model.select(
211
- Knowledgebase.tenant_id).join(
212
- Knowledgebase, on=(
213
- Knowledgebase.id == cls.model.kb_id)).where(
214
- cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
215
- docs = docs.dicts()
216
- if not docs:
217
- return
218
- return docs[0]["tenant_id"]
219
-
220
- @classmethod
221
- @DB.connection_context()
222
- def get_tenant_id_by_name(cls, name):
223
- docs = cls.model.select(
224
- Knowledgebase.tenant_id).join(
225
- Knowledgebase, on=(
226
- Knowledgebase.id == cls.model.kb_id)).where(
227
- cls.model.name == name, Knowledgebase.status == StatusEnum.VALID.value)
228
- docs = docs.dicts()
229
- if not docs:
230
- return
231
- return docs[0]["tenant_id"]
232
-
233
- @classmethod
234
- @DB.connection_context()
235
- def get_embd_id(cls, doc_id):
236
- docs = cls.model.select(
237
- Knowledgebase.embd_id).join(
238
- Knowledgebase, on=(
239
- Knowledgebase.id == cls.model.kb_id)).where(
240
- cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
241
- docs = docs.dicts()
242
- if not docs:
243
- return
244
- return docs[0]["embd_id"]
245
-
246
- @classmethod
247
- @DB.connection_context()
248
- def get_doc_id_by_doc_name(cls, doc_name):
249
- fields = [cls.model.id]
250
- doc_id = cls.model.select(*fields) \
251
- .where(cls.model.name == doc_name)
252
- doc_id = doc_id.dicts()
253
- if not doc_id:
254
- return
255
- return doc_id[0]["id"]
256
-
257
- @classmethod
258
- @DB.connection_context()
259
- def get_thumbnails(cls, docids):
260
- fields = [cls.model.id, cls.model.thumbnail]
261
- return list(cls.model.select(
262
- *fields).where(cls.model.id.in_(docids)).dicts())
263
-
264
- @classmethod
265
- @DB.connection_context()
266
- def update_parser_config(cls, id, config):
267
- e, d = cls.get_by_id(id)
268
- if not e:
269
- raise LookupError(f"Document({id}) not found.")
270
-
271
- def dfs_update(old, new):
272
- for k, v in new.items():
273
- if k not in old:
274
- old[k] = v
275
- continue
276
- if isinstance(v, dict):
277
- assert isinstance(old[k], dict)
278
- dfs_update(old[k], v)
279
- else:
280
- old[k] = v
281
- dfs_update(d.parser_config, config)
282
- cls.update_by_id(id, {"parser_config": d.parser_config})
283
-
284
- @classmethod
285
- @DB.connection_context()
286
- def get_doc_count(cls, tenant_id):
287
- docs = cls.model.select(cls.model.id).join(Knowledgebase,
288
- on=(Knowledgebase.id == cls.model.kb_id)).where(
289
- Knowledgebase.tenant_id == tenant_id)
290
- return len(docs)
291
-
292
- @classmethod
293
- @DB.connection_context()
294
- def begin2parse(cls, docid):
295
- cls.update_by_id(
296
- docid, {"progress": random.random() * 1 / 100.,
297
- "progress_msg": "Task dispatched...",
298
- "process_begin_at": get_format_time()
299
- })
300
-
301
- @classmethod
302
- @DB.connection_context()
303
- def update_progress(cls):
304
- docs = cls.get_unfinished_docs()
305
- for d in docs:
306
- try:
307
- tsks = Task.query(doc_id=d["id"], order_by=Task.create_time)
308
- if not tsks:
309
- continue
310
- msg = []
311
- prg = 0
312
- finished = True
313
- bad = 0
314
- e, doc = DocumentService.get_by_id(d["id"])
315
- status = doc.run  # TaskStatus.RUNNING.value
316
- for t in tsks:
317
- if 0 <= t.progress < 1:
318
- finished = False
319
- prg += t.progress if t.progress >= 0 else 0
320
- if t.progress_msg not in msg:
321
- msg.append(t.progress_msg)
322
- if t.progress == -1:
323
- bad += 1
324
- prg /= len(tsks)
325
- if finished and bad:
326
- prg = -1
327
- status = TaskStatus.FAIL.value
328
- elif finished:
329
- if d["parser_config"].get("raptor", {}).get("use_raptor") and d["progress_msg"].lower().find(" raptor")<0:
330
- queue_raptor_tasks(d)
331
- prg *= 0.98
332
- msg.append("------ RAPTOR -------")
333
- else:
334
- status = TaskStatus.DONE.value
335
-
336
- msg = "\n".join(msg)
337
- info = {
338
- "process_duation": datetime.timestamp(
339
- datetime.now()) -
340
- d["process_begin_at"].timestamp(),
341
- "run": status}
342
- if prg != 0:
343
- info["progress"] = prg
344
- if msg:
345
- info["progress_msg"] = msg
346
- cls.update_by_id(d["id"], info)
347
- except Exception as e:
348
- stat_logger.error("fetch task exception:" + str(e))
349
-
350
- @classmethod
351
- @DB.connection_context()
352
- def get_kb_doc_count(cls, kb_id):
353
- return len(cls.model.select(cls.model.id).where(
354
- cls.model.kb_id == kb_id).dicts())
355
-
356
-
357
- @classmethod
358
- @DB.connection_context()
359
- def do_cancel(cls, doc_id):
360
- try:
361
- _, doc = DocumentService.get_by_id(doc_id)
362
- return doc.run == TaskStatus.CANCEL.value or doc.progress < 0
363
- except Exception as e:
364
- pass
365
- return False
366
-
367
-
368
- def queue_raptor_tasks(doc):
369
- def new_task():
370
- nonlocal doc
371
- return {
372
- "id": get_uuid(),
373
- "doc_id": doc["id"],
374
- "from_page": 0,
375
- "to_page": -1,
376
- "progress_msg": "Start to do RAPTOR (Recursive Abstractive Processing For Tree-Organized Retrieval)."
377
- }
378
-
379
- task = new_task()
380
- bulk_insert_into_db(Task, [task], True)
381
- task["type"] = "raptor"
382
- assert REDIS_CONN.queue_product(SVR_QUEUE_NAME, message=task), "Can't access Redis. Please check the Redis' status."
 
1
+ #
2
+ # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ #
16
+ import random
17
+ from datetime import datetime
18
+ from elasticsearch_dsl import Q
19
+ from peewee import fn
20
+
21
+ from api.db.db_utils import bulk_insert_into_db
22
+ from api.settings import stat_logger
23
+ from api.utils import current_timestamp, get_format_time, get_uuid
24
+ from rag.settings import SVR_QUEUE_NAME
25
+ from rag.utils.es_conn import ELASTICSEARCH
26
+ from rag.utils.minio_conn import MINIO
27
+ from rag.nlp import search
28
+
29
+ from api.db import FileType, TaskStatus, ParserType
30
+ from api.db.db_models import DB, Knowledgebase, Tenant, Task
31
+ from api.db.db_models import Document
32
+ from api.db.services.common_service import CommonService
33
+ from api.db.services.knowledgebase_service import KnowledgebaseService
34
+ from api.db import StatusEnum
35
+ from rag.utils.redis_conn import REDIS_CONN
36
+
37
+
38
+ class DocumentService(CommonService):
39
+ model = Document
40
+
41
+ @classmethod
42
+ @DB.connection_context()
43
+ def get_by_kb_id(cls, kb_id, page_number, items_per_page,
44
+ orderby, desc, keywords):
45
+ if keywords:
46
+ docs = cls.model.select().where(
47
+ (cls.model.kb_id == kb_id),
48
+ (fn.LOWER(cls.model.name).contains(keywords.lower()))
49
+ )
50
+ else:
51
+ docs = cls.model.select().where(cls.model.kb_id == kb_id)
52
+ count = docs.count()
53
+ if desc:
54
+ docs = docs.order_by(cls.model.getter_by(orderby).desc())
55
+ else:
56
+ docs = docs.order_by(cls.model.getter_by(orderby).asc())
57
+
58
+ docs = docs.paginate(page_number, items_per_page)
59
+
60
+ return list(docs.dicts()), count
61
+
62
+ @classmethod
63
+ @DB.connection_context()
64
+ def list_documents_in_dataset(cls, dataset_id, offset, count, order_by, descend, keywords):
65
+ if keywords:
66
+ docs = cls.model.select().where(
67
+ (cls.model.kb_id == dataset_id),
68
+ (fn.LOWER(cls.model.name).contains(keywords.lower()))
69
+ )
70
+ else:
71
+ docs = cls.model.select().where(cls.model.kb_id == dataset_id)
72
+
73
+ total = docs.count()
74
+
75
+ if descend == 'True':
76
+ docs = docs.order_by(cls.model.getter_by(order_by).desc())
77
+ if descend == 'False':
78
+ docs = docs.order_by(cls.model.getter_by(order_by).asc())
79
+
80
+ docs = list(docs.dicts())
81
+ docs_length = len(docs)
82
+
83
+ if offset < 0 or offset > docs_length:
84
+ raise IndexError("Offset is out of the valid range.")
85
+
86
+ if count == -1:
87
+ return docs[offset:], total
88
+
89
+ return docs[offset:offset + count], total
90
+
91
+ @classmethod
92
+ @DB.connection_context()
93
+ def insert(cls, doc):
94
+ if not cls.save(**doc):
95
+ raise RuntimeError("Database error (Document)!")
96
+ e, doc = cls.get_by_id(doc["id"])
97
+ if not e:
98
+ raise RuntimeError("Database error (Document retrieval)!")
99
+ e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
100
+ if not KnowledgebaseService.update_by_id(
101
+ kb.id, {"doc_num": kb.doc_num + 1}):
102
+ raise RuntimeError("Database error (Knowledgebase)!")
103
+ return doc
104
+
105
+ @classmethod
106
+ @DB.connection_context()
107
+ def remove_document(cls, doc, tenant_id):
108
+ ELASTICSEARCH.deleteByQuery(
109
+ Q("match", doc_id=doc.id), idxnm=search.index_name(tenant_id))
110
+ cls.clear_chunk_num(doc.id)
111
+ return cls.delete_by_id(doc.id)
112
+
113
+ @classmethod
114
+ @DB.connection_context()
115
+ def get_newly_uploaded(cls):
116
+ fields = [
117
+ cls.model.id,
118
+ cls.model.kb_id,
119
+ cls.model.parser_id,
120
+ cls.model.parser_config,
121
+ cls.model.name,
122
+ cls.model.type,
123
+ cls.model.location,
124
+ cls.model.size,
125
+ Knowledgebase.tenant_id,
126
+ Tenant.embd_id,
127
+ Tenant.img2txt_id,
128
+ Tenant.asr_id,
129
+ cls.model.update_time]
130
+ docs = cls.model.select(*fields) \
131
+ .join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id)) \
132
+ .join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))\
133
+ .where(
134
+ cls.model.status == StatusEnum.VALID.value,
135
+ ~(cls.model.type == FileType.VIRTUAL.value),
136
+ cls.model.progress == 0,
137
+ cls.model.update_time >= current_timestamp() - 1000 * 600,
138
+ cls.model.run == TaskStatus.RUNNING.value)\
139
+ .order_by(cls.model.update_time.asc())
140
+ return list(docs.dicts())
141
+
142
+ @classmethod
143
+ @DB.connection_context()
144
+ def get_unfinished_docs(cls):
145
+ fields = [cls.model.id, cls.model.process_begin_at, cls.model.parser_config, cls.model.progress_msg, cls.model.run]
146
+ docs = cls.model.select(*fields) \
147
+ .where(
148
+ cls.model.status == StatusEnum.VALID.value,
149
+ ~(cls.model.type == FileType.VIRTUAL.value),
150
+ cls.model.progress < 1,
151
+ cls.model.progress > 0)
152
+ return list(docs.dicts())
153
+
154
+ @classmethod
155
+ @DB.connection_context()
156
+ def increment_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duation):
157
+ num = cls.model.update(token_num=cls.model.token_num + token_num,
158
+ chunk_num=cls.model.chunk_num + chunk_num,
159
+ process_duation=cls.model.process_duation + duation).where(
160
+ cls.model.id == doc_id).execute()
161
+ if num == 0:
162
+ raise LookupError(
163
+ "Document not found which is supposed to be there")
164
+ num = Knowledgebase.update(
165
+ token_num=Knowledgebase.token_num +
166
+ token_num,
167
+ chunk_num=Knowledgebase.chunk_num +
168
+ chunk_num).where(
169
+ Knowledgebase.id == kb_id).execute()
170
+ return num
171
+
172
+ @classmethod
173
+ @DB.connection_context()
174
+ def decrement_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duation):
175
+ num = cls.model.update(token_num=cls.model.token_num - token_num,
176
+ chunk_num=cls.model.chunk_num - chunk_num,
177
+ process_duation=cls.model.process_duation + duation).where(
178
+ cls.model.id == doc_id).execute()
179
+ if num == 0:
180
+ raise LookupError(
181
+ "Document not found which is supposed to be there")
182
+ num = Knowledgebase.update(
183
+ token_num=Knowledgebase.token_num -
184
+ token_num,
185
+ chunk_num=Knowledgebase.chunk_num -
186
+ chunk_num
187
+ ).where(
188
+ Knowledgebase.id == kb_id).execute()
189
+ return num
190
+
191
+ @classmethod
192
+ @DB.connection_context()
193
+ def clear_chunk_num(cls, doc_id):
194
+ doc = cls.model.get_by_id(doc_id)
195
+ assert doc, "Can't fine document in database."
196
+
197
+ num = Knowledgebase.update(
198
+ token_num=Knowledgebase.token_num -
199
+ doc.token_num,
200
+ chunk_num=Knowledgebase.chunk_num -
201
+ doc.chunk_num,
202
+ doc_num=Knowledgebase.doc_num-1
203
+ ).where(
204
+ Knowledgebase.id == doc.kb_id).execute()
205
+ return num
206
+
207
+ @classmethod
208
+ @DB.connection_context()
209
+ def get_tenant_id(cls, doc_id):
210
+ docs = cls.model.select(
211
+ Knowledgebase.tenant_id).join(
212
+ Knowledgebase, on=(
213
+ Knowledgebase.id == cls.model.kb_id)).where(
214
+ cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
215
+ docs = docs.dicts()
216
+ if not docs:
217
+ return
218
+ return docs[0]["tenant_id"]
219
+
220
+ @classmethod
221
+ @DB.connection_context()
222
+ def get_tenant_id_by_name(cls, name):
223
+ docs = cls.model.select(
224
+ Knowledgebase.tenant_id).join(
225
+ Knowledgebase, on=(
226
+ Knowledgebase.id == cls.model.kb_id)).where(
227
+ cls.model.name == name, Knowledgebase.status == StatusEnum.VALID.value)
228
+ docs = docs.dicts()
229
+ if not docs:
230
+ return
231
+ return docs[0]["tenant_id"]
232
+
233
+ @classmethod
234
+ @DB.connection_context()
235
+ def get_embd_id(cls, doc_id):
236
+ docs = cls.model.select(
237
+ Knowledgebase.embd_id).join(
238
+ Knowledgebase, on=(
239
+ Knowledgebase.id == cls.model.kb_id)).where(
240
+ cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
241
+ docs = docs.dicts()
242
+ if not docs:
243
+ return
244
+ return docs[0]["embd_id"]
245
+
246
+ @classmethod
247
+ @DB.connection_context()
248
+ def get_doc_id_by_doc_name(cls, doc_name):
249
+ fields = [cls.model.id]
250
+ doc_id = cls.model.select(*fields) \
251
+ .where(cls.model.name == doc_name)
252
+ doc_id = doc_id.dicts()
253
+ if not doc_id:
254
+ return
255
+ return doc_id[0]["id"]
256
+
257
+ @classmethod
258
+ @DB.connection_context()
259
+ def get_thumbnails(cls, docids):
260
+ fields = [cls.model.id, cls.model.thumbnail]
261
+ return list(cls.model.select(
262
+ *fields).where(cls.model.id.in_(docids)).dicts())
263
+
264
+ @classmethod
265
+ @DB.connection_context()
266
+ def update_parser_config(cls, id, config):
267
+ e, d = cls.get_by_id(id)
268
+ if not e:
269
+ raise LookupError(f"Document({id}) not found.")
270
+
271
+ def dfs_update(old, new):
272
+ for k, v in new.items():
273
+ if k not in old:
274
+ old[k] = v
275
+ continue
276
+ if isinstance(v, dict):
277
+ assert isinstance(old[k], dict)
278
+ dfs_update(old[k], v)
279
+ else:
280
+ old[k] = v
281
+ dfs_update(d.parser_config, config)
282
+ cls.update_by_id(id, {"parser_config": d.parser_config})
283
+
284
+ @classmethod
285
+ @DB.connection_context()
286
+ def get_doc_count(cls, tenant_id):
287
+ docs = cls.model.select(cls.model.id).join(Knowledgebase,
288
+ on=(Knowledgebase.id == cls.model.kb_id)).where(
289
+ Knowledgebase.tenant_id == tenant_id)
290
+ return len(docs)
291
+
292
+ @classmethod
293
+ @DB.connection_context()
294
+ def begin2parse(cls, docid):
295
+ cls.update_by_id(
296
+ docid, {"progress": random.random() * 1 / 100.,
297
+ "progress_msg": "Task dispatched...",
298
+ "process_begin_at": get_format_time()
299
+ })
300
+
301
+ @classmethod
302
+ @DB.connection_context()
303
+ def update_progress(cls):
304
+ docs = cls.get_unfinished_docs()
305
+ for d in docs:
306
+ try:
307
+ tsks = Task.query(doc_id=d["id"], order_by=Task.create_time)
308
+ if not tsks:
309
+ continue
310
+ msg = []
311
+ prg = 0
312
+ finished = True
313
+ bad = 0
314
+ e, doc = DocumentService.get_by_id(d["id"])
315
+ status = doc.run  # the document's current TaskStatus value
316
+ for t in tsks:
317
+ if 0 <= t.progress < 1:
318
+ finished = False
319
+ prg += t.progress if t.progress >= 0 else 0
320
+ if t.progress_msg not in msg:
321
+ msg.append(t.progress_msg)
322
+ if t.progress == -1:
323
+ bad += 1
324
+ prg /= len(tsks)
325
+ if finished and bad:
326
+ prg = -1
327
+ status = TaskStatus.FAIL.value
328
+ elif finished:
329
+ if d["parser_config"].get("raptor", {}).get("use_raptor") and d["progress_msg"].lower().find(" raptor")<0:
330
+ queue_raptor_tasks(d)
331
+ prg *= 0.98
332
+ msg.append("------ RAPTOR -------")
333
+ else:
334
+ status = TaskStatus.DONE.value
335
+
336
+ msg = "\n".join(msg)
337
+ info = {
338
+ "process_duation": datetime.timestamp(
339
+ datetime.now()) -
340
+ d["process_begin_at"].timestamp(),
341
+ "run": status}
342
+ if prg != 0:
343
+ info["progress"] = prg
344
+ if msg:
345
+ info["progress_msg"] = msg
346
+ cls.update_by_id(d["id"], info)
347
+ except Exception as e:
348
+ stat_logger.error("fetch task exception:" + str(e))
349
+
350
+ @classmethod
351
+ @DB.connection_context()
352
+ def get_kb_doc_count(cls, kb_id):
353
+ return len(cls.model.select(cls.model.id).where(
354
+ cls.model.kb_id == kb_id).dicts())
355
+
356
+
357
+ @classmethod
358
+ @DB.connection_context()
359
+ def do_cancel(cls, doc_id):
360
+ try:
361
+ _, doc = DocumentService.get_by_id(doc_id)
362
+ return doc.run == TaskStatus.CANCEL.value or doc.progress < 0
363
+ except Exception as e:
364
+ pass
365
+ return False
366
+
367
+
368
+ def queue_raptor_tasks(doc):
369
+ def new_task():
370
+ nonlocal doc
371
+ return {
372
+ "id": get_uuid(),
373
+ "doc_id": doc["id"],
374
+ "from_page": 0,
375
+ "to_page": -1,
376
+ "progress_msg": "Start to do RAPTOR (Recursive Abstractive Processing For Tree-Organized Retrieval)."
377
+ }
378
+
379
+ task = new_task()
380
+ bulk_insert_into_db(Task, [task], True)
381
+ task["type"] = "raptor"
382
+ assert REDIS_CONN.queue_product(SVR_QUEUE_NAME, message=task), "Can't access Redis. Please check the Redis status."
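
The update_progress loop above folds the progress of a document's chunk tasks into one figure: the mean of all non-negative task progresses, with the document marked failed (-1) only once every task has finished and at least one reported -1 (the RAPTOR re-queue branch is left out here). A minimal sketch of that folding rule; fold_progress is an illustrative name, not part of this module:

# Sketch of the progress-folding rule in DocumentService.update_progress.
# "tasks" holds per-task progress: 0..1 while running or done, -1 on failure.
def fold_progress(tasks):
    finished = all(not (0 <= p < 1) for p in tasks)
    bad = sum(1 for p in tasks if p == -1)
    prg = sum(p for p in tasks if p >= 0) / len(tasks)
    if finished and bad:
        return -1.0  # any failed chunk fails the whole document
    return prg       # otherwise report the mean progress so far

assert fold_progress([1.0, 0.5]) == 0.75
assert fold_progress([1.0, -1]) == -1.0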
api/db/services/knowledgebase_service.py CHANGED
@@ -1,144 +1,144 @@
1
- #
2
- # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- #
16
- from api.db import StatusEnum, TenantPermission
17
- from api.db.db_models import Knowledgebase, DB, Tenant
18
- from api.db.services.common_service import CommonService
19
-
20
-
21
- class KnowledgebaseService(CommonService):
22
- model = Knowledgebase
23
-
24
- @classmethod
25
- @DB.connection_context()
26
- def get_by_tenant_ids(cls, joined_tenant_ids, user_id,
27
- page_number, items_per_page, orderby, desc):
28
- kbs = cls.model.select().where(
29
- ((cls.model.tenant_id.in_(joined_tenant_ids) & (cls.model.permission ==
30
- TenantPermission.TEAM.value)) | (
31
- cls.model.tenant_id == user_id))
32
- & (cls.model.status == StatusEnum.VALID.value)
33
- )
34
- if desc:
35
- kbs = kbs.order_by(cls.model.getter_by(orderby).desc())
36
- else:
37
- kbs = kbs.order_by(cls.model.getter_by(orderby).asc())
38
-
39
- kbs = kbs.paginate(page_number, items_per_page)
40
-
41
- return list(kbs.dicts())
42
-
43
- @classmethod
44
- @DB.connection_context()
45
- def get_by_tenant_ids_by_offset(cls, joined_tenant_ids, user_id, offset, count, orderby, desc):
46
- kbs = cls.model.select().where(
47
- ((cls.model.tenant_id.in_(joined_tenant_ids) & (cls.model.permission ==
48
- TenantPermission.TEAM.value)) | (
49
- cls.model.tenant_id == user_id))
50
- & (cls.model.status == StatusEnum.VALID.value)
51
- )
52
- if desc:
53
- kbs = kbs.order_by(cls.model.getter_by(orderby).desc())
54
- else:
55
- kbs = kbs.order_by(cls.model.getter_by(orderby).asc())
56
-
57
- kbs = list(kbs.dicts())
58
-
59
- kbs_length = len(kbs)
60
- if offset < 0 or offset > kbs_length:
61
- raise IndexError("Offset is out of the valid range.")
62
-
63
- if count == -1:
64
- return kbs[offset:]
65
-
66
- return kbs[offset:offset+count]
67
-
68
- @classmethod
69
- @DB.connection_context()
70
- def get_detail(cls, kb_id):
71
- fields = [
72
- cls.model.id,
73
- #Tenant.embd_id,
74
- cls.model.embd_id,
75
- cls.model.avatar,
76
- cls.model.name,
77
- cls.model.language,
78
- cls.model.description,
79
- cls.model.permission,
80
- cls.model.doc_num,
81
- cls.model.token_num,
82
- cls.model.chunk_num,
83
- cls.model.parser_id,
84
- cls.model.parser_config]
85
- kbs = cls.model.select(*fields).join(Tenant, on=(
86
- (Tenant.id == cls.model.tenant_id) & (Tenant.status == StatusEnum.VALID.value))).where(
87
- (cls.model.id == kb_id),
88
- (cls.model.status == StatusEnum.VALID.value)
89
- )
90
- if not kbs:
91
- return
92
- d = kbs[0].to_dict()
93
- #d["embd_id"] = kbs[0].tenant.embd_id
94
- return d
95
-
96
- @classmethod
97
- @DB.connection_context()
98
- def update_parser_config(cls, id, config):
99
- e, m = cls.get_by_id(id)
100
- if not e:
101
- raise LookupError(f"knowledgebase({id}) not found.")
102
-
103
- def dfs_update(old, new):
104
- for k, v in new.items():
105
- if k not in old:
106
- old[k] = v
107
- continue
108
- if isinstance(v, dict):
109
- assert isinstance(old[k], dict)
110
- dfs_update(old[k], v)
111
- elif isinstance(v, list):
112
- assert isinstance(old[k], list)
113
- old[k] = list(set(old[k] + v))
114
- else:
115
- old[k] = v
116
-
117
- dfs_update(m.parser_config, config)
118
- cls.update_by_id(id, {"parser_config": m.parser_config})
119
-
120
- @classmethod
121
- @DB.connection_context()
122
- def get_field_map(cls, ids):
123
- conf = {}
124
- for k in cls.get_by_ids(ids):
125
- if k.parser_config and "field_map" in k.parser_config:
126
- conf.update(k.parser_config["field_map"])
127
- return conf
128
-
129
- @classmethod
130
- @DB.connection_context()
131
- def get_by_name(cls, kb_name, tenant_id):
132
- kb = cls.model.select().where(
133
- (cls.model.name == kb_name)
134
- & (cls.model.tenant_id == tenant_id)
135
- & (cls.model.status == StatusEnum.VALID.value)
136
- )
137
- if kb:
138
- return True, kb[0]
139
- return False, None
140
-
141
- @classmethod
142
- @DB.connection_context()
143
- def get_all_ids(cls):
144
- return [m["id"] for m in cls.model.select(cls.model.id).dicts()]
 
1
+ #
2
+ # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ #
16
+ from api.db import StatusEnum, TenantPermission
17
+ from api.db.db_models import Knowledgebase, DB, Tenant
18
+ from api.db.services.common_service import CommonService
19
+
20
+
21
+ class KnowledgebaseService(CommonService):
22
+ model = Knowledgebase
23
+
24
+ @classmethod
25
+ @DB.connection_context()
26
+ def get_by_tenant_ids(cls, joined_tenant_ids, user_id,
27
+ page_number, items_per_page, orderby, desc):
28
+ kbs = cls.model.select().where(
29
+ ((cls.model.tenant_id.in_(joined_tenant_ids) & (cls.model.permission ==
30
+ TenantPermission.TEAM.value)) | (
31
+ cls.model.tenant_id == user_id))
32
+ & (cls.model.status == StatusEnum.VALID.value)
33
+ )
34
+ if desc:
35
+ kbs = kbs.order_by(cls.model.getter_by(orderby).desc())
36
+ else:
37
+ kbs = kbs.order_by(cls.model.getter_by(orderby).asc())
38
+
39
+ kbs = kbs.paginate(page_number, items_per_page)
40
+
41
+ return list(kbs.dicts())
42
+
43
+ @classmethod
44
+ @DB.connection_context()
45
+ def get_by_tenant_ids_by_offset(cls, joined_tenant_ids, user_id, offset, count, orderby, desc):
46
+ kbs = cls.model.select().where(
47
+ ((cls.model.tenant_id.in_(joined_tenant_ids) & (cls.model.permission ==
48
+ TenantPermission.TEAM.value)) | (
49
+ cls.model.tenant_id == user_id))
50
+ & (cls.model.status == StatusEnum.VALID.value)
51
+ )
52
+ if desc:
53
+ kbs = kbs.order_by(cls.model.getter_by(orderby).desc())
54
+ else:
55
+ kbs = kbs.order_by(cls.model.getter_by(orderby).asc())
56
+
57
+ kbs = list(kbs.dicts())
58
+
59
+ kbs_length = len(kbs)
60
+ if offset < 0 or offset > kbs_length:
61
+ raise IndexError("Offset is out of the valid range.")
62
+
63
+ if count == -1:
64
+ return kbs[offset:]
65
+
66
+ return kbs[offset:offset+count]
67
+
68
+ @classmethod
69
+ @DB.connection_context()
70
+ def get_detail(cls, kb_id):
71
+ fields = [
72
+ cls.model.id,
73
+ #Tenant.embd_id,
74
+ cls.model.embd_id,
75
+ cls.model.avatar,
76
+ cls.model.name,
77
+ cls.model.language,
78
+ cls.model.description,
79
+ cls.model.permission,
80
+ cls.model.doc_num,
81
+ cls.model.token_num,
82
+ cls.model.chunk_num,
83
+ cls.model.parser_id,
84
+ cls.model.parser_config]
85
+ kbs = cls.model.select(*fields).join(Tenant, on=(
86
+ (Tenant.id == cls.model.tenant_id) & (Tenant.status == StatusEnum.VALID.value))).where(
87
+ (cls.model.id == kb_id),
88
+ (cls.model.status == StatusEnum.VALID.value)
89
+ )
90
+ if not kbs:
91
+ return
92
+ d = kbs[0].to_dict()
93
+ #d["embd_id"] = kbs[0].tenant.embd_id
94
+ return d
95
+
96
+ @classmethod
97
+ @DB.connection_context()
98
+ def update_parser_config(cls, id, config):
99
+ e, m = cls.get_by_id(id)
100
+ if not e:
101
+ raise LookupError(f"knowledgebase({id}) not found.")
102
+
103
+ def dfs_update(old, new):
104
+ for k, v in new.items():
105
+ if k not in old:
106
+ old[k] = v
107
+ continue
108
+ if isinstance(v, dict):
109
+ assert isinstance(old[k], dict)
110
+ dfs_update(old[k], v)
111
+ elif isinstance(v, list):
112
+ assert isinstance(old[k], list)
113
+ old[k] = list(set(old[k] + v))
114
+ else:
115
+ old[k] = v
116
+
117
+ dfs_update(m.parser_config, config)
118
+ cls.update_by_id(id, {"parser_config": m.parser_config})
119
+
120
+ @classmethod
121
+ @DB.connection_context()
122
+ def get_field_map(cls, ids):
123
+ conf = {}
124
+ for k in cls.get_by_ids(ids):
125
+ if k.parser_config and "field_map" in k.parser_config:
126
+ conf.update(k.parser_config["field_map"])
127
+ return conf
128
+
129
+ @classmethod
130
+ @DB.connection_context()
131
+ def get_by_name(cls, kb_name, tenant_id):
132
+ kb = cls.model.select().where(
133
+ (cls.model.name == kb_name)
134
+ & (cls.model.tenant_id == tenant_id)
135
+ & (cls.model.status == StatusEnum.VALID.value)
136
+ )
137
+ if kb:
138
+ return True, kb[0]
139
+ return False, None
140
+
141
+ @classmethod
142
+ @DB.connection_context()
143
+ def get_all_ids(cls):
144
+ return [m["id"] for m in cls.model.select(cls.model.id).dicts()]
api/db/services/llm_service.py CHANGED
@@ -1,242 +1,242 @@
1
- #
2
- # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- #
16
- from api.db.services.user_service import TenantService
17
- from api.settings import database_logger
18
- from rag.llm import EmbeddingModel, CvModel, ChatModel, RerankModel, Seq2txtModel
19
- from api.db import LLMType
20
- from api.db.db_models import DB, UserTenant
21
- from api.db.db_models import LLMFactories, LLM, TenantLLM
22
- from api.db.services.common_service import CommonService
23
-
24
-
25
- class LLMFactoriesService(CommonService):
26
- model = LLMFactories
27
-
28
-
29
- class LLMService(CommonService):
30
- model = LLM
31
-
32
-
33
- class TenantLLMService(CommonService):
34
- model = TenantLLM
35
-
36
- @classmethod
37
- @DB.connection_context()
38
- def get_api_key(cls, tenant_id, model_name):
39
- objs = cls.query(tenant_id=tenant_id, llm_name=model_name)
40
- if not objs:
41
- return
42
- return objs[0]
43
-
44
- @classmethod
45
- @DB.connection_context()
46
- def get_my_llms(cls, tenant_id):
47
- fields = [
48
- cls.model.llm_factory,
49
- LLMFactories.logo,
50
- LLMFactories.tags,
51
- cls.model.model_type,
52
- cls.model.llm_name,
53
- cls.model.used_tokens
54
- ]
55
- objs = cls.model.select(*fields).join(LLMFactories, on=(cls.model.llm_factory == LLMFactories.name)).where(
56
- cls.model.tenant_id == tenant_id, ~cls.model.api_key.is_null()).dicts()
57
-
58
- return list(objs)
59
-
60
- @classmethod
61
- @DB.connection_context()
62
- def model_instance(cls, tenant_id, llm_type,
63
- llm_name=None, lang="Chinese"):
64
- e, tenant = TenantService.get_by_id(tenant_id)
65
- if not e:
66
- raise LookupError("Tenant not found")
67
-
68
- if llm_type == LLMType.EMBEDDING.value:
69
- mdlnm = tenant.embd_id if not llm_name else llm_name
70
- elif llm_type == LLMType.SPEECH2TEXT.value:
71
- mdlnm = tenant.asr_id
72
- elif llm_type == LLMType.IMAGE2TEXT.value:
73
- mdlnm = tenant.img2txt_id if not llm_name else llm_name
74
- elif llm_type == LLMType.CHAT.value:
75
- mdlnm = tenant.llm_id if not llm_name else llm_name
76
- elif llm_type == LLMType.RERANK:
77
- mdlnm = tenant.rerank_id if not llm_name else llm_name
78
- else:
79
- assert False, "LLM type error"
80
-
81
- model_config = cls.get_api_key(tenant_id, mdlnm)
82
- if model_config: model_config = model_config.to_dict()
83
- if not model_config:
84
- if llm_type in [LLMType.EMBEDDING, LLMType.RERANK]:
85
- llm = LLMService.query(llm_name=llm_name if llm_name else mdlnm)
86
- if llm and llm[0].fid in ["Youdao", "FastEmbed", "BAAI"]:
87
- model_config = {"llm_factory": llm[0].fid, "api_key":"", "llm_name": llm_name if llm_name else mdlnm, "api_base": ""}
88
- if not model_config:
89
- if llm_name == "flag-embedding":
90
- model_config = {"llm_factory": "Tongyi-Qianwen", "api_key": "",
91
- "llm_name": llm_name, "api_base": ""}
92
- else:
93
- if not mdlnm:
94
- raise LookupError(f"Type of {llm_type} model is not set.")
95
- raise LookupError("Model({}) not authorized".format(mdlnm))
96
-
97
- if llm_type == LLMType.EMBEDDING.value:
98
- if model_config["llm_factory"] not in EmbeddingModel:
99
- return
100
- return EmbeddingModel[model_config["llm_factory"]](
101
- model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])
102
-
103
- if llm_type == LLMType.RERANK:
104
- if model_config["llm_factory"] not in RerankModel:
105
- return
106
- return RerankModel[model_config["llm_factory"]](
107
- model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])
108
-
109
- if llm_type == LLMType.IMAGE2TEXT.value:
110
- if model_config["llm_factory"] not in CvModel:
111
- return
112
- return CvModel[model_config["llm_factory"]](
113
- model_config["api_key"], model_config["llm_name"], lang,
114
- base_url=model_config["api_base"]
115
- )
116
-
117
- if llm_type == LLMType.CHAT.value:
118
- if model_config["llm_factory"] not in ChatModel:
119
- return
120
- return ChatModel[model_config["llm_factory"]](
121
- model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])
122
-
123
- if llm_type == LLMType.SPEECH2TEXT:
124
- if model_config["llm_factory"] not in Seq2txtModel:
125
- return
126
- return Seq2txtModel[model_config["llm_factory"]](
127
- model_config["api_key"], model_config["llm_name"], lang,
128
- base_url=model_config["api_base"]
129
- )
130
-
131
- @classmethod
132
- @DB.connection_context()
133
- def increase_usage(cls, tenant_id, llm_type, used_tokens, llm_name=None):
134
- e, tenant = TenantService.get_by_id(tenant_id)
135
- if not e:
136
- raise LookupError("Tenant not found")
137
-
138
- if llm_type == LLMType.EMBEDDING.value:
139
- mdlnm = tenant.embd_id
140
- elif llm_type == LLMType.SPEECH2TEXT.value:
141
- mdlnm = tenant.asr_id
142
- elif llm_type == LLMType.IMAGE2TEXT.value:
143
- mdlnm = tenant.img2txt_id
144
- elif llm_type == LLMType.CHAT.value:
145
- mdlnm = tenant.llm_id if not llm_name else llm_name
146
- elif llm_type == LLMType.RERANK:
147
- mdlnm = tenant.llm_id if not llm_name else llm_name
148
- else:
149
- assert False, "LLM type error"
150
-
151
- num = 0
152
- try:
153
- for u in cls.query(tenant_id = tenant_id, llm_name=mdlnm):
154
- num += cls.model.update(used_tokens = u.used_tokens + used_tokens)\
155
- .where(cls.model.tenant_id == tenant_id, cls.model.llm_name == mdlnm)\
156
- .execute()
157
- except Exception as e:
158
- pass
159
- return num
160
-
161
- @classmethod
162
- @DB.connection_context()
163
- def get_openai_models(cls):
164
- objs = cls.model.select().where(
165
- (cls.model.llm_factory == "OpenAI"),
166
- ~(cls.model.llm_name == "text-embedding-3-small"),
167
- ~(cls.model.llm_name == "text-embedding-3-large")
168
- ).dicts()
169
- return list(objs)
170
-
171
-
172
- class LLMBundle(object):
173
- def __init__(self, tenant_id, llm_type, llm_name=None, lang="Chinese"):
174
- self.tenant_id = tenant_id
175
- self.llm_type = llm_type
176
- self.llm_name = llm_name
177
- self.mdl = TenantLLMService.model_instance(
178
- tenant_id, llm_type, llm_name, lang=lang)
179
- assert self.mdl, "Can't find mole for {}/{}/{}".format(
180
- tenant_id, llm_type, llm_name)
181
- self.max_length = 512
182
- for lm in LLMService.query(llm_name=llm_name):
183
- self.max_length = lm.max_tokens
184
- break
185
-
186
- def encode(self, texts: list, batch_size=32):
187
- emd, used_tokens = self.mdl.encode(texts, batch_size)
188
- if not TenantLLMService.increase_usage(
189
- self.tenant_id, self.llm_type, used_tokens):
190
- database_logger.error(
191
- "Can't update token usage for {}/EMBEDDING".format(self.tenant_id))
192
- return emd, used_tokens
193
-
194
- def encode_queries(self, query: str):
195
- emd, used_tokens = self.mdl.encode_queries(query)
196
- if not TenantLLMService.increase_usage(
197
- self.tenant_id, self.llm_type, used_tokens):
198
- database_logger.error(
199
- "Can't update token usage for {}/EMBEDDING".format(self.tenant_id))
200
- return emd, used_tokens
201
-
202
- def similarity(self, query: str, texts: list):
203
- sim, used_tokens = self.mdl.similarity(query, texts)
204
- if not TenantLLMService.increase_usage(
205
- self.tenant_id, self.llm_type, used_tokens):
206
- database_logger.error(
207
- "Can't update token usage for {}/RERANK".format(self.tenant_id))
208
- return sim, used_tokens
209
-
210
- def describe(self, image, max_tokens=300):
211
- txt, used_tokens = self.mdl.describe(image, max_tokens)
212
- if not TenantLLMService.increase_usage(
213
- self.tenant_id, self.llm_type, used_tokens):
214
- database_logger.error(
215
- "Can't update token usage for {}/IMAGE2TEXT".format(self.tenant_id))
216
- return txt
217
-
218
- def transcription(self, audio):
219
- txt, used_tokens = self.mdl.transcription(audio)
220
- if not TenantLLMService.increase_usage(
221
- self.tenant_id, self.llm_type, used_tokens):
222
- database_logger.error(
223
- "Can't update token usage for {}/SEQUENCE2TXT".format(self.tenant_id))
224
- return txt
225
-
226
- def chat(self, system, history, gen_conf):
227
- txt, used_tokens = self.mdl.chat(system, history, gen_conf)
228
- if not TenantLLMService.increase_usage(
229
- self.tenant_id, self.llm_type, used_tokens, self.llm_name):
230
- database_logger.error(
231
- "Can't update token usage for {}/CHAT".format(self.tenant_id))
232
- return txt
233
-
234
- def chat_streamly(self, system, history, gen_conf):
235
- for txt in self.mdl.chat_streamly(system, history, gen_conf):
236
- if isinstance(txt, int):
237
- if not TenantLLMService.increase_usage(
238
- self.tenant_id, self.llm_type, txt, self.llm_name):
239
- database_logger.error(
240
- "Can't update token usage for {}/CHAT".format(self.tenant_id))
241
- return
242
- yield txt
 
1
+ #
2
+ # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ #
16
+ from api.db.services.user_service import TenantService
17
+ from api.settings import database_logger
18
+ from rag.llm import EmbeddingModel, CvModel, ChatModel, RerankModel, Seq2txtModel
19
+ from api.db import LLMType
20
+ from api.db.db_models import DB, UserTenant
21
+ from api.db.db_models import LLMFactories, LLM, TenantLLM
22
+ from api.db.services.common_service import CommonService
23
+
24
+
25
+ class LLMFactoriesService(CommonService):
26
+ model = LLMFactories
27
+
28
+
29
+ class LLMService(CommonService):
30
+ model = LLM
31
+
32
+
33
+ class TenantLLMService(CommonService):
34
+ model = TenantLLM
35
+
36
+ @classmethod
37
+ @DB.connection_context()
38
+ def get_api_key(cls, tenant_id, model_name):
39
+ objs = cls.query(tenant_id=tenant_id, llm_name=model_name)
40
+ if not objs:
41
+ return
42
+ return objs[0]
43
+
44
+ @classmethod
45
+ @DB.connection_context()
46
+ def get_my_llms(cls, tenant_id):
47
+ fields = [
48
+ cls.model.llm_factory,
49
+ LLMFactories.logo,
50
+ LLMFactories.tags,
51
+ cls.model.model_type,
52
+ cls.model.llm_name,
53
+ cls.model.used_tokens
54
+ ]
55
+ objs = cls.model.select(*fields).join(LLMFactories, on=(cls.model.llm_factory == LLMFactories.name)).where(
56
+ cls.model.tenant_id == tenant_id, ~cls.model.api_key.is_null()).dicts()
57
+
58
+ return list(objs)
59
+
60
+ @classmethod
61
+ @DB.connection_context()
62
+ def model_instance(cls, tenant_id, llm_type,
63
+ llm_name=None, lang="Chinese"):
64
+ e, tenant = TenantService.get_by_id(tenant_id)
65
+ if not e:
66
+ raise LookupError("Tenant not found")
67
+
68
+ if llm_type == LLMType.EMBEDDING.value:
69
+ mdlnm = tenant.embd_id if not llm_name else llm_name
70
+ elif llm_type == LLMType.SPEECH2TEXT.value:
71
+ mdlnm = tenant.asr_id
72
+ elif llm_type == LLMType.IMAGE2TEXT.value:
73
+ mdlnm = tenant.img2txt_id if not llm_name else llm_name
74
+ elif llm_type == LLMType.CHAT.value:
75
+ mdlnm = tenant.llm_id if not llm_name else llm_name
76
+ elif llm_type == LLMType.RERANK:
77
+ mdlnm = tenant.rerank_id if not llm_name else llm_name
78
+ else:
79
+ assert False, "LLM type error"
80
+
81
+ model_config = cls.get_api_key(tenant_id, mdlnm)
82
+ if model_config: model_config = model_config.to_dict()
83
+ if not model_config:
84
+ if llm_type in [LLMType.EMBEDDING, LLMType.RERANK]:
85
+ llm = LLMService.query(llm_name=llm_name if llm_name else mdlnm)
86
+ if llm and llm[0].fid in ["Youdao", "FastEmbed", "BAAI"]:
87
+ model_config = {"llm_factory": llm[0].fid, "api_key":"", "llm_name": llm_name if llm_name else mdlnm, "api_base": ""}
88
+ if not model_config:
89
+ if llm_name == "flag-embedding":
90
+ model_config = {"llm_factory": "Tongyi-Qianwen", "api_key": "",
91
+ "llm_name": llm_name, "api_base": ""}
92
+ else:
93
+ if not mdlnm:
94
+ raise LookupError(f"Type of {llm_type} model is not set.")
95
+ raise LookupError("Model({}) not authorized".format(mdlnm))
96
+
97
+ if llm_type == LLMType.EMBEDDING.value:
98
+ if model_config["llm_factory"] not in EmbeddingModel:
99
+ return
100
+ return EmbeddingModel[model_config["llm_factory"]](
101
+ model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])
102
+
103
+ if llm_type == LLMType.RERANK:
104
+ if model_config["llm_factory"] not in RerankModel:
105
+ return
106
+ return RerankModel[model_config["llm_factory"]](
107
+ model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])
108
+
109
+ if llm_type == LLMType.IMAGE2TEXT.value:
110
+ if model_config["llm_factory"] not in CvModel:
111
+ return
112
+ return CvModel[model_config["llm_factory"]](
113
+ model_config["api_key"], model_config["llm_name"], lang,
114
+ base_url=model_config["api_base"]
115
+ )
116
+
117
+ if llm_type == LLMType.CHAT.value:
118
+ if model_config["llm_factory"] not in ChatModel:
119
+ return
120
+ return ChatModel[model_config["llm_factory"]](
121
+ model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])
122
+
123
+ if llm_type == LLMType.SPEECH2TEXT:
124
+ if model_config["llm_factory"] not in Seq2txtModel:
125
+ return
126
+ return Seq2txtModel[model_config["llm_factory"]](
127
+ model_config["api_key"], model_config["llm_name"], lang,
128
+ base_url=model_config["api_base"]
129
+ )
130
+
131
+ @classmethod
132
+ @DB.connection_context()
133
+ def increase_usage(cls, tenant_id, llm_type, used_tokens, llm_name=None):
134
+ e, tenant = TenantService.get_by_id(tenant_id)
135
+ if not e:
136
+ raise LookupError("Tenant not found")
137
+
138
+ if llm_type == LLMType.EMBEDDING.value:
139
+ mdlnm = tenant.embd_id
140
+ elif llm_type == LLMType.SPEECH2TEXT.value:
141
+ mdlnm = tenant.asr_id
142
+ elif llm_type == LLMType.IMAGE2TEXT.value:
143
+ mdlnm = tenant.img2txt_id
144
+ elif llm_type == LLMType.CHAT.value:
145
+ mdlnm = tenant.llm_id if not llm_name else llm_name
146
+ elif llm_type == LLMType.RERANK:
147
+ mdlnm = tenant.llm_id if not llm_name else llm_name
148
+ else:
149
+ assert False, "LLM type error"
150
+
151
+ num = 0
152
+ try:
153
+ for u in cls.query(tenant_id = tenant_id, llm_name=mdlnm):
154
+ num += cls.model.update(used_tokens = u.used_tokens + used_tokens)\
155
+ .where(cls.model.tenant_id == tenant_id, cls.model.llm_name == mdlnm)\
156
+ .execute()
157
+ except Exception as e:
158
+ pass
159
+ return num
160
+
161
+ @classmethod
162
+ @DB.connection_context()
163
+ def get_openai_models(cls):
164
+ objs = cls.model.select().where(
165
+ (cls.model.llm_factory == "OpenAI"),
166
+ ~(cls.model.llm_name == "text-embedding-3-small"),
167
+ ~(cls.model.llm_name == "text-embedding-3-large")
168
+ ).dicts()
169
+ return list(objs)
170
+
171
+
172
+ class LLMBundle(object):
173
+ def __init__(self, tenant_id, llm_type, llm_name=None, lang="Chinese"):
174
+ self.tenant_id = tenant_id
175
+ self.llm_type = llm_type
176
+ self.llm_name = llm_name
177
+ self.mdl = TenantLLMService.model_instance(
178
+ tenant_id, llm_type, llm_name, lang=lang)
179
+ assert self.mdl, "Can't find mole for {}/{}/{}".format(
180
+ tenant_id, llm_type, llm_name)
181
+ self.max_length = 512
182
+ for lm in LLMService.query(llm_name=llm_name):
183
+ self.max_length = lm.max_tokens
184
+ break
185
+
186
+ def encode(self, texts: list, batch_size=32):
187
+ emd, used_tokens = self.mdl.encode(texts, batch_size)
188
+ if not TenantLLMService.increase_usage(
189
+ self.tenant_id, self.llm_type, used_tokens):
190
+ database_logger.error(
191
+ "Can't update token usage for {}/EMBEDDING".format(self.tenant_id))
192
+ return emd, used_tokens
193
+
194
+ def encode_queries(self, query: str):
195
+ emd, used_tokens = self.mdl.encode_queries(query)
196
+ if not TenantLLMService.increase_usage(
197
+ self.tenant_id, self.llm_type, used_tokens):
198
+ database_logger.error(
199
+ "Can't update token usage for {}/EMBEDDING".format(self.tenant_id))
200
+ return emd, used_tokens
201
+
202
+ def similarity(self, query: str, texts: list):
203
+ sim, used_tokens = self.mdl.similarity(query, texts)
204
+ if not TenantLLMService.increase_usage(
205
+ self.tenant_id, self.llm_type, used_tokens):
206
+ database_logger.error(
207
+ "Can't update token usage for {}/RERANK".format(self.tenant_id))
208
+ return sim, used_tokens
209
+
210
+ def describe(self, image, max_tokens=300):
211
+ txt, used_tokens = self.mdl.describe(image, max_tokens)
212
+ if not TenantLLMService.increase_usage(
213
+ self.tenant_id, self.llm_type, used_tokens):
214
+ database_logger.error(
215
+ "Can't update token usage for {}/IMAGE2TEXT".format(self.tenant_id))
216
+ return txt
217
+
218
+ def transcription(self, audio):
219
+ txt, used_tokens = self.mdl.transcription(audio)
220
+ if not TenantLLMService.increase_usage(
221
+ self.tenant_id, self.llm_type, used_tokens):
222
+ database_logger.error(
223
+ "Can't update token usage for {}/SEQUENCE2TXT".format(self.tenant_id))
224
+ return txt
225
+
226
+ def chat(self, system, history, gen_conf):
227
+ txt, used_tokens = self.mdl.chat(system, history, gen_conf)
228
+ if not TenantLLMService.increase_usage(
229
+ self.tenant_id, self.llm_type, used_tokens, self.llm_name):
230
+ database_logger.error(
231
+ "Can't update token usage for {}/CHAT".format(self.tenant_id))
232
+ return txt
233
+
234
+ def chat_streamly(self, system, history, gen_conf):
235
+ for txt in self.mdl.chat_streamly(system, history, gen_conf):
236
+ if isinstance(txt, int):
237
+ if not TenantLLMService.increase_usage(
238
+ self.tenant_id, self.llm_type, txt, self.llm_name):
239
+ database_logger.error(
240
+ "Can't update token usage for {}/CHAT".format(self.tenant_id))
241
+ return
242
+ yield txt
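
LLMBundle above is the call-site wrapper around these services: model_instance resolves the tenant's configured model for a given LLMType, and every encode/similarity/chat call books the returned token count through increase_usage. A hedged usage sketch; the tenant id and model name are placeholders, a matching tenant must already exist in the database, and the gen_conf keys accepted depend on the underlying model:

from api.db import LLMType
from api.db.services.llm_service import LLMBundle

# "tenant-0000" and "qwen-plus" are illustrative values only.
chat_mdl = LLMBundle("tenant-0000", LLMType.CHAT, llm_name="qwen-plus")
answer = chat_mdl.chat(
    "You are a helpful assistant.",            # system prompt
    [{"role": "user", "content": "Hello!"}],   # chat history
    {"temperature": 0.1},                      # gen_conf forwarded to the model
)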
api/db/services/task_service.py CHANGED
@@ -1,175 +1,175 @@
1
- #
2
- # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- #
16
- import os
17
- import random
18
-
19
- from api.db.db_utils import bulk_insert_into_db
20
- from deepdoc.parser import PdfParser
21
- from peewee import JOIN
22
- from api.db.db_models import DB, File2Document, File
23
- from api.db import StatusEnum, FileType, TaskStatus
24
- from api.db.db_models import Task, Document, Knowledgebase, Tenant
25
- from api.db.services.common_service import CommonService
26
- from api.db.services.document_service import DocumentService
27
- from api.utils import current_timestamp, get_uuid
28
- from deepdoc.parser.excel_parser import RAGFlowExcelParser
29
- from rag.settings import SVR_QUEUE_NAME
30
- from rag.utils.minio_conn import MINIO
31
- from rag.utils.redis_conn import REDIS_CONN
32
-
33
-
34
- class TaskService(CommonService):
35
- model = Task
36
-
37
- @classmethod
38
- @DB.connection_context()
39
- def get_tasks(cls, task_id):
40
- fields = [
41
- cls.model.id,
42
- cls.model.doc_id,
43
- cls.model.from_page,
44
- cls.model.to_page,
45
- Document.kb_id,
46
- Document.parser_id,
47
- Document.parser_config,
48
- Document.name,
49
- Document.type,
50
- Document.location,
51
- Document.size,
52
- Knowledgebase.tenant_id,
53
- Knowledgebase.language,
54
- Knowledgebase.embd_id,
55
- Tenant.img2txt_id,
56
- Tenant.asr_id,
57
- Tenant.llm_id,
58
- cls.model.update_time]
59
- docs = cls.model.select(*fields) \
60
- .join(Document, on=(cls.model.doc_id == Document.id)) \
61
- .join(Knowledgebase, on=(Document.kb_id == Knowledgebase.id)) \
62
- .join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id)) \
63
- .where(cls.model.id == task_id)
64
- docs = list(docs.dicts())
65
- if not docs: return []
66
-
67
- cls.model.update(progress_msg=cls.model.progress_msg + "\n" + "Task has been received.",
68
- progress=random.random() / 10.).where(
69
- cls.model.id == docs[0]["id"]).execute()
70
- return docs
71
-
72
- @classmethod
73
- @DB.connection_context()
74
- def get_ongoing_doc_name(cls):
75
- with DB.lock("get_task", -1):
76
- docs = cls.model.select(*[Document.id, Document.kb_id, Document.location, File.parent_id]) \
77
- .join(Document, on=(cls.model.doc_id == Document.id)) \
78
- .join(File2Document, on=(File2Document.document_id == Document.id), join_type=JOIN.LEFT_OUTER) \
79
- .join(File, on=(File2Document.file_id == File.id), join_type=JOIN.LEFT_OUTER) \
80
- .where(
81
- Document.status == StatusEnum.VALID.value,
82
- Document.run == TaskStatus.RUNNING.value,
83
- ~(Document.type == FileType.VIRTUAL.value),
84
- cls.model.progress < 1,
85
- cls.model.create_time >= current_timestamp() - 1000 * 600
86
- )
87
- docs = list(docs.dicts())
88
- if not docs: return []
89
-
90
- return list(set([(d["parent_id"] if d["parent_id"] else d["kb_id"], d["location"]) for d in docs]))
91
-
92
- @classmethod
93
- @DB.connection_context()
94
- def do_cancel(cls, id):
95
- try:
96
- task = cls.model.get_by_id(id)
97
- _, doc = DocumentService.get_by_id(task.doc_id)
98
- return doc.run == TaskStatus.CANCEL.value or doc.progress < 0
99
- except Exception as e:
100
- pass
101
- return False
102
-
103
- @classmethod
104
- @DB.connection_context()
105
- def update_progress(cls, id, info):
106
- if os.environ.get("MACOS"):
107
- if info["progress_msg"]:
108
- cls.model.update(progress_msg=cls.model.progress_msg + "\n" + info["progress_msg"]).where(
109
- cls.model.id == id).execute()
110
- if "progress" in info:
111
- cls.model.update(progress=info["progress"]).where(
112
- cls.model.id == id).execute()
113
- return
114
-
115
- with DB.lock("update_progress", -1):
116
- if info["progress_msg"]:
117
- cls.model.update(progress_msg=cls.model.progress_msg + "\n" + info["progress_msg"]).where(
118
- cls.model.id == id).execute()
119
- if "progress" in info:
120
- cls.model.update(progress=info["progress"]).where(
121
- cls.model.id == id).execute()
122
-
123
-
124
- def queue_tasks(doc, bucket, name):
125
- def new_task():
126
- nonlocal doc
127
- return {
128
- "id": get_uuid(),
129
- "doc_id": doc["id"]
130
- }
131
- tsks = []
132
-
133
- if doc["type"] == FileType.PDF.value:
134
- file_bin = MINIO.get(bucket, name)
135
- do_layout = doc["parser_config"].get("layout_recognize", True)
136
- pages = PdfParser.total_page_number(doc["name"], file_bin)
137
- page_size = doc["parser_config"].get("task_page_size", 12)
138
- if doc["parser_id"] == "paper":
139
- page_size = doc["parser_config"].get("task_page_size", 22)
140
- if doc["parser_id"] == "one":
141
- page_size = 1000000000
142
- if doc["parser_id"] == "knowledge_graph":
143
- page_size = 1000000000
144
- if not do_layout:
145
- page_size = 1000000000
146
- page_ranges = doc["parser_config"].get("pages")
147
- if not page_ranges:
148
- page_ranges = [(1, 100000)]
149
- for s, e in page_ranges:
150
- s -= 1
151
- s = max(0, s)
152
- e = min(e - 1, pages)
153
- for p in range(s, e, page_size):
154
- task = new_task()
155
- task["from_page"] = p
156
- task["to_page"] = min(p + page_size, e)
157
- tsks.append(task)
158
-
159
- elif doc["parser_id"] == "table":
160
- file_bin = MINIO.get(bucket, name)
161
- rn = RAGFlowExcelParser.row_number(
162
- doc["name"], file_bin)
163
- for i in range(0, rn, 3000):
164
- task = new_task()
165
- task["from_page"] = i
166
- task["to_page"] = min(i + 3000, rn)
167
- tsks.append(task)
168
- else:
169
- tsks.append(new_task())
170
-
171
- bulk_insert_into_db(Task, tsks, True)
172
- DocumentService.begin2parse(doc["id"])
173
-
174
- for t in tsks:
175
- assert REDIS_CONN.queue_product(SVR_QUEUE_NAME, message=t), "Can't access Redis. Please check the Redis status."
 
1
+ #
2
+ # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ #
16
+ import os
17
+ import random
18
+
19
+ from api.db.db_utils import bulk_insert_into_db
20
+ from deepdoc.parser import PdfParser
21
+ from peewee import JOIN
22
+ from api.db.db_models import DB, File2Document, File
23
+ from api.db import StatusEnum, FileType, TaskStatus
24
+ from api.db.db_models import Task, Document, Knowledgebase, Tenant
25
+ from api.db.services.common_service import CommonService
26
+ from api.db.services.document_service import DocumentService
27
+ from api.utils import current_timestamp, get_uuid
28
+ from deepdoc.parser.excel_parser import RAGFlowExcelParser
29
+ from rag.settings import SVR_QUEUE_NAME
30
+ from rag.utils.minio_conn import MINIO
31
+ from rag.utils.redis_conn import REDIS_CONN
32
+
33
+
34
+ class TaskService(CommonService):
35
+ model = Task
36
+
37
+ @classmethod
38
+ @DB.connection_context()
39
+ def get_tasks(cls, task_id):
40
+ fields = [
41
+ cls.model.id,
42
+ cls.model.doc_id,
43
+ cls.model.from_page,
44
+ cls.model.to_page,
45
+ Document.kb_id,
46
+ Document.parser_id,
47
+ Document.parser_config,
48
+ Document.name,
49
+ Document.type,
50
+ Document.location,
51
+ Document.size,
52
+ Knowledgebase.tenant_id,
53
+ Knowledgebase.language,
54
+ Knowledgebase.embd_id,
55
+ Tenant.img2txt_id,
56
+ Tenant.asr_id,
57
+ Tenant.llm_id,
58
+ cls.model.update_time]
59
+ docs = cls.model.select(*fields) \
60
+ .join(Document, on=(cls.model.doc_id == Document.id)) \
61
+ .join(Knowledgebase, on=(Document.kb_id == Knowledgebase.id)) \
62
+ .join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id)) \
63
+ .where(cls.model.id == task_id)
64
+ docs = list(docs.dicts())
65
+ if not docs: return []
66
+
67
+ cls.model.update(progress_msg=cls.model.progress_msg + "\n" + "Task has been received.",
68
+ progress=random.random() / 10.).where(
69
+ cls.model.id == docs[0]["id"]).execute()
70
+ return docs
71
+
72
+ @classmethod
73
+ @DB.connection_context()
74
+ def get_ongoing_doc_name(cls):
75
+ with DB.lock("get_task", -1):
76
+ docs = cls.model.select(*[Document.id, Document.kb_id, Document.location, File.parent_id]) \
77
+ .join(Document, on=(cls.model.doc_id == Document.id)) \
78
+ .join(File2Document, on=(File2Document.document_id == Document.id), join_type=JOIN.LEFT_OUTER) \
79
+ .join(File, on=(File2Document.file_id == File.id), join_type=JOIN.LEFT_OUTER) \
80
+ .where(
81
+ Document.status == StatusEnum.VALID.value,
82
+ Document.run == TaskStatus.RUNNING.value,
83
+ ~(Document.type == FileType.VIRTUAL.value),
84
+ cls.model.progress < 1,
85
+ cls.model.create_time >= current_timestamp() - 1000 * 600
86
+ )
87
+ docs = list(docs.dicts())
88
+ if not docs: return []
89
+
90
+ return list(set([(d["parent_id"] if d["parent_id"] else d["kb_id"], d["location"]) for d in docs]))
91
+
92
+ @classmethod
93
+ @DB.connection_context()
94
+ def do_cancel(cls, id):
95
+ try:
96
+ task = cls.model.get_by_id(id)
97
+ _, doc = DocumentService.get_by_id(task.doc_id)
98
+ return doc.run == TaskStatus.CANCEL.value or doc.progress < 0
99
+ except Exception as e:
100
+ pass
101
+ return False
102
+
103
+ @classmethod
104
+ @DB.connection_context()
105
+ def update_progress(cls, id, info):
106
+ if os.environ.get("MACOS"):
107
+ if info["progress_msg"]:
108
+ cls.model.update(progress_msg=cls.model.progress_msg + "\n" + info["progress_msg"]).where(
109
+ cls.model.id == id).execute()
110
+ if "progress" in info:
111
+ cls.model.update(progress=info["progress"]).where(
112
+ cls.model.id == id).execute()
113
+ return
114
+
115
+ with DB.lock("update_progress", -1):
116
+ if info["progress_msg"]:
117
+ cls.model.update(progress_msg=cls.model.progress_msg + "\n" + info["progress_msg"]).where(
118
+ cls.model.id == id).execute()
119
+ if "progress" in info:
120
+ cls.model.update(progress=info["progress"]).where(
121
+ cls.model.id == id).execute()
122
+
123
+
124
+ def queue_tasks(doc, bucket, name):
125
+ def new_task():
126
+ nonlocal doc
127
+ return {
128
+ "id": get_uuid(),
129
+ "doc_id": doc["id"]
130
+ }
131
+ tsks = []
132
+
133
+ if doc["type"] == FileType.PDF.value:
134
+ file_bin = MINIO.get(bucket, name)
135
+ do_layout = doc["parser_config"].get("layout_recognize", True)
136
+ pages = PdfParser.total_page_number(doc["name"], file_bin)
137
+ page_size = doc["parser_config"].get("task_page_size", 12)
138
+ if doc["parser_id"] == "paper":
139
+ page_size = doc["parser_config"].get("task_page_size", 22)
140
+ if doc["parser_id"] == "one":
141
+ page_size = 1000000000
142
+ if doc["parser_id"] == "knowledge_graph":
143
+ page_size = 1000000000
144
+ if not do_layout:
145
+ page_size = 1000000000
146
+ page_ranges = doc["parser_config"].get("pages")
147
+ if not page_ranges:
148
+ page_ranges = [(1, 100000)]
149
+ for s, e in page_ranges:
150
+ s -= 1
151
+ s = max(0, s)
152
+ e = min(e - 1, pages)
153
+ for p in range(s, e, page_size):
154
+ task = new_task()
155
+ task["from_page"] = p
156
+ task["to_page"] = min(p + page_size, e)
157
+ tsks.append(task)
158
+
159
+ elif doc["parser_id"] == "table":
160
+ file_bin = MINIO.get(bucket, name)
161
+ rn = RAGFlowExcelParser.row_number(
162
+ doc["name"], file_bin)
163
+ for i in range(0, rn, 3000):
164
+ task = new_task()
165
+ task["from_page"] = i
166
+ task["to_page"] = min(i + 3000, rn)
167
+ tsks.append(task)
168
+ else:
169
+ tsks.append(new_task())
170
+
171
+ bulk_insert_into_db(Task, tsks, True)
172
+ DocumentService.begin2parse(doc["id"])
173
+
174
+ for t in tsks:
175
+ assert REDIS_CONN.queue_product(SVR_QUEUE_NAME, message=t), "Can't access Redis. Please check the Redis status."
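
queue_tasks above shards a document into parser tasks: PDFs are cut into page windows of task_page_size pages (12 by default, 22 for the paper parser, effectively unbounded for one/knowledge_graph or when layout recognition is off), table files into 3000-row slices, and everything else becomes a single task. The page-windowing arithmetic in isolation (split_pages is an illustrative name):

# Isolated page-windowing rule from queue_tasks: 1-based inclusive ranges in,
# 0-based half-open [from_page, to_page) windows out.
def split_pages(page_ranges, total_pages, page_size):
    windows = []
    for s, e in page_ranges:
        s = max(0, s - 1)            # convert to a 0-based start
        e = min(e - 1, total_pages)  # clamp to the document length
        for p in range(s, e, page_size):
            windows.append((p, min(p + page_size, e)))
    return windows

assert split_pages([(1, 100000)], 30, 12) == [(0, 12), (12, 24), (24, 30)]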
api/ragflow_server.py CHANGED
@@ -1,100 +1,100 @@
1
- #
2
- # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- #
16
-
17
- import logging
18
- import os
19
- import signal
20
- import sys
21
- import time
22
- import traceback
23
- from concurrent.futures import ThreadPoolExecutor
24
-
25
- from werkzeug.serving import run_simple
26
- from api.apps import app
27
- from api.db.runtime_config import RuntimeConfig
28
- from api.db.services.document_service import DocumentService
29
- from api.settings import (
30
- HOST, HTTP_PORT, access_logger, database_logger, stat_logger,
31
- )
32
- from api import utils
33
-
34
- from api.db.db_models import init_database_tables as init_web_db
35
- from api.db.init_data import init_web_data
36
- from api.versions import get_versions
37
-
38
-
39
- def update_progress():
40
- while True:
41
- time.sleep(1)
42
- try:
43
- DocumentService.update_progress()
44
- except Exception as e:
45
- stat_logger.error("update_progress exception:" + str(e))
46
-
47
-
48
- if __name__ == '__main__':
49
- print("""
50
- ____ ______ __
51
- / __ \ ____ _ ____ _ / ____// /____ _ __
52
- / /_/ // __ `// __ `// /_ / // __ \| | /| / /
53
- / _, _// /_/ // /_/ // __/ / // /_/ /| |/ |/ /
54
- /_/ |_| \__,_/ \__, //_/ /_/ \____/ |__/|__/
55
- /____/
56
-
57
- """, flush=True)
58
- stat_logger.info(
59
- f'project base: {utils.file_utils.get_project_base_directory()}'
60
- )
61
-
62
- # init db
63
- init_web_db()
64
- init_web_data()
65
- # init runtime config
66
- import argparse
67
- parser = argparse.ArgumentParser()
68
- parser.add_argument('--version', default=False, help="rag flow version", action='store_true')
69
- parser.add_argument('--debug', default=False, help="debug mode", action='store_true')
70
- args = parser.parse_args()
71
- if args.version:
72
- print(get_versions())
73
- sys.exit(0)
74
-
75
- RuntimeConfig.DEBUG = args.debug
76
- if RuntimeConfig.DEBUG:
77
- stat_logger.info("run on debug mode")
78
-
79
- RuntimeConfig.init_env()
80
- RuntimeConfig.init_config(JOB_SERVER_HOST=HOST, HTTP_PORT=HTTP_PORT)
81
-
82
- peewee_logger = logging.getLogger('peewee')
83
- peewee_logger.propagate = False
84
- # rag_arch.common.log.ROpenHandler
85
- peewee_logger.addHandler(database_logger.handlers[0])
86
- peewee_logger.setLevel(database_logger.level)
87
-
88
- thr = ThreadPoolExecutor(max_workers=1)
89
- thr.submit(update_progress)
90
-
91
- # start http server
92
- try:
93
- stat_logger.info("RAG Flow http server start...")
94
- werkzeug_logger = logging.getLogger("werkzeug")
95
- for h in access_logger.handlers:
96
- werkzeug_logger.addHandler(h)
97
- run_simple(hostname=HOST, port=HTTP_PORT, application=app, threaded=True, use_reloader=RuntimeConfig.DEBUG, use_debugger=RuntimeConfig.DEBUG)
98
- except Exception:
99
- traceback.print_exc()
100
  os.kill(os.getpid(), signal.SIGKILL)
 
1
+ #
2
+ # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ #
16
+
17
+ import logging
18
+ import os
19
+ import signal
20
+ import sys
21
+ import time
22
+ import traceback
23
+ from concurrent.futures import ThreadPoolExecutor
24
+
25
+ from werkzeug.serving import run_simple
26
+ from api.apps import app
27
+ from api.db.runtime_config import RuntimeConfig
28
+ from api.db.services.document_service import DocumentService
29
+ from api.settings import (
30
+ HOST, HTTP_PORT, access_logger, database_logger, stat_logger,
31
+ )
32
+ from api import utils
33
+
34
+ from api.db.db_models import init_database_tables as init_web_db
35
+ from api.db.init_data import init_web_data
36
+ from api.versions import get_versions
37
+
38
+
39
+ def update_progress():
40
+ while True:
41
+ time.sleep(1)
42
+ try:
43
+ DocumentService.update_progress()
44
+ except Exception as e:
45
+ stat_logger.error("update_progress exception:" + str(e))
46
+
47
+
48
+ if __name__ == '__main__':
49
+ print("""
50
+ ____ ______ __
51
+ / __ \ ____ _ ____ _ / ____// /____ _ __
52
+ / /_/ // __ `// __ `// /_ / // __ \| | /| / /
53
+ / _, _// /_/ // /_/ // __/ / // /_/ /| |/ |/ /
54
+ /_/ |_| \__,_/ \__, //_/ /_/ \____/ |__/|__/
55
+ /____/
56
+
57
+ """, flush=True)
58
+ stat_logger.info(
59
+ f'project base: {utils.file_utils.get_project_base_directory()}'
60
+ )
61
+
62
+ # init db
63
+ init_web_db()
64
+ init_web_data()
65
+ # init runtime config
66
+ import argparse
67
+ parser = argparse.ArgumentParser()
68
+ parser.add_argument('--version', default=False, help="rag flow version", action='store_true')
69
+ parser.add_argument('--debug', default=False, help="debug mode", action='store_true')
70
+ args = parser.parse_args()
71
+ if args.version:
72
+ print(get_versions())
73
+ sys.exit(0)
74
+
75
+ RuntimeConfig.DEBUG = args.debug
76
+ if RuntimeConfig.DEBUG:
77
+ stat_logger.info("run on debug mode")
78
+
79
+ RuntimeConfig.init_env()
80
+ RuntimeConfig.init_config(JOB_SERVER_HOST=HOST, HTTP_PORT=HTTP_PORT)
81
+
82
+ peewee_logger = logging.getLogger('peewee')
83
+ peewee_logger.propagate = False
84
+ # rag_arch.common.log.ROpenHandler
85
+ peewee_logger.addHandler(database_logger.handlers[0])
86
+ peewee_logger.setLevel(database_logger.level)
87
+
88
+ thr = ThreadPoolExecutor(max_workers=1)
89
+ thr.submit(update_progress)
90
+
91
+ # start http server
92
+ try:
93
+ stat_logger.info("RAG Flow http server start...")
94
+ werkzeug_logger = logging.getLogger("werkzeug")
95
+ for h in access_logger.handlers:
96
+ werkzeug_logger.addHandler(h)
97
+ run_simple(hostname=HOST, port=HTTP_PORT, application=app, threaded=True, use_reloader=RuntimeConfig.DEBUG, use_debugger=RuntimeConfig.DEBUG)
98
+ except Exception:
99
+ traceback.print_exc()
100
  os.kill(os.getpid(), signal.SIGKILL)
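
The logger wiring above keeps ORM chatter out of the root logger: peewee's logger stops propagating and writes straight to the application's database log handler, and werkzeug's logger is given the access-log handlers the same way. The pattern in isolation; the file name below is an example stand-in for database_logger.handlers[0]:

import logging

# Route a third-party library's logger into an app-owned handler, mirroring
# the peewee/werkzeug wiring in ragflow_server.py above.
app_handler = logging.FileHandler("database.log")
lib_logger = logging.getLogger("peewee")
lib_logger.propagate = False           # stop records duplicating at the root logger
lib_logger.addHandler(app_handler)
lib_logger.setLevel(logging.WARNING)   # match the app logger's level (LoggerFactory.LEVEL is 30)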
api/settings.py CHANGED
@@ -1,251 +1,251 @@
1
- #
2
- # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- #
16
- import os
17
- from enum import IntEnum, Enum
18
- from api.utils.file_utils import get_project_base_directory
19
- from api.utils.log_utils import LoggerFactory, getLogger
20
-
21
- # Logger
22
- LoggerFactory.set_directory(
23
- os.path.join(
24
- get_project_base_directory(),
25
- "logs",
26
- "api"))
27
- # {CRITICAL: 50, FATAL:50, ERROR:40, WARNING:30, WARN:30, INFO:20, DEBUG:10, NOTSET:0}
28
- LoggerFactory.LEVEL = 30
29
-
30
- stat_logger = getLogger("stat")
31
- access_logger = getLogger("access")
32
- database_logger = getLogger("database")
33
- chat_logger = getLogger("chat")
34
-
35
- from rag.utils.es_conn import ELASTICSEARCH
36
- from rag.nlp import search
37
- from graphrag import search as kg_search
38
- from api.utils import get_base_config, decrypt_database_config
39
-
40
- API_VERSION = "v1"
41
- RAG_FLOW_SERVICE_NAME = "ragflow"
42
- SERVER_MODULE = "rag_flow_server.py"
43
- TEMP_DIRECTORY = os.path.join(get_project_base_directory(), "temp")
44
- RAG_FLOW_CONF_PATH = os.path.join(get_project_base_directory(), "conf")
45
-
46
- SUBPROCESS_STD_LOG_NAME = "std.log"
47
-
48
- ERROR_REPORT = True
49
- ERROR_REPORT_WITH_PATH = False
50
-
51
- MAX_TIMESTAMP_INTERVAL = 60
52
- SESSION_VALID_PERIOD = 7 * 24 * 60 * 60
53
-
54
- REQUEST_TRY_TIMES = 3
55
- REQUEST_WAIT_SEC = 2
56
- REQUEST_MAX_WAIT_SEC = 300
57
-
58
- USE_REGISTRY = get_base_config("use_registry")
59
-
60
- default_llm = {
61
- "Tongyi-Qianwen": {
62
- "chat_model": "qwen-plus",
63
- "embedding_model": "text-embedding-v2",
64
- "image2text_model": "qwen-vl-max",
65
- "asr_model": "paraformer-realtime-8k-v1",
66
- },
67
- "OpenAI": {
68
- "chat_model": "gpt-3.5-turbo",
69
- "embedding_model": "text-embedding-ada-002",
70
- "image2text_model": "gpt-4-vision-preview",
71
- "asr_model": "whisper-1",
72
- },
73
- "Azure-OpenAI": {
74
- "chat_model": "azure-gpt-35-turbo",
75
- "embedding_model": "azure-text-embedding-ada-002",
76
- "image2text_model": "azure-gpt-4-vision-preview",
77
- "asr_model": "azure-whisper-1",
78
- },
79
- "ZHIPU-AI": {
80
- "chat_model": "glm-3-turbo",
81
- "embedding_model": "embedding-2",
82
- "image2text_model": "glm-4v",
83
- "asr_model": "",
84
- },
85
- "Ollama": {
86
- "chat_model": "qwen-14B-chat",
87
- "embedding_model": "flag-embedding",
88
- "image2text_model": "",
89
- "asr_model": "",
90
- },
91
- "Moonshot": {
92
- "chat_model": "moonshot-v1-8k",
93
- "embedding_model": "",
94
- "image2text_model": "",
95
- "asr_model": "",
96
- },
97
- "DeepSeek": {
98
- "chat_model": "deepseek-chat",
99
- "embedding_model": "",
100
- "image2text_model": "",
101
- "asr_model": "",
102
- },
103
- "VolcEngine": {
104
- "chat_model": "",
105
- "embedding_model": "",
106
- "image2text_model": "",
107
- "asr_model": "",
108
- },
109
- "BAAI": {
110
- "chat_model": "",
111
- "embedding_model": "BAAI/bge-large-zh-v1.5",
112
- "image2text_model": "",
113
- "asr_model": "",
114
- "rerank_model": "BAAI/bge-reranker-v2-m3",
115
- }
116
- }
117
- LLM = get_base_config("user_default_llm", {})
118
- LLM_FACTORY = LLM.get("factory", "Tongyi-Qianwen")
119
- LLM_BASE_URL = LLM.get("base_url")
120
-
121
- if LLM_FACTORY not in default_llm:
122
- print(
123
- "\33[91m【ERROR】\33[0m:",
124
- f"LLM factory {LLM_FACTORY} has not supported yet, switch to 'Tongyi-Qianwen/QWen' automatically, and please check the API_KEY in service_conf.yaml.")
125
- LLM_FACTORY = "Tongyi-Qianwen"
126
- CHAT_MDL = default_llm[LLM_FACTORY]["chat_model"]
127
- EMBEDDING_MDL = default_llm["BAAI"]["embedding_model"]
128
- RERANK_MDL = default_llm["BAAI"]["rerank_model"]
129
- ASR_MDL = default_llm[LLM_FACTORY]["asr_model"]
130
- IMAGE2TEXT_MDL = default_llm[LLM_FACTORY]["image2text_model"]
131
-
132
- API_KEY = LLM.get("api_key", "")
133
- PARSERS = LLM.get(
134
- "parsers",
135
- "naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,knowledge_graph:Knowledge Graph,email:Email")
136
-
137
- # distribution
138
- DEPENDENT_DISTRIBUTION = get_base_config("dependent_distribution", False)
139
- RAG_FLOW_UPDATE_CHECK = False
140
-
141
- HOST = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("host", "127.0.0.1")
142
- HTTP_PORT = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("http_port")
143
-
144
- SECRET_KEY = get_base_config(
145
- RAG_FLOW_SERVICE_NAME,
146
- {}).get(
147
- "secret_key",
148
- "infiniflow")
149
- TOKEN_EXPIRE_IN = get_base_config(
150
- RAG_FLOW_SERVICE_NAME, {}).get(
151
- "token_expires_in", 3600)
152
-
153
- NGINX_HOST = get_base_config(
154
- RAG_FLOW_SERVICE_NAME, {}).get(
155
- "nginx", {}).get("host") or HOST
156
- NGINX_HTTP_PORT = get_base_config(
157
- RAG_FLOW_SERVICE_NAME, {}).get(
158
- "nginx", {}).get("http_port") or HTTP_PORT
159
-
160
- RANDOM_INSTANCE_ID = get_base_config(
161
- RAG_FLOW_SERVICE_NAME, {}).get(
162
- "random_instance_id", False)
163
-
164
- PROXY = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("proxy")
165
- PROXY_PROTOCOL = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("protocol")
166
-
167
- DATABASE = decrypt_database_config(name="mysql")
168
-
169
- # Switch
170
- # upload
171
- UPLOAD_DATA_FROM_CLIENT = True
172
-
173
- # authentication
174
- AUTHENTICATION_CONF = get_base_config("authentication", {})
175
-
176
- # client
177
- CLIENT_AUTHENTICATION = AUTHENTICATION_CONF.get(
178
- "client", {}).get(
179
- "switch", False)
180
- HTTP_APP_KEY = AUTHENTICATION_CONF.get("client", {}).get("http_app_key")
181
- GITHUB_OAUTH = get_base_config("oauth", {}).get("github")
182
- FEISHU_OAUTH = get_base_config("oauth", {}).get("feishu")
183
- WECHAT_OAUTH = get_base_config("oauth", {}).get("wechat")
184
-
185
- # site
186
- SITE_AUTHENTICATION = AUTHENTICATION_CONF.get("site", {}).get("switch", False)
187
-
188
- # permission
189
- PERMISSION_CONF = get_base_config("permission", {})
190
- PERMISSION_SWITCH = PERMISSION_CONF.get("switch")
191
- COMPONENT_PERMISSION = PERMISSION_CONF.get("component")
192
- DATASET_PERMISSION = PERMISSION_CONF.get("dataset")
193
-
194
- HOOK_MODULE = get_base_config("hook_module")
195
- HOOK_SERVER_NAME = get_base_config("hook_server_name")
196
-
197
- ENABLE_MODEL_STORE = get_base_config('enable_model_store', False)
198
- # authentication
199
- USE_AUTHENTICATION = False
200
- USE_DATA_AUTHENTICATION = False
201
- AUTOMATIC_AUTHORIZATION_OUTPUT_DATA = True
202
- USE_DEFAULT_TIMEOUT = False
203
- AUTHENTICATION_DEFAULT_TIMEOUT = 7 * 24 * 60 * 60 # s
204
- PRIVILEGE_COMMAND_WHITELIST = []
205
- CHECK_NODES_IDENTITY = False
206
-
207
- retrievaler = search.Dealer(ELASTICSEARCH)
208
- kg_retrievaler = kg_search.KGSearch(ELASTICSEARCH)
209
-
210
-
211
- class CustomEnum(Enum):
212
- @classmethod
213
- def valid(cls, value):
214
- try:
215
- cls(value)
216
- return True
217
- except BaseException:
218
- return False
219
-
220
- @classmethod
221
- def values(cls):
222
- return [member.value for member in cls.__members__.values()]
223
-
224
- @classmethod
225
- def names(cls):
226
- return [member.name for member in cls.__members__.values()]
227
-
228
-
229
- class PythonDependenceName(CustomEnum):
230
- Rag_Source_Code = "python"
231
- Python_Env = "miniconda"
232
-
233
-
234
- class ModelStorage(CustomEnum):
235
- REDIS = "redis"
236
- MYSQL = "mysql"
237
-
238
-
239
- class RetCode(IntEnum, CustomEnum):
240
- SUCCESS = 0
241
- NOT_EFFECTIVE = 10
242
- EXCEPTION_ERROR = 100
243
- ARGUMENT_ERROR = 101
244
- DATA_ERROR = 102
245
- OPERATING_ERROR = 103
246
- CONNECTION_ERROR = 105
247
- RUNNING = 106
248
- PERMISSION_ERROR = 108
249
- AUTHENTICATION_ERROR = 109
250
- UNAUTHORIZED = 401
251
- SERVER_ERROR = 500
 
1
+ #
2
+ # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ #
16
+ import os
17
+ from enum import IntEnum, Enum
18
+ from api.utils.file_utils import get_project_base_directory
19
+ from api.utils.log_utils import LoggerFactory, getLogger
20
+
21
+ # Logger
22
+ LoggerFactory.set_directory(
23
+ os.path.join(
24
+ get_project_base_directory(),
25
+ "logs",
26
+ "api"))
27
+ # {CRITICAL: 50, FATAL:50, ERROR:40, WARNING:30, WARN:30, INFO:20, DEBUG:10, NOTSET:0}
28
+ LoggerFactory.LEVEL = 30
29
+
30
+ stat_logger = getLogger("stat")
31
+ access_logger = getLogger("access")
32
+ database_logger = getLogger("database")
33
+ chat_logger = getLogger("chat")
34
+
35
+ from rag.utils.es_conn import ELASTICSEARCH
36
+ from rag.nlp import search
37
+ from graphrag import search as kg_search
38
+ from api.utils import get_base_config, decrypt_database_config
39
+
40
+ API_VERSION = "v1"
41
+ RAG_FLOW_SERVICE_NAME = "ragflow"
42
+ SERVER_MODULE = "rag_flow_server.py"
43
+ TEMP_DIRECTORY = os.path.join(get_project_base_directory(), "temp")
44
+ RAG_FLOW_CONF_PATH = os.path.join(get_project_base_directory(), "conf")
45
+
46
+ SUBPROCESS_STD_LOG_NAME = "std.log"
47
+
48
+ ERROR_REPORT = True
49
+ ERROR_REPORT_WITH_PATH = False
50
+
51
+ MAX_TIMESTAMP_INTERVAL = 60
52
+ SESSION_VALID_PERIOD = 7 * 24 * 60 * 60
53
+
54
+ REQUEST_TRY_TIMES = 3
55
+ REQUEST_WAIT_SEC = 2
56
+ REQUEST_MAX_WAIT_SEC = 300
57
+
58
+ USE_REGISTRY = get_base_config("use_registry")
59
+
60
+ default_llm = {
61
+ "Tongyi-Qianwen": {
62
+ "chat_model": "qwen-plus",
63
+ "embedding_model": "text-embedding-v2",
64
+ "image2text_model": "qwen-vl-max",
65
+ "asr_model": "paraformer-realtime-8k-v1",
66
+ },
67
+ "OpenAI": {
68
+ "chat_model": "gpt-3.5-turbo",
69
+ "embedding_model": "text-embedding-ada-002",
70
+ "image2text_model": "gpt-4-vision-preview",
71
+ "asr_model": "whisper-1",
72
+ },
73
+ "Azure-OpenAI": {
74
+ "chat_model": "azure-gpt-35-turbo",
75
+ "embedding_model": "azure-text-embedding-ada-002",
76
+ "image2text_model": "azure-gpt-4-vision-preview",
77
+ "asr_model": "azure-whisper-1",
78
+ },
79
+ "ZHIPU-AI": {
80
+ "chat_model": "glm-3-turbo",
81
+ "embedding_model": "embedding-2",
82
+ "image2text_model": "glm-4v",
83
+ "asr_model": "",
84
+ },
85
+ "Ollama": {
86
+ "chat_model": "qwen-14B-chat",
87
+ "embedding_model": "flag-embedding",
88
+ "image2text_model": "",
89
+ "asr_model": "",
90
+ },
91
+ "Moonshot": {
92
+ "chat_model": "moonshot-v1-8k",
93
+ "embedding_model": "",
94
+ "image2text_model": "",
95
+ "asr_model": "",
96
+ },
97
+ "DeepSeek": {
98
+ "chat_model": "deepseek-chat",
99
+ "embedding_model": "",
100
+ "image2text_model": "",
101
+ "asr_model": "",
102
+ },
103
+ "VolcEngine": {
104
+ "chat_model": "",
105
+ "embedding_model": "",
106
+ "image2text_model": "",
107
+ "asr_model": "",
108
+ },
109
+ "BAAI": {
110
+ "chat_model": "",
111
+ "embedding_model": "BAAI/bge-large-zh-v1.5",
112
+ "image2text_model": "",
113
+ "asr_model": "",
114
+ "rerank_model": "BAAI/bge-reranker-v2-m3",
115
+ }
116
+ }
117
+ LLM = get_base_config("user_default_llm", {})
118
+ LLM_FACTORY = LLM.get("factory", "Tongyi-Qianwen")
119
+ LLM_BASE_URL = LLM.get("base_url")
120
+
121
+ if LLM_FACTORY not in default_llm:
122
+ print(
123
+ "\33[91m【ERROR】\33[0m:",
124
+ f"LLM factory {LLM_FACTORY} has not supported yet, switch to 'Tongyi-Qianwen/QWen' automatically, and please check the API_KEY in service_conf.yaml.")
125
+ LLM_FACTORY = "Tongyi-Qianwen"
126
+ CHAT_MDL = default_llm[LLM_FACTORY]["chat_model"]
127
+ EMBEDDING_MDL = default_llm["BAAI"]["embedding_model"]
128
+ RERANK_MDL = default_llm["BAAI"]["rerank_model"]
129
+ ASR_MDL = default_llm[LLM_FACTORY]["asr_model"]
130
+ IMAGE2TEXT_MDL = default_llm[LLM_FACTORY]["image2text_model"]
131
+
132
+ API_KEY = LLM.get("api_key", "")
133
+ PARSERS = LLM.get(
134
+ "parsers",
135
+ "naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,knowledge_graph:Knowledge Graph,email:Email")
136
+
137
+ # distribution
138
+ DEPENDENT_DISTRIBUTION = get_base_config("dependent_distribution", False)
139
+ RAG_FLOW_UPDATE_CHECK = False
140
+
141
+ HOST = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("host", "127.0.0.1")
142
+ HTTP_PORT = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("http_port")
143
+
144
+ SECRET_KEY = get_base_config(
145
+ RAG_FLOW_SERVICE_NAME,
146
+ {}).get(
147
+ "secret_key",
148
+ "infiniflow")
149
+ TOKEN_EXPIRE_IN = get_base_config(
150
+ RAG_FLOW_SERVICE_NAME, {}).get(
151
+ "token_expires_in", 3600)
152
+
153
+ NGINX_HOST = get_base_config(
154
+ RAG_FLOW_SERVICE_NAME, {}).get(
155
+ "nginx", {}).get("host") or HOST
156
+ NGINX_HTTP_PORT = get_base_config(
157
+ RAG_FLOW_SERVICE_NAME, {}).get(
158
+ "nginx", {}).get("http_port") or HTTP_PORT
159
+
160
+ RANDOM_INSTANCE_ID = get_base_config(
161
+ RAG_FLOW_SERVICE_NAME, {}).get(
162
+ "random_instance_id", False)
163
+
164
+ PROXY = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("proxy")
165
+ PROXY_PROTOCOL = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("protocol")
166
+
167
+ DATABASE = decrypt_database_config(name="mysql")
168
+
169
+ # Switch
170
+ # upload
171
+ UPLOAD_DATA_FROM_CLIENT = True
172
+
173
+ # authentication
174
+ AUTHENTICATION_CONF = get_base_config("authentication", {})
175
+
176
+ # client
177
+ CLIENT_AUTHENTICATION = AUTHENTICATION_CONF.get(
178
+ "client", {}).get(
179
+ "switch", False)
180
+ HTTP_APP_KEY = AUTHENTICATION_CONF.get("client", {}).get("http_app_key")
181
+ GITHUB_OAUTH = get_base_config("oauth", {}).get("github")
182
+ FEISHU_OAUTH = get_base_config("oauth", {}).get("feishu")
183
+ WECHAT_OAUTH = get_base_config("oauth", {}).get("wechat")
184
+
185
+ # site
186
+ SITE_AUTHENTICATION = AUTHENTICATION_CONF.get("site", {}).get("switch", False)
187
+
188
+ # permission
189
+ PERMISSION_CONF = get_base_config("permission", {})
190
+ PERMISSION_SWITCH = PERMISSION_CONF.get("switch")
191
+ COMPONENT_PERMISSION = PERMISSION_CONF.get("component")
192
+ DATASET_PERMISSION = PERMISSION_CONF.get("dataset")
193
+
194
+ HOOK_MODULE = get_base_config("hook_module")
195
+ HOOK_SERVER_NAME = get_base_config("hook_server_name")
196
+
197
+ ENABLE_MODEL_STORE = get_base_config('enable_model_store', False)
198
+ # authentication
199
+ USE_AUTHENTICATION = False
200
+ USE_DATA_AUTHENTICATION = False
201
+ AUTOMATIC_AUTHORIZATION_OUTPUT_DATA = True
202
+ USE_DEFAULT_TIMEOUT = False
203
+ AUTHENTICATION_DEFAULT_TIMEOUT = 7 * 24 * 60 * 60 # s
204
+ PRIVILEGE_COMMAND_WHITELIST = []
205
+ CHECK_NODES_IDENTITY = False
206
+
207
+ retrievaler = search.Dealer(ELASTICSEARCH)
208
+ kg_retrievaler = kg_search.KGSearch(ELASTICSEARCH)
209
+
210
+
211
+ class CustomEnum(Enum):
212
+ @classmethod
213
+ def valid(cls, value):
214
+ try:
215
+ cls(value)
216
+ return True
217
+ except BaseException:
218
+ return False
219
+
220
+ @classmethod
221
+ def values(cls):
222
+ return [member.value for member in cls.__members__.values()]
223
+
224
+ @classmethod
225
+ def names(cls):
226
+ return [member.name for member in cls.__members__.values()]
227
+
228
+
229
+ class PythonDependenceName(CustomEnum):
230
+ Rag_Source_Code = "python"
231
+ Python_Env = "miniconda"
232
+
233
+
234
+ class ModelStorage(CustomEnum):
235
+ REDIS = "redis"
236
+ MYSQL = "mysql"
237
+
238
+
239
+ class RetCode(IntEnum, CustomEnum):
240
+ SUCCESS = 0
241
+ NOT_EFFECTIVE = 10
242
+ EXCEPTION_ERROR = 100
243
+ ARGUMENT_ERROR = 101
244
+ DATA_ERROR = 102
245
+ OPERATING_ERROR = 103
246
+ CONNECTION_ERROR = 105
247
+ RUNNING = 106
248
+ PERMISSION_ERROR = 108
249
+ AUTHENTICATION_ERROR = 109
250
+ UNAUTHORIZED = 401
251
+ SERVER_ERROR = 500
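The CustomEnum mixin above gives RetCode (and the other enums) cheap validation and introspection. A minimal usage sketch, assuming only the api.settings module defined in this file:

from api.settings import RetCode

assert RetCode.valid(0)              # 0 maps to RetCode.SUCCESS
assert not RetCode.valid(999)        # unknown codes return False instead of raising
print(RetCode.values())              # [0, 10, 100, 101, 102, 103, 105, 106, 108, 109, 401, 500]
print(RetCode.names())               # ['SUCCESS', 'NOT_EFFECTIVE', 'EXCEPTION_ERROR', ...]
status = int(RetCode.UNAUTHORIZED)   # RetCode is an IntEnum, so this doubles as HTTP 401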
api/utils/__init__.py CHANGED
@@ -1,346 +1,346 @@
1
- #
2
- # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- #
16
- import base64
17
- import datetime
18
- import io
19
- import json
20
- import os
21
- import pickle
22
- import socket
23
- import time
24
- import uuid
25
- import requests
26
- from enum import Enum, IntEnum
27
- import importlib
28
- from Cryptodome.PublicKey import RSA
29
- from Cryptodome.Cipher import PKCS1_v1_5 as Cipher_pkcs1_v1_5
30
-
31
- from filelock import FileLock
32
-
33
- from . import file_utils
34
-
35
- SERVICE_CONF = "service_conf.yaml"
36
-
37
-
38
- def conf_realpath(conf_name):
39
- conf_path = f"conf/{conf_name}"
40
- return os.path.join(file_utils.get_project_base_directory(), conf_path)
41
-
42
-
43
- def get_base_config(key, default=None, conf_name=SERVICE_CONF) -> dict:
44
- local_config = {}
45
- local_path = conf_realpath(f'local.{conf_name}')
46
- if default is None:
47
- default = os.environ.get(key.upper())
48
-
49
- if os.path.exists(local_path):
50
- local_config = file_utils.load_yaml_conf(local_path)
51
- if not isinstance(local_config, dict):
52
- raise ValueError(f'Invalid config file: "{local_path}".')
53
-
54
- if key is not None and key in local_config:
55
- return local_config[key]
56
-
57
- config_path = conf_realpath(conf_name)
58
- config = file_utils.load_yaml_conf(config_path)
59
-
60
- if not isinstance(config, dict):
61
- raise ValueError(f'Invalid config file: "{config_path}".')
62
-
63
- config.update(local_config)
64
- return config.get(key, default) if key is not None else config
65
-
66
-
67
- use_deserialize_safe_module = get_base_config(
68
- 'use_deserialize_safe_module', False)
69
-
70
-
71
- class CoordinationCommunicationProtocol(object):
72
- HTTP = "http"
73
- GRPC = "grpc"
74
-
75
-
76
- class BaseType:
77
- def to_dict(self):
78
- return dict([(k.lstrip("_"), v) for k, v in self.__dict__.items()])
79
-
80
- def to_dict_with_type(self):
81
- def _dict(obj):
82
- module = None
83
- if issubclass(obj.__class__, BaseType):
84
- data = {}
85
- for attr, v in obj.__dict__.items():
86
- k = attr.lstrip("_")
87
- data[k] = _dict(v)
88
- module = obj.__module__
89
- elif isinstance(obj, (list, tuple)):
90
- data = []
91
- for i, vv in enumerate(obj):
92
- data.append(_dict(vv))
93
- elif isinstance(obj, dict):
94
- data = {}
95
- for _k, vv in obj.items():
96
- data[_k] = _dict(vv)
97
- else:
98
- data = obj
99
- return {"type": obj.__class__.__name__,
100
- "data": data, "module": module}
101
- return _dict(self)
102
-
103
-
104
- class CustomJSONEncoder(json.JSONEncoder):
105
- def __init__(self, **kwargs):
106
- self._with_type = kwargs.pop("with_type", False)
107
- super().__init__(**kwargs)
108
-
109
- def default(self, obj):
110
- if isinstance(obj, datetime.datetime):
111
- return obj.strftime('%Y-%m-%d %H:%M:%S')
112
- elif isinstance(obj, datetime.date):
113
- return obj.strftime('%Y-%m-%d')
114
- elif isinstance(obj, datetime.timedelta):
115
- return str(obj)
116
- elif issubclass(type(obj), Enum) or issubclass(type(obj), IntEnum):
117
- return obj.value
118
- elif isinstance(obj, set):
119
- return list(obj)
120
- elif issubclass(type(obj), BaseType):
121
- if not self._with_type:
122
- return obj.to_dict()
123
- else:
124
- return obj.to_dict_with_type()
125
- elif isinstance(obj, type):
126
- return obj.__name__
127
- else:
128
- return json.JSONEncoder.default(self, obj)
129
-
130
-
131
- def rag_uuid():
132
- return uuid.uuid1().hex
133
-
134
-
135
- def string_to_bytes(string):
136
- return string if isinstance(
137
- string, bytes) else string.encode(encoding="utf-8")
138
-
139
-
140
- def bytes_to_string(byte):
141
- return byte.decode(encoding="utf-8")
142
-
143
-
144
- def json_dumps(src, byte=False, indent=None, with_type=False):
145
- dest = json.dumps(
146
- src,
147
- indent=indent,
148
- cls=CustomJSONEncoder,
149
- with_type=with_type)
150
- if byte:
151
- dest = string_to_bytes(dest)
152
- return dest
153
-
154
-
155
- def json_loads(src, object_hook=None, object_pairs_hook=None):
156
- if isinstance(src, bytes):
157
- src = bytes_to_string(src)
158
- return json.loads(src, object_hook=object_hook,
159
- object_pairs_hook=object_pairs_hook)
160
-
161
-
162
- def current_timestamp():
163
- return int(time.time() * 1000)
164
-
165
-
166
- def timestamp_to_date(timestamp, format_string="%Y-%m-%d %H:%M:%S"):
167
- if not timestamp:
168
- timestamp = time.time()
169
- timestamp = int(timestamp) / 1000
170
- time_array = time.localtime(timestamp)
171
- str_date = time.strftime(format_string, time_array)
172
- return str_date
173
-
174
-
175
- def date_string_to_timestamp(time_str, format_string="%Y-%m-%d %H:%M:%S"):
176
- time_array = time.strptime(time_str, format_string)
177
- time_stamp = int(time.mktime(time_array) * 1000)
178
- return time_stamp
179
-
180
-
181
- def serialize_b64(src, to_str=False):
182
- dest = base64.b64encode(pickle.dumps(src))
183
- if not to_str:
184
- return dest
185
- else:
186
- return bytes_to_string(dest)
187
-
188
-
189
- def deserialize_b64(src):
190
- src = base64.b64decode(
191
- string_to_bytes(src) if isinstance(
192
- src, str) else src)
193
- if use_deserialize_safe_module:
194
- return restricted_loads(src)
195
- return pickle.loads(src)
196
-
197
-
198
- safe_module = {
199
- 'numpy',
200
- 'rag_flow'
201
- }
202
-
203
-
204
- class RestrictedUnpickler(pickle.Unpickler):
205
- def find_class(self, module, name):
206
- import importlib
207
- if module.split('.')[0] in safe_module:
208
- _module = importlib.import_module(module)
209
- return getattr(_module, name)
210
- # Forbid everything else.
211
- raise pickle.UnpicklingError("global '%s.%s' is forbidden" %
212
- (module, name))
213
-
214
-
215
- def restricted_loads(src):
216
- """Helper function analogous to pickle.loads()."""
217
- return RestrictedUnpickler(io.BytesIO(src)).load()
218
-
219
-
220
- def get_lan_ip():
221
- if os.name != "nt":
222
- import fcntl
223
- import struct
224
-
225
- def get_interface_ip(ifname):
226
- s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
227
- return socket.inet_ntoa(
228
- fcntl.ioctl(s.fileno(), 0x8915, struct.pack('256s', string_to_bytes(ifname[:15])))[20:24])
229
-
230
- ip = socket.gethostbyname(socket.getfqdn())
231
- if ip.startswith("127.") and os.name != "nt":
232
- interfaces = [
233
- "bond1",
234
- "eth0",
235
- "eth1",
236
- "eth2",
237
- "wlan0",
238
- "wlan1",
239
- "wifi0",
240
- "ath0",
241
- "ath1",
242
- "ppp0",
243
- ]
244
- for ifname in interfaces:
245
- try:
246
- ip = get_interface_ip(ifname)
247
- break
248
- except IOError:
249
- pass
250
- return ip or ''
251
-
252
-
253
- def from_dict_hook(in_dict: dict):
254
- if "type" in in_dict and "data" in in_dict:
255
- if in_dict["module"] is None:
256
- return in_dict["data"]
257
- else:
258
- return getattr(importlib.import_module(
259
- in_dict["module"]), in_dict["type"])(**in_dict["data"])
260
- else:
261
- return in_dict
262
-
263
-
264
- def decrypt_database_password(password):
265
- encrypt_password = get_base_config("encrypt_password", False)
266
- encrypt_module = get_base_config("encrypt_module", False)
267
- private_key = get_base_config("private_key", None)
268
-
269
- if not password or not encrypt_password:
270
- return password
271
-
272
- if not private_key:
273
- raise ValueError("No private key")
274
-
275
- module_fun = encrypt_module.split("#")
276
- pwdecrypt_fun = getattr(
277
- importlib.import_module(
278
- module_fun[0]),
279
- module_fun[1])
280
-
281
- return pwdecrypt_fun(private_key, password)
282
-
283
-
284
- def decrypt_database_config(
285
- database=None, passwd_key="password", name="database"):
286
- if not database:
287
- database = get_base_config(name, {})
288
-
289
- database[passwd_key] = decrypt_database_password(database[passwd_key])
290
- return database
291
-
292
-
293
- def update_config(key, value, conf_name=SERVICE_CONF):
294
- conf_path = conf_realpath(conf_name=conf_name)
295
- if not os.path.isabs(conf_path):
296
- conf_path = os.path.join(
297
- file_utils.get_project_base_directory(), conf_path)
298
-
299
- with FileLock(os.path.join(os.path.dirname(conf_path), ".lock")):
300
- config = file_utils.load_yaml_conf(conf_path=conf_path) or {}
301
- config[key] = value
302
- file_utils.rewrite_yaml_conf(conf_path=conf_path, config=config)
303
-
304
-
305
- def get_uuid():
306
- return uuid.uuid1().hex
307
-
308
-
309
- def datetime_format(date_time: datetime.datetime) -> datetime.datetime:
310
- return datetime.datetime(date_time.year, date_time.month, date_time.day,
311
- date_time.hour, date_time.minute, date_time.second)
312
-
313
-
314
- def get_format_time() -> datetime.datetime:
315
- return datetime_format(datetime.datetime.now())
316
-
317
-
318
- def str2date(date_time: str):
319
- return datetime.datetime.strptime(date_time, '%Y-%m-%d')
320
-
321
-
322
- def elapsed2time(elapsed):
323
- seconds = elapsed / 1000
324
- minutes, second = divmod(seconds, 60)
325
- hour, minutes = divmod(minutes, 60)
326
- return '%02d:%02d:%02d' % (hour, minutes, second)
327
-
328
-
329
- def decrypt(line):
330
- file_path = os.path.join(
331
- file_utils.get_project_base_directory(),
332
- "conf",
333
- "private.pem")
334
- rsa_key = RSA.importKey(open(file_path).read(), "Welcome")
335
- cipher = Cipher_pkcs1_v1_5.new(rsa_key)
336
- return cipher.decrypt(base64.b64decode(
337
- line), "Fail to decrypt password!").decode('utf-8')
338
-
339
-
340
- def download_img(url):
341
- if not url:
342
- return ""
343
- response = requests.get(url)
344
- return "data:" + \
345
- response.headers.get('Content-Type', 'image/jpeg') + ";" + \
346
- "base64," + base64.b64encode(response.content).decode("utf-8")
 
1
+ #
2
+ # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ #
16
+ import base64
17
+ import datetime
18
+ import io
19
+ import json
20
+ import os
21
+ import pickle
22
+ import socket
23
+ import time
24
+ import uuid
25
+ import requests
26
+ from enum import Enum, IntEnum
27
+ import importlib
28
+ from Cryptodome.PublicKey import RSA
29
+ from Cryptodome.Cipher import PKCS1_v1_5 as Cipher_pkcs1_v1_5
30
+
31
+ from filelock import FileLock
32
+
33
+ from . import file_utils
34
+
35
+ SERVICE_CONF = "service_conf.yaml"
36
+
37
+
38
+ def conf_realpath(conf_name):
39
+ conf_path = f"conf/{conf_name}"
40
+ return os.path.join(file_utils.get_project_base_directory(), conf_path)
41
+
42
+
43
+ def get_base_config(key, default=None, conf_name=SERVICE_CONF) -> dict:
44
+ local_config = {}
45
+ local_path = conf_realpath(f'local.{conf_name}')
46
+ if default is None:
47
+ default = os.environ.get(key.upper())
48
+
49
+ if os.path.exists(local_path):
50
+ local_config = file_utils.load_yaml_conf(local_path)
51
+ if not isinstance(local_config, dict):
52
+ raise ValueError(f'Invalid config file: "{local_path}".')
53
+
54
+ if key is not None and key in local_config:
55
+ return local_config[key]
56
+
57
+ config_path = conf_realpath(conf_name)
58
+ config = file_utils.load_yaml_conf(config_path)
59
+
60
+ if not isinstance(config, dict):
61
+ raise ValueError(f'Invalid config file: "{config_path}".')
62
+
63
+ config.update(local_config)
64
+ return config.get(key, default) if key is not None else config
65
+
66
+
67
+ use_deserialize_safe_module = get_base_config(
68
+ 'use_deserialize_safe_module', False)
69
+
70
+
71
+ class CoordinationCommunicationProtocol(object):
72
+ HTTP = "http"
73
+ GRPC = "grpc"
74
+
75
+
76
+ class BaseType:
77
+ def to_dict(self):
78
+ return dict([(k.lstrip("_"), v) for k, v in self.__dict__.items()])
79
+
80
+ def to_dict_with_type(self):
81
+ def _dict(obj):
82
+ module = None
83
+ if issubclass(obj.__class__, BaseType):
84
+ data = {}
85
+ for attr, v in obj.__dict__.items():
86
+ k = attr.lstrip("_")
87
+ data[k] = _dict(v)
88
+ module = obj.__module__
89
+ elif isinstance(obj, (list, tuple)):
90
+ data = []
91
+ for i, vv in enumerate(obj):
92
+ data.append(_dict(vv))
93
+ elif isinstance(obj, dict):
94
+ data = {}
95
+ for _k, vv in obj.items():
96
+ data[_k] = _dict(vv)
97
+ else:
98
+ data = obj
99
+ return {"type": obj.__class__.__name__,
100
+ "data": data, "module": module}
101
+ return _dict(self)
102
+
103
+
104
+ class CustomJSONEncoder(json.JSONEncoder):
105
+ def __init__(self, **kwargs):
106
+ self._with_type = kwargs.pop("with_type", False)
107
+ super().__init__(**kwargs)
108
+
109
+ def default(self, obj):
110
+ if isinstance(obj, datetime.datetime):
111
+ return obj.strftime('%Y-%m-%d %H:%M:%S')
112
+ elif isinstance(obj, datetime.date):
113
+ return obj.strftime('%Y-%m-%d')
114
+ elif isinstance(obj, datetime.timedelta):
115
+ return str(obj)
116
+ elif issubclass(type(obj), Enum) or issubclass(type(obj), IntEnum):
117
+ return obj.value
118
+ elif isinstance(obj, set):
119
+ return list(obj)
120
+ elif issubclass(type(obj), BaseType):
121
+ if not self._with_type:
122
+ return obj.to_dict()
123
+ else:
124
+ return obj.to_dict_with_type()
125
+ elif isinstance(obj, type):
126
+ return obj.__name__
127
+ else:
128
+ return json.JSONEncoder.default(self, obj)
129
+
130
+
131
+ def rag_uuid():
132
+ return uuid.uuid1().hex
133
+
134
+
135
+ def string_to_bytes(string):
136
+ return string if isinstance(
137
+ string, bytes) else string.encode(encoding="utf-8")
138
+
139
+
140
+ def bytes_to_string(byte):
141
+ return byte.decode(encoding="utf-8")
142
+
143
+
144
+ def json_dumps(src, byte=False, indent=None, with_type=False):
145
+ dest = json.dumps(
146
+ src,
147
+ indent=indent,
148
+ cls=CustomJSONEncoder,
149
+ with_type=with_type)
150
+ if byte:
151
+ dest = string_to_bytes(dest)
152
+ return dest
153
+
154
+
155
+ def json_loads(src, object_hook=None, object_pairs_hook=None):
156
+ if isinstance(src, bytes):
157
+ src = bytes_to_string(src)
158
+ return json.loads(src, object_hook=object_hook,
159
+ object_pairs_hook=object_pairs_hook)
160
+
161
+
162
+ def current_timestamp():
163
+ return int(time.time() * 1000)
164
+
165
+
166
+ def timestamp_to_date(timestamp, format_string="%Y-%m-%d %H:%M:%S"):
167
+ if not timestamp:
168
+ timestamp = time.time()
169
+ timestamp = int(timestamp) / 1000
170
+ time_array = time.localtime(timestamp)
171
+ str_date = time.strftime(format_string, time_array)
172
+ return str_date
173
+
174
+
175
+ def date_string_to_timestamp(time_str, format_string="%Y-%m-%d %H:%M:%S"):
176
+ time_array = time.strptime(time_str, format_string)
177
+ time_stamp = int(time.mktime(time_array) * 1000)
178
+ return time_stamp
179
+
180
+
181
+ def serialize_b64(src, to_str=False):
182
+ dest = base64.b64encode(pickle.dumps(src))
183
+ if not to_str:
184
+ return dest
185
+ else:
186
+ return bytes_to_string(dest)
187
+
188
+
189
+ def deserialize_b64(src):
190
+ src = base64.b64decode(
191
+ string_to_bytes(src) if isinstance(
192
+ src, str) else src)
193
+ if use_deserialize_safe_module:
194
+ return restricted_loads(src)
195
+ return pickle.loads(src)
196
+
197
+
198
+ safe_module = {
199
+ 'numpy',
200
+ 'rag_flow'
201
+ }
202
+
203
+
204
+ class RestrictedUnpickler(pickle.Unpickler):
205
+ def find_class(self, module, name):
206
+ import importlib
207
+ if module.split('.')[0] in safe_module:
208
+ _module = importlib.import_module(module)
209
+ return getattr(_module, name)
210
+ # Forbid everything else.
211
+ raise pickle.UnpicklingError("global '%s.%s' is forbidden" %
212
+ (module, name))
213
+
214
+
215
+ def restricted_loads(src):
216
+ """Helper function analogous to pickle.loads()."""
217
+ return RestrictedUnpickler(io.BytesIO(src)).load()
218
+
219
+
220
+ def get_lan_ip():
221
+ if os.name != "nt":
222
+ import fcntl
223
+ import struct
224
+
225
+ def get_interface_ip(ifname):
226
+ s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
227
+ return socket.inet_ntoa(
228
+ fcntl.ioctl(s.fileno(), 0x8915, struct.pack('256s', string_to_bytes(ifname[:15])))[20:24])
229
+
230
+ ip = socket.gethostbyname(socket.getfqdn())
231
+ if ip.startswith("127.") and os.name != "nt":
232
+ interfaces = [
233
+ "bond1",
234
+ "eth0",
235
+ "eth1",
236
+ "eth2",
237
+ "wlan0",
238
+ "wlan1",
239
+ "wifi0",
240
+ "ath0",
241
+ "ath1",
242
+ "ppp0",
243
+ ]
244
+ for ifname in interfaces:
245
+ try:
246
+ ip = get_interface_ip(ifname)
247
+ break
248
+ except IOError:
249
+ pass
250
+ return ip or ''
251
+
252
+
253
+ def from_dict_hook(in_dict: dict):
254
+ if "type" in in_dict and "data" in in_dict:
255
+ if in_dict["module"] is None:
256
+ return in_dict["data"]
257
+ else:
258
+ return getattr(importlib.import_module(
259
+ in_dict["module"]), in_dict["type"])(**in_dict["data"])
260
+ else:
261
+ return in_dict
262
+
263
+
264
+ def decrypt_database_password(password):
265
+ encrypt_password = get_base_config("encrypt_password", False)
266
+ encrypt_module = get_base_config("encrypt_module", False)
267
+ private_key = get_base_config("private_key", None)
268
+
269
+ if not password or not encrypt_password:
270
+ return password
271
+
272
+ if not private_key:
273
+ raise ValueError("No private key")
274
+
275
+ module_fun = encrypt_module.split("#")
276
+ pwdecrypt_fun = getattr(
277
+ importlib.import_module(
278
+ module_fun[0]),
279
+ module_fun[1])
280
+
281
+ return pwdecrypt_fun(private_key, password)
282
+
283
+
284
+ def decrypt_database_config(
285
+ database=None, passwd_key="password", name="database"):
286
+ if not database:
287
+ database = get_base_config(name, {})
288
+
289
+ database[passwd_key] = decrypt_database_password(database[passwd_key])
290
+ return database
291
+
292
+
293
+ def update_config(key, value, conf_name=SERVICE_CONF):
294
+ conf_path = conf_realpath(conf_name=conf_name)
295
+ if not os.path.isabs(conf_path):
296
+ conf_path = os.path.join(
297
+ file_utils.get_project_base_directory(), conf_path)
298
+
299
+ with FileLock(os.path.join(os.path.dirname(conf_path), ".lock")):
300
+ config = file_utils.load_yaml_conf(conf_path=conf_path) or {}
301
+ config[key] = value
302
+ file_utils.rewrite_yaml_conf(conf_path=conf_path, config=config)
303
+
304
+
305
+ def get_uuid():
306
+ return uuid.uuid1().hex
307
+
308
+
309
+ def datetime_format(date_time: datetime.datetime) -> datetime.datetime:
310
+ return datetime.datetime(date_time.year, date_time.month, date_time.day,
311
+ date_time.hour, date_time.minute, date_time.second)
312
+
313
+
314
+ def get_format_time() -> datetime.datetime:
315
+ return datetime_format(datetime.datetime.now())
316
+
317
+
318
+ def str2date(date_time: str):
319
+ return datetime.datetime.strptime(date_time, '%Y-%m-%d')
320
+
321
+
322
+ def elapsed2time(elapsed):
323
+ seconds = elapsed / 1000
324
+ minuter, second = divmod(seconds, 60)
325
+ hour, minuter = divmod(minuter, 60)
326
+ return '%02d:%02d:%02d' % (hour, minuter, second)
327
+
328
+
329
+ def decrypt(line):
330
+ file_path = os.path.join(
331
+ file_utils.get_project_base_directory(),
332
+ "conf",
333
+ "private.pem")
334
+ rsa_key = RSA.importKey(open(file_path).read(), "Welcome")
335
+ cipher = Cipher_pkcs1_v1_5.new(rsa_key)
336
+ return cipher.decrypt(base64.b64decode(
337
+ line), "Fail to decrypt password!").decode('utf-8')
338
+
339
+
340
+ def download_img(url):
341
+ if not url:
342
+ return ""
343
+ response = requests.get(url)
344
+ return "data:" + \
345
+ response.headers.get('Content-Type', 'image/jpg') + ";" + \
346
+ "base64," + base64.b64encode(response.content).decode("utf-8")
api/utils/api_utils.py CHANGED
@@ -1,269 +1,269 @@
1
- #
2
- # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- #
16
- import json
17
- import random
18
- import time
19
- from functools import wraps
20
- from io import BytesIO
21
- from flask import (
22
- Response, jsonify, send_file, make_response,
23
- request as flask_request,
24
- )
25
- from werkzeug.http import HTTP_STATUS_CODES
26
-
27
- from api.utils import json_dumps
28
- from api.settings import RetCode
29
- from api.settings import (
30
- REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC,
31
- stat_logger, CLIENT_AUTHENTICATION, HTTP_APP_KEY, SECRET_KEY
32
- )
33
- import requests
34
- import functools
35
- from api.utils import CustomJSONEncoder
36
- from uuid import uuid1
37
- from base64 import b64encode
38
- from hmac import HMAC
39
- from urllib.parse import quote, urlencode
40
-
41
- requests.models.complexjson.dumps = functools.partial(
42
- json.dumps, cls=CustomJSONEncoder)
43
-
44
-
45
- def request(**kwargs):
46
- sess = requests.Session()
47
- stream = kwargs.pop('stream', sess.stream)
48
- timeout = kwargs.pop('timeout', None)
49
- kwargs['headers'] = {
50
- k.replace(
51
- '_',
52
- '-').upper(): v for k,
53
- v in kwargs.get(
54
- 'headers',
55
- {}).items()}
56
- prepped = requests.Request(**kwargs).prepare()
57
-
58
- if CLIENT_AUTHENTICATION and HTTP_APP_KEY and SECRET_KEY:
59
- timestamp = str(round(time.time() * 1000))  # milliseconds since the epoch
60
- nonce = str(uuid1())
61
- signature = b64encode(HMAC(SECRET_KEY.encode('ascii'), b'\n'.join([
62
- timestamp.encode('ascii'),
63
- nonce.encode('ascii'),
64
- HTTP_APP_KEY.encode('ascii'),
65
- prepped.path_url.encode('ascii'),
66
- prepped.body if kwargs.get('json') else b'',
67
- urlencode(
68
- sorted(
69
- kwargs['data'].items()),
70
- quote_via=quote,
71
- safe='-._~').encode('ascii')
72
- if kwargs.get('data') and isinstance(kwargs['data'], dict) else b'',
73
- ]), 'sha1').digest()).decode('ascii')
74
-
75
- prepped.headers.update({
76
- 'TIMESTAMP': timestamp,
77
- 'NONCE': nonce,
78
- 'APP-KEY': HTTP_APP_KEY,
79
- 'SIGNATURE': signature,
80
- })
81
-
82
- return sess.send(prepped, stream=stream, timeout=timeout)
83
-
84
-
85
- def get_exponential_backoff_interval(retries, full_jitter=False):
86
- """Calculate the exponential backoff wait time."""
87
- # Will be zero if factor equals 0
88
- countdown = min(REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC * (2 ** retries))
89
- # Full jitter according to
90
- # https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/
91
- if full_jitter:
92
- countdown = random.randrange(countdown + 1)
93
- # Adjust according to maximum wait time and account for negative values.
94
- return max(0, countdown)
95
-
96
-
97
- def get_json_result(retcode=RetCode.SUCCESS, retmsg='success',
98
- data=None, job_id=None, meta=None):
99
- import re
100
- result_dict = {
101
- "retcode": retcode,
102
- "retmsg": retmsg,
103
- # "retmsg": re.sub(r"rag", "seceum", retmsg, flags=re.IGNORECASE),
104
- "data": data,
105
- "jobId": job_id,
106
- "meta": meta,
107
- }
108
-
109
- response = {}
110
- for key, value in result_dict.items():
111
- if value is None and key != "retcode":
112
- continue
113
- else:
114
- response[key] = value
115
- return jsonify(response)
116
-
117
-
118
- def get_data_error_result(retcode=RetCode.DATA_ERROR,
119
- retmsg='Sorry! Data missing!'):
120
- import re
121
- result_dict = {
122
- "retcode": retcode,
123
- "retmsg": re.sub(
124
- r"rag",
125
- "seceum",
126
- retmsg,
127
- flags=re.IGNORECASE)}
128
- response = {}
129
- for key, value in result_dict.items():
130
- if value is None and key != "retcode":
131
- continue
132
- else:
133
- response[key] = value
134
- return jsonify(response)
135
-
136
-
137
- def server_error_response(e):
138
- stat_logger.exception(e)
139
- try:
140
- if e.code == 401:
141
- return get_json_result(retcode=401, retmsg=repr(e))
142
- except BaseException:
143
- pass
144
- if len(e.args) > 1:
145
- return get_json_result(
146
- retcode=RetCode.EXCEPTION_ERROR, retmsg=repr(e.args[0]), data=e.args[1])
147
- if repr(e).find("index_not_found_exception") >= 0:
148
- return get_json_result(retcode=RetCode.EXCEPTION_ERROR, retmsg="No chunk found, please upload file and parse it.")
149
-
150
- return get_json_result(retcode=RetCode.EXCEPTION_ERROR, retmsg=repr(e))
151
-
152
-
153
- def error_response(response_code, retmsg=None):
154
- if retmsg is None:
155
- retmsg = HTTP_STATUS_CODES.get(response_code, 'Unknown Error')
156
-
157
- return Response(json.dumps({
158
- 'retmsg': retmsg,
159
- 'retcode': response_code,
160
- }), status=response_code, mimetype='application/json')
161
-
162
-
163
- def validate_request(*args, **kwargs):
164
- def wrapper(func):
165
- @wraps(func)
166
- def decorated_function(*_args, **_kwargs):
167
- input_arguments = flask_request.json or flask_request.form.to_dict()
168
- no_arguments = []
169
- error_arguments = []
170
- for arg in args:
171
- if arg not in input_arguments:
172
- no_arguments.append(arg)
173
- for k, v in kwargs.items():
174
- config_value = input_arguments.get(k, None)
175
- if config_value is None:
176
- no_arguments.append(k)
177
- elif isinstance(v, (tuple, list)):
178
- if config_value not in v:
179
- error_arguments.append((k, set(v)))
180
- elif config_value != v:
181
- error_arguments.append((k, v))
182
- if no_arguments or error_arguments:
183
- error_string = ""
184
- if no_arguments:
185
- error_string += "required argument are missing: {}; ".format(
186
- ",".join(no_arguments))
187
- if error_arguments:
188
- error_string += "required argument values: {}".format(
189
- ",".join(["{}={}".format(a[0], a[1]) for a in error_arguments]))
190
- return get_json_result(
191
- retcode=RetCode.ARGUMENT_ERROR, retmsg=error_string)
192
- return func(*_args, **_kwargs)
193
- return decorated_function
194
- return wrapper
195
-
196
-
197
- def is_localhost(ip):
198
- return ip in {'127.0.0.1', '::1', '[::1]', 'localhost'}
199
-
200
-
201
- def send_file_in_mem(data, filename):
202
- if not isinstance(data, (str, bytes)):
203
- data = json_dumps(data)
204
- if isinstance(data, str):
205
- data = data.encode('utf-8')
206
-
207
- f = BytesIO()
208
- f.write(data)
209
- f.seek(0)
210
-
211
- return send_file(f, as_attachment=True, attachment_filename=filename)
212
-
213
-
214
- def get_json_result(retcode=RetCode.SUCCESS, retmsg='success', data=None):  # NOTE: shadows the fuller get_json_result defined above
215
- response = {"retcode": retcode, "retmsg": retmsg, "data": data}
216
- return jsonify(response)
217
-
218
-
219
- def cors_reponse(retcode=RetCode.SUCCESS,
220
- retmsg='success', data=None, auth=None):
221
- result_dict = {"retcode": retcode, "retmsg": retmsg, "data": data}
222
- response_dict = {}
223
- for key, value in result_dict.items():
224
- if value is None and key != "retcode":
225
- continue
226
- else:
227
- response_dict[key] = value
228
- response = make_response(jsonify(response_dict))
229
- if auth:
230
- response.headers["Authorization"] = auth
231
- response.headers["Access-Control-Allow-Origin"] = "*"
232
- response.headers["Access-Control-Allow-Method"] = "*"
233
- response.headers["Access-Control-Allow-Headers"] = "*"
235
- response.headers["Access-Control-Expose-Headers"] = "Authorization"
236
- return response
237
-
238
- def construct_result(code=RetCode.DATA_ERROR, message='data is missing'):
239
- import re
240
- result_dict = {"code": code, "message": re.sub(r"rag", "seceum", message, flags=re.IGNORECASE)}
241
- response = {}
242
- for key, value in result_dict.items():
243
- if value is None and key != "code":
244
- continue
245
- else:
246
- response[key] = value
247
- return jsonify(response)
248
-
249
-
250
- def construct_json_result(code=RetCode.SUCCESS, message='success', data=None):
251
- if data is None:
252
- return jsonify({"code": code, "message": message})
253
- else:
254
- return jsonify({"code": code, "message": message, "data": data})
255
-
256
-
257
- def construct_error_response(e):
258
- stat_logger.exception(e)
259
- try:
260
- if e.code == 401:
261
- return construct_json_result(code=RetCode.UNAUTHORIZED, message=repr(e))
262
- except BaseException:
263
- pass
264
- if len(e.args) > 1:
265
- return construct_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1])
266
- if repr(e).find("index_not_found_exception") >=0:
267
- return construct_json_result(code=RetCode.EXCEPTION_ERROR, message="No chunk found, please upload file and parse it.")
268
-
269
- return construct_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e))
 
1
+ #
2
+ # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ #
16
+ import json
17
+ import random
18
+ import time
19
+ from functools import wraps
20
+ from io import BytesIO
21
+ from flask import (
22
+ Response, jsonify, send_file, make_response,
23
+ request as flask_request,
24
+ )
25
+ from werkzeug.http import HTTP_STATUS_CODES
26
+
27
+ from api.utils import json_dumps
28
+ from api.settings import RetCode
29
+ from api.settings import (
30
+ REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC,
31
+ stat_logger, CLIENT_AUTHENTICATION, HTTP_APP_KEY, SECRET_KEY
32
+ )
33
+ import requests
34
+ import functools
35
+ from api.utils import CustomJSONEncoder
36
+ from uuid import uuid1
37
+ from base64 import b64encode
38
+ from hmac import HMAC
39
+ from urllib.parse import quote, urlencode
40
+
41
+ requests.models.complexjson.dumps = functools.partial(
42
+ json.dumps, cls=CustomJSONEncoder)
43
+
44
+
45
+ def request(**kwargs):
46
+ sess = requests.Session()
47
+ stream = kwargs.pop('stream', sess.stream)
48
+ timeout = kwargs.pop('timeout', None)
49
+ kwargs['headers'] = {
50
+ k.replace(
51
+ '_',
52
+ '-').upper(): v for k,
53
+ v in kwargs.get(
54
+ 'headers',
55
+ {}).items()}
56
+ prepped = requests.Request(**kwargs).prepare()
57
+
58
+ if CLIENT_AUTHENTICATION and HTTP_APP_KEY and SECRET_KEY:
59
+ timestamp = str(round(time.time() * 1000))  # milliseconds since the epoch
60
+ nonce = str(uuid1())
61
+ signature = b64encode(HMAC(SECRET_KEY.encode('ascii'), b'\n'.join([
62
+ timestamp.encode('ascii'),
63
+ nonce.encode('ascii'),
64
+ HTTP_APP_KEY.encode('ascii'),
65
+ prepped.path_url.encode('ascii'),
66
+ prepped.body if kwargs.get('json') else b'',
67
+ urlencode(
68
+ sorted(
69
+ kwargs['data'].items()),
70
+ quote_via=quote,
71
+ safe='-._~').encode('ascii')
72
+ if kwargs.get('data') and isinstance(kwargs['data'], dict) else b'',
73
+ ]), 'sha1').digest()).decode('ascii')
74
+
75
+ prepped.headers.update({
76
+ 'TIMESTAMP': timestamp,
77
+ 'NONCE': nonce,
78
+ 'APP-KEY': HTTP_APP_KEY,
79
+ 'SIGNATURE': signature,
80
+ })
81
+
82
+ return sess.send(prepped, stream=stream, timeout=timeout)
83
+
84
+
85
+ def get_exponential_backoff_interval(retries, full_jitter=False):
86
+ """Calculate the exponential backoff wait time."""
87
+ # Will be zero if factor equals 0
88
+ countdown = min(REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC * (2 ** retries))
89
+ # Full jitter according to
90
+ # https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/
91
+ if full_jitter:
92
+ countdown = random.randrange(countdown + 1)
93
+ # Adjust according to maximum wait time and account for negative values.
94
+ return max(0, countdown)
95
+
96
+
97
+ def get_json_result(retcode=RetCode.SUCCESS, retmsg='success',
98
+ data=None, job_id=None, meta=None):
99
+ import re
100
+ result_dict = {
101
+ "retcode": retcode,
102
+ "retmsg": retmsg,
103
+ # "retmsg": re.sub(r"rag", "seceum", retmsg, flags=re.IGNORECASE),
104
+ "data": data,
105
+ "jobId": job_id,
106
+ "meta": meta,
107
+ }
108
+
109
+ response = {}
110
+ for key, value in result_dict.items():
111
+ if value is None and key != "retcode":
112
+ continue
113
+ else:
114
+ response[key] = value
115
+ return jsonify(response)
116
+
117
+
118
+ def get_data_error_result(retcode=RetCode.DATA_ERROR,
119
+ retmsg='Sorry! Data missing!'):
120
+ import re
121
+ result_dict = {
122
+ "retcode": retcode,
123
+ "retmsg": re.sub(
124
+ r"rag",
125
+ "seceum",
126
+ retmsg,
127
+ flags=re.IGNORECASE)}
128
+ response = {}
129
+ for key, value in result_dict.items():
130
+ if value is None and key != "retcode":
131
+ continue
132
+ else:
133
+ response[key] = value
134
+ return jsonify(response)
135
+
136
+
137
+ def server_error_response(e):
138
+ stat_logger.exception(e)
139
+ try:
140
+ if e.code == 401:
141
+ return get_json_result(retcode=401, retmsg=repr(e))
142
+ except BaseException:
143
+ pass
144
+ if len(e.args) > 1:
145
+ return get_json_result(
146
+ retcode=RetCode.EXCEPTION_ERROR, retmsg=repr(e.args[0]), data=e.args[1])
147
+ if repr(e).find("index_not_found_exception") >= 0:
148
+ return get_json_result(retcode=RetCode.EXCEPTION_ERROR, retmsg="No chunk found, please upload file and parse it.")
149
+
150
+ return get_json_result(retcode=RetCode.EXCEPTION_ERROR, retmsg=repr(e))
151
+
152
+
153
+ def error_response(response_code, retmsg=None):
154
+ if retmsg is None:
155
+ retmsg = HTTP_STATUS_CODES.get(response_code, 'Unknown Error')
156
+
157
+ return Response(json.dumps({
158
+ 'retmsg': retmsg,
159
+ 'retcode': response_code,
160
+ }), status=response_code, mimetype='application/json')
161
+
162
+
163
+ def validate_request(*args, **kwargs):
164
+ def wrapper(func):
165
+ @wraps(func)
166
+ def decorated_function(*_args, **_kwargs):
167
+ input_arguments = flask_request.json or flask_request.form.to_dict()
168
+ no_arguments = []
169
+ error_arguments = []
170
+ for arg in args:
171
+ if arg not in input_arguments:
172
+ no_arguments.append(arg)
173
+ for k, v in kwargs.items():
174
+ config_value = input_arguments.get(k, None)
175
+ if config_value is None:
176
+ no_arguments.append(k)
177
+ elif isinstance(v, (tuple, list)):
178
+ if config_value not in v:
179
+ error_arguments.append((k, set(v)))
180
+ elif config_value != v:
181
+ error_arguments.append((k, v))
182
+ if no_arguments or error_arguments:
183
+ error_string = ""
184
+ if no_arguments:
185
+ error_string += "required argument are missing: {}; ".format(
186
+ ",".join(no_arguments))
187
+ if error_arguments:
188
+ error_string += "required argument values: {}".format(
189
+ ",".join(["{}={}".format(a[0], a[1]) for a in error_arguments]))
190
+ return get_json_result(
191
+ retcode=RetCode.ARGUMENT_ERROR, retmsg=error_string)
192
+ return func(*_args, **_kwargs)
193
+ return decorated_function
194
+ return wrapper
195
+
196
+
197
+ def is_localhost(ip):
198
+ return ip in {'127.0.0.1', '::1', '[::1]', 'localhost'}
199
+
200
+
201
+ def send_file_in_mem(data, filename):
202
+ if not isinstance(data, (str, bytes)):
203
+ data = json_dumps(data)
204
+ if isinstance(data, str):
205
+ data = data.encode('utf-8')
206
+
207
+ f = BytesIO()
208
+ f.write(data)
209
+ f.seek(0)
210
+
211
+ return send_file(f, as_attachment=True, attachment_filename=filename)
212
+
213
+
214
+ def get_json_result(retcode=RetCode.SUCCESS, retmsg='success', data=None):  # NOTE: shadows the fuller get_json_result defined above
215
+ response = {"retcode": retcode, "retmsg": retmsg, "data": data}
216
+ return jsonify(response)
217
+
218
+
219
+ def cors_reponse(retcode=RetCode.SUCCESS,
220
+ retmsg='success', data=None, auth=None):
221
+ result_dict = {"retcode": retcode, "retmsg": retmsg, "data": data}
222
+ response_dict = {}
223
+ for key, value in result_dict.items():
224
+ if value is None and key != "retcode":
225
+ continue
226
+ else:
227
+ response_dict[key] = value
228
+ response = make_response(jsonify(response_dict))
229
+ if auth:
230
+ response.headers["Authorization"] = auth
231
+ response.headers["Access-Control-Allow-Origin"] = "*"
232
+ response.headers["Access-Control-Allow-Method"] = "*"
233
+ response.headers["Access-Control-Allow-Headers"] = "*"
235
+ response.headers["Access-Control-Expose-Headers"] = "Authorization"
236
+ return response
237
+
238
+ def construct_result(code=RetCode.DATA_ERROR, message='data is missing'):
239
+ import re
240
+ result_dict = {"code": code, "message": re.sub(r"rag", "seceum", message, flags=re.IGNORECASE)}
241
+ response = {}
242
+ for key, value in result_dict.items():
243
+ if value is None and key != "code":
244
+ continue
245
+ else:
246
+ response[key] = value
247
+ return jsonify(response)
248
+
249
+
250
+ def construct_json_result(code=RetCode.SUCCESS, message='success', data=None):
251
+ if data is None:
252
+ return jsonify({"code": code, "message": message})
253
+ else:
254
+ return jsonify({"code": code, "message": message, "data": data})
255
+
256
+
257
+ def construct_error_response(e):
258
+ stat_logger.exception(e)
259
+ try:
260
+ if e.code == 401:
261
+ return construct_json_result(code=RetCode.UNAUTHORIZED, message=repr(e))
262
+ except BaseException:
263
+ pass
264
+ if len(e.args) > 1:
265
+ return construct_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1])
266
+ if repr(e).find("index_not_found_exception") >=0:
267
+ return construct_json_result(code=RetCode.EXCEPTION_ERROR, message="No chunk found, please upload file and parse it.")
268
+
269
+ return construct_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e))
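A minimal sketch of wiring validate_request and the backoff helper into a Flask view; the route, app object, and field names here are illustrative and not part of this diff:

import time
from flask import Flask
from api.utils.api_utils import (
    get_exponential_backoff_interval, get_json_result, validate_request)

app = Flask(__name__)

@app.route("/v1/demo", methods=["POST"])            # hypothetical endpoint
@validate_request("name", mode=("fast", "slow"))    # 'name' required; 'mode' limited to these values
def demo():
    return get_json_result(data={"ok": True})

def wait_before_retry(retries):
    # full jitter keeps concurrent clients from retrying in lockstep
    time.sleep(get_exponential_backoff_interval(retries, full_jitter=True))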
api/utils/commands.py CHANGED
@@ -1,78 +1,78 @@
1
- #
2
- # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- #
16
-
17
- import base64
18
- import click
19
- import re
20
-
21
- from flask import Flask
22
- from werkzeug.security import generate_password_hash
23
-
24
- from api.db.services import UserService
25
-
26
-
27
- @click.command('reset-password', help='Reset the account password.')
28
- @click.option('--email', prompt=True, help='The email address of the account whose password you need to reset')
29
- @click.option('--new-password', prompt=True, help='The new password.')
30
- @click.option('--password-confirm', prompt=True, help='Confirmation of the new password.')
31
- def reset_password(email, new_password, password_confirm):
32
- if str(new_password).strip() != str(password_confirm).strip():
33
- click.echo(click.style('Sorry, the two passwords do not match.', fg='red'))
34
- return
35
- user = UserService.query(email=email)
36
- if not user:
37
- click.echo(click.style('Sorry, this email is not registered.', fg='red'))
38
- return
39
- encode_password = base64.b64encode(new_password.encode('utf-8')).decode('utf-8')
40
- password_hash = generate_password_hash(encode_password)
41
- user_dict = {
42
- 'password': password_hash
43
- }
44
- UserService.update_user(user[0].id, user_dict)
45
- click.echo(click.style('Congratulations! Password has been reset.', fg='green'))
46
-
47
-
48
- @click.command('reset-email', help='Reset the account email.')
49
- @click.option('--email', prompt=True, help='The old email address of the account whose email you need to reset')
50
- @click.option('--new-email', prompt=True, help='The new email.')
51
- @click.option('--email-confirm', prompt=True, help='Confirmation of the new email.')
52
- def reset_email(email, new_email, email_confirm):
53
- if str(new_email).strip() != str(email_confirm).strip():
54
- click.echo(click.style('Sorry, the new email and its confirmation do not match.', fg='red'))
55
- return
56
- if str(new_email).strip() == str(email).strip():
57
- click.echo(click.style('Sorry, new email and old email are the same.', fg='red'))
58
- return
59
- user = UserService.query(email=email)
60
- if not user:
61
- click.echo(click.style('Sorry, the account [{}] does not exist.'.format(email), fg='red'))
62
- return
63
- if not re.match(r"^[\w\._-]+@([\w_-]+\.)+[\w-]{2,4}$", new_email):
64
- click.echo(click.style('Sorry, {} is not a valid email.'.format(new_email), fg='red'))
65
- return
66
- new_user = UserService.query(email=new_email)
67
- if new_user:
68
- click.echo(click.style('Sorry, the account [{}] already exists.'.format(new_email), fg='red'))
69
- return
70
- user_dict = {
71
- 'email': new_email
72
- }
73
- UserService.update_user(user[0].id, user_dict)
74
- click.echo(click.style('Congratulations! Email has been reset.', fg='green'))
75
-
76
- def register_commands(app: Flask):
77
- app.cli.add_command(reset_password)
78
- app.cli.add_command(reset_email)
 
1
+ #
2
+ # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ #
16
+
17
+ import base64
18
+ import click
19
+ import re
20
+
21
+ from flask import Flask
22
+ from werkzeug.security import generate_password_hash
23
+
24
+ from api.db.services import UserService
25
+
26
+
27
+ @click.command('reset-password', help='Reset the account password.')
28
+ @click.option('--email', prompt=True, help='The email address of the account whose password you need to reset')
29
+ @click.option('--new-password', prompt=True, help='The new password.')
30
+ @click.option('--password-confirm', prompt=True, help='Confirmation of the new password.')
31
+ def reset_password(email, new_password, password_confirm):
32
+ if str(new_password).strip() != str(password_confirm).strip():
33
+ click.echo(click.style('Sorry, the two passwords do not match.', fg='red'))
34
+ return
35
+ user = UserService.query(email=email)
36
+ if not user:
37
+ click.echo(click.style('Sorry, this email is not registered.', fg='red'))
38
+ return
39
+ encode_password = base64.b64encode(new_password.encode('utf-8')).decode('utf-8')
40
+ password_hash = generate_password_hash(encode_password)
41
+ user_dict = {
42
+ 'password': password_hash
43
+ }
44
+ UserService.update_user(user[0].id, user_dict)
45
+ click.echo(click.style('Congratulations! Password has been reset.', fg='green'))
46
+
47
+
48
+ @click.command('reset-email', help='Reset the account email.')
49
+ @click.option('--email', prompt=True, help='The old email address of the account whose email you need to reset')
50
+ @click.option('--new-email', prompt=True, help='The new email.')
51
+ @click.option('--email-confirm', prompt=True, help='Confirmation of the new email.')
52
+ def reset_email(email, new_email, email_confirm):
53
+ if str(new_email).strip() != str(email_confirm).strip():
54
+ click.echo(click.style('Sorry, the new email and its confirmation do not match.', fg='red'))
55
+ return
56
+ if str(new_email).strip() == str(email).strip():
57
+ click.echo(click.style('Sorry, new email and old email are the same.', fg='red'))
58
+ return
59
+ user = UserService.query(email=email)
60
+ if not user:
61
+ click.echo(click.style('Sorry, the account [{}] does not exist.'.format(email), fg='red'))
62
+ return
63
+ if not re.match(r"^[\w\._-]+@([\w_-]+\.)+[\w-]{2,4}$", new_email):
64
+ click.echo(click.style('Sorry, {} is not a valid email.'.format(new_email), fg='red'))
65
+ return
66
+ new_user = UserService.query(email=new_email)
67
+ if new_user:
68
+ click.echo(click.style('Sorry, the account [{}] already exists.'.format(new_email), fg='red'))
69
+ return
70
+ user_dict = {
71
+ 'email': new_email
72
+ }
73
+ UserService.update_user(user[0].id, user_dict)
74
+ click.echo(click.style('Congratulations! Email has been reset.', fg='green'))
75
+
76
+ def register_commands(app: Flask):
77
+ app.cli.add_command(reset_password)
78
+ app.cli.add_command(reset_email)
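
A minimal usage sketch for the two commands above, driving them through Flask's test CLI runner; the module path follows this diff (api/utils/commands.py), and the email address and passwords are illustrative.

```python
# Sketch: register the CLI commands on a Flask app and invoke reset-password.
from flask import Flask

from api.utils.commands import register_commands  # path assumed from this diff

app = Flask(__name__)
register_commands(app)

runner = app.test_cli_runner()
result = runner.invoke(args=[
    "reset-password",
    "--email", "user@example.com",    # hypothetical account
    "--new-password", "s3cret",
    "--password-confirm", "s3cret",
])
print(result.output)
```

In a running deployment the same commands are reachable as `flask reset-password` and `flask reset-email`.
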
api/utils/file_utils.py CHANGED
@@ -1,207 +1,207 @@
1
+ #
2
+ # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ #
16
+ import base64
17
+ import json
18
+ import os
19
+ import re
20
+ from io import BytesIO
21
+
22
+ import pdfplumber
23
+ from PIL import Image
24
+ from cachetools import LRUCache, cached
25
+ from ruamel.yaml import YAML
26
+
27
+ from api.db import FileType
28
+
29
+ PROJECT_BASE = os.getenv("RAG_PROJECT_BASE") or os.getenv("RAG_DEPLOY_BASE")
30
+ RAG_BASE = os.getenv("RAG_BASE")
31
+
32
+
33
+ def get_project_base_directory(*args):
34
+ global PROJECT_BASE
35
+ if PROJECT_BASE is None:
36
+ PROJECT_BASE = os.path.abspath(
37
+ os.path.join(
38
+ os.path.dirname(os.path.realpath(__file__)),
39
+ os.pardir,
40
+ os.pardir,
41
+ )
42
+ )
43
+
44
+ if args:
45
+ return os.path.join(PROJECT_BASE, *args)
46
+ return PROJECT_BASE
47
+
48
+
49
+ def get_rag_directory(*args):
50
+ global RAG_BASE
51
+ if RAG_BASE is None:
52
+ RAG_BASE = os.path.abspath(
53
+ os.path.join(
54
+ os.path.dirname(os.path.realpath(__file__)),
55
+ os.pardir,
56
+ os.pardir,
57
+ os.pardir,
58
+ )
59
+ )
60
+ if args:
61
+ return os.path.join(RAG_BASE, *args)
62
+ return RAG_BASE
63
+
64
+
65
+ def get_rag_python_directory(*args):
66
+ return get_rag_directory("python", *args)
67
+
68
+
69
+ def get_home_cache_dir():
70
+ cache_dir = os.path.join(os.path.expanduser('~'), ".ragflow")
71
+ try:
72
+ os.mkdir(cache_dir)
73
+ except OSError:
74
+ pass
75
+ return cache_dir
76
+
77
+
78
+ @cached(cache=LRUCache(maxsize=10))
79
+ def load_json_conf(conf_path):
80
+ if os.path.isabs(conf_path):
81
+ json_conf_path = conf_path
82
+ else:
83
+ json_conf_path = os.path.join(get_project_base_directory(), conf_path)
84
+ try:
85
+ with open(json_conf_path) as f:
86
+ return json.load(f)
87
+ except BaseException:
88
+ raise EnvironmentError(
89
+ "loading json file config from '{}' failed!".format(json_conf_path)
90
+ )
91
+
92
+
93
+ def dump_json_conf(config_data, conf_path):
94
+ if os.path.isabs(conf_path):
95
+ json_conf_path = conf_path
96
+ else:
97
+ json_conf_path = os.path.join(get_project_base_directory(), conf_path)
98
+ try:
99
+ with open(json_conf_path, "w") as f:
100
+ json.dump(config_data, f, indent=4)
101
+ except BaseException:
102
+ raise EnvironmentError(
103
+ "loading json file config from '{}' failed!".format(json_conf_path)
104
+ )
105
+
106
+
107
+ def load_json_conf_real_time(conf_path):
108
+ if os.path.isabs(conf_path):
109
+ json_conf_path = conf_path
110
+ else:
111
+ json_conf_path = os.path.join(get_project_base_directory(), conf_path)
112
+ try:
113
+ with open(json_conf_path) as f:
114
+ return json.load(f)
115
+ except BaseException:
116
+ raise EnvironmentError(
117
+ "loading json file config from '{}' failed!".format(json_conf_path)
118
+ )
119
+
120
+
121
+ def load_yaml_conf(conf_path):
122
+ if not os.path.isabs(conf_path):
123
+ conf_path = os.path.join(get_project_base_directory(), conf_path)
124
+ try:
125
+ with open(conf_path) as f:
126
+ yaml = YAML(typ='safe', pure=True)
127
+ return yaml.load(f)
128
+ except Exception as e:
129
+ raise EnvironmentError(
130
+ "loading yaml file config from {} failed:".format(conf_path), e
131
+ )
132
+
133
+
134
+ def rewrite_yaml_conf(conf_path, config):
135
+ if not os.path.isabs(conf_path):
136
+ conf_path = os.path.join(get_project_base_directory(), conf_path)
137
+ try:
138
+ with open(conf_path, "w") as f:
139
+ yaml = YAML(typ="safe")
140
+ yaml.dump(config, f)
141
+ except Exception as e:
142
+ raise EnvironmentError(
143
+ "rewrite yaml file config {} failed:".format(conf_path), e
144
+ )
145
+
146
+
147
+ def rewrite_json_file(filepath, json_data):
148
+ with open(filepath, "w") as f:
149
+ json.dump(json_data, f, indent=4, separators=(",", ": "))
151
+
152
+
153
+ def filename_type(filename):
154
+ filename = filename.lower()
155
+ if re.match(r".*\.pdf$", filename):
156
+ return FileType.PDF.value
157
+
158
+ if re.match(
159
+ r".*\.(eml|doc|docx|ppt|pptx|yml|xml|htm|json|csv|txt|ini|xls|xlsx|wps|rtf|hlp|pages|numbers|key|md|py|js|java|c|cpp|h|php|go|ts|sh|cs|kt|html|sql)$", filename):
160
+ return FileType.DOC.value
161
+
162
+ if re.match(
163
+ r".*\.(wav|flac|ape|alac|wavpack|wv|mp3|aac|ogg|vorbis|opus|mp3)$", filename):
164
+ return FileType.AURAL.value
165
+
166
+ if re.match(r".*\.(jpg|jpeg|png|tif|gif|pcx|tga|exif|fpx|svg|psd|cdr|pcd|dxf|ufo|eps|ai|raw|WMF|webp|avif|apng|icon|ico|mpg|mpeg|avi|rm|rmvb|mov|wmv|asf|dat|asx|wvx|mpe|mpa|mp4)$", filename):
167
+ return FileType.VISUAL.value
168
+
169
+ return FileType.OTHER.value
170
+
171
+
172
+ def thumbnail(filename, blob):
173
+ filename = filename.lower()
174
+ if re.match(r".*\.pdf$", filename):
175
+ pdf = pdfplumber.open(BytesIO(blob))
176
+ buffered = BytesIO()
177
+ pdf.pages[0].to_image(resolution=32).annotated.save(buffered, format="png")
178
+ return "data:image/png;base64," + \
179
+ base64.b64encode(buffered.getvalue()).decode("utf-8")
180
+
181
+ if re.match(r".*\.(jpg|jpeg|png|tif|gif|icon|ico|webp)$", filename):
182
+ image = Image.open(BytesIO(blob))
183
+ image.thumbnail((30, 30))
184
+ buffered = BytesIO()
185
+ image.save(buffered, format="png")
186
+ return "data:image/png;base64," + \
187
+ base64.b64encode(buffered.getvalue()).decode("utf-8")
188
+
189
+ if re.match(r".*\.(ppt|pptx)$", filename):
190
+ import aspose.slides as slides
191
+ import aspose.pydrawing as drawing
192
+ try:
193
+ with slides.Presentation(BytesIO(blob)) as presentation:
194
+ buffered = BytesIO()
195
+ presentation.slides[0].get_thumbnail(0.03, 0.03).save(
196
+ buffered, drawing.imaging.ImageFormat.png)
197
+ return "data:image/png;base64," + \
198
+ base64.b64encode(buffered.getvalue()).decode("utf-8")
199
+ except Exception:
200
+ pass
201
+
202
+
203
+ def traversal_files(base):
204
+ for root, ds, fs in os.walk(base):
205
+ for f in fs:
206
+ fullname = os.path.join(root, f)
207
+ yield fullname
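
A short sketch of the helpers above: walk a directory tree with traversal_files and classify each file with the extension rules in filename_type; the input directory is illustrative.

```python
# Sketch: classify every file under a (hypothetical) directory.
from api.utils.file_utils import filename_type, traversal_files

for path in traversal_files("/data/docs"):   # illustrative input directory
    # filename_type returns a FileType value: PDF, DOC, AURAL, VISUAL or OTHER.
    print(path, "->", filename_type(path))
```
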
api/utils/log_utils.py CHANGED
@@ -1,313 +1,313 @@
1
+ #
2
+ # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ #
16
+ import os
17
+ import typing
18
+ import traceback
19
+ import logging
20
+ import inspect
21
+ from logging.handlers import TimedRotatingFileHandler
22
+ from threading import RLock
23
+
24
+ from api.utils import file_utils
25
+
26
+
27
+ class LoggerFactory(object):
28
+ TYPE = "FILE"
29
+ LOG_FORMAT = "[%(levelname)s] [%(asctime)s] [%(module)s.%(funcName)s] [line:%(lineno)d]: %(message)s"
30
+ logging.basicConfig(format=LOG_FORMAT)
31
+ LEVEL = logging.DEBUG
32
+ logger_dict = {}
33
+ global_handler_dict = {}
34
+
35
+ LOG_DIR = None
36
+ PARENT_LOG_DIR = None
37
+ log_share = True
38
+
39
+ append_to_parent_log = None
40
+
41
+ lock = RLock()
42
+ # CRITICAL = 50
43
+ # FATAL = CRITICAL
44
+ # ERROR = 40
45
+ # WARNING = 30
46
+ # WARN = WARNING
47
+ # INFO = 20
48
+ # DEBUG = 10
49
+ # NOTSET = 0
50
+ levels = (10, 20, 30, 40)
51
+ schedule_logger_dict = {}
52
+
53
+ @staticmethod
54
+ def set_directory(directory=None, parent_log_dir=None,
55
+ append_to_parent_log=None, force=False):
56
+ if parent_log_dir:
57
+ LoggerFactory.PARENT_LOG_DIR = parent_log_dir
58
+ if append_to_parent_log:
59
+ LoggerFactory.append_to_parent_log = append_to_parent_log
60
+ with LoggerFactory.lock:
61
+ if not directory:
62
+ directory = file_utils.get_project_base_directory("logs")
63
+ if not LoggerFactory.LOG_DIR or force:
64
+ LoggerFactory.LOG_DIR = directory
65
+ if LoggerFactory.log_share:
66
+ oldmask = os.umask(000)
67
+ os.makedirs(LoggerFactory.LOG_DIR, exist_ok=True)
68
+ os.umask(oldmask)
69
+ else:
70
+ os.makedirs(LoggerFactory.LOG_DIR, exist_ok=True)
71
+ for loggerName, ghandler in LoggerFactory.global_handler_dict.items():
72
+ for className, (logger,
73
+ handler) in LoggerFactory.logger_dict.items():
74
+ logger.removeHandler(ghandler)
75
+ ghandler.close()
76
+ LoggerFactory.global_handler_dict = {}
77
+ for className, (logger,
78
+ handler) in LoggerFactory.logger_dict.items():
79
+ logger.removeHandler(handler)
80
+ _handler = None
81
+ if handler:
82
+ handler.close()
83
+ if className != "default":
84
+ _handler = LoggerFactory.get_handler(className)
85
+ logger.addHandler(_handler)
86
+ LoggerFactory.assemble_global_handler(logger)
87
+ LoggerFactory.logger_dict[className] = logger, _handler
88
+
89
+ @staticmethod
90
+ def new_logger(name):
91
+ logger = logging.getLogger(name)
92
+ logger.propagate = False
93
+ logger.setLevel(LoggerFactory.LEVEL)
94
+ return logger
95
+
96
+ @staticmethod
97
+ def get_logger(class_name=None):
98
+ with LoggerFactory.lock:
99
+ if class_name in LoggerFactory.logger_dict.keys():
100
+ logger, handler = LoggerFactory.logger_dict[class_name]
101
+ if not logger:
102
+ logger, handler = LoggerFactory.init_logger(class_name)
103
+ else:
104
+ logger, handler = LoggerFactory.init_logger(class_name)
105
+ return logger
106
+
107
+ @staticmethod
108
+ def get_global_handler(logger_name, level=None, log_dir=None):
109
+ if not LoggerFactory.LOG_DIR:
110
+ return logging.StreamHandler()
111
+ if log_dir:
112
+ logger_name_key = logger_name + "_" + log_dir
113
+ else:
114
+ logger_name_key = logger_name + "_" + LoggerFactory.LOG_DIR
115
+ # if loggerName not in LoggerFactory.globalHandlerDict:
116
+ if logger_name_key not in LoggerFactory.global_handler_dict:
117
+ with LoggerFactory.lock:
118
+ if logger_name_key not in LoggerFactory.global_handler_dict:
119
+ handler = LoggerFactory.get_handler(
120
+ logger_name, level, log_dir)
121
+ LoggerFactory.global_handler_dict[logger_name_key] = handler
122
+ return LoggerFactory.global_handler_dict[logger_name_key]
123
+
124
+ @staticmethod
125
+ def get_handler(class_name, level=None, log_dir=None,
126
+ log_type=None, job_id=None):
127
+ if not log_type:
128
+ if not LoggerFactory.LOG_DIR or not class_name:
129
+ return logging.StreamHandler()
130
+ # return Diy_StreamHandler()
131
+
132
+ if not log_dir:
133
+ log_file = os.path.join(
134
+ LoggerFactory.LOG_DIR,
135
+ "{}.log".format(class_name))
136
+ else:
137
+ log_file = os.path.join(log_dir, "{}.log".format(class_name))
138
+ else:
139
+ log_file = os.path.join(log_dir, "rag_flow_{}.log".format(
140
+ log_type) if level == LoggerFactory.LEVEL else 'rag_flow_{}_error.log'.format(log_type))
141
+
142
+ os.makedirs(os.path.dirname(log_file), exist_ok=True)
143
+ if LoggerFactory.log_share:
144
+ handler = ROpenHandler(log_file,
145
+ when='D',
146
+ interval=1,
147
+ backupCount=14,
148
+ delay=True)
149
+ else:
150
+ handler = TimedRotatingFileHandler(log_file,
151
+ when='D',
152
+ interval=1,
153
+ backupCount=14,
154
+ delay=True)
155
+ if level:
156
+ handler.level = level
157
+
158
+ return handler
159
+
160
+ @staticmethod
161
+ def init_logger(class_name):
162
+ with LoggerFactory.lock:
163
+ logger = LoggerFactory.new_logger(class_name)
164
+ handler = None
165
+ if class_name:
166
+ handler = LoggerFactory.get_handler(class_name)
167
+ logger.addHandler(handler)
168
+ LoggerFactory.logger_dict[class_name] = logger, handler
169
+
170
+ else:
171
+ LoggerFactory.logger_dict["default"] = logger, handler
172
+
173
+ LoggerFactory.assemble_global_handler(logger)
174
+ return logger, handler
175
+
176
+ @staticmethod
177
+ def assemble_global_handler(logger):
178
+ if LoggerFactory.LOG_DIR:
179
+ for level in LoggerFactory.levels:
180
+ if level >= LoggerFactory.LEVEL:
181
+ level_logger_name = logging._levelToName[level]
182
+ logger.addHandler(
183
+ LoggerFactory.get_global_handler(
184
+ level_logger_name, level))
185
+ if LoggerFactory.append_to_parent_log and LoggerFactory.PARENT_LOG_DIR:
186
+ for level in LoggerFactory.levels:
187
+ if level >= LoggerFactory.LEVEL:
188
+ level_logger_name = logging._levelToName[level]
189
+ logger.addHandler(
190
+ LoggerFactory.get_global_handler(level_logger_name, level, LoggerFactory.PARENT_LOG_DIR))
191
+
192
+
193
+ def setDirectory(directory=None):
194
+ LoggerFactory.set_directory(directory)
195
+
196
+
197
+ def setLevel(level):
198
+ LoggerFactory.LEVEL = level
199
+
200
+
201
+ def getLogger(className=None, useLevelFile=False):
202
+ if className is None:
203
+ frame = inspect.stack()[1]
204
+ module = inspect.getmodule(frame[0])
205
+ className = 'stat'
206
+ return LoggerFactory.get_logger(className)
207
+
208
+
209
+ def exception_to_trace_string(ex):
210
+ return "".join(traceback.TracebackException.from_exception(ex).format())
211
+
212
+
213
+ class ROpenHandler(TimedRotatingFileHandler):
214
+ def _open(self):
215
+ prevumask = os.umask(000)
216
+ rtv = TimedRotatingFileHandler._open(self)
217
+ os.umask(prevumask)
218
+ return rtv
219
+
220
+
221
+ def sql_logger(job_id='', log_type='sql'):
222
+ key = job_id + log_type
223
+ if key in LoggerFactory.schedule_logger_dict.keys():
224
+ return LoggerFactory.schedule_logger_dict[key]
225
+ return get_job_logger(job_id=job_id, log_type=log_type)
226
+
227
+
228
+ def ready_log(msg, job=None, task=None, role=None, party_id=None, detail=None):
229
+ prefix, suffix = base_msg(job, task, role, party_id, detail)
230
+ return f"{prefix}{msg} ready{suffix}"
231
+
232
+
233
+ def start_log(msg, job=None, task=None, role=None, party_id=None, detail=None):
234
+ prefix, suffix = base_msg(job, task, role, party_id, detail)
235
+ return f"{prefix}start to {msg}{suffix}"
236
+
237
+
238
+ def successful_log(msg, job=None, task=None, role=None,
239
+ party_id=None, detail=None):
240
+ prefix, suffix = base_msg(job, task, role, party_id, detail)
241
+ return f"{prefix}{msg} successfully{suffix}"
242
+
243
+
244
+ def warning_log(msg, job=None, task=None, role=None,
245
+ party_id=None, detail=None):
246
+ prefix, suffix = base_msg(job, task, role, party_id, detail)
247
+ return f"{prefix}{msg} is not effective{suffix}"
248
+
249
+
250
+ def failed_log(msg, job=None, task=None, role=None,
251
+ party_id=None, detail=None):
252
+ prefix, suffix = base_msg(job, task, role, party_id, detail)
253
+ return f"{prefix}failed to {msg}{suffix}"
254
+
255
+
256
+ def base_msg(job=None, task=None, role: str = None,
257
+ party_id: typing.Union[str, int] = None, detail=None):
258
+ if detail:
259
+ detail_msg = f" detail: \n{detail}"
260
+ else:
261
+ detail_msg = ""
262
+ if task is not None:
263
+ return f"task {task.f_task_id} {task.f_task_version} ", f" on {task.f_role} {task.f_party_id}{detail_msg}"
264
+ elif job is not None:
265
+ return "", f" on {job.f_role} {job.f_party_id}{detail_msg}"
266
+ elif role and party_id:
267
+ return "", f" on {role} {party_id}{detail_msg}"
268
+ else:
269
+ return "", f"{detail_msg}"
270
+
271
+
275
+
276
+ def get_logger_base_dir():
277
+ job_log_dir = file_utils.get_project_base_directory('logs')
278
+ return job_log_dir
279
+
280
+
281
+ def get_job_logger(job_id, log_type):
282
+ rag_flow_log_dir = file_utils.get_project_base_directory('logs', 'rag_flow')
283
+ job_log_dir = file_utils.get_project_base_directory('logs', job_id)
284
+ if not job_id:
285
+ log_dirs = [rag_flow_log_dir]
286
+ else:
287
+ if log_type == 'audit':
288
+ log_dirs = [job_log_dir, rag_flow_log_dir]
289
+ else:
290
+ log_dirs = [job_log_dir]
291
+ if LoggerFactory.log_share:
292
+ oldmask = os.umask(000)
293
+ os.makedirs(job_log_dir, exist_ok=True)
294
+ os.makedirs(rag_flow_log_dir, exist_ok=True)
295
+ os.umask(oldmask)
296
+ else:
297
+ os.makedirs(job_log_dir, exist_ok=True)
298
+ os.makedirs(rag_flow_log_dir, exist_ok=True)
299
+ logger = LoggerFactory.new_logger(f"{job_id}_{log_type}")
300
+ for job_log_dir in log_dirs:
301
+ handler = LoggerFactory.get_handler(class_name=None, level=LoggerFactory.LEVEL,
302
+ log_dir=job_log_dir, log_type=log_type, job_id=job_id)
303
+ error_handler = LoggerFactory.get_handler(
304
+ class_name=None,
305
+ level=logging.ERROR,
306
+ log_dir=job_log_dir,
307
+ log_type=log_type,
308
+ job_id=job_id)
309
+ logger.addHandler(handler)
310
+ logger.addHandler(error_handler)
311
+ with LoggerFactory.lock:
312
+ LoggerFactory.schedule_logger_dict[job_id + log_type] = logger
313
+ return logger
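
A minimal sketch of the factory above, assuming an illustrative log directory: per the get_handler logic, a named logger writes to <LOG_DIR>/<name>.log, with the per-level global handlers attached on top.

```python
# Sketch: configure the log directory, then fetch a named logger and log.
from api.utils.log_utils import LoggerFactory, getLogger

LoggerFactory.set_directory("/tmp/ragflow_logs")  # illustrative directory
logger = getLogger("database")                    # -> /tmp/ragflow_logs/database.log
logger.info("connection pool initialized")
```
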
api/utils/t_crypt.py CHANGED
@@ -1,24 +1,24 @@
1
+ import base64
2
+ import os
3
+ import sys
4
+ from Cryptodome.PublicKey import RSA
5
+ from Cryptodome.Cipher import PKCS1_v1_5 as Cipher_pkcs1_v1_5
6
+ from api.utils import decrypt, file_utils
7
+
8
+
9
+ def crypt(line):
10
+ file_path = os.path.join(
11
+ file_utils.get_project_base_directory(),
12
+ "conf",
13
+ "public.pem")
14
+ rsa_key = RSA.importKey(open(file_path).read(), "Welcome")
15
+ cipher = Cipher_pkcs1_v1_5.new(rsa_key)
16
+ password_base64 = base64.b64encode(line.encode('utf-8')).decode("utf-8")
17
+ encrypted_password = cipher.encrypt(password_base64.encode())
18
+ return base64.b64encode(encrypted_password).decode('utf-8')
19
+
20
+
21
+ if __name__ == "__main__":
22
+ pswd = crypt(sys.argv[1])
23
+ print(pswd)
24
+ print(decrypt(pswd))
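
A round-trip sketch, assuming conf/public.pem and the matching private key used by api.utils.decrypt are both in place; the plaintext is illustrative.

```python
# Sketch: encrypt with crypt() and verify that decrypt() reverses it.
from api.utils.t_crypt import crypt
from api.utils import decrypt

token = crypt("my_password")            # base64(RSA(base64(plaintext)))
assert decrypt(token) == "my_password"
```
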
api/versions.py CHANGED
@@ -1,28 +1,28 @@
1
+ #
2
+ # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ #
17
+ import dotenv
18
+ import typing
20
+
21
+
22
+ def get_versions() -> typing.Mapping[str, typing.Any]:
23
+ dotenv.load_dotenv(dotenv.find_dotenv())
24
+ return dotenv.dotenv_values()
25
+
26
+
27
+ def get_rag_version() -> typing.Optional[str]:
28
  return get_versions().get("RAGFLOW_VERSION", "dev")
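
Usage is a one-liner: get_rag_version() reads RAGFLOW_VERSION from the nearest .env file and falls back to "dev".

```python
# Sketch: print the running RAGFlow version.
from api.versions import get_rag_version

print(get_rag_version())  # "dev" unless a .env file pins RAGFLOW_VERSION
```
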
conf/service_conf.yaml CHANGED
@@ -1,49 +1,49 @@
1
+ ragflow:
2
+ host: 0.0.0.0
3
+ http_port: 9380
4
+ mysql:
5
+ name: 'rag_flow'
6
+ user: 'root'
7
+ password: 'infini_rag_flow'
8
+ host: 'mysql'
9
+ port: 3306
10
+ max_connections: 100
11
+ stale_timeout: 30
12
+ minio:
13
+ user: 'rag_flow'
14
+ password: 'infini_rag_flow'
15
+ host: 'minio:9000'
16
+ es:
17
+ hosts: 'http://es01:9200'
18
+ username: 'elastic'
19
+ password: 'infini_rag_flow'
20
+ redis:
21
+ db: 1
22
+ password: 'infini_rag_flow'
23
+ host: 'redis:6379'
24
+ user_default_llm:
25
+ factory: 'Tongyi-Qianwen'
26
+ api_key: 'sk-xxxxxxxxxxxxx'
27
+ base_url: ''
28
+ oauth:
29
+ github:
30
+ client_id: xxxxxxxxxxxxxxxxxxxxxxxxx
31
+ secret_key: xxxxxxxxxxxxxxxxxxxxxxxxxxxx
32
+ url: https://github.com/login/oauth/access_token
33
+ feishu:
34
+ app_id: cli_xxxxxxxxxxxxxxxxxxx
35
+ app_secret: xxxxxxxxxxxxxxxxxxxxxxxxxxxx
36
+ app_access_token_url: https://open.feishu.cn/open-apis/auth/v3/app_access_token/internal
37
+ user_access_token_url: https://open.feishu.cn/open-apis/authen/v1/oidc/access_token
38
+ grant_type: 'authorization_code'
39
+ authentication:
40
+ client:
41
+ switch: false
42
+ http_app_key:
43
+ http_secret_key:
44
+ site:
45
+ switch: false
46
+ permission:
47
+ switch: false
48
+ component: false
49
+ dataset: false
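
Since load_yaml_conf from api/utils/file_utils.py above resolves relative paths against the project base, this file can be loaded directly; a minimal sketch:

```python
# Sketch: load the service configuration and read one setting.
from api.utils.file_utils import load_yaml_conf

conf = load_yaml_conf("conf/service_conf.yaml")
print(conf["ragflow"]["http_port"])  # -> 9380
```
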
deepdoc/README.md CHANGED
@@ -1,122 +1,122 @@
1
+ English | [简体中文](./README_zh.md)
2
+
3
+ # *Deep*Doc
4
+
5
+ - [1. Introduction](#1)
6
+ - [2. Vision](#2)
7
+ - [3. Parser](#3)
8
+
9
+ <a name="1"></a>
10
+ ## 1. Introduction
11
+
12
+ With documents from various domains, in various formats, and with diverse retrieval requirements,
13
+ accurate analysis becomes a very challenging task. *Deep*Doc was born for exactly that purpose.
14
+ There are two parts in *Deep*Doc so far: vision and parser.
15
+ You can run the following test programs if you're interested in our results for OCR, layout recognition and TSR.
16
+ ```bash
17
+ python deepdoc/vision/t_ocr.py -h
18
+ usage: t_ocr.py [-h] --inputs INPUTS [--output_dir OUTPUT_DIR]
19
+
20
+ options:
21
+ -h, --help show this help message and exit
22
+ --inputs INPUTS Directory where to store images or PDFs, or a file path to a single image or PDF
23
+ --output_dir OUTPUT_DIR
24
+ Directory where to store the output images. Default: './ocr_outputs'
25
+ ```
26
+ ```bash
27
+ python deepdoc/vision/t_recognizer.py -h
28
+ usage: t_recognizer.py [-h] --inputs INPUTS [--output_dir OUTPUT_DIR] [--threshold THRESHOLD] [--mode {layout,tsr}]
29
+
30
+ options:
31
+ -h, --help show this help message and exit
32
+ --inputs INPUTS Directory where to store images or PDFs, or a file path to a single image or PDF
33
+ --output_dir OUTPUT_DIR
34
+ Directory where to store the output images. Default: './layouts_outputs'
35
+ --threshold THRESHOLD
36
+ A threshold to filter out detections. Default: 0.5
37
+ --mode {layout,tsr} Task mode: layout recognition or table structure recognition
38
+ ```
39
+
40
+ Our models are served on HuggingFace. If you have trouble downloading models from HuggingFace, the following mirror setting might help:
41
+ ```bash
42
+ export HF_ENDPOINT=https://hf-mirror.com
43
+ ```
44
+
45
+ <a name="2"></a>
46
+ ## 2. Vision
47
+
48
+ We use visual information to resolve problems the way a human being does.
49
+ - OCR. Since many documents are presented as images, or can at least be converted to images,
50
+ OCR is an essential, fundamental, even universal solution for text extraction.
51
+ ```bash
52
+ python deepdoc/vision/t_ocr.py --inputs=path_to_images_or_pdfs --output_dir=path_to_store_result
53
+ ```
54
+ The input can be a directory of images or PDFs, or a single image or PDF file.
55
+ You can look into the folder 'path_to_store_result', which contains images demonstrating the positions of the results
56
+ and txt files containing the OCR text.
57
+ <div align="center" style="margin-top:20px;margin-bottom:20px;">
58
+ <img src="https://github.com/infiniflow/ragflow/assets/12318111/f25bee3d-aaf7-4102-baf5-d5208361d110" width="900"/>
59
+ </div>
60
+
61
+ - Layout recognition. Documents from different domains may have various layouts;
62
+ for example, newspapers, magazines, books and résumés are distinct in terms of layout.
63
+ Only with an accurate layout analysis can a machine decide whether text parts are consecutive,
64
+ whether a part needs Table Structure Recognition (TSR), or whether a part is a figure described by a nearby caption.
65
+ We have 10 basic layout components, which cover most cases:
66
+ - Text
67
+ - Title
68
+ - Figure
69
+ - Figure caption
70
+ - Table
71
+ - Table caption
72
+ - Header
73
+ - Footer
74
+ - Reference
75
+ - Equation
76
+
77
+ Try the following command to see the layout detection results.
78
+ ```bash
79
+ python deepdoc/vision/t_recognizer.py --inputs=path_to_images_or_pdfs --threshold=0.2 --mode=layout --output_dir=path_to_store_result
80
+ ```
81
+ The input can be a directory of images or PDFs, or a single image or PDF file.
82
+ You can look into the folder 'path_to_store_result', which contains images demonstrating the detection results, as follows:
83
+ <div align="center" style="margin-top:20px;margin-bottom:20px;">
84
+ <img src="https://github.com/infiniflow/ragflow/assets/12318111/07e0f625-9b28-43d0-9fbb-5bf586cd286f" width="1000"/>
85
+ </div>
86
+
87
+ - Table Structure Recognition (TSR). Tables are a frequently used structure for presenting data, including numbers and text.
88
+ The structure of a table can be very complex, with hierarchical headers, spanning cells and projected row headers.
89
+ Along with TSR, we also reassemble the content into sentences that an LLM can comprehend well.
90
+ We have five labels for the TSR task:
91
+ - Column
92
+ - Row
93
+ - Column header
94
+ - Projected row header
95
+ - Spanning cell
96
+
97
+ Try the following command to see the TSR results.
98
+ ```bash
99
+ python deepdoc/vision/t_recognizer.py --inputs=path_to_images_or_pdfs --threshold=0.2 --mode=tsr --output_dir=path_to_store_result
100
+ ```
101
+ The input can be a directory of images or PDFs, or a single image or PDF file.
102
+ You can look into the folder 'path_to_store_result', which contains both images and HTML pages demonstrating the detection results, as follows:
103
+ <div align="center" style="margin-top:20px;margin-bottom:20px;">
104
+ <img src="https://github.com/infiniflow/ragflow/assets/12318111/cb24e81b-f2ba-49f3-ac09-883d75606f4c" width="1000"/>
105
+ </div>
106
+
107
+ <a name="3"></a>
108
+ ## 3. Parser
109
+
110
+ Four kinds of document formats, namely PDF, DOCX, EXCEL and PPT, have their corresponding parsers.
111
+ The most complex one is the PDF parser, owing to PDF's flexibility. The output of the PDF parser includes:
112
+ - Text chunks with their own positions in the PDF (page number and rectangular coordinates).
113
+ - Tables, with images cropped from the PDF and contents already translated into natural language sentences.
114
+ - Figures, with captions and the text inside the figures.
115
+
116
+ ### Résumé
117
+
118
+ The résumé is a very complicated kind of document. A résumé composed of unstructured text
119
+ with various layouts can be resolved into structured data composed of nearly a hundred fields.
120
+ We haven't open-sourced the parser yet; we only open-source the processing method applied after the parsing procedure.
121
+
122
 
deepdoc/parser/ppt_parser.py CHANGED
@@ -1,61 +1,61 @@
1
+ # Licensed under the Apache License, Version 2.0 (the "License");
2
+ # you may not use this file except in compliance with the License.
3
+ # You may obtain a copy of the License at
4
+ #
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ #
7
+ # Unless required by applicable law or agreed to in writing, software
8
+ # distributed under the License is distributed on an "AS IS" BASIS,
9
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10
+ # See the License for the specific language governing permissions and
11
+ # limitations under the License.
12
+ #
13
+
14
+ from io import BytesIO
15
+ from pptx import Presentation
16
+
17
+
18
+ class RAGFlowPptParser(object):
19
+ def __init__(self):
20
+ super().__init__()
21
+
22
+ def __extract(self, shape):
23
+ if shape.shape_type == 19:  # MSO_SHAPE_TYPE.TABLE
24
+ tb = shape.table
25
+ rows = []
26
+ for i in range(1, len(tb.rows)):
27
+ rows.append("; ".join([tb.cell(
28
+ 0, j).text + ": " + tb.cell(i, j).text for j in range(len(tb.columns)) if tb.cell(i, j)]))
29
+ return "\n".join(rows)
30
+
31
+ if shape.has_text_frame:
32
+ return shape.text_frame.text
33
+
34
+ if shape.shape_type == 6:
35
+ texts = []
36
+ for p in sorted(shape.shapes, key=lambda x: (x.top // 10, x.left)):
37
+ t = self.__extract(p)
38
+ if t:
39
+ texts.append(t)
40
+ return "\n".join(texts)
41
+
42
+ def __call__(self, fnm, from_page, to_page, callback=None):
43
+ ppt = Presentation(fnm) if isinstance(
44
+ fnm, str) else Presentation(
45
+ BytesIO(fnm))
46
+ txts = []
47
+ self.total_page = len(ppt.slides)
48
+ for i, slide in enumerate(ppt.slides):
49
+ if i < from_page:
50
+ continue
51
+ if i >= to_page:
52
+ break
53
+ texts = []
54
+ for shape in sorted(
55
+ slide.shapes, key=lambda x: ((x.top if x.top is not None else 0) // 10, x.left)):
56
+ txt = self.__extract(shape)
57
+ if txt:
58
+ texts.append(txt)
59
+ txts.append("\n".join(texts))
60
+
61
+ return txts
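
For reference, `RAGFlowPptParser` above is callable: it returns one concatenated text string per slide in the half-open page range `[from_page, to_page)`. A minimal usage sketch, assuming the import path follows the file location shown above; the deck file name is hypothetical:

```python
from deepdoc.parser.ppt_parser import RAGFlowPptParser

# In python-pptx, shape_type 19 is MSO_SHAPE_TYPE.TABLE and 6 is
# MSO_SHAPE_TYPE.GROUP, so __extract() flattens tables row by row
# ("header: cell; ...") and recurses into grouped shapes.
parser = RAGFlowPptParser()
slide_texts = parser("slides.pptx", from_page=0, to_page=16)  # hypothetical deck
for i, text in enumerate(slide_texts):
    print(f"--- slide {i} ---")
    print(text)
```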
deepdoc/parser/resume/__init__.py CHANGED
@@ -1,65 +1,65 @@
1
- # Licensed under the Apache License, Version 2.0 (the "License");
2
- # you may not use this file except in compliance with the License.
3
- # You may obtain a copy of the License at
4
- #
5
- # http://www.apache.org/licenses/LICENSE-2.0
6
- #
7
- # Unless required by applicable law or agreed to in writing, software
8
- # distributed under the License is distributed on an "AS IS" BASIS,
9
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10
- # See the License for the specific language governing permissions and
11
- # limitations under the License.
12
- #
13
-
14
- import datetime
15
-
16
-
17
- def refactor(cv):
18
- for n in ["raw_txt", "parser_name", "inference", "ori_text", "use_time", "time_stat"]:
19
- if n in cv and cv[n] is not None: del cv[n]
20
- cv["is_deleted"] = 0
21
- if "basic" not in cv: cv["basic"] = {}
22
- if cv["basic"].get("photo2"): del cv["basic"]["photo2"]
23
-
24
- for n in ["education", "work", "certificate", "project", "language", "skill", "training"]:
25
- if n not in cv or cv[n] is None: continue
26
- if type(cv[n]) == type({}): cv[n] = [v for _, v in cv[n].items()]
27
- if type(cv[n]) != type([]):
28
- del cv[n]
29
- continue
30
- vv = []
31
- for v in cv[n]:
32
- if "external" in v and v["external"] is not None: del v["external"]
33
- vv.append(v)
34
- cv[n] = {str(i): vv[i] for i in range(len(vv))}
35
-
36
- basics = [
37
- ("basic_salary_month", "salary_month"),
38
- ("expect_annual_salary_from", "expect_annual_salary"),
39
- ]
40
- for n, t in basics:
41
- if cv["basic"].get(n):
42
- cv["basic"][t] = cv["basic"][n]
43
- del cv["basic"][n]
44
-
45
- work = sorted([v for _, v in cv.get("work", {}).items()], key=lambda x: x.get("start_time", ""))
46
- edu = sorted([v for _, v in cv.get("education", {}).items()], key=lambda x: x.get("start_time", ""))
47
-
48
- if work:
49
- cv["basic"]["work_start_time"] = work[0].get("start_time", "")
50
- cv["basic"]["management_experience"] = 'Y' if any(
51
- [w.get("management_experience", '') == 'Y' for w in work]) else 'N'
52
- cv["basic"]["annual_salary"] = work[-1].get("annual_salary_from", "0")
53
-
54
- for n in ["annual_salary_from", "annual_salary_to", "industry_name", "position_name", "responsibilities",
55
- "corporation_type", "scale", "corporation_name"]:
56
- cv["basic"][n] = work[-1].get(n, "")
57
-
58
- if edu:
59
- for n in ["school_name", "discipline_name"]:
60
- if n in edu[-1]: cv["basic"][n] = edu[-1][n]
61
-
62
- cv["basic"]["updated_at"] = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
63
- if "contact" not in cv: cv["contact"] = {}
64
- if not cv["contact"].get("name"): cv["contact"]["name"] = cv["basic"].get("name", "")
65
  return cv
 
1
+ # Licensed under the Apache License, Version 2.0 (the "License");
2
+ # you may not use this file except in compliance with the License.
3
+ # You may obtain a copy of the License at
4
+ #
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ #
7
+ # Unless required by applicable law or agreed to in writing, software
8
+ # distributed under the License is distributed on an "AS IS" BASIS,
9
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10
+ # See the License for the specific language governing permissions and
11
+ # limitations under the License.
12
+ #
13
+
14
+ import datetime
15
+
16
+
17
+ def refactor(cv):
18
+ for n in ["raw_txt", "parser_name", "inference", "ori_text", "use_time", "time_stat"]:
19
+ if n in cv and cv[n] is not None: del cv[n]
20
+ cv["is_deleted"] = 0
21
+ if "basic" not in cv: cv["basic"] = {}
22
+ if cv["basic"].get("photo2"): del cv["basic"]["photo2"]
23
+
24
+ for n in ["education", "work", "certificate", "project", "language", "skill", "training"]:
25
+ if n not in cv or cv[n] is None: continue
26
+ if type(cv[n]) == type({}): cv[n] = [v for _, v in cv[n].items()]
27
+ if type(cv[n]) != type([]):
28
+ del cv[n]
29
+ continue
30
+ vv = []
31
+ for v in cv[n]:
32
+ if "external" in v and v["external"] is not None: del v["external"]
33
+ vv.append(v)
34
+ cv[n] = {str(i): vv[i] for i in range(len(vv))}
35
+
36
+ basics = [
37
+ ("basic_salary_month", "salary_month"),
38
+ ("expect_annual_salary_from", "expect_annual_salary"),
39
+ ]
40
+ for n, t in basics:
41
+ if cv["basic"].get(n):
42
+ cv["basic"][t] = cv["basic"][n]
43
+ del cv["basic"][n]
44
+
45
+ work = sorted([v for _, v in cv.get("work", {}).items()], key=lambda x: x.get("start_time", ""))
46
+ edu = sorted([v for _, v in cv.get("education", {}).items()], key=lambda x: x.get("start_time", ""))
47
+
48
+ if work:
49
+ cv["basic"]["work_start_time"] = work[0].get("start_time", "")
50
+ cv["basic"]["management_experience"] = 'Y' if any(
51
+ [w.get("management_experience", '') == 'Y' for w in work]) else 'N'
52
+ cv["basic"]["annual_salary"] = work[-1].get("annual_salary_from", "0")
53
+
54
+ for n in ["annual_salary_from", "annual_salary_to", "industry_name", "position_name", "responsibilities",
55
+ "corporation_type", "scale", "corporation_name"]:
56
+ cv["basic"][n] = work[-1].get(n, "")
57
+
58
+ if edu:
59
+ for n in ["school_name", "discipline_name"]:
60
+ if n in edu[-1]: cv["basic"][n] = edu[-1][n]
61
+
62
+ cv["basic"]["updated_at"] = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
63
+ if "contact" not in cv: cv["contact"] = {}
64
+ if not cv["contact"].get("name"): cv["contact"]["name"] = cv["basic"].get("name", "")
65
  return cv
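
For reference, `refactor()` mutates and returns the parsed résumé dict: it drops bookkeeping fields, renames a couple of salary keys, re-keys the list sections as string-indexed dicts, and lifts the latest work/education entries into `basic`. A minimal sketch with invented sample values:

```python
from deepdoc.parser.resume import refactor

cv = {
    "raw_txt": "…",  # bookkeeping field, dropped by refactor()
    "basic": {"basic_salary_month": "20k"},
    "work": {"0": {"start_time": "2015-07", "management_experience": "Y",
                   "annual_salary_from": "300000", "corporation_name": "ACME"}},
    "education": [{"school_name": "Some University", "discipline_name": "CS",
                   "start_time": "2011-09"}],
}
cv = refactor(cv)
print(cv["basic"]["salary_month"])   # "20k" (renamed from basic_salary_month)
print(cv["basic"]["annual_salary"])  # "300000" (from the latest work entry)
print(cv["basic"]["school_name"])    # "Some University"
print("raw_txt" in cv)               # False
```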