cecilia-uu committed
Commit 090b2e7 · 1 Parent(s): 2ef3748

create list_dataset api and tests (#1138)


### What problem does this PR solve?

This PR completes both the HTTP API and the Python SDK for `list_dataset`.
In addition, it adds tests for the new endpoint and SDK method.
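
As a quick illustration of the intended usage (a minimal sketch, not part of this diff: the key and host address are placeholders, and the response layout follows what the new tests assert, i.e. a `(status_code, body)` tuple whose `body["data"]` is a list of dataset records):

```python
from ragflow import RAGFlow

# Placeholder credentials; the SDK tests read real values from sdk/python/test/common.py.
ragflow = RAGFlow("<API_KEY>", "http://127.0.0.1:9380")

# List datasets: offset/count paginate the result, orderby/desc control sorting.
status_code, body = ragflow.list_dataset(offset=0, count=10, orderby="create_time", desc=True)
assert status_code == 200
print([dataset["name"] for dataset in body["data"]])
```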

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

api/apps/dataset_api.py CHANGED
@@ -46,7 +46,7 @@ from api.contants import NAME_LENGTH_LIMIT
 
 # ------------------------------ create a dataset ---------------------------------------
 @manager.route('/', methods=['POST'])
-@login_required # use login
+@login_required # use login
 @validate_request("name") # check name key
 def create_dataset():
     # Check if Authorization header is present
@@ -111,10 +111,27 @@ def create_dataset():
         if not KnowledgebaseService.save(**request_body):
             # failed to create new dataset
             return construct_result()
-        return construct_json_result(data={"dataset_id": request_body["id"]})
+        return construct_json_result(data={"dataset_name": request_body["name"]})
     except Exception as e:
         return construct_error_response(e)
 
+# -----------------------------list datasets-------------------------------------------------------
+@manager.route('/', methods=['GET'])
+@login_required
+def list_datasets():
+    offset = request.args.get("offset", 0)
+    count = request.args.get("count", -1)
+    orderby = request.args.get("orderby", "create_time")
+    desc = request.args.get("desc", True)
+    try:
+        tenants = TenantService.get_joined_tenants_by_user_id(current_user.id)
+        kbs = KnowledgebaseService.get_by_tenant_ids(
+            [m["tenant_id"] for m in tenants], current_user.id, int(offset), int(count), orderby, desc)
+        return construct_json_result(data=kbs, code=RetCode.DATA_ERROR, message=f"attempt to list datasets")
+    except Exception as e:
+        return construct_error_response(e)
+
+# ---------------------------------delete a dataset ----------------------------
 
 @manager.route('/<dataset_id>', methods=['DELETE'])
 @login_required
@@ -135,8 +152,5 @@ def get_dataset(dataset_id):
     return construct_json_result(code=RetCode.DATA_ERROR, message=f"attempt to get detail of dataset: {dataset_id}")
 
 
-@manager.route('/', methods=['GET'])
-@login_required
-def list_datasets():
-    return construct_json_result(code=RetCode.DATA_ERROR, message=f"attempt to list datasets")
+
 
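For reference, the new `GET` route can also be exercised directly over HTTP. A hedged sketch (the mount point of the dataset blueprint and the exact `Authorization` header format are not shown in this diff, so both are assumptions; only the query parameters `offset`, `count`, `orderby` and `desc` come from the code above):

```python
import requests

# Assumed URL: the SDK docstring describes an api_url of http://<host_address>/api/v1.
url = "http://127.0.0.1:9380/api/v1/dataset"
headers = {"Authorization": "<API_KEY>"}  # assumed header format
params = {"offset": 0, "count": 10, "orderby": "create_time", "desc": True}

response = requests.get(url, params=params, headers=headers)
print(response.status_code)
print(response.json())  # JSON built by construct_json_result / construct_error_response
```
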
api/db/services/knowledgebase_service.py CHANGED
@@ -40,6 +40,29 @@ class KnowledgebaseService(CommonService):
 
         return list(kbs.dicts())
 
+    @classmethod
+    @DB.connection_context()
+    def get_by_tenant_ids(cls, joined_tenant_ids, user_id,
+                          offset, count, orderby, desc):
+        kbs = cls.model.select().where(
+            ((cls.model.tenant_id.in_(joined_tenant_ids) & (cls.model.permission ==
+                                                            TenantPermission.TEAM.value)) | (
+                cls.model.tenant_id == user_id))
+            & (cls.model.status == StatusEnum.VALID.value)
+        )
+        if desc:
+            kbs = kbs.order_by(cls.model.getter_by(orderby).desc())
+        else:
+            kbs = kbs.order_by(cls.model.getter_by(orderby).asc())
+
+        kbs = list(kbs.dicts())
+
+        kbs_length = len(kbs)
+        if offset < 0 or offset > kbs_length:
+            raise IndexError("Offset is out of the valid range.")
+
+        return kbs[offset:offset+count]
+
     @classmethod
     @DB.connection_context()
     def get_detail(cls, kb_id):
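
Pagination in `get_by_tenant_ids` happens in Python after the full, ordered result set has been fetched: the rows are materialised with `list(kbs.dicts())`, the offset is validated, and the method returns `kbs[offset:offset + count]`. A standalone sketch of that slicing behaviour (plain dicts stand in for the Peewee rows):

```python
def paginate(rows, offset, count):
    # Mirrors the tail of KnowledgebaseService.get_by_tenant_ids.
    if offset < 0 or offset > len(rows):
        raise IndexError("Offset is out of the valid range.")
    return rows[offset:offset + count]

rows = [{"name": f"dataset{i}"} for i in range(5)]
print(paginate(rows, 0, 3))   # first three rows
print(paginate(rows, 2, 2))   # rows 2 and 3
# With the default count of -1 the slice becomes rows[offset:offset - 1];
# ordinary Python slice semantics apply.
print(paginate(rows, 0, -1))  # everything except the last row
```
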
sdk/python/ragflow/__init__.py CHANGED
@@ -1,3 +1,5 @@
 import importlib.metadata
 
 __version__ = importlib.metadata.version("ragflow")
+
+from .ragflow import RAGFlow
sdk/python/ragflow/dataset.py CHANGED
@@ -18,4 +18,4 @@ class DataSet:
         self.user_key = user_key
         self.dataset_url = dataset_url
         self.uuid = uuid
-        self.name = name
+        self.name = name
sdk/python/ragflow/ragflow.py CHANGED
@@ -17,7 +17,10 @@ import os
 import requests
 import json
 
-class RAGFLow:
+from httpx import HTTPError
+
+
+class RAGFlow:
     def __init__(self, user_key, base_url, version = 'v1'):
         '''
         api_url: http://<host_address>/api/v1
@@ -36,16 +39,39 @@ class RAGFLow:
         result_dict = json.loads(res.text)
         return result_dict
 
-    def delete_dataset(self, dataset_name = None, dataset_id = None):
+    def delete_dataset(self, dataset_name=None, dataset_id=None):
         return dataset_name
 
-    def list_dataset(self):
-        response = requests.get(self.dataset_url)
-        print(response)
-        if response.status_code == 200:
-            return response.json()['datasets']
-        else:
-            return None
+    def list_dataset(self, offset=0, count=-1, orderby="create_time", desc=True):
+        params = {
+            "offset": offset,
+            "count": count,
+            "orderby": orderby,
+            "desc": desc
+        }
+        try:
+            response = requests.get(url=self.dataset_url, params=params, headers=self.authorization_header)
+            response.raise_for_status()  # if it is not 200
+            original_data = response.json()
+            # TODO: format the data
+            # print(original_data)
+            # # Process the original data into the desired format
+            # formatted_data = {
+            #     "datasets": [
+            #         {
+            #             "id": dataset["id"],
+            #             "created": dataset["create_time"],  # Adjust the key based on the actual response
+            #             "fileCount": dataset["doc_num"],  # Adjust the key based on the actual response
+            #             "name": dataset["name"]
+            #         }
+            #         for dataset in original_data
+            #     ]
+            # }
+            return response.status_code, original_data
+        except HTTPError as http_err:
+            print(f"HTTP error occurred: {http_err}")
+        except Exception as err:
+            print(f"An error occurred: {err}")
 
     def get_dataset(self, dataset_id):
         endpoint = f"{self.dataset_url}/{dataset_id}"
@@ -61,4 +87,4 @@ class RAGFLow:
         if response.status_code == 200:
             return True
         else:
-            return False
+            return False
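
On the failure path exercised by the new tests, the endpoint catches the `IndexError` raised for an out-of-range offset and wraps it via `construct_error_response`, so the SDK still hands back a `(status_code, body)` tuple and the error text surfaces in `body["message"]`. A small sketch of handling that case (placeholder credentials, mirroring `test_list_dataset_failure`):

```python
from ragflow import RAGFlow

ragflow = RAGFlow("<API_KEY>", "http://127.0.0.1:9380")

# A negative offset makes get_by_tenant_ids raise IndexError on the server side.
status_code, body = ragflow.list_dataset(-1, -1)
if "IndexError" in body.get("message", ""):
    print("Invalid offset:", body["message"])
```
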
sdk/python/test/common.py CHANGED
@@ -1,4 +1,4 @@
 
 
-API_KEY = 'IjJiMTVkZWNhMjU3MzExZWY4YzNiNjQ0OTdkMTllYjM3Ig.ZmQZrA.x9Z7c-1ErBUSL3m8SRtBRgGq5uE'
+API_KEY = 'ImFmNWQ3YTY0Mjg5NjExZWZhNTdjMzA0M2Q3ZWU1MzdlIg.ZmldwA.9oP9pVtuEQSpg-Z18A2eOkWO-3E'
 HOST_ADDRESS = 'http://127.0.0.1:9380'
sdk/python/test/test_dataset.py CHANGED
@@ -1,10 +1,10 @@
 from test_sdkbase import TestSdk
-import ragflow
-from ragflow.ragflow import RAGFLow
+from ragflow import RAGFlow
 import pytest
-from unittest.mock import MagicMock
 from common import API_KEY, HOST_ADDRESS
 
+
+
 class TestDataset(TestSdk):
 
     def test_create_dataset(self):
@@ -15,12 +15,92 @@ class TestDataset(TestSdk):
         4. update the kb
         5. delete the kb
         '''
-        ragflow = RAGFLow(API_KEY, HOST_ADDRESS)
 
+        ragflow = RAGFlow(API_KEY, HOST_ADDRESS)
         # create a kb
         res = ragflow.create_dataset("kb1")
         assert res['code'] == 0 and res['message'] == 'success'
-        dataset_id = res['data']['dataset_id']
-        print(dataset_id)
+        dataset_name = res['data']['dataset_name']
+
+    def test_list_dataset_success(self):
+        ragflow = RAGFlow(API_KEY, HOST_ADDRESS)
+        # Call the list_datasets method
+        response = ragflow.list_dataset()
+
+        code, datasets = response
+
+        assert code == 200
+
+    def test_list_dataset_with_checking_size_and_name(self):
+        datasets_to_create = ["dataset1", "dataset2", "dataset3"]
+        ragflow = RAGFlow(API_KEY, HOST_ADDRESS)
+        created_response = [ragflow.create_dataset(name) for name in datasets_to_create]
+
+        real_name_to_create = set()
+        for response in created_response:
+            assert 'data' in response, "Response is missing 'data' key"
+            dataset_name = response['data']['dataset_name']
+            real_name_to_create.add(dataset_name)
+
+        status_code, listed_data = ragflow.list_dataset(0, 3)
+        listed_data = listed_data['data']
+
+        listed_names = {d['name'] for d in listed_data}
+        assert listed_names == real_name_to_create
+        assert status_code == 200
+        assert len(listed_data) == len(datasets_to_create)
+
+    def test_list_dataset_with_getting_empty_result(self):
+        ragflow = RAGFlow(API_KEY, HOST_ADDRESS)
+        datasets_to_create = []
+        created_response = [ragflow.create_dataset(name) for name in datasets_to_create]
+
+        real_name_to_create = set()
+        for response in created_response:
+            assert 'data' in response, "Response is missing 'data' key"
+            dataset_name = response['data']['dataset_name']
+            real_name_to_create.add(dataset_name)
+
+        status_code, listed_data = ragflow.list_dataset(0, 0)
+        listed_data = listed_data['data']
+
+        listed_names = {d['name'] for d in listed_data}
+        assert listed_names == real_name_to_create
+        assert status_code == 200
+        assert len(listed_data) == 0
+
+    def test_list_dataset_with_creating_100_knowledge_bases(self):
+        ragflow = RAGFlow(API_KEY, HOST_ADDRESS)
+        datasets_to_create = ["dataset1"] * 100
+        created_response = [ragflow.create_dataset(name) for name in datasets_to_create]
+
+        real_name_to_create = set()
+        for response in created_response:
+            assert 'data' in response, "Response is missing 'data' key"
+            dataset_name = response['data']['dataset_name']
+            real_name_to_create.add(dataset_name)
+
+        status_code, listed_data = ragflow.list_dataset(0, 100)
+        listed_data = listed_data['data']
+
+        listed_names = {d['name'] for d in listed_data}
+        assert listed_names == real_name_to_create
+        assert status_code == 200
+        assert len(listed_data) == 100
+
+    def test_list_dataset_with_showing_one_dataset(self):
+        ragflow = RAGFlow(API_KEY, HOST_ADDRESS)
+        response = ragflow.list_dataset(0, 1)
+        code, response = response
+        datasets = response['data']
+        assert len(datasets) == 1
+
+    def test_list_dataset_failure(self):
+        ragflow = RAGFlow(API_KEY, HOST_ADDRESS)
+        response = ragflow.list_dataset(-1, -1)
+        _, res = response
+        assert "IndexError" in res['message']
+
+
+
 
-        # TODO: list the kb