colibri.qdrant / tests /openapi /test_order_by.py
Gouzi Mohaled
Ajout du dossier tests
3932407
import datetime
import math
import random
import pytest
from more_itertools import batched
from .helpers.collection_setup import basic_collection_setup, drop_collection
from .helpers.helpers import request_with_validation
# Number of points upserted into the module-scoped test collection. The
# descending-order and multi-value tests derive their expectations from it.
total_points = 300
def upsert_points(collection_name, amount=100):
    """Fill ``collection_name`` with ``amount`` synthetic points.

    Point ids are ``0 .. amount - 1``. Every payload carries fields suited to
    ordering tests: unique numeric values (``price``, ``payload_id``), a
    two-element integer array (``multi_id``), a float that may repeat between
    neighbouring points (``maybe_repeated_float``), a full timestamp
    (``date_rfc3339``) and a date-only string (``date_simple``).
    Points are sent in batches of 50 with ``wait=true``.

    Args:
        collection_name: Name of the target collection.
        amount: Number of points to generate.

    Raises:
        AssertionError: If any upsert request is not successful; the response
            body is used as the assertion message.
    """

    def maybe_repeated():
        """Descending sequence of possibly repeated floats"""
        current = float(amount)
        while True:
            # About half of the time the previous value is kept, so ties exist.
            current = current if random.random() > 0.5 else current - 1.0
            yield current

    def dates():
        """Non-decreasing sequence of timestamp strings that may repeat."""
        moment = datetime.datetime.now()
        while True:
            if random.random() > 0.5:
                moment += datetime.timedelta(days=1)
            yield moment.isoformat() + "Z"  # RFC-3339 format

    maybe_repeated_generator = maybe_repeated()
    date_generator = dates()

    points = [
        {
            "id": i,
            "vector": [0.1 * i] * 4,
            "payload": {
                "city": "London" if i % 3 == 0 else "Moscow",
                "is_middle_split": amount * 0.25 < i < amount * 0.75,
                "price": float(amount - i),
                "payload_id": i,
                "multi_id": [i, amount - i + 1],
                "maybe_repeated_float": next(maybe_repeated_generator),
                "date_rfc3339": next(date_generator),
                # Keep only the date part of the next generated timestamp.
                "date_simple": next(date_generator).split("T")[0],
            },
        }
        for i in range(amount)
    ]

    for batch in batched(points, 50):
        response = request_with_validation(
            api="/collections/{collection_name}/points",
            method="PUT",
            path_params={"collection_name": collection_name},
            query_params={"wait": "true"},
            body={"points": list(batch)},
        )
        # Include the response body so a failed upsert is diagnosable,
        # matching the other request assertions in this module.
        assert response.ok, response.json()
def create_payload_index(collection_name, field_name, field_schema):
    """Create a payload index for ``field_name`` with schema ``field_schema``.

    The request is sent with ``wait=true``. An unsuccessful response fails the
    assertion, with the response body as its message.
    """
    index_request = {"field_name": field_name, "field_schema": field_schema}
    response = request_with_validation(
        api="/collections/{collection_name}/index",
        method="PUT",
        path_params={"collection_name": collection_name},
        query_params={"wait": "true"},
        body=index_request,
    )
    assert response.ok, response.json()
@pytest.fixture(autouse=True, scope="module")
def setup(on_disk_vectors, collection_name):
    """Create, populate and index the collection once per module, then drop it."""
    basic_collection_setup(collection_name=collection_name, on_disk_vectors=on_disk_vectors)
    upsert_points(collection_name=collection_name, amount=total_points)

    # (payload field, index schema) pairs; indexes are created in this order.
    indexed_fields = (
        ("city", "keyword"),
        ("is_middle_split", "bool"),
        ("price", "float"),
        ("maybe_repeated_float", "float"),
        ("payload_id", "integer"),
        ("multi_id", "integer"),
        ("date_rfc3339", "datetime"),
        ("date_simple", "datetime"),
    )
    for field_name, field_schema in indexed_fields:
        create_payload_index(
            collection_name=collection_name,
            field_name=field_name,
            field_schema=field_schema,
        )

    yield

    drop_collection(collection_name=collection_name)
def test_order_by_int_ascending(collection_name):
    """Scrolling ordered by ``payload_id`` ascending yields the five lowest ids."""
    response = request_with_validation(
        api="/collections/{collection_name}/points/scroll",
        method="POST",
        path_params={"collection_name": collection_name},
        body={
            "order_by": {"key": "payload_id", "direction": "asc"},
            "limit": 5,
        },
    )
    assert response.ok, response.json()

    result = response.json()["result"]
    assert len(result["points"]) == 5

    ids = [point["id"] for point in result["points"]]
    assert [0, 1, 2, 3, 4] == ids

    # Offset is not supported for order_by
    assert result["next_page_offset"] is None
def test_order_by_int_descending(collection_name):
    """Scrolling ordered by ``payload_id`` descending yields the five highest ids."""
    response = request_with_validation(
        api="/collections/{collection_name}/points/scroll",
        method="POST",
        path_params={"collection_name": collection_name},
        body={
            "order_by": {"key": "payload_id", "direction": "desc"},
            "limit": 5,
        },
    )
    assert response.ok, response.json()

    result = response.json()["result"]
    assert len(result["points"]) == 5

    ids = [point["id"] for point in result["points"]]

    # We expect the last ids
    expected_ids = [total_points - (i + 1) for i in range(5)]
    assert expected_ids == ids

    # Offset is not supported for order_by
    assert result["next_page_offset"] is None
