# draw_to_search / app.py
import os
from pathlib import Path
import pandas as pd
import numpy as np
from transformers import CLIPProcessor, CLIPTextModel, CLIPModel
import torch
from torch import nn
import gradio as gr
import requests
LABELS = Path('class_names.txt').read_text().splitlines()
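
# Small CNN sketch classifier: three conv/pool blocks over a 28x28 grayscale
# drawing (128 * 3 * 3 = 1152 features after flattening), followed by a
# len(LABELS)-way output layer matching the checkpoint in pytorch_model.bin.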
class_model = nn.Sequential(
    nn.Conv2d(1, 32, 3, padding='same'),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Conv2d(32, 64, 3, padding='same'),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Conv2d(64, 128, 3, padding='same'),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Flatten(),
    nn.Linear(1152, 256),
    nn.ReLU(),
    nn.Linear(256, len(LABELS)),
)
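
# Load the trained classifier weights (strict=False ignores missing or
# unexpected keys) and switch the model to inference mode.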
state_dict = torch.load('pytorch_model.bin', map_location='cpu')
class_model.load_state_dict(state_dict, strict=False)
class_model.eval()
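
# CLIP is used only as a text encoder here: the predicted class name is
# embedded and matched against precomputed image embeddings.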
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
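
# data2.csv lists one image URL per row (its 'path' column); embeddings.npy
# holds precomputed CLIP image embeddings for those rows. L2-normalizing them
# makes the dot product with a text embedding rank images by cosine similarity.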
df = pd.read_csv('data2.csv')
embeddings_npy = np.load('embeddings.npy')
embeddings = np.divide(embeddings_npy, np.sqrt(np.sum(embeddings_npy**2, axis=1, keepdims=True)))
def compute_text_embeddings(list_of_strings):
    # Encode a batch of strings with the CLIP text encoder.
    inputs = processor(text=list_of_strings, return_tensors="pt", padding=True)
    return model.get_text_features(**inputs)
def download_img(path):
    # Download a remote image and save it locally so Gradio can serve it as a file.
    img_data = requests.get(path).content
    local_path = path.split("/")[-1] + ".jpg"
    with open(local_path, 'wb') as handler:
        handler.write(img_data)
    return local_path
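
# Main Gradio callback: classify the sketch, then use the best label as a
# text query to retrieve the three closest images.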
def predict(im):
    # Scale pixel values to [0, 1] and add batch and channel dimensions.
    x = torch.tensor(im, dtype=torch.float32).unsqueeze(0).unsqueeze(0) / 255.
    with torch.no_grad():
        out = class_model(x)
    probabilities = torch.nn.functional.softmax(out[0], dim=0)
    values, indices = torch.topk(probabilities, 5)
    # Use the top predicted class name as the text query for CLIP retrieval.
    query = LABELS[indices[0]]
    n_results = 3
    text_embeddings = compute_text_embeddings([query]).detach().numpy()
    # Rank images by dot product with the normalized image embeddings and keep the top n_results.
    results = np.argsort((embeddings @ text_embeddings.T)[:, 0])[-1:-n_results - 1:-1]
    paths = [download_img(df.iloc[i]['path']) for i in results]
    print(paths)
    # Return the label confidences plus one path per Image output component.
    return {LABELS[i]: v.item() for i, v in zip(indices, values)}, paths[0], paths[1], paths[2]
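
# Sketchpad input; outputs are the top-3 label confidences plus the three
# retrieved images (one Image component per returned path).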
title = "Draw to Search"
iface = gr.Interface(
    fn=predict,
    inputs='sketchpad',
    outputs=[
        gr.outputs.Label(num_top_classes=3),
        gr.outputs.Image(type="file"),
        gr.outputs.Image(type="file"),
        gr.outputs.Image(type="file"),
    ],
    title=title,
)
iface.launch(debug=True)