Spaces:
Running
Running
File size: 2,775 Bytes
609bf1e d24bb26 c0c77aa 373a0b2 609bf1e 373a0b2 a8c63a3 af11455 609bf1e 391a1fb 609bf1e 391a1fb 609bf1e 391a1fb 609bf1e 9fd66a3 a8c63a3 1932c4f 609bf1e eaf7419 609bf1e 3989d63 609bf1e 3989d63 bd9a56b 3989d63 6001755 609bf1e 3989d63 3eb4358 609bf1e bb75ee4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 |
import gradio as gr
from transformers import AutoImageProcessor, AutoModel
import torch
from PIL import Image
import json
import numpy as np
import faiss
# Initialise the DINOv2 image model and its preprocessor; inference runs on CPU.
device = torch.device("cpu")
processor = AutoImageProcessor.from_pretrained("facebook/dinov2-large")
model = AutoModel.from_pretrained("facebook/dinov2-large")
# torch.jit.trace needs plain tuple outputs rather than a dict-like output object.
model.config.return_dict = False
model.to(device)
model.eval()  # inference only

# Trace the model once with a dummy 1x3x224x224 batch.
# NOTE(review): presumably matches the processor's output size — confirm if
# the preprocessing config changes.
example_input = torch.rand(1, 3, 224, 224).to(device)
with torch.no_grad():  # no autograd bookkeeping is needed while tracing
    traced_model = torch.jit.trace(model, example_input)
traced_model = traced_model.to(device)

# Load the FAISS index that process_image searches.
index = faiss.read_index("xbgp-faiss.index")
# Load the list mapping each FAISS row number to a gamerpic id.
# JSON is UTF-8 by spec, so don't rely on the platform default encoding.
with open("xbgp-faiss-map.json", "r", encoding="utf-8") as f:
    images = json.load(f)
def _resize_short_side(image, size=224):
    """
    Resize ``image`` so its shorter side is ``size`` px, keeping aspect ratio.

    The longer side is scaled by the same factor and truncated to an int.
    Integer arithmetic is used so a float product that lands just below a
    whole number cannot truncate the result one pixel short.
    """
    width, height = image.size
    if width < height:
        new_width, new_height = size, height * size // width
    else:
        new_width, new_height = width * size // height, size
    return image.resize((new_width, new_height), Image.LANCZOS)


def process_image(image):
    """
    Find the gamerpics most similar to an uploaded image.

    The image is converted to RGB, resized so its shorter side is 224 px,
    embedded with the traced DINOv2 model, L2-normalised, and searched
    against the loaded FAISS index (top 50).

    Args:
        image: A PIL image, as supplied by ``gr.Image(type="pil")``, or
            ``None`` when the form is submitted without an upload.

    Returns:
        A list of at most 50 dicts in FAISS result order, each with
        ``"id"`` (the entry from the JSON map) and ``"score"``
        (``1 / (distance + 1) * 100`` rounded to 2 places, as a string
        ending in ``"%"``). Returns ``[]`` when ``image`` is ``None``.
    """
    if image is None:
        # Nothing was uploaded; there is nothing to search for.
        return []

    if image.mode != "RGB":
        image = image.convert("RGB")
    image = _resize_short_side(image, 224)

    with torch.no_grad():
        pixel_values = processor(images=image, return_tensors="pt")["pixel_values"].to(device)
        # The model was traced with return_dict=False, so the output is a tuple.
        # NOTE(review): outputs[0] is presumably the last hidden state,
        # (batch, tokens, dim) — confirm.
        outputs = traced_model(pixel_values)
        # Mean-pool over dim 1 to get one vector per image.
        embeddings = outputs[0].mean(dim=1)

    # FAISS needs a C-contiguous float32 array. Normalising to unit length makes
    # L2 distance rank like cosine similarity, assuming the indexed vectors were
    # normalised the same way.
    vector = np.ascontiguousarray(embeddings.cpu().numpy(), dtype=np.float32)
    faiss.normalize_L2(vector)

    distances, indices = index.search(vector, 50)
    matches = []
    for distance, row in zip(distances[0], indices[0]):
        # FAISS pads with label -1 when it has fewer than k results. Without
        # this guard, images[-1] would silently report the last map entry.
        if row < 0:
            continue
        matches.append({
            "id": images[int(row)],
            "score": str(round((1 / (distance + 1) * 100), 2)) + "%",
        })
    return matches
# Build the web UI: upload one image, get the list of matches back as JSON.
iface = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil"),  # hand process_image a PIL image
    outputs="json",  # process_image returns a list of {"id", "score"} dicts
).queue()  # enable Gradio's request queue

# Only start the server when run as a script, so the module can be imported
# (e.g. to reuse process_image) without launching the app.
if __name__ == "__main__":
    iface.launch()
|