# app.py — CLIP-based similar-image search demo (Gradio Space).
import torch
import requests
import numpy as np
import pandas as pd
import gradio as gr
from io import BytesIO
from PIL import Image as PILIMAGE
from transformers import CLIPProcessor, CLIPModel
## Cache for the model/processor and precomputed index so they are loaded
## once per process instead of on every request.
_RESOURCES = {}


def _load_resources():
    """Load the CLIP model/processor and the precomputed photo index (once).

    Returns the cache dict with keys: device, model, processor, photos,
    photo_features, photo_ids.
    """
    if not _RESOURCES:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
        processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
        # Photo metadata, precomputed CLIP features, and the matching id list.
        photos = pd.read_csv("./photos.tsv000", sep='\t', header=0)
        photo_features = np.load("./features.npy")
        photo_ids = list(pd.read_csv("./photo_ids.csv")["photo_id"])
        _RESOURCES.update(
            device=device,
            model=model,
            processor=processor,
            photos=photos,
            photo_features=photo_features,
            photo_ids=photo_ids,
        )
    return _RESOURCES


def find_similar(image):
    """Return the indexed photo most similar to the query image.

    Parameters
    ----------
    image : numpy.ndarray
        RGB image array of shape (H, W, 3), as supplied by the Gradio
        Image input.

    Returns
    -------
    PIL.Image.Image
        The best-matching photo, fetched from its ``photo_image_url``.
    """
    res = _load_resources()
    image = PILIMAGE.fromarray(image.astype('uint8'), 'RGB')
    ## Embed the query image with CLIP and L2-normalize the feature vector
    ## so the dot product below is a cosine similarity.
    with torch.no_grad():
        pixel_values = res["processor"](
            text=None, images=image, return_tensors="pt", padding=True
        )["pixel_values"]
        feature = res["model"].get_image_features(pixel_values.to(res["device"]))
        feature /= feature.norm(dim=-1, keepdim=True)
        feature = feature.cpu().numpy()
    ## Similarity against every indexed photo; argmax gives the best match
    ## in O(n) (the original sorted all pairs just to take the top one).
    similarities = (feature @ res["photo_features"].T).squeeze(0)
    idx = int(np.argmax(similarities))
    photo_id = res["photo_ids"][idx]
    try:
        photo_data = res["photos"][res["photos"]["photo_id"] == photo_id].iloc[0]
    except IndexError:
        # photo_id missing from the metadata table; fall back to the first row
        # (preserves the original best-effort behavior, but no longer swallows
        # unrelated exceptions via a bare `except:`).
        photo_data = res["photos"].iloc[0]
    # Fetch a 640px-wide rendition; bounded timeout so the UI cannot hang.
    response = requests.get(photo_data["photo_image_url"] + "?w=640", timeout=10)
    img = PILIMAGE.open(BytesIO(response.content))
    return img
## Build and launch the Gradio demo. The original assigned the return value
## of `.launch()` to `iface`, which is NOT the Interface object; construct
## and launch as separate steps so `iface` holds the actual Interface.
iface = gr.Interface(
    fn=find_similar,
    inputs=gr.inputs.Image(),
    outputs=gr.outputs.Image(type="pil"),
)
iface.launch()