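"""Gradio demo: text-to-video search over a precomputed video feature database.

A query sentence is embedded with the Diangle/clip4clip-webvid text encoder and
matched against precomputed CLIP4Clip visual features, using a fast FAISS
binary (Hamming) search followed by a cosine re-ranking step.
"""
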
import os

import faiss
import gradio
import numpy as np
import pandas as pd
import torch
from transformers import AutoTokenizer, CLIPTextModelWithProjection


DATA_PATH = './data'

# Full float visual features (used for the cosine re-ranking step) and a
# bit-packed version of the same features (assumed to be the packbits of
# features > 0.0), used for the fast Hamming search.
ft_visual_features_file = DATA_PATH + '/dataset_v1_visual_features_database.npy'
binary_visual_features_file = DATA_PATH + '/dataset_v1_visual_features_database_packed.npy'
ft_visual_features_database = np.load(ft_visual_features_file)
binary_visual_features = np.load(binary_visual_features_file)

database_csv_path = os.path.join(DATA_PATH, 'dataset_v1.csv')
database_df = pd.read_csv(database_csv_path)
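# database_df holds per-video metadata; its 'contentUrl' column is what the
# demo returns for playback.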


class NearestNeighbors:
    """Small FAISS wrapper for cosine or binary (Hamming) nearest-neighbor search."""

    def __init__(self, n_neighbors=10, metric='cosine', rerank_from=-1):
        """
        metric: 'cosine' or 'binary'.
        If metric != 'cosine' and rerank_from > n_neighbors, the top
        rerank_from binary hits are re-ranked by cosine similarity.
        """
        self.n_neighbors = n_neighbors
        self.metric = metric
        self.rerank_from = rerank_from
        
    def normalize(self, a):
        # L2-normalize each row: divide by the norm, not the squared norm
        # (np.sum(a**2) is the squared norm, which would mis-scale the vectors).
        return a / np.linalg.norm(a, axis=1, keepdims=True)
    
    def fit(self, data, o_data=None):
        if self.metric == 'cosine':
            data = self.normalize(data)
            self.index = faiss.IndexFlatIP(data.shape[1])
        elif self.metric == 'binary':
            # Keep the float features around for the optional cosine re-rank.
            self.o_data = data if o_data is None else o_data
            # `data` is assumed to be already bit-packed (uint8, 8 dims per
            # byte), so the binary index dimension is in bits, not bytes.
            self.index = faiss.IndexBinaryFlat(data.shape[1] * 8)
        self.index.add(np.ascontiguousarray(data))
        
    def kneighbors(self, q_data):
        if self.metric == 'cosine':
            q_data = self.normalize(q_data)
            sim, idx = self.index.search(q_data, self.n_neighbors)
        elif self.metric == 'binary':
            # Pack the query the same way as the database: one bit per
            # dimension, set when the value is positive.
            bq_data = np.packbits(q_data > 0.0, axis=1)
            sim, idx = self.index.search(bq_data, max(self.rerank_from, self.n_neighbors))

            if self.rerank_from > self.n_neighbors:
                # Re-rank the Hamming candidates by exact cosine similarity
                # against the original float features, then keep the top k.
                sim_float = np.zeros([len(q_data), self.rerank_from], dtype=float)
                for i, q in enumerate(q_data):
                    candidates = np.take_along_axis(self.o_data, idx[i:i + 1, :].T, axis=0)
                    # Normalize candidates so the score is a true cosine
                    # (a no-op if the stored features are already unit-norm).
                    candidates = candidates / np.linalg.norm(candidates, axis=1, keepdims=True)
                    sim_float[i, :] = q @ candidates.T
                    sort_idx = np.argsort(sim_float[i, :])[::-1]
                    sim_float[i, :] = sim_float[i, :][sort_idx]
                    idx[i, :] = idx[i, :][sort_idx]
                sim = sim_float[:, :self.n_neighbors]
                idx = idx[:, :self.n_neighbors]

        return sim, idx
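
# Example usage of NearestNeighbors (hypothetical arrays, for illustration):
#
#     packed = np.packbits(features > 0.0, axis=1)   # features: (N, d) float32
#     nn = NearestNeighbors(n_neighbors=5, metric='binary', rerank_from=100)
#     nn.fit(packed, o_data=features)
#     sims, idxs = nn.kneighbors(queries)            # queries: (Q, d) float32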
    

# Load the text encoder once at startup instead of on every query.
tokenizer = AutoTokenizer.from_pretrained("Diangle/clip4clip-webvid")
text_model = CLIPTextModelWithProjection.from_pretrained("Diangle/clip4clip-webvid")


def search(search_sentence):
    inputs = tokenizer(text=search_sentence, return_tensors="pt", padding=True)

    with torch.no_grad():
        outputs = text_model(input_ids=inputs["input_ids"],
                             attention_mask=inputs["attention_mask"],
                             return_dict=False)

        # CLIP4Clip-style pooling: project each token's hidden state with the
        # text projection, then take the embedding at the EOS position (the
        # highest token id in CLIP's vocabulary).
        text_projection = text_model.state_dict()['text_projection.weight']
        text_embeds = outputs[1] @ text_projection
        final_output = text_embeds[torch.arange(text_embeds.shape[0]),
                                   inputs["input_ids"].argmax(dim=-1)]

        # L2-normalize the sentence embedding before similarity search.
        final_output = final_output / final_output.norm(dim=-1, keepdim=True)
    sequence_output = final_output.cpu().detach().numpy()

    # Hamming search over the packed features, re-ranked by cosine similarity
    # against the full float features. The precomputed packed file is assumed
    # to match np.packbits(ft_visual_features_database > 0.0, axis=1).
    nn_search = NearestNeighbors(n_neighbors=5, metric='binary', rerank_from=100)
    nn_search.fit(binary_visual_features, o_data=ft_visual_features_database)
    sims, idxs = nn_search.kneighbors(sequence_output)
    return database_df.iloc[idxs[0]]['contentUrl'].to_list()
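
# Example query (illustrative only):
#
#     urls = search("a dog catching a frisbee")   # -> list of five video URLs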

    
gradio.close_all()

interface = gradio.Interface(
    search,
    inputs=[gradio.Textbox()],
    outputs=[gradio.Video(format='mp4') for _ in range(5)],
    title='Video Search Demo',
    description='Type some text to search by content within a video database!',
).launch()