import os
import tempfile
import warnings
from pathlib import Path

import gradio as gr
import numpy as np
import pandas as pd
import requests
import torch
from PIL import Image, ImageFile
from torch import nn
from transformers import CLIPModel, CLIPProcessor, CLIPTextModel

# Let PIL decode truncated image files instead of raising.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Seconds to wait for a thumbnail download before giving up.
REQUEST_TIMEOUT = 30

# One class name per line; output index i of class_model maps to LABELS[i].
LABELS = Path('class_names.txt').read_text().splitlines()

# Small CNN sketch classifier over single-channel input. The 1152 = 128 * 3 * 3
# flatten size means the input side must shrink to 3 after three 2x2
# max-pools (e.g. 28x28 -- TODO confirm against the sketchpad resolution).
class_model = nn.Sequential(
    nn.Conv2d(1, 32, 3, padding='same'),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Conv2d(32, 64, 3, padding='same'),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Conv2d(64, 128, 3, padding='same'),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Flatten(),
    nn.Linear(1152, 256),
    nn.ReLU(),
    nn.Linear(256, len(LABELS)),
)
state_dict = torch.load('pytorch_model.bin', map_location='cpu')
# strict=False tolerates key mismatches, but a silent mismatch would leave
# layers randomly initialised and the predictions meaningless -- report it.
_load_result = class_model.load_state_dict(state_dict, strict=False)
if _load_result.missing_keys or _load_result.unexpected_keys:
    warnings.warn(
        "Sketch classifier checkpoint does not match the model: "
        f"missing={_load_result.missing_keys}, "
        f"unexpected={_load_result.unexpected_keys}"
    )
class_model.eval()

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Catalogue: row i of clip.csv (columns 'id', 'thumbnail', ...) is described
# by row i of the precomputed embedding matrix in clip.npy.
df = pd.read_csv('clip.csv')
embeddings_npy = np.load('clip.npy')
# L2-normalise each row so the dot product in predict() ranks by cosine
# similarity. Zero rows would otherwise turn into NaN; NaN sorts last in
# argsort, so the reversed top-k slice would rank them as the *best* matches.
_norms = np.sqrt(np.sum(embeddings_npy**2, axis=1, keepdims=True))
_norms[_norms == 0] = 1
embeddings = np.divide(embeddings_npy, _norms)


def compute_text_embeddings(list_of_strings):
    """Return CLIP text features (one row per input string) as a torch tensor."""
    inputs = processor(text=list_of_strings, return_tensors="pt", padding=True)
    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        return model.get_text_features(**inputs)


def compute_image_embeddings(list_of_images):
    """Return CLIP image features (one row per input image) as a torch tensor."""
    # `padding` is a tokenizer option; images need no padding argument.
    inputs = processor(images=list_of_images, return_tensors="pt")
    with torch.no_grad():
        return model.get_image_features(**inputs)


def load_image(image, same_height=False):
    """Convert an image array to an RGB PIL image rescaled around 224 px.

    With ``same_height=True`` the height becomes 224; otherwise the shorter
    side becomes 224. The aspect ratio is preserved in both cases.
    """
    im = Image.fromarray(np.uint8(image))
    if im.mode != 'RGB':
        im = im.convert('RGB')
    width, height = im.size
    ratio = 224 / (height if same_height else min(width, height))
    return im.resize((int(width * ratio), int(height * ratio)))


def download_img(identifier, url, timeout=REQUEST_TIMEOUT):
    """Download ``url`` to ``<identifier>.jpg`` unless cached; return that path.

    Raises ``requests.HTTPError`` on a non-2xx response (so an error page is
    never cached as an image) and ``requests.Timeout`` if the server does not
    respond within ``timeout`` seconds.
    """
    local_path = f"{identifier}.jpg"
    if not os.path.isfile(local_path):
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
        # Write to a unique temp file and rename atomically, so an interrupted
        # or concurrent download never leaves a truncated file that the
        # isfile() check above would then treat as a valid cache entry.
        fd, tmp_path = tempfile.mkstemp(
            suffix='.part', dir=os.path.dirname(local_path) or '.'
        )
        try:
            with os.fdopen(fd, 'wb') as handler:
                handler.write(response.content)
            os.replace(tmp_path, local_path)
        except BaseException:
            os.remove(tmp_path)
            raise
    return local_path


def predict(image=None, text=None, sketch=None):
    """Find the catalogue images closest to a query in CLIP embedding space.

    The query is taken, in priority order, from ``image`` (an image array),
    ``text`` (a non-empty string), or ``sketch`` (a 2-D single-channel array
    classified by ``class_model``; its top label becomes the text query).

    Returns a list whose first element is a ``{label: score}`` dict for the
    Label output, followed by the local paths of the best-matching images.

    Raises ValueError if no query was supplied.
    """
    if image is not None:
        input_embeddings = compute_image_embeddings([load_image(image)]).detach().numpy()
        topk = {"local": 100}
    else:
        if text:
            query = text
            topk = {text: 100}
        else:
            if sketch is None:
                raise ValueError("predict() needs an image, a text or a sketch")
            # (H, W) -> (1, 1, H, W): batch and channel axes for Conv2d.
            x = torch.tensor(sketch, dtype=torch.float32).unsqueeze(0).unsqueeze(0) / 255.
            with torch.no_grad():
                out = class_model(x)
            probabilities = torch.nn.functional.softmax(out[0], dim=0)
            values, indices = torch.topk(probabilities, min(5, len(LABELS)))
            query = LABELS[indices[0]]
            topk = {LABELS[i]: v.item() for i, v in zip(indices, values)}
        input_embeddings = compute_text_embeddings([query]).detach().numpy()

    n_results = 3
    scores = (embeddings @ input_embeddings.T)[:, 0]
    # Indices of the n_results highest scores, best first.
    results = np.argsort(scores)[-1:-n_results - 1:-1]
    outputs = [download_img(df.iloc[i]['id'], df.iloc[i]['thumbnail']) for i in results]
    outputs.insert(0, topk)
    print(outputs)
    return outputs


def predict_sketch(sketch):
    """Gradio entry point: search the catalogue using only a drawn sketch."""
    return predict(None, None, sketch)


title = "Draw to search in the Nasjonalbiblioteket"
description = "Find images in the Nasjonalbiblioteket image collections based on what you draw"

interface = gr.Interface(
    fn=predict_sketch,
    inputs=["sketchpad"],
    outputs=[
        gr.outputs.Label(num_top_classes=3),
        gr.outputs.Image(type="file"),
        gr.outputs.Image(type="file"),
        gr.outputs.Image(type="file"),
    ],
    title=title,
    description=description,
    live=True,
)
interface.launch(debug=True)