from typing import Callable, Dict, Iterable

import faiss
import gradio as gr
import torch
import torch.nn as nn
import torchvision
from datasets import load_dataset
from omegaconf import OmegaConf
from PIL import Image
from timm import create_model
from torch import Tensor
from torchvision import transforms
def get_model(args, arch, load_from, arch_path):
    """Build a backbone from timm (pretrained) or torchvision (optionally
    loading weights from a checkpoint at `arch_path`)."""
    if load_from == 'timm':
        model = create_model(arch, pretrained=True)
        print("Loaded model from timm")
    elif load_from == 'torchvision':
        if arch != 'resnet50':
            raise ValueError(f"Unsupported torchvision arch: {arch}")
        model = torchvision.models.resnet50(pretrained=False)
        if len(arch_path) > 0:
            print("Loading pretrained weights from checkpoint")
            model.load_state_dict(torch.load(arch_path, map_location='cpu')['state_dict'],
                                  strict=True)
    else:
        raise ValueError(f"Unknown model source: {load_from}")
    model = model.to(args.PARAMETERS.device)
    model.eval()
    return model
def get_transform(args):
    return transforms.Compose([
        transforms.Resize([args.PARAMETERS.img_resize, args.PARAMETERS.img_resize]),
        transforms.CenterCrop([args.PARAMETERS.img_crop, args.PARAMETERS.img_crop]),
        transforms.ToTensor(),
    ])
class FeatureExtractor(nn.Module):
    """Wraps a model and captures the outputs of the named layers via forward hooks."""

    def __init__(self, model: nn.Module, layers: Iterable[str]):
        super().__init__()
        self.model = model
        self.layers = layers
        self._features = {layer: torch.empty(0) for layer in layers}
        for layer_id in layers:
            # Resolve each submodule by its qualified name and attach a hook
            # that stores its output on every forward pass.
            layer = dict([*self.model.named_modules()])[layer_id]
            layer.register_forward_hook(self.save_outputs_hook(layer_id))

    def save_outputs_hook(self, layer_id: str) -> Callable:
        def fn(_, __, output):
            self._features[layer_id] = output
        return fn

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        _ = self.model(x)  # run a forward pass; the hooks populate self._features
        return self._features
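# Quick usage sketch for FeatureExtractor (illustrative only; the layer name
# 'avgpool' is an assumption that holds for torchvision ResNets, and the dummy
# input shape matches a standard 224x224 crop):
#
#   extractor = FeatureExtractor(torchvision.models.resnet50(), layers=['avgpool'])
#   feats = extractor(torch.randn(1, 3, 224, 224))['avgpool']  # shape (1, 2048, 1, 1)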
def _load_dataset(args):
    if args.PARAMETERS.metric == 'L2':
        faiss_metric = faiss.METRIC_L2
    else:
        raise ValueError(f"Unsupported metric: {args.PARAMETERS.metric}")
    dataset = load_dataset(args.PARAMETERS.dataset, split='train')
    # Build one FAISS index per precomputed embedding column (robust and standard).
    dataset = dataset.add_faiss_index(column=args.ROBUST.embedding_col, metric_type=faiss_metric)
    dataset = dataset.add_faiss_index(column=args.NONROBUST.embedding_col, metric_type=faiss_metric)
    return dataset
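# The OmegaConf config loaded below must provide every key read in this script.
# An illustrative sketch of configs/resnet.yaml (all values here are
# assumptions, not taken from the actual file):
#
#   PARAMETERS:
#     device: cuda
#     img_resize: 256
#     img_crop: 224
#     metric: L2
#     dataset: my-org/wikiart-embeddings       # hypothetical HF dataset id
#   ROBUST:
#     arch: resnet50
#     load_from: torchvision
#     arch_path: checkpoints/robust_resnet50.pt  # hypothetical checkpoint path
#     embedding_col: robust_embedding
#     layer: avgpool
#   NONROBUST:
#     arch: resnet50
#     load_from: timm
#     arch_path: ""
#     embedding_col: standard_embedding
#     layer: global_pool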
args = OmegaConf.load("configs/resnet.yaml")
wiki_dataset = _load_dataset(args)
TRANSFORMS = get_transform(args)
robust_model = get_model(args, args.ROBUST.arch, args.ROBUST.load_from, args.ROBUST.arch_path)
non_robust_model = get_model(args, args.NONROBUST.arch, args.NONROBUST.load_from, args.NONROBUST.arch_path)
fe_robust_model = FeatureExtractor(robust_model, layers=[args.ROBUST.layer])
fe_nonrobust_model = FeatureExtractor(non_robust_model, layers=[args.NONROBUST.layer])
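# Note: the FAISS indexes assume the dataset already stores one embedding
# column per model. A minimal sketch of how such a column could be precomputed
# with datasets.map, run before add_faiss_index (illustrative, not part of
# this app's pipeline):
#
#   def _embed(batch):
#       imgs = torch.stack([TRANSFORMS(img) for img in batch['image']])
#       with torch.no_grad():
#           feats = fe_robust_model(imgs.to(args.PARAMETERS.device))[args.ROBUST.layer]
#       return {args.ROBUST.embedding_col: feats.view(len(batch['image']), -1).cpu().numpy()}
#
#   dataset = dataset.map(_embed, batched=True, batch_size=32)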
def retrieval_fn(image, radio):
    # Gradio delivers the upload as a numpy array; convert to PIL if needed.
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    image = TRANSFORMS(image).unsqueeze(0).to(args.PARAMETERS.device)
    if radio == 'robust':
        extractor, layer, index_name = fe_robust_model, args.ROBUST.layer, args.ROBUST.embedding_col
    elif radio == 'standard':
        extractor, layer, index_name = fe_nonrobust_model, args.NONROBUST.layer, args.NONROBUST.embedding_col
    else:
        raise ValueError(f"Unknown model choice: {radio}")
    with torch.no_grad():
        emb = extractor(image)[layer]
    emb = emb.view(1, -1).cpu().numpy()
    scores, retrieved_examples = wiki_dataset.get_nearest_examples(index_name=index_name,
                                                                   query=emb,
                                                                   k=3)
    return scores, retrieved_examples
def gradio_fn(image, radio):
    # Flatten the top-3 matches into [desc1, img1, desc2, img2, desc3, img3],
    # matching the output order of the Gradio components below.
    scores, retrieved_examples = retrieval_fn(image, radio)
    outputs = []
    for description, match_image in zip(retrieved_examples['description'],
                                        retrieved_examples['image']):
        outputs.append(description)
        outputs.append(match_image)
    return outputs
if __name__ == '__main__':
    with gr.Blocks() as demo:
        gr.Markdown("# Robust vs Standard Image Retrieval")
        with gr.Tabs():
            with gr.TabItem("Upload your Image"):
                with gr.Row():
                    with gr.Column():
                        with gr.Row():
                            image_input = gr.Image(label="Input Image")
                        with gr.Row():
                            radio_button = gr.Radio(["robust", "standard"],
                                                    value="robust",
                                                    label="OD Model")
                        with gr.Row():
                            calculate_button = gr.Button("Compute")
                    with gr.Column():
                        textbox1 = gr.Textbox(label="Artist / Title / Style / Genre / Date")
                        output_image1 = gr.Image(label="1st Best match")
                        textbox2 = gr.Textbox(label="Artist / Title / Style / Genre / Date")
                        output_image2 = gr.Image(label="2nd Best match")
                        textbox3 = gr.Textbox(label="Artist / Title / Style / Genre / Date")
                        output_image3 = gr.Image(label="3rd Best match")
        calculate_button.click(fn=gradio_fn,
                               inputs=[image_input, radio_button],
                               outputs=[textbox1, output_image1,
                                        textbox2, output_image2,
                                        textbox3, output_image3])
    demo.launch(share=False, debug=True)