File size: 5,066 Bytes
5f988b9
 
 
 
 
 
 
 
 
 
eec4792
5f988b9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import functools
import os

import gradio
import numpy as np
import pandas as pd
from IPython import display
import faiss
import torch
from transformers import AutoTokenizer, CLIPTextModelWithProjection


DATA_PATH = '/data'

# Pre-computed video feature files for the fine-tuned CLIP4Clip (WebVid)
# checkpoint: a float feature matrix and, per its file name, a bit-packed
# ("binary20", "packed") variant.
ft_visual_features_file = os.path.join(
    DATA_PATH,
    'features',
    'features_Cl4Cl_ckpt_webvid_retrieval_looseType_bs26_gpus2_lr7_150k_finalsample',
    'dataset_v1_visual_features_database.npy',
)
binary_visual_features_file = os.path.join(
    DATA_PATH,
    'features',
    'features_Cl4Cl_ckpt_webvid_retrieval_looseType_bs26_gpus2_lr7_150k_finalsample_binary20',
    'dataset_v1_visual_features_database_packed.npy',
)
ft_visual_features_database = np.load(ft_visual_features_file)
# NOTE(review): search() re-binarises ft_visual_features_database instead of
# reading this array; confirm it is still needed before keeping it in memory.
binary_visual_features = np.load(binary_visual_features_file)

# Video metadata; search() looks rows up by feature-row position, so this CSV is
# assumed to be row-aligned with the feature matrix.
database_csv_path = os.path.join(DATA_PATH, 'dataset_v1.csv')
database_df = pd.read_csv(database_csv_path)

# Gradio can display the video URLs directly; display_videos instead renders
# them as HTML through IPython (e.g. for a notebook preview).
def videos_html(display_df):
    """Return an HTML snippet with one muted, looping, autoplaying <video> per row.

    Args:
        display_df: DataFrame with a 'contentUrl' column (video source) and a
            'name' column (caption text).

    Returns:
        The concatenated HTML string ('' for an empty frame). Both values are
        HTML-escaped so characters such as '<', '&' or '"' cannot break the
        generated markup.
    """
    from html import escape

    return ''.join(
        '<video autoplay loop muted> <source src="{}" type="video/mp4"> </video> '
        '<div class="caption">{}</div><br/>'.format(escape(str(url)), escape(str(caption)))
        for url, caption in zip(display_df['contentUrl'], display_df['name'])
    )


def display_videos(display_df):
    """Return an IPython ``HTML`` object previewing the videos in ``display_df``.

    See ``videos_html`` for the expected columns and the markup produced.
    """
    return display.HTML(videos_html(display_df))


class NearestNeighbors:
    """
    Exact nearest-neighbour search over a faiss flat index.

    Supported metrics:

    * 'cosine' -- rows are L2-normalised and searched by inner product
      (faiss.IndexFlatIP), so the returned similarities are cosine similarities.
    * 'binary' -- rows are bit-packed codes searched with faiss.IndexBinaryFlat.
      When rerank_from > n_neighbors, rerank_from candidates are fetched and
      re-scored by cosine similarity against the float vectors passed to fit()
      as o_data; the best n_neighbors are kept.
    """

    _METRICS = ('cosine', 'binary')

    def __init__(self, n_neighbors=10, metric='cosine', rerank_from=-1):
        """
        Args:
            n_neighbors: number of neighbours returned per query.
            metric: 'cosine' or 'binary'.
            rerank_from: 'binary' only -- number of candidates fetched from the
                binary index and reranked by cosine similarity. Reranking only
                happens when rerank_from > n_neighbors (the default -1 disables it).

        Raises:
            ValueError: if metric is not 'cosine' or 'binary'.
        """
        # Fail fast: with any other metric fit()/kneighbors() would hit
        # undefined attributes or locals.
        if metric not in self._METRICS:
            raise ValueError('metric must be one of {}, got {!r}'.format(self._METRICS, metric))
        self.n_neighbors = n_neighbors
        self.metric = metric
        self.rerank_from = rerank_from

    def normalize(self, a):
        """Return ``a`` with every row scaled to unit L2 norm (all-zero rows stay zero)."""
        norms = np.linalg.norm(a, axis=1, keepdims=True)
        # Avoid 0/0 -> NaN for all-zero rows.
        norms[norms == 0] = 1
        return a / norms

    def fit(self, data, o_data=None):
        """Build the faiss index over ``data``.

        Args:
            data: for 'cosine', a 2-D float array (one vector per row); it is
                converted to float32 and row-normalised. For 'binary', a 2-D
                array of already bit-packed codes (8 bits per byte, e.g. from
                np.packbits).
            o_data: 'binary' only -- float vectors used for reranking,
                row-aligned with ``data``. Defaults to ``data`` itself.
        """
        if self.metric == 'cosine':
            data = self.normalize(np.asarray(data, dtype=np.float32))
            self.index = faiss.IndexFlatIP(data.shape[1])
        else:  # 'binary'
            self.o_data = data if o_data is None else o_data
            # data is assumed to be already packed: each byte holds 8 bits.
            self.index = faiss.IndexBinaryFlat(data.shape[1] * 8)
        self.index.add(np.ascontiguousarray(data))

    def kneighbors(self, q_data):
        """Return ``(sim, idx)`` for the nearest neighbours of each query row.

        Args:
            q_data: 2-D float array of query vectors. For 'binary' it is
                binarised here (value > 0 -> bit 1) and bit-packed.

        Returns:
            ``idx`` holds row indices into the fitted data, one row per query.
            ``sim`` holds cosine similarities for 'cosine' and for reranked
            'binary' searches (best first). It holds the binary-index distances
            returned by faiss when no rerank is done.
        """
        if self.metric == 'cosine':
            print('cosine search')
            queries = self.normalize(np.asarray(q_data, dtype=np.float32))
            return self.index.search(np.ascontiguousarray(queries), self.n_neighbors)

        print('binary search')
        bq_data = np.packbits(np.asarray(q_data) > 0.0, axis=1)
        print(bq_data.shape, self.index.d)
        sim, idx = self.index.search(bq_data, max(self.rerank_from, self.n_neighbors))
        if self.rerank_from <= self.n_neighbors:
            return sim, idx
        return self._rerank(q_data, idx)

    def _rerank(self, q_data, idx):
        """Re-score binary candidates by cosine similarity against ``self.o_data``.

        ``idx`` is the (n_queries, n_candidates) candidate matrix from the binary
        search and is reordered in place. Returns ``(sim, idx)`` truncated to
        n_neighbors columns, best match first.
        """
        queries = self.normalize(np.asarray(q_data, dtype=float))
        sim_float = np.empty(idx.shape, dtype=float)
        for i, q in enumerate(queries):
            row = idx[i]
            # faiss reports missing results with label -1 (index smaller than k);
            # score them -inf so they sort last instead of aliasing o_data[-1].
            valid = row >= 0
            scores = np.full(row.shape, -np.inf)
            candidates = np.asarray(self.o_data[row[valid]], dtype=float)
            scores[valid] = self.normalize(candidates) @ q
            order = np.argsort(scores)[::-1]
            sim_float[i] = scores[order]
            idx[i] = row[order]
        return sim_float[:, :self.n_neighbors], idx[:, :self.n_neighbors]
    

