File size: 6,794 Bytes
41b72e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f75ba36
41b72e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d5bf018
41b72e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60c53f1
41b72e9
 
 
d5bf018
41b72e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60c53f1
41b72e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6992194
 
 
 
 
41b72e9
 
 
ec67083
41b72e9
 
 
6992194
41b72e9
 
 
 
 
 
 
 
6992194
 
 
 
 
 
 
 
 
 
 
 
 
ec67083
 
41b72e9
5b33534
18a2d8e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
import numpy as np
import torch
from tqdm import tqdm
import clip
from glob import glob
import gradio as gr
import os
import torchvision
import pickle
from collections import Counter

from SimSearch import FaissCosineNeighbors

# DOWNLOAD THE DATASET and Files

def _fetch_and_unpack_assets():
  """Download the GTA V video and ViT-B/32 embedding archives, then unpack them.

  Both archives are fetched from Google Drive into the working directory.
  The embeddings are extracted to ``Embeddings/VIT32/`` and the videos to
  ``Videos/``. The archives are kept on disk (``remove_finished=False``).
  """
  # (Google Drive file id, local archive name), downloaded in this order.
  downloads = (
    ('1kB1vNdVaNS1OGZ3K8BspBUKkPACCsnrG', 'GTAV-Videos.zip'),
    ('1IF1ljcoFd31C-PA2SO8F5fEblDYf0Bw6', 'GTAV-embedding-vit32.zip'),
  )
  for file_id, archive_name in downloads:
    torchvision.datasets.utils.download_file_from_google_drive(file_id, '.', archive_name)

  # EXTRACT: (archive, destination directory), extracted in this order.
  extractions = (
    ('GTAV-embedding-vit32.zip', 'Embeddings/VIT32/'),
    ('GTAV-Videos.zip', 'Videos/'),
  )
  for archive_name, target_dir in extractions:
    torchvision.datasets.utils.extract_archive(from_path=archive_name, to_path=target_dir, remove_finished=False)


_fetch_and_unpack_assets()

# List the model names the clip package can load (the result is not used).
clip.available_models()

# # Searcher

