File size: 2,171 Bytes
52b1203
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import json
import faiss
import os

from functools import lru_cache
from pathlib import Path


class DatabaseCursor:
    def __init__(
        self,
        index_file_path: str,
        index_subframes_file_path: str,
        keyframes_groups_json_path: str,
        subframes_groups_json_path: str,
    ):
        self._load_index(index_file_path, index_subframes_file_path)
        self._load_keyframes_groups_info(
            keyframes_groups_json_path, subframes_groups_json_path
        )

    @lru_cache(maxsize=1)
    def _load_index(self, index_file_path, index_subframes_file_path):
        self.index = faiss.read_index(index_file_path)
        index_subframes = faiss.read_index(index_subframes_file_path)
        try:
            self.index.merge_from(index_subframes)
        except:
            raise Exception("dtb_cursor::cannot merge keyframes and subframes index")

    @lru_cache(maxsize=1)
    def _load_keyframes_groups_info(
        self, keyframes_groups_json_path: str, subframes_groups_json_path: str
    ):
        with open(keyframes_groups_json_path) as file:
            keyframes_group_info = json.loads(file.read())
            self.no_keyframes = len(keyframes_group_info)
        with open(subframes_groups_json_path) as file:
            subframes_groups_info = json.loads(file.read())
            self.no_subframes = len(subframes_groups_info)

        self.frames_groups_info = keyframes_group_info
        self.frames_groups_info.extend(subframes_groups_info)
        print(self.index.ntotal)
        assert self.index.ntotal == len(
            self.frames_groups_info
        ), "dtb_cursor::Index length and map lenght mismatch"

    def kNN_search(self, query_vector: str, topk: int = 10):
        results = []
        distances, ids = self.index.search(query_vector, topk)
        for i in range(len(ids[0])):
            frame_detail = self.frames_groups_info[ids[0][i]]
            frame_detail["distance"] = str(distances[0][i])
            frame_detail["folder"] = (
                "Keyframes" if ids[0][i] < self.no_keyframes else "Subframes"
            )
            results.append(frame_detail)
        return results