Spaces:
Build error
Build error
File size: 5,190 Bytes
3932407 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 |
import json
from logging import warning
from typing import Any, Dict, List
import jsonschema
import requests
import warnings
from schemathesis.models import APIOperation
from schemathesis.specs.openapi.references import ConvertingResolver
from schemathesis.specs.openapi.schemas import OpenApi30
from functools import lru_cache
from .settings import QDRANT_HOST, SCHEMA, QDRANT_HOST_HEADERS
def get_api_string(host, api, path_params):
    """Build the request URL for ``api`` on ``host``.

    ``host`` and ``api`` are concatenated, then every ``{placeholder}`` in the
    combined string is filled from ``path_params`` with ``str.format``.
    Entries in ``path_params`` that the template does not reference are ignored.

    >>> get_api_string('http://localhost:6333', '/collections/{name}', {'name': 'hello', 'a': 'b'})
    'http://localhost:6333/collections/hello'
    """
    url_template = f"{host}{api}"
    filled_url = url_template.format(**path_params)
    return filled_url
def validate_schema(data, operation_schema: OpenApi30, raw_definitions):
    """Validate ``data`` against a raw JSON Schema fragment from the OpenAPI spec.

    A ``ConvertingResolver`` is built over the operation's full raw schema.
    It is passed to ``jsonschema.validate`` as the ``resolver``, which
    jsonschema uses for ``$ref`` lookups. Validation follows Draft 7 rules.

    :param data: concrete values to validate
    :param operation_schema: loaded OpenAPI 3.0 schema that owns the operation
    :param raw_definitions: JSON Schema fragment that ``data`` must satisfy
    :raises jsonschema.ValidationError: if ``data`` does not match
    :return: None
    """
    base_uri = operation_schema.location or ""
    ref_resolver = ConvertingResolver(
        base_uri,
        operation_schema.raw_schema,
        nullable_name=operation_schema.nullable_name,
        is_response_schema=False,
    )
    jsonschema.validate(
        instance=data,
        schema=raw_definitions,
        cls=jsonschema.Draft7Validator,
        resolver=ref_resolver,
    )
def _assert_params_match(kind: str, declared, supplied: dict) -> None:
    """Assert that ``supplied`` has every required parameter in ``declared`` and no undeclared ones.

    :param kind: label used in assertion messages, e.g. ``"path"`` or ``"query"``
    :param declared: the operation's parameter objects (each exposes ``.name`` and ``.is_required``)
    :param supplied: parameter values passed by the caller
    :raises AssertionError: on a missing required or an unknown parameter
    """
    # Build the declared-name set once instead of once per supplied parameter.
    declared_names = set()
    for param in declared:
        declared_names.add(param.name)
        if param.is_required:
            assert param.name in supplied, f"missing required {kind} parameter {param.name!r}"
    for name in supplied:
        assert name in declared_names, f"unknown {kind} parameter {name!r}"


def request_with_validation(
    api: str,
    method: str,
    path_params: dict = None,
    query_params: dict = None,
    body: dict = None
) -> requests.Response:
    """Send a request to the Qdrant API after checking it against the OpenAPI spec.

    Steps:
      1. Look up the operation for ``api``/``method`` in ``SCHEMA``.
      2. If ``body`` is truthy, validate it against the operation's
         ``application/json`` request-body schema.
      3. Assert that every required path/query parameter is supplied and
         that no undeclared path/query parameter is passed.
      4. For ``POST`` calls whose ``api`` ends with ``/delete`` and that lack
         ``wait``, emit a warning and send ``wait=true``.
      5. Call ``QDRANT_HOST`` through ``requests`` and check the response
         with ``operation.validate_response``.

    The caller's ``query_params`` dict is never modified; step 4 acts on a copy.

    :param api: path template as listed in the spec, e.g. ``/collections/{name}``
    :param method: HTTP method name used as the ``SCHEMA`` key and, lowercased,
        as the ``requests`` function name
    :param path_params: values for ``{placeholders}`` in ``api``
    :param query_params: query-string parameters
    :param body: JSON request body
    :raises AssertionError: on missing required or unknown parameters
    :raises RuntimeError: if ``requests`` has no function named ``method.lower()``
    :return: the ``requests.Response`` returned by the call
    """
    operation: APIOperation = SCHEMA[api][method]
    assert isinstance(operation.schema, OpenApi30)

    if body:
        validate_schema(
            data=body,
            operation_schema=operation.schema,
            raw_definitions=operation.definition.raw['requestBody']['content']['application/json']['schema']
        )

    if path_params is None:
        path_params = {}
    # Copy: the delete branch below may add "wait", and that must not leak
    # back into the dict the caller passed (which may be reused for other calls).
    query_params = {} if query_params is None else dict(query_params)

    action = getattr(requests, method.lower(), None)

    _assert_params_match("path", operation.path_parameters.items, path_params)
    _assert_params_match("query", operation.query.items, query_params)

    if not action:
        raise RuntimeError(f"Method {method} does not exists")

    if api.endswith("/delete") and method == "POST" and "wait" not in query_params:
        warnings.warn(f"Delete call for {api} missing wait=true param, adding it")
        query_params["wait"] = "true"

    response = action(
        url=get_api_string(QDRANT_HOST, api, path_params),
        params=query_params,
        json=body,
        headers=qdrant_host_headers()
    )
    operation.validate_response(response)
    return response