class GamePhysicsSearcher:
  """Text-to-video search over precomputed CLIP embeddings of one game's videos.

  Each ``*.npy`` file found under ``<EMBEDDING_PATH>/<GAME_NAME>/`` stands for
  one video. Every row of every file is fitted into a ``FaissCosineNeighbors``
  index, labelled with the integer class id of the file it came from. A text
  query can therefore be mapped back to video files under
  ``./Videos/<GAME_NAME>/``.
  """

  def __init__(self, CLIP_MODEL, GAME_NAME, EMBEDDING_PATH='./Embeddings/VIT32/'):
    """Locate the game's embedding files and build the similarity index.

    Args:
      CLIP_MODEL: loaded CLIP model used to embed text queries.
      GAME_NAME: game sub-directory below ``EMBEDDING_PATH`` and ``./Videos/``.
      EMBEDDING_PATH: root directory holding one folder of embeddings per game.

    Raises:
      FileNotFoundError: if no ``.npy`` embedding file exists for the game.
    """
    self.CLIP_MODEL = CLIP_MODEL
    self.GAME_NAME = GAME_NAME
    self.simsearcher = FaissCosineNeighbors()

    # os.path.join removes the need for a trailing slash on EMBEDDING_PATH.
    # Sorting makes the file <-> class-id mapping reproducible, because glob
    # order depends on the filesystem.
    pattern = os.path.join(EMBEDDING_PATH, self.GAME_NAME, '*.npy')
    self.all_embeddings = sorted(glob(pattern))
    if not self.all_embeddings:
      # Fail early with a clear message instead of fitting an empty index.
      raise FileNotFoundError(f'No embedding files match {pattern!r}')

    self.filenames = [os.path.basename(x) for x in self.all_embeddings]
    self.file_to_class_id = {name: i for i, name in enumerate(self.filenames)}
    self.class_id_to_file = dict(enumerate(self.filenames))
    self.build_index()

  def read_features(self, file_path):
    """Return the pickled feature object stored in ``file_path``.

    NOTE(review): these files come from a downloaded archive and are read with
    ``pickle``, which can execute arbitrary code. Only load trusted archives.
    """
    with open(file_path, 'rb') as f:
      video_features = pickle.load(f)
    return video_features

  def read_all_features(self):
    """Load every embedding file and stack its rows into training arrays.

    Returns:
      ``(X_train, y_train)``. ``X_train`` holds every row of every file, in
      ``self.all_embeddings`` order. ``y_train[j]`` is the class id (the index
      into ``self.all_embeddings``) of the file that row ``j`` came from.
    """
    X_train = []
    y_train = []
    for class_id, vfile in enumerate(tqdm(self.all_embeddings)):
      vfeatures = self.read_features(vfile)
      X_train.extend(vfeatures)
      # One label per row. Assumes the loaded object is an array whose first
      # axis enumerates embeddings (uses .shape[0]).
      y_train.extend([class_id] * vfeatures.shape[0])
    return np.asarray(X_train), np.asarray(y_train)

  def build_index(self):
    """Fit the cosine-similarity searcher on all of the game's embeddings."""
    X_train, y_train = self.read_all_features()
    self.simsearcher.fit(X_train, y_train)

  def text_to_vector(self, query):
    """Embed ``query`` with CLIP and L2-normalise it; the result is on the CPU."""
    # clip.load places the model on CUDA when available. The tokens must live
    # on the model's device, or encode_text fails with a device mismatch.
    device = next(self.CLIP_MODEL.parameters()).device
    text_tokens = clip.tokenize(query).to(device)
    with torch.no_grad():
      text_features = self.CLIP_MODEL.encode_text(text_tokens).float()
    text_features /= text_features.norm(dim=-1, keepdim=True)
    # Hand a CPU tensor to the searcher, as on a CPU-only machine.
    return text_features.cpu()

  # Source: https://stackoverflow.com/a/480227
  def f7(self, seq):
    """Return the items of ``seq`` without duplicates, keeping first-seen order."""
    seen = set()
    seen_add = seen.add  # bound once so the comprehension skips the attribute lookup
    return [x for x in seq if not (x in seen or seen_add(x))]

  def search_top_k(self, q, k=5, pool_size=1000, search_mod='Majority'):
    """Return paths of up to ``k`` videos matching the text query ``q``.

    The ``pool_size`` embedding rows nearest to the query are retrieved, and
    their video class ids are aggregated according to ``search_mod``:

    * ``'Majority'``: the ``k`` videos contributing the most rows to the pool.
    * ``'Top-K'``: the first ``k`` distinct videos, in the order returned by
      the searcher.

    Raises:
      ValueError: if ``search_mod`` is not one of the modes above.
    """
    if search_mod not in ('Majority', 'Top-K'):
      # Validate up front so an unknown mode gives a clear error instead of
      # an UnboundLocalError further down.
      raise ValueError(f'Unknown search_mod {search_mod!r}; expected "Majority" or "Top-K"')

    q = self.text_to_vector(q)
    # pool_size may arrive as a float (e.g. from a UI slider). The neighbour
    # count is cast to an integer.
    nearest_labels = self.simsearcher.get_nearest_labels(q, int(pool_size))[0]

    if search_mod == 'Majority':
      topKs = [label for label, _ in Counter(nearest_labels).most_common(k)]
    else:
      topKs = self.f7(nearest_labels)[:k]

    # Swap only the extension. str.replace('npy', 'mp4') would also rewrite
    # any 'npy' appearing inside the file name.
    return [
      f'./Videos/{self.GAME_NAME}/' + os.path.splitext(self.class_id_to_file[label])[0] + '.mp4'
      for label in topKs
    ]



################ SEARCH CORE ################
# CREATE CLIP MODEL
vit_model, vit_preprocess = clip.load("ViT-B/32")
vit_model.eval()

