import os
from pathlib import Path

import gradio as gr
import numpy as np
import pandas as pd
import requests
import torch
from PIL import Image, ImageFile
from torch import nn
from transformers import CLIPModel, CLIPProcessor

# Tolerate thumbnails that were truncated during download.
ImageFile.LOAD_TRUNCATED_IMAGES = True


# Class names for the sketch classifier, one label per line.
LABELS = Path('class_names.txt').read_text().splitlines()

# Small CNN sketch classifier. The 1152-unit flatten implies 28x28 grayscale
# inputs: three 2x2 max-poolings take 28 -> 14 -> 7 -> 3, and 128 * 3 * 3 = 1152.
class_model = nn.Sequential(
    nn.Conv2d(1, 32, 3, padding='same'),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Conv2d(32, 64, 3, padding='same'),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Conv2d(64, 128, 3, padding='same'),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Flatten(),
    nn.Linear(1152, 256),
    nn.ReLU(),
    nn.Linear(256, len(LABELS)),
)
state_dict = torch.load('pytorch_model.bin', map_location='cpu')
# strict=False silently skips checkpoint keys that do not match the layers above.
class_model.load_state_dict(state_dict, strict=False)
class_model.eval()
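
# Hypothetical sanity check (not part of the original app): verify that the
# classifier really accepts 28x28 grayscale sketches, as implied by the
# 1152-unit flatten. Uncomment to run once after loading the weights:
# with torch.no_grad():
#     assert class_model(torch.zeros(1, 1, 28, 28)).shape == (1, len(LABELS))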


# CLIP model used to embed both text queries and query images.
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Collection metadata (id, thumbnail URL, ...) and precomputed CLIP embeddings;
# rows are L2-normalized so that dot products give cosine similarities.
df = pd.read_csv('clip.csv')
embeddings_npy = np.load('clip.npy')
embeddings = embeddings_npy / np.linalg.norm(embeddings_npy, axis=1, keepdims=True)


def compute_text_embeddings(list_of_strings):
    # Tokenize and embed text queries with CLIP's text tower.
    inputs = processor(text=list_of_strings, return_tensors="pt", padding=True)
    return model.get_text_features(**inputs)


def compute_image_embeddings(list_of_images):
    # Preprocess and embed query images with CLIP's vision tower.
    inputs = processor(images=list_of_images, return_tensors="pt", padding=True)
    return model.get_image_features(**inputs)
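
# Illustrative only (the query string is made up): because the rows of
# `embeddings` are unit-normalized, a plain dot product ranks the collection
# by cosine similarity up to the query's norm, e.g.:
#
#     q = compute_text_embeddings(["a sailing ship"]).detach().numpy()
#     best = int(np.argmax(embeddings @ q.T))  # row in df of the closest image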


def load_image(image, same_height=False):
    # Convert the incoming array to an RGB PIL image and scale it so that
    # either the height (same_height=True) or the shorter side is 224 px.
    im = Image.fromarray(np.uint8(image))
    if im.mode != 'RGB':
        im = im.convert('RGB')
    if same_height:
        ratio = 224 / im.size[1]
    else:
        ratio = 224 / min(im.size)
    return im.resize((int(im.size[0] * ratio), int(im.size[1] * ratio)))


def download_img(identifier, url):
    # Download a thumbnail once and cache it on disk as "<identifier>.jpg".
    local_path = f"{identifier}.jpg"
    if not os.path.isfile(local_path):
        response = requests.get(url, timeout=10)
        response.raise_for_status()  # avoid caching an error page as a .jpg
        with open(local_path, 'wb') as handler:
            handler.write(response.content)
    return local_path


def predict(image=None, text=None, sketch=None):
    # Build a query embedding from whichever input was provided:
    # an image, a free-text query, or a sketch classified into a label.
    if image is not None:
        input_embeddings = compute_image_embeddings([load_image(image)]).detach().numpy()
        topk = {"local": 100}  # placeholder entry for the Label output
    else:
        if text:
            query = text
            topk = {text: 100}
        else:
            # Scale the sketch to [0, 1] and add batch and channel dimensions.
            x = torch.tensor(sketch, dtype=torch.float32).unsqueeze(0).unsqueeze(0) / 255.
            with torch.no_grad():
                out = class_model(x)
            probabilities = torch.nn.functional.softmax(out[0], dim=0)
            values, indices = torch.topk(probabilities, 5)
            # Use the most likely class name as the CLIP text query.
            query = LABELS[indices[0]]
            topk = {LABELS[i]: v.item() for i, v in zip(indices, values)}
        input_embeddings = compute_text_embeddings([query]).detach().numpy()

    # Rank the collection by cosine similarity and keep the top 3 matches.
    n_results = 3
    results = np.argsort((embeddings @ input_embeddings.T)[:, 0])[-1:-n_results - 1:-1]
    outputs = [download_img(df.iloc[i]['id'], df.iloc[i]['thumbnail']) for i in results]
    outputs.insert(0, topk)
    print(outputs)
    return outputs


def predict_sketch(sketch):
    # Gradio entry point: only the sketchpad input is wired up below.
    return predict(None, None, sketch)
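
# Example invocation outside the Gradio UI (hypothetical text query):
#
#     label_dict, *image_paths = predict(text="boat")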


title = "Draw to search in the Nasjonalbiblioteket"
description = "Find images in the Nasjonalbiblioteket image collections based on what you draw"
# Note: gr.outputs.* is the legacy Gradio (pre-3.0) component API; on current
# Gradio releases use gr.Label(...) and gr.Image(type="filepath") instead.
interface = gr.Interface(
    fn=predict_sketch,
    inputs=["sketchpad"],
    outputs=[
        gr.outputs.Label(num_top_classes=3),
        gr.outputs.Image(type="file"),
        gr.outputs.Image(type="file"),
        gr.outputs.Image(type="file"),
    ],
    title=title,
    description=description,
    live=True,
)
interface.launch(debug=True)