Spaces:
Runtime error
Runtime error
import os | |
from pathlib import Path | |
import pandas as pd, numpy as np | |
from transformers import CLIPProcessor, CLIPTextModel, CLIPModel | |
import torch | |
from torch import nn | |
import gradio as gr | |
import requests | |
# Class names for the sketch classifier, one per line of class_names.txt.
LABELS = Path('class_names.txt').read_text().splitlines()

# Three identical conv stages (3x3 "same" convolution -> ReLU -> 2x2 max-pool)
# that widen the channel count 1 -> 32 -> 64 -> 128.
_feature_layers = []
for _in_channels, _out_channels in ((1, 32), (32, 64), (64, 128)):
    _feature_layers.append(nn.Conv2d(_in_channels, _out_channels, 3, padding='same'))
    _feature_layers.append(nn.ReLU())
    _feature_layers.append(nn.MaxPool2d(2))

# The layers are passed positionally so the Sequential indices (and therefore
# the parameter names looked up in the checkpoint) stay 0..13.
class_model = nn.Sequential(
    *_feature_layers,
    nn.Flatten(),
    # 1152 = 128 channels * 3 * 3, which is consistent with a 28x28 input
    # after three 2x poolings (28 -> 14 -> 7 -> 3).
    nn.Linear(1152, 256),
    nn.ReLU(),
    nn.Linear(256, len(LABELS)),
)

# NOTE(review): torch.load unpickles the file; only load checkpoints that come
# from a trusted source.
state_dict = torch.load('pytorch_model.bin', map_location='cpu')
# strict=False: keys that are missing from, or unexpected in, the checkpoint
# are skipped instead of raising, so a mismatched file loads silently.
class_model.load_state_dict(state_dict, strict=False)
# Inference only: switch the module to evaluation mode.
class_model.eval()
# CLIP model and its matching processor, both loaded from the same checkpoint.
_CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"
model = CLIPModel.from_pretrained(_CLIP_CHECKPOINT)
processor = CLIPProcessor.from_pretrained(_CLIP_CHECKPOINT)

# Table of searchable images; predict() reads its 'path' column by row position.
df = pd.read_csv('data2.csv')

# Stored embedding matrix, one row per entry. Each row is scaled to unit
# Euclidean length so that a dot product with a query vector ranks rows by
# cosine similarity.
embeddings_npy = np.load('embeddings.npy')
_row_norms = np.sqrt(np.square(embeddings_npy).sum(axis=1, keepdims=True))
embeddings = embeddings_npy / _row_norms
def compute_text_embeddings(list_of_strings):
    """Embed a batch of strings with the module-level CLIP model.

    Args:
        list_of_strings: Texts to encode. They are run through the
            module-level ``processor`` with padding enabled and PyTorch
            tensors requested.

    Returns:
        The result of ``model.get_text_features`` for the processed batch.
    """
    encoded = processor(text=list_of_strings, return_tensors="pt", padding=True)
    text_features = model.get_text_features(**encoded)
    return text_features
def download_img(path, timeout=10):
    """Download the resource at URL ``path`` into the current directory.

    Args:
        path: URL to fetch.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout, ``requests`` can block indefinitely on a stalled
            connection.

    Returns:
        The local file name that was written: the last ``/``-separated
        segment of ``path`` with ``.jpg`` appended.

    Raises:
        requests.RequestException: If the request fails, times out, or the
            server answers with an HTTP error status.
    """
    response = requests.get(path, timeout=timeout)
    # Fail loudly on 4xx/5xx instead of saving an error page as a ".jpg".
    response.raise_for_status()
    local_path = path.split("/")[-1] + ".jpg"
    with open(local_path, 'wb') as handler:
        handler.write(response.content)
    return local_path
def predict(query):
    """Classify a sketch, then fetch the images that best match its top label.

    Args:
        query: The drawing supplied by the ``sketchpad`` input. It is
            converted with ``torch.tensor`` and scaled by 1/255, so it is
            assumed to be a 2-D grayscale array with values in 0-255 whose
            size suits the classifier -- TODO confirm against the installed
            Gradio version.

    Returns:
        A flat tuple ``(confidences, path_1, ..., path_n)``. ``confidences``
        maps each of the top (up to five) class names to its softmax
        probability. The remaining items are the local file names of the
        downloaded best-matching images, best match first. One value is
        returned per interface output component.
    """
    # Number of images to retrieve; must equal the number of Image outputs
    # declared on the interface.
    n_results = 3

    # Add batch and channel dimensions and scale the pixel values by 1/255.
    x = torch.tensor(query, dtype=torch.float32).unsqueeze(0).unsqueeze(0) / 255.
    with torch.no_grad():
        out = class_model(x)
    probabilities = torch.nn.functional.softmax(out[0], dim=0)
    # Clamp k so a label file with fewer than five classes does not make
    # topk raise.
    values, indices = torch.topk(probabilities, min(5, probabilities.numel()))
    top_indices = indices.tolist()
    confidences = {
        LABELS[i]: v for i, v in zip(top_indices, values.tolist())
    }

    # Search with the *name* of the most likely class (not its probability).
    search_text = LABELS[top_indices[0]]
    text_embeddings = compute_text_embeddings([search_text]).detach().numpy()
    scores = (embeddings @ text_embeddings.T)[:, 0]
    # Row indices of the n_results highest scores, best match first.
    results = np.argsort(scores)[::-1][:n_results]
    paths = [download_img(df.iloc[i]['path']) for i in results]
    print(paths)
    return (confidences, *paths)
title = "Draw to Search"

# Output components, in the order the prediction function returns its values:
# one label widget showing the top three classes, then three image slots.
_outputs = [gr.outputs.Label(num_top_classes=3)]
_outputs.extend(gr.outputs.Image(type="file") for _ in range(3))

iface = gr.Interface(
    fn=predict,
    inputs='sketchpad',
    outputs=_outputs,
    title=title,
)
# Start the web app with debug mode enabled.
iface.launch(debug=True)