import datetime
import math
import random

import pytest
from more_itertools import batched

from .helpers.collection_setup import basic_collection_setup, drop_collection
from .helpers.helpers import request_with_validation
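
# Integration tests for scrolling points with `order_by`: ascending/descending
# ordering, pagination over the whole collection (with and without filters),
# multi-valued keys, and the restriction that `offset` cannot be combined with
# `order_by`.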

total_points = 300


def upsert_points(collection_name, amount=100):
    def maybe_repeated():
        """Descending sequence of possibly repeated floats"""
        repeated_float = float(amount)
        while True:
            repeated_float = repeated_float if random.random() > 0.5 else repeated_float - 1.0
            yield repeated_float

    maybe_repeated_generator = maybe_repeated()

    def date():
        date = datetime.datetime.now()
        while True:
            if random.random() > 0.5:
                date += datetime.timedelta(days=1)
            yield date.isoformat() + "Z"  # RFC-3339 format

    date_generator = date()

    points = [
        {
            "id": i,
            "vector": [0.1 * i] * 4,
            "payload": {
                "city": "London" if i % 3 == 0 else "Moscow",
                "is_middle_split": amount * 0.25 < i < amount * 0.75,
                "price": float(amount - i),
                "payload_id": i,
                "multi_id": [i, amount - i + 1],
                "maybe_repeated_float": next(maybe_repeated_generator),
                "date_rfc3339": next(date_generator),
                "date_simple": next(date_generator).split("T")[0],
            },
        }
        for i in range(amount)
    ]

    for batch in batched(points, 50):
        response = request_with_validation(
            api="/collections/{collection_name}/points",
            method="PUT",
            path_params={"collection_name": collection_name},
            query_params={"wait": "true"},
            body={"points": list(batch)},
        )
        assert response.ok


def create_payload_index(collection_name, field_name, field_schema):
    response = request_with_validation(
        api="/collections/{collection_name}/index",
        method="PUT",
        path_params={"collection_name": collection_name},
        query_params={"wait": "true"},
        body={"field_name": field_name, "field_schema": field_schema},
    )
    assert response.ok, response.json()


@pytest.fixture(autouse=True, scope="module")  # fixture decorator restored; autouse/module scope assumed
def setup(on_disk_vectors, collection_name):
    basic_collection_setup(collection_name=collection_name, on_disk_vectors=on_disk_vectors)

    upsert_points(collection_name=collection_name, amount=total_points)

    # Index the fields used for ordering and filtering; `order_by` requires a
    # payload index on its key.
    create_payload_index(
        collection_name=collection_name, field_name="city", field_schema="keyword"
    )
    create_payload_index(
        collection_name=collection_name, field_name="is_middle_split", field_schema="bool"
    )
    create_payload_index(collection_name=collection_name, field_name="price", field_schema="float")
    create_payload_index(
        collection_name=collection_name, field_name="maybe_repeated_float", field_schema="float"
    )
    create_payload_index(
        collection_name=collection_name, field_name="payload_id", field_schema="integer"
    )
    create_payload_index(
        collection_name=collection_name, field_name="multi_id", field_schema="integer"
    )
    create_payload_index(
        collection_name=collection_name, field_name="date_rfc3339", field_schema="datetime"
    )
    create_payload_index(
        collection_name=collection_name, field_name="date_simple", field_schema="datetime"
    )
    yield
    drop_collection(collection_name=collection_name)


def test_order_by_int_ascending(collection_name):
    response = request_with_validation(
        api="/collections/{collection_name}/points/scroll",
        method="POST",
        path_params={"collection_name": collection_name},
        body={
            "order_by": {"key": "payload_id", "direction": "asc"},
            "limit": 5,
        },
    )
    assert response.ok, response.json()

    result = response.json()["result"]
    assert len(result["points"]) == 5

    ids = [x["id"] for x in result["points"]]
    assert [0, 1, 2, 3, 4] == ids

    # Offset is not supported for order_by
    assert result["next_page_offset"] is None


def test_order_by_int_descending(collection_name):
    response = request_with_validation(
        api="/collections/{collection_name}/points/scroll",
        method="POST",
        path_params={"collection_name": collection_name},
        body={
            "order_by": {"key": "payload_id", "direction": "desc"},
            "limit": 5,
        },
    )
    assert response.ok, response.json()

    result = response.json()["result"]
    assert len(result["points"]) == 5

    ids = [x["id"] for x in result["points"]]

    # We expect the last ids
    expected_ids = [total_points - (i + 1) for i in range(5)]
    assert expected_ids == ids

    # Offset is not supported for order_by
    assert result["next_page_offset"] is None

def paginate_whole_collection(collection_name, key, direction, must=None):
    limit = 23
    pages = 0
    points_count = 0
    points_set = set()

    last_value = None
    last_value_ids = set()
    start_from = None

    # Get filtered total points
    response = request_with_validation(
        api="/collections/{collection_name}/points/count",
        method="POST",
        path_params={"collection_name": collection_name},
        body={
            "filter": {"must": must},
            "exact": True,
        },
    )
    assert response.ok, response.json()
    expected_points = response.json()["result"]["count"]

    while True:
        response = request_with_validation(
            api="/collections/{collection_name}/points/scroll",
            method="POST",
            path_params={"collection_name": collection_name},
            body={
                "order_by": {"key": key, "direction": direction, "start_from": start_from},
                "limit": limit,
                "filter": {
                    "must": must,
                    "must_not": [{"has_id": list(last_value_ids)}],
                },
            },
        )
        assert response.ok, response.json()

        points = response.json()["result"]["points"]

        if len(points) > 0:
            last_value = points[-1]["payload"][key]

            # Exclude the ids we've already seen for the start_from value.
            # This is what we expect the users to do in order to paginate with order_by
            if start_from != last_value:
                last_value_ids.clear()

            start_from = last_value

            same_value_points = [
                point["id"] for point in points if point["payload"][key] == last_value
            ]
            last_value_ids.update(same_value_points)

        points_len = len(points)
        points_count += points_len
        pages += 1

        # Check no duplicates
        for point in points:
            assert point["id"] not in points_set
            points_set.add(point["id"])

        if points_len < limit:
            break

    try:
        assert math.ceil(expected_points / limit) == pages
        assert expected_points == points_count
    except AssertionError:
        # Check which points we're missing
        response = request_with_validation(
            api="/collections/{collection_name}/points/scroll",
            method="POST",
            path_params={"collection_name": collection_name},
            body={
                "limit": total_points,
                "filter": {"must": must},
            },
        )
        assert response.ok, response.json()
        filtered_points = {point["id"] for point in response.json()["result"]["points"]}
        assert filtered_points == points_set, f"Missing points: {filtered_points - points_set}"


# The decorators below are reconstructed: the key list mirrors the sortable payload
# fields indexed in `setup`, and the timeout value is assumed.
@pytest.mark.timeout(60)  # possibly break of an infinite loop
@pytest.mark.parametrize(
    "key",
    ["price", "payload_id", "multi_id", "maybe_repeated_float", "date_rfc3339", "date_simple"],
)
@pytest.mark.parametrize("direction", ["asc", "desc"])
def test_paginate_whole_collection(collection_name, key, direction):
    paginate_whole_collection(collection_name, key, direction)


# Same reconstructed parametrization as above.
@pytest.mark.timeout(60)  # possibly break of an infinite loop
@pytest.mark.parametrize(
    "key",
    ["price", "payload_id", "multi_id", "maybe_repeated_float", "date_rfc3339", "date_simple"],
)
@pytest.mark.parametrize("direction", ["asc", "desc"])
def test_order_by_pagination_with_filters(collection_name, key, direction):
    musts = [
        [{"key": "city", "match": {"value": "London"}}],
        [{"key": "is_middle_split", "match": {"value": False}}],
    ]

    for must in musts:
        paginate_whole_collection(collection_name, key, direction, must)

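
# `multi_id` holds two values per point, so ordering by it should surface every
# point exactly twice.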
def test_multi_values_appear_multiple_times(collection_name):
    limit = total_points * 2
    response = request_with_validation(
        api="/collections/{collection_name}/points/scroll",
        method="POST",
        path_params={"collection_name": collection_name},
        body={
            "order_by": "multi_id",
            "limit": limit,
        },
    )
    assert response.ok, response.json()

    points = response.json()["result"]["points"]
    assert len(points) == limit

    freqs = {}
    for point in points:
        id_ = point["id"]
        freqs[id_] = freqs.get(id_, 0) + 1

    assert all(count == 2 for count in freqs.values())

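
# `offset` is not supported together with `order_by`; the request must be
# rejected with a 400.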
def test_cannot_use_offset_with_order_by(collection_name):
    response = request_with_validation(
        api="/collections/{collection_name}/points/scroll",
        method="POST",
        path_params={"collection_name": collection_name},
        body={
            "order_by": "payload_id",
            "offset": 10,
            "limit": 10,
        },
    )
    assert not response.ok
    assert response.status_code == 400