# from client implementation:
# https://github.com/qdrant/qdrant-client/blob/d18cb1702f4cf8155766c7b32d1e4a68af11cd6a/qdrant_client/hybrid/fusion.py#L6C1-L31C25
def reciprocal_rank_fusion(
    responses: List[List[Any]], limit: int = 10
) -> List[Any]:
    """Fuse several ranked result lists with reciprocal rank fusion (RRF).

    A point at zero-based position ``pos`` in a response contributes
    ``1 / (2 + pos)`` to its fused score. Contributions are summed per
    ``"id"`` across all responses. The first occurrence of each id supplies
    the fields of the returned point.

    The input responses and their point dicts are left unmodified. Each
    returned point is a shallow copy whose ``"score"`` is the fused score.

    :param responses: ranked result lists; each point is a mapping with an ``"id"`` key
    :param limit: maximum number of fused points to return
    :return: up to ``limit`` fused points, highest fused score first
    """
    # The constant mitigates the impact of high rankings by outlier systems.
    ranking_constant = 2

    scores: Dict[Any, float] = {}  # id -> fused score
    first_seen: Dict[Any, Any] = {}  # id -> first point carrying that id
    for response in responses:
        for pos, scored_point in enumerate(response):
            point_id = scored_point["id"]
            scores[point_id] = scores.get(point_id, 0.0) + 1 / (ranking_constant + pos)
            first_seen.setdefault(point_id, scored_point)

    # sorted() is stable, so equal scores keep first-seen order.
    ranked = sorted(scores.items(), key=lambda item: item[1], reverse=True)
    fused = []
    for point_id, score in ranked[:limit]:
        point = dict(first_seen[point_id])  # copy: do not overwrite the caller's score
        point["score"] = score
        fused.append(point)
    return fused
def distribution_based_score_fusion(responses: List[List[Any]], limit: int = 10) -> List[Any]:
    """Fuse several scored result lists with distribution-based score fusion (DBSF).

    Each non-empty response is normalized on its own. With ``mean`` and
    sample standard deviation ``std`` of its scores, ``low = mean - 3*std``
    and ``high = mean + 3*std``, and a score ``s`` becomes
    ``(s - low) / (high - low)``. Normalized scores of points that share an
    ``"id"`` are summed, and the ``limit`` highest-scoring points are
    returned, best first. Empty responses contribute nothing.

    The input responses and their point dicts are left unmodified. The
    returned points are shallow copies carrying the fused ``"score"``.

    NOTE(review): two cases still raise ``ZeroDivisionError``:
      - a response with exactly one point, because the sample variance
        divides by ``len - 1``;
      - a response whose scores are all equal, because ``high - low == 0``.
    How the server handles these cases is not visible here; confirm before
    adding a fallback.

    :param responses: result lists; each point is a mapping with ``"id"`` and ``"score"``
    :param limit: maximum number of fused points to return
    :return: fused points sorted by descending score
    """
    def normalize(response: List[Any]) -> List[Any]:
        """Return copies of ``response``'s points with scores rescaled by mean ± 3·std."""
        total = sum([point["score"] for point in response])
        mean = total / len(response)
        variance = sum([(point["score"] - mean) ** 2 for point in response]) / (len(response) - 1)
        std_dev = variance ** 0.5
        low = mean - 3 * std_dev
        high = mean + 3 * std_dev
        return [{**point, "score": (point["score"] - low) / (high - low)} for point in response]

    points_map: Dict[Any, Any] = {}  # id -> fused point (our own copy, safe to mutate)
    for response in responses:
        if not response:
            # An empty result list has no scores to fuse (and its mean is undefined).
            continue
        for point in normalize(response):
            entry = points_map.get(point["id"])
            if entry is None:
                points_map[point["id"]] = point
            else:
                entry["score"] += point["score"]

    sorted_points = sorted(points_map.values(), key=lambda item: item["score"], reverse=True)
    return sorted_points[:limit]
@lru_cache
def qdrant_host_headers():
    """Return the extra HTTP headers to send with requests to the Qdrant host.

    The headers are parsed from the JSON string ``QDRANT_HOST_HEADERS``
    (imported from ``.settings``). Because of ``lru_cache``, parsing happens
    on the first call only. Later calls return the same dict object, so
    callers must not mutate the result.
    """
    headers = json.loads(QDRANT_HOST_HEADERS)
    return headers