# Building a searcher reads every embedding file and fits the index, so one
# searcher per (game, model) pair is cached for the life of the process.
saved_searchers = {}
def gradio_search(query, game_name, selected_model, aggregator, pool_size, k=6):
  """Gradio callback: search ``game_name`` for videos matching ``query``.

  Args:
    query: free-text description of the scene to find.
    game_name: game whose embeddings are searched.
    selected_model: CLIP backbone name; only ``'ViT-B/32'`` is supported.
    aggregator: ``'Majority'`` or ``'Top-K'`` (see GamePhysicsSearcher.search_top_k).
    pool_size: number of nearest embedding rows to aggregate.
    k: number of videos to return.

  Returns:
    A list whose first item is a comma-separated echo of the search
    parameters. It is followed by ``k`` video paths, padded with ``None``
    when fewer matches are found.

  Raises:
    ValueError: if ``selected_model`` is not supported.
  """
  cache_key = f'{game_name}_{selected_model}'
  searcher = saved_searchers.get(cache_key)
  if searcher is None:
    if selected_model == 'ViT-B/32':
      searcher = GamePhysicsSearcher(CLIP_MODEL=vit_model, GAME_NAME=game_name)
    else:
      # A bare `raise` here had no active exception and produced a confusing
      # RuntimeError.
      raise ValueError(f'Unsupported model: {selected_model!r}')
    saved_searchers[cache_key] = searcher

  relevant_videos = searcher.search_top_k(query, k=k, pool_size=pool_size, search_mod=aggregator)
  params = ', '.join(map(str, [query, game_name, selected_model, aggregator, pool_size]))
  # Pad so the list always has 1 + k entries. A shorter list would presumably
  # leave some Gradio output components without a value.
  # NOTE(review): the interface declares five video outputs while k defaults
  # to 6. The extra path is presumably ignored; confirm.
  results = [params] + list(relevant_videos) + [None] * (k - len(relevant_videos))
  print(results)
  return results

# Games offered in the UI. Each name must match a folder under the
# embedding and ./Videos/ directories.
list_of_games = ['Grand Theft Auto V']


title = "CLIP + GamePhysics - Searching dataset of Gameplay bugs"
# The matching videos are returned to the video output panels (the console
# only receives a debug print), so the description says so.
description = "Enter your query and select the game you want to search. The most relevant gameplay videos will be shown as the results."
article = """
This demo shows how to use the CLIP model to search for gameplay bugs in a video game.
"""

# GRADIO APP
# (query, aggregator) pairs for the examples table. Every example searches
# the first game with ViT-B/32 and a pool of 1000 neighbours.
_EXAMPLE_QUERIES = [
  ("A red car", 'Top-K'),
  ("A person wearing pink", 'Top-K'),
  ("A car flying in the air", 'Majority'),
  ("A person flying in the air", 'Majority'),
  ("A car in vertical position", 'Majority'),
  ("A bike inside a car", 'Majority'),
  ("A bike on a wall", 'Majority'),
  ("A car stuck in a rock", 'Majority'),
  ("A car stuck in a tree", 'Majority'),
]

iface = gr.Interface(
  fn=gradio_search,
  inputs=[
    gr.inputs.Textbox(lines=1, placeholder='Search Query', default="A person flying in the air", label=None),
    gr.inputs.Radio(list_of_games, label="Game To Search"),
    gr.inputs.Radio(['ViT-B/32'], label="MODEL"),
    gr.inputs.Radio(['Majority', 'Top-K'], label="Aggregator"),
    gr.inputs.Slider(300, 2000, label="Pool Size", default=1000),
  ],
  # A text box echoing the search parameters, followed by five video panels.
  outputs=[gr.outputs.Textbox(type="auto", label='Search Params')]
          + [gr.outputs.Video(type='mp4', label=f'Result {i}') for i in range(1, 6)],
  examples=[[query, list_of_games[0], 'ViT-B/32', aggregator, 1000]
            for query, aggregator in _EXAMPLE_QUERIES],
  title=title,
  description=description,
  article=article,
  enable_queue=True,
)

iface.launch()