|
from qdrant_client.http import models |
|
import pickle as pickle |
|
import torch |
|
import io |
|
|
|
# Select the compute target once at import time: the first CUDA device when
# torch reports CUDA as available, otherwise the CPU.
if torch.cuda.is_available():
    device_str = "cuda:0"
else:
    device_str = "cpu"
device = torch.device(device_str)
|
|
|
|
|
class Device_Unpickler(pickle.Unpickler):
    """Unpickler that rebuilds pickled torch storages on ``device_str``.

    Every global lookup is delegated to :class:`pickle.Unpickler` except
    ``torch.storage._load_from_bytes``. That lookup is replaced by a loader
    that calls ``torch.load`` with ``map_location=device_str``. For example,
    objects pickled on a GPU machine can then be loaded on a CPU-only one.
    """

    def find_class(self, module, name):
        """Resolve ``module.name``, swapping in a device-remapping storage loader."""
        if (module, name) != ("torch.storage", "_load_from_bytes"):
            return super().find_class(module, name)

        def _load_storage_on_device(payload):
            # ``payload`` is the byte string the pickle stored for this
            # storage; remap it onto the module-level target device.
            return torch.load(io.BytesIO(payload), map_location=device_str)

        return _load_storage_on_device
|
|
|
|
|
def pickle_to_document_store(path):
    """Load a pickled document store from ``path`` onto ``device_str``.

    Torch storages inside the pickle are rebuilt on ``device_str`` through
    ``Device_Unpickler``. Afterwards, ``device_str`` is written to
    ``store.embeddings.encode_kwargs["device"]`` so that the embedding
    settings name the same device.

    SECURITY: unpickling can execute arbitrary code; only load files from a
    trusted source.

    Args:
        path: Filesystem path of the pickled document store.

    Returns:
        The unpickled document store object.
    """
    with open(path, "rb") as handle:
        store = Device_Unpickler(handle).load()

    # NOTE(review): presumably encode_kwargs is forwarded to the embedding
    # model's encode call; verify against the embeddings class in use.
    encode_options = store.embeddings.encode_kwargs
    encode_options["device"] = device_str
    return store
|
|
|
|
|
def get_qdrant_filters(filter_dict: dict):
    """Build a Qdrant filter based on a filter dict.

    Each key of ``filter_dict`` names a field stored under the point's
    ``metadata`` payload, and its value gives the accepted values for that
    field. Every field becomes a ``MatchAny`` condition, and all conditions go
    into ``must``: all fields must match, and within a field any listed value
    matches. The filter dict must be formatted like::

        filter_dict = {'file_name': ['file1', 'file2'], 'sub_type': ['text']}

    A single ``str`` or ``int`` value is accepted as shorthand for a
    one-element list, e.g. ``{'sub_type': 'text'}``.

    Args:
        filter_dict: Mapping of metadata field name to accepted value(s).

    Returns:
        models.Filter: A filter combining one ``MatchAny`` condition per field.
    """
    conditions = []
    for field, values in filter_dict.items():
        # MatchAny.any must be a list of str or int. Wrap a bare scalar so
        # that {'sub_type': 'text'} works instead of failing validation.
        if isinstance(values, (str, int)):
            values = [values]
        conditions.append(
            models.FieldCondition(
                key=f"metadata.{field}",
                match=models.MatchAny(any=values),
            )
        )
    return models.Filter(must=conditions)
|
|