def paginate_whole_collection(collection_name, key, direction, must=None):
    """Scroll through the (optionally filtered) collection ordered by ``key``.

    Pages are requested with a fixed limit of 23. Each following page starts
    from the last value seen (``start_from``), and the ids already returned
    for that value are excluded through a ``has_id`` ``must_not`` clause.

    Args:
        collection_name: Collection to scroll.
        key: Payload field used for ordering.
        direction: Ordering direction passed to ``order_by``.
        must: Optional ``must`` filter conditions, applied to the count
            request and to every scroll request.

    Raises:
        AssertionError: If a request fails, if a point is returned more than
            once, or if the page/point totals are off and the paginated ids
            differ from the ids matching the filter.
    """
    limit = 23
    pages = 0
    points_count = 0
    points_set = set()
    last_value_ids = set()
    start_from = None

    # Get filtered total points
    response = request_with_validation(
        api="/collections/{collection_name}/points/count",
        method="POST",
        path_params={"collection_name": collection_name},
        body={
            "filter": {"must": must},
            "exact": True,
        },
    )
    assert response.ok, response.json()
    expected_points = response.json()["result"]["count"]

    while True:
        response = request_with_validation(
            api="/collections/{collection_name}/points/scroll",
            method="POST",
            path_params={"collection_name": collection_name},
            body={
                "order_by": {"key": key, "direction": direction, "start_from": start_from},
                "limit": limit,
                "filter": {
                    "must": must,
                    "must_not": [{"has_id": list(last_value_ids)}],
                },
            },
        )
        assert response.ok, response.json()

        # Parse the body once and reuse it for all bookkeeping below.
        points = response.json()["result"]["points"]

        if points:
            last_value = points[-1]["payload"][key]

            # Exclude the ids we've already seen for the start_from value.
            # This is what we expect the users to do in order to paginate with order_by
            if start_from != last_value:
                last_value_ids.clear()
            start_from = last_value

            last_value_ids.update(
                point["id"] for point in points if point["payload"][key] == last_value
            )

        points_len = len(points)
        points_count += points_len
        pages += 1

        # Check no duplicates
        for point in points:
            assert point["id"] not in points_set, f"Duplicate point: {point['id']}"
            points_set.add(point["id"])

        # A short (or empty) page means the collection is exhausted.
        if points_len < limit:
            break

    try:
        assert math.ceil(expected_points / limit) == pages
        assert expected_points == points_count
    except AssertionError:
        # The totals alone do not fail the test: fetch every point matching
        # the filter and report exactly which ids pagination did not return.
        response = request_with_validation(
            api="/collections/{collection_name}/points/scroll",
            method="POST",
            path_params={"collection_name": collection_name},
            body={
                "limit": total_points,
                "filter": {"must": must},
            },
        )
        assert response.ok, response.json()

        filtered_points = {point["id"] for point in response.json()["result"]["points"]}

        assert filtered_points == points_set, f"Missing points: {filtered_points - points_set}"
@pytest.mark.parametrize(
    "key, direction",
    [
        (field, order)
        for field in (
            "payload_id",
            "price",
            "maybe_repeated_float",
            "date_rfc3339",
            "date_simple",
        )
        for order in ("asc", "desc")
    ],
)
@pytest.mark.timeout(60)  # guards against a pagination loop that never ends
def test_paginate_whole_collection(collection_name, key, direction):
    """Paginate the unfiltered collection for every ordering key and direction."""
    paginate_whole_collection(collection_name, key, direction)
@pytest.mark.parametrize(
    "key, direction",
    [
        (field, order)
        for field in (
            "payload_id",
            "price",
            "maybe_repeated_float",
            "date_rfc3339",
            "date_simple",
        )
        for order in ("asc", "desc")
    ],
)
@pytest.mark.timeout(60)  # guards against a pagination loop that never ends
def test_order_by_pagination_with_filters(collection_name, key, direction):
    """Paginate in order under each ``must`` filter, one filter after the other."""
    london_only = [{"key": "city", "match": {"value": "London"}}]
    outside_middle_split = [{"key": "is_middle_split", "match": {"value": False}}]

    for must in (london_only, outside_middle_split):
        paginate_whole_collection(collection_name, key, direction, must)
def test_multi_values_appear_multiple_times(collection_name):
    """Ordering by the array field ``multi_id`` returns each point once per value.

    Every point is upserted with two ``multi_id`` values, so a scroll with a
    limit of twice the collection size must return each id exactly twice.
    """
    # Function-scope import keeps this block self-contained.
    from collections import Counter

    limit = total_points * 2

    response = request_with_validation(
        api="/collections/{collection_name}/points/scroll",
        method="POST",
        path_params={"collection_name": collection_name},
        body={
            "order_by": "multi_id",
            "limit": limit,
        },
    )
    assert response.ok, response.json()

    points = response.json()["result"]["points"]
    assert len(points) == limit

    freqs = Counter(point["id"] for point in points)

    # Report the offending ids instead of a bare boolean failure.
    unexpected = {id_: count for id_, count in freqs.items() if count != 2}
    assert not unexpected, f"Points not returned exactly twice: {unexpected}"
def test_cannot_use_offset_with_order_by(collection_name):
    """A scroll request that combines ``order_by`` with ``offset`` is rejected with 400."""
    scroll_request = {
        "order_by": "payload_id",
        "offset": 10,
        "limit": 10,
    }
    response = request_with_validation(
        api="/collections/{collection_name}/points/scroll",
        method="POST",
        path_params={"collection_name": collection_name},
        body=scroll_request,
    )

    assert not response.ok
    assert response.status_code == 400