@functools.lru_cache(maxsize=1)
def _load_text_encoder():
    """Load the CLIP4Clip text encoder and its tokenizer once and reuse them."""
    model = CLIPTextModelWithProjection.from_pretrained("Diangle/clip4clip-webvid")
    tokenizer = AutoTokenizer.from_pretrained("Diangle/clip4clip-webvid")
    return model, tokenizer


@functools.lru_cache(maxsize=1)
def _load_video_index():
    """Build the binary index over the video features once and reuse it.

    Database vectors are binarised by sign (value > 0 -> bit 1) and bit-packed.
    The top 100 binary candidates are reranked against the float features, and
    the best 5 are kept.
    """
    nn_search = NearestNeighbors(n_neighbors=5, metric='binary', rerank_from=100)
    nn_search.fit(np.packbits(ft_visual_features_database > 0.0, axis=1),
                  o_data=ft_visual_features_database)
    return nn_search


def search(search_sentence):
    """Return the 'contentUrl' values of the 5 videos best matching ``search_sentence``.

    The sentence is embedded with the CLIP4Clip text encoder and unit-normalised.
    It is then matched against the video feature database (binary search plus
    float rerank). The model, tokenizer and index are built on the first call
    and cached for later queries.
    """
    model, tokenizer = _load_text_encoder()
    inputs = tokenizer(text=search_sentence, return_tensors="pt", padding=True)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        outputs = model(input_ids=inputs["input_ids"],
                        attention_mask=inputs["attention_mask"],
                        return_dict=False)
        # NOTE(review): the projection is applied by hand to outputs[1]
        # (presumably the last hidden state) as `hidden @ weight`. If
        # text_projection is an nn.Linear, its own forward would use weight.T.
        # Confirm this matches how the video features were embedded before
        # changing it.
        text_embeds = outputs[1] @ model.text_projection.weight
        # One embedding per sentence, taken at the position of the largest token
        # id (presumably CLIP's end-of-text token).
        eot_positions = inputs["input_ids"].argmax(dim=-1)
        final_output = text_embeds[torch.arange(text_embeds.shape[0]), eot_positions]

    # Unit-normalise the query embedding.
    final_output = final_output / final_output.norm(dim=-1, keepdim=True)
    query = final_output.cpu().numpy()

    _, idxs = _load_video_index().kneighbors(query)
    return database_df.iloc[idxs[0]]['contentUrl'].to_list()

    
# Shut down any Gradio apps still running in this process before starting a new one.
gradio.close_all()

# Keep a handle on the Interface itself; chaining .launch() here would bind
# `interface` to launch()'s return value instead.
interface = gradio.Interface(
    search,
    inputs=[gradio.Textbox()],
    # One video player per result; must match the number of URLs search() returns.
    outputs=[gradio.Video(format='mp4') for _ in range(5)],
    title='Video Search Demo',
    description='Type some text to search by content within a video database!',
)
interface.launch()