cecilia-uu
committed on
Commit
·
090b2e7
1
Parent(s):
2ef3748
create list_dataset api and tests (#1138)
Browse files

### What problem does this PR solve?
This PR completes both the HTTP API and the Python SDK for 'list_dataset'.
In addition, there are tests for it.
### Type of change
- [x] New Feature (non-breaking change which adds functionality)
- api/apps/dataset_api.py +20 -6
- api/db/services/knowledgebase_service.py +23 -0
- sdk/python/ragflow/__init__.py +2 -0
- sdk/python/ragflow/dataset.py +1 -1
- sdk/python/ragflow/ragflow.py +36 -10
- sdk/python/test/common.py +1 -1
- sdk/python/test/test_dataset.py +87 -7
api/apps/dataset_api.py
CHANGED
@@ -46,7 +46,7 @@ from api.contants import NAME_LENGTH_LIMIT
|
|
46 |
|
47 |
# ------------------------------ create a dataset ---------------------------------------
|
48 |
@manager.route('/', methods=['POST'])
|
49 |
-
@login_required
|
50 |
@validate_request("name") # check name key
|
51 |
def create_dataset():
|
52 |
# Check if Authorization header is present
|
@@ -111,10 +111,27 @@ def create_dataset():
|
|
111 |
if not KnowledgebaseService.save(**request_body):
|
112 |
# failed to create new dataset
|
113 |
return construct_result()
|
114 |
-
return construct_json_result(data={"
|
115 |
except Exception as e:
|
116 |
return construct_error_response(e)
|
117 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
118 |
|
119 |
@manager.route('/<dataset_id>', methods=['DELETE'])
|
120 |
@login_required
|
@@ -135,8 +152,5 @@ def get_dataset(dataset_id):
|
|
135 |
return construct_json_result(code=RetCode.DATA_ERROR, message=f"attempt to get detail of dataset: {dataset_id}")
|
136 |
|
137 |
|
138 |
-
|
139 |
-
@login_required
|
140 |
-
def list_datasets():
|
141 |
-
return construct_json_result(code=RetCode.DATA_ERROR, message=f"attempt to list datasets")
|
142 |
|
|
|
46 |
|
47 |
# ------------------------------ create a dataset ---------------------------------------
|
48 |
@manager.route('/', methods=['POST'])
|
49 |
+
@login_required # use login
|
50 |
@validate_request("name") # check name key
|
51 |
def create_dataset():
|
52 |
# Check if Authorization header is present
|
|
|
111 |
if not KnowledgebaseService.save(**request_body):
|
112 |
# failed to create new dataset
|
113 |
return construct_result()
|
114 |
+
return construct_json_result(data={"dataset_name": request_body["name"]})
|
115 |
except Exception as e:
|
116 |
return construct_error_response(e)
|
117 |
|
118 |
+
# -----------------------------list datasets-------------------------------------------------------
|
119 |
+
@manager.route('/', methods=['GET'])
|
120 |
+
@login_required
|
121 |
+
def list_datasets():
|
122 |
+
offset = request.args.get("offset", 0)
|
123 |
+
count = request.args.get("count", -1)
|
124 |
+
orderby = request.args.get("orderby", "create_time")
|
125 |
+
desc = request.args.get("desc", True)
|
126 |
+
try:
|
127 |
+
tenants = TenantService.get_joined_tenants_by_user_id(current_user.id)
|
128 |
+
kbs = KnowledgebaseService.get_by_tenant_ids(
|
129 |
+
[m["tenant_id"] for m in tenants], current_user.id, int(offset), int(count), orderby, desc)
|
130 |
+
return construct_json_result(data=kbs, code=RetCode.DATA_ERROR, message=f"attempt to list datasets")
|
131 |
+
except Exception as e:
|
132 |
+
return construct_error_response(e)
|
133 |
+
|
134 |
+
# ---------------------------------delete a dataset ----------------------------
|
135 |
|
136 |
@manager.route('/<dataset_id>', methods=['DELETE'])
|
137 |
@login_required
|
|
|
152 |
return construct_json_result(code=RetCode.DATA_ERROR, message=f"attempt to get detail of dataset: {dataset_id}")
|
153 |
|
154 |
|
155 |
+
|
|
|
|
|
|
|
156 |
|
api/db/services/knowledgebase_service.py
CHANGED
@@ -40,6 +40,29 @@ class KnowledgebaseService(CommonService):
|
|
40 |
|
41 |
return list(kbs.dicts())
|
42 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
43 |
@classmethod
|
44 |
@DB.connection_context()
|
45 |
def get_detail(cls, kb_id):
|
|
|
40 |
|
41 |
return list(kbs.dicts())
|
42 |
|
43 |
+
@classmethod
|
44 |
+
@DB.connection_context()
|
45 |
+
def get_by_tenant_ids(cls, joined_tenant_ids, user_id,
|
46 |
+
offset, count, orderby, desc):
|
47 |
+
kbs = cls.model.select().where(
|
48 |
+
((cls.model.tenant_id.in_(joined_tenant_ids) & (cls.model.permission ==
|
49 |
+
TenantPermission.TEAM.value)) | (
|
50 |
+
cls.model.tenant_id == user_id))
|
51 |
+
& (cls.model.status == StatusEnum.VALID.value)
|
52 |
+
)
|
53 |
+
if desc:
|
54 |
+
kbs = kbs.order_by(cls.model.getter_by(orderby).desc())
|
55 |
+
else:
|
56 |
+
kbs = kbs.order_by(cls.model.getter_by(orderby).asc())
|
57 |
+
|
58 |
+
kbs = list(kbs.dicts())
|
59 |
+
|
60 |
+
kbs_length = len(kbs)
|
61 |
+
if offset < 0 or offset > kbs_length:
|
62 |
+
raise IndexError("Offset is out of the valid range.")
|
63 |
+
|
64 |
+
return kbs[offset:offset+count]
|
65 |
+
|
66 |
@classmethod
|
67 |
@DB.connection_context()
|
68 |
def get_detail(cls, kb_id):
|
sdk/python/ragflow/__init__.py
CHANGED
@@ -1,3 +1,5 @@
|
|
1 |
import importlib.metadata
|
2 |
|
3 |
__version__ = importlib.metadata.version("ragflow")
|
|
|
|
|
|
1 |
import importlib.metadata
|
2 |
|
3 |
__version__ = importlib.metadata.version("ragflow")
|
4 |
+
|
5 |
+
from .ragflow import RAGFlow
|
sdk/python/ragflow/dataset.py
CHANGED
@@ -18,4 +18,4 @@ class DataSet:
|
|
18 |
self.user_key = user_key
|
19 |
self.dataset_url = dataset_url
|
20 |
self.uuid = uuid
|
21 |
-
self.name = name
|
|
|
18 |
self.user_key = user_key
|
19 |
self.dataset_url = dataset_url
|
20 |
self.uuid = uuid
|
21 |
+
self.name = name
|
sdk/python/ragflow/ragflow.py
CHANGED
@@ -17,7 +17,10 @@ import os
|
|
17 |
import requests
|
18 |
import json
|
19 |
|
20 |
-
|
|
|
|
|
|
|
21 |
def __init__(self, user_key, base_url, version = 'v1'):
|
22 |
'''
|
23 |
api_url: http://<host_address>/api/v1
|
@@ -36,16 +39,39 @@ class RAGFLow:
|
|
36 |
result_dict = json.loads(res.text)
|
37 |
return result_dict
|
38 |
|
39 |
-
def delete_dataset(self, dataset_name
|
40 |
return dataset_name
|
41 |
|
42 |
-
def list_dataset(self):
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
|
50 |
def get_dataset(self, dataset_id):
|
51 |
endpoint = f"{self.dataset_url}/{dataset_id}"
|
@@ -61,4 +87,4 @@ class RAGFLow:
|
|
61 |
if response.status_code == 200:
|
62 |
return True
|
63 |
else:
|
64 |
-
return False
|
|
|
17 |
import requests
|
18 |
import json
|
19 |
|
20 |
+
from httpx import HTTPError
|
21 |
+
|
22 |
+
|
23 |
+
class RAGFlow:
|
24 |
def __init__(self, user_key, base_url, version = 'v1'):
|
25 |
'''
|
26 |
api_url: http://<host_address>/api/v1
|
|
|
39 |
result_dict = json.loads(res.text)
|
40 |
return result_dict
|
41 |
|
42 |
+
def delete_dataset(self, dataset_name=None, dataset_id=None):
|
43 |
return dataset_name
|
44 |
|
45 |
+
def list_dataset(self, offset=0, count=-1, orderby="create_time", desc=True):
|
46 |
+
params = {
|
47 |
+
"offset": offset,
|
48 |
+
"count": count,
|
49 |
+
"orderby": orderby,
|
50 |
+
"desc": desc
|
51 |
+
}
|
52 |
+
try:
|
53 |
+
response = requests.get(url=self.dataset_url, params=params, headers=self.authorization_header)
|
54 |
+
response.raise_for_status() # if it is not 200
|
55 |
+
original_data = response.json()
|
56 |
+
# TODO: format the data
|
57 |
+
# print(original_data)
|
58 |
+
# # Process the original data into the desired format
|
59 |
+
# formatted_data = {
|
60 |
+
# "datasets": [
|
61 |
+
# {
|
62 |
+
# "id": dataset["id"],
|
63 |
+
# "created": dataset["create_time"], # Adjust the key based on the actual response
|
64 |
+
# "fileCount": dataset["doc_num"], # Adjust the key based on the actual response
|
65 |
+
# "name": dataset["name"]
|
66 |
+
# }
|
67 |
+
# for dataset in original_data
|
68 |
+
# ]
|
69 |
+
# }
|
70 |
+
return response.status_code, original_data
|
71 |
+
except HTTPError as http_err:
|
72 |
+
print(f"HTTP error occurred: {http_err}")
|
73 |
+
except Exception as err:
|
74 |
+
print(f"An error occurred: {err}")
|
75 |
|
76 |
def get_dataset(self, dataset_id):
|
77 |
endpoint = f"{self.dataset_url}/{dataset_id}"
|
|
|
87 |
if response.status_code == 200:
|
88 |
return True
|
89 |
else:
|
90 |
+
return False
|
sdk/python/test/common.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
|
2 |
|
3 |
-
API_KEY = '
|
4 |
HOST_ADDRESS = 'http://127.0.0.1:9380'
|
|
|
1 |
|
2 |
|
3 |
+
API_KEY = 'ImFmNWQ3YTY0Mjg5NjExZWZhNTdjMzA0M2Q3ZWU1MzdlIg.ZmldwA.9oP9pVtuEQSpg-Z18A2eOkWO-3E'
|
4 |
HOST_ADDRESS = 'http://127.0.0.1:9380'
|
sdk/python/test/test_dataset.py
CHANGED
@@ -1,10 +1,10 @@
|
|
1 |
from test_sdkbase import TestSdk
|
2 |
-
import
|
3 |
-
from ragflow.ragflow import RAGFLow
|
4 |
import pytest
|
5 |
-
from unittest.mock import MagicMock
|
6 |
from common import API_KEY, HOST_ADDRESS
|
7 |
|
|
|
|
|
8 |
class TestDataset(TestSdk):
|
9 |
|
10 |
def test_create_dataset(self):
|
@@ -15,12 +15,92 @@ class TestDataset(TestSdk):
|
|
15 |
4. update the kb
|
16 |
5. delete the kb
|
17 |
'''
|
18 |
-
ragflow = RAGFLow(API_KEY, HOST_ADDRESS)
|
19 |
|
|
|
20 |
# create a kb
|
21 |
res = ragflow.create_dataset("kb1")
|
22 |
assert res['code'] == 0 and res['message'] == 'success'
|
23 |
-
|
24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
|
26 |
-
# TODO: list the kb
|
|
|
1 |
from test_sdkbase import TestSdk
|
2 |
+
from ragflow import RAGFlow
|
|
|
3 |
import pytest
|
|
|
4 |
from common import API_KEY, HOST_ADDRESS
|
5 |
|
6 |
+
|
7 |
+
|
8 |
class TestDataset(TestSdk):
|
9 |
|
10 |
def test_create_dataset(self):
|
|
|
15 |
4. update the kb
|
16 |
5. delete the kb
|
17 |
'''
|
|
|
18 |
|
19 |
+
ragflow = RAGFlow(API_KEY, HOST_ADDRESS)
|
20 |
# create a kb
|
21 |
res = ragflow.create_dataset("kb1")
|
22 |
assert res['code'] == 0 and res['message'] == 'success'
|
23 |
+
dataset_name = res['data']['dataset_name']
|
24 |
+
|
25 |
+
def test_list_dataset_success(self):
|
26 |
+
ragflow = RAGFlow(API_KEY, HOST_ADDRESS)
|
27 |
+
# Call the list_datasets method
|
28 |
+
response = ragflow.list_dataset()
|
29 |
+
|
30 |
+
code, datasets = response
|
31 |
+
|
32 |
+
assert code == 200
|
33 |
+
|
34 |
+
def test_list_dataset_with_checking_size_and_name(self):
|
35 |
+
datasets_to_create = ["dataset1", "dataset2", "dataset3"]
|
36 |
+
ragflow = RAGFlow(API_KEY, HOST_ADDRESS)
|
37 |
+
created_response = [ragflow.create_dataset(name) for name in datasets_to_create]
|
38 |
+
|
39 |
+
real_name_to_create = set()
|
40 |
+
for response in created_response:
|
41 |
+
assert 'data' in response, "Response is missing 'data' key"
|
42 |
+
dataset_name = response['data']['dataset_name']
|
43 |
+
real_name_to_create.add(dataset_name)
|
44 |
+
|
45 |
+
status_code, listed_data = ragflow.list_dataset(0, 3)
|
46 |
+
listed_data = listed_data['data']
|
47 |
+
|
48 |
+
listed_names = {d['name'] for d in listed_data}
|
49 |
+
assert listed_names == real_name_to_create
|
50 |
+
assert status_code == 200
|
51 |
+
assert len(listed_data) == len(datasets_to_create)
|
52 |
+
|
53 |
+
def test_list_dataset_with_getting_empty_result(self):
|
54 |
+
ragflow = RAGFlow(API_KEY, HOST_ADDRESS)
|
55 |
+
datasets_to_create = []
|
56 |
+
created_response = [ragflow.create_dataset(name) for name in datasets_to_create]
|
57 |
+
|
58 |
+
real_name_to_create = set()
|
59 |
+
for response in created_response:
|
60 |
+
assert 'data' in response, "Response is missing 'data' key"
|
61 |
+
dataset_name = response['data']['dataset_name']
|
62 |
+
real_name_to_create.add(dataset_name)
|
63 |
+
|
64 |
+
status_code, listed_data = ragflow.list_dataset(0, 0)
|
65 |
+
listed_data = listed_data['data']
|
66 |
+
|
67 |
+
listed_names = {d['name'] for d in listed_data}
|
68 |
+
assert listed_names == real_name_to_create
|
69 |
+
assert status_code == 200
|
70 |
+
assert len(listed_data) == 0
|
71 |
+
|
72 |
+
def test_list_dataset_with_creating_100_knowledge_bases(self):
|
73 |
+
ragflow = RAGFlow(API_KEY, HOST_ADDRESS)
|
74 |
+
datasets_to_create = ["dataset1"] * 100
|
75 |
+
created_response = [ragflow.create_dataset(name) for name in datasets_to_create]
|
76 |
+
|
77 |
+
real_name_to_create = set()
|
78 |
+
for response in created_response:
|
79 |
+
assert 'data' in response, "Response is missing 'data' key"
|
80 |
+
dataset_name = response['data']['dataset_name']
|
81 |
+
real_name_to_create.add(dataset_name)
|
82 |
+
|
83 |
+
status_code, listed_data = ragflow.list_dataset(0, 100)
|
84 |
+
listed_data = listed_data['data']
|
85 |
+
|
86 |
+
listed_names = {d['name'] for d in listed_data}
|
87 |
+
assert listed_names == real_name_to_create
|
88 |
+
assert status_code == 200
|
89 |
+
assert len(listed_data) == 100
|
90 |
+
|
91 |
+
def test_list_dataset_with_showing_one_dataset(self):
|
92 |
+
ragflow = RAGFlow(API_KEY, HOST_ADDRESS)
|
93 |
+
response = ragflow.list_dataset(0, 1)
|
94 |
+
code, response = response
|
95 |
+
datasets = response['data']
|
96 |
+
assert len(datasets) == 1
|
97 |
+
|
98 |
+
def test_list_dataset_failure(self):
|
99 |
+
ragflow = RAGFlow(API_KEY, HOST_ADDRESS)
|
100 |
+
response = ragflow.list_dataset(-1, -1)
|
101 |
+
_, res = response
|
102 |
+
assert "IndexError" in res['message']
|
103 |
+
|
104 |
+
|
105 |
+
|
106 |
|
|