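"""
Gradio app: reverse image search for gamerpics.

Embeds an uploaded image with DINOv2 (facebook/dinov2-large) and returns the
50 nearest neighbors from a prebuilt FAISS index (xbgp-faiss.index), using
xbgp-faiss-map.json to map index positions back to image identifiers.
"""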
import gradio as gr
from transformers import AutoImageProcessor, AutoModel
import torch
from PIL import Image
import json
import numpy as np
import faiss


# Initialize the DINOv2 model and image processor used for similarity search
device = torch.device("cpu")
processor = AutoImageProcessor.from_pretrained("facebook/dinov2-large")
model = AutoModel.from_pretrained("facebook/dinov2-large")
model.config.return_dict = False  # Set return_dict to False for JIT tracing
model.to(device)

# Prepare an example input for tracing (224x224 matches the processor output)
example_input = torch.rand(1, 3, 224, 224).to(device)
traced_model = torch.jit.trace(model, example_input)
traced_model = traced_model.to(device)
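
# Sanity check (an added sketch, not part of the original flow): the traced
# model should reproduce the eager model's hidden states on the example input.
with torch.no_grad():
    eager_hidden = model(example_input)[0]
    traced_hidden = traced_model(example_input)[0]
assert torch.allclose(eager_hidden, traced_hidden, atol=1e-4)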

# Load faiss index
index = faiss.read_index("xbgp-faiss.index")

# Load faiss map
with open("xbgp-faiss-map.json", "r") as f:
    images = json.load(f)
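# The map is assumed to be a JSON array whose i-th entry is the identifier of
# the image stored at position i in the FAISS index.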


def process_image(image):
    """
    Process the image and extract features using the DINOv2 model.
    """
    # Add your image processing code here.
    # This will include preprocessing the image, passing it through the model,
    # and then formatting the output (extracted features).

    # Convert to RGB if it isn't already
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Resize so the shorter side is 224px, preserving aspect ratio
    width, height = image.size
    if width < height:
        w_percent = 224 / float(width)
        new_width = 224
        new_height = int(float(height) * float(w_percent))
    else:
        h_percent = 224 / float(height)
        new_height = 224
        new_width = int(float(width) * float(h_percent))
    image = image.resize((new_width, new_height), Image.LANCZOS)

    # Extract the features from the uploaded image
    with torch.no_grad():
        inputs = processor(images=image, return_tensors="pt")["pixel_values"].to(device)

        # Use the traced model for inference
        outputs = traced_model(inputs)

    # Mean-pool the patch embeddings, then L2-normalize so FAISS scores cosine similarity
    embeddings = outputs[0].mean(dim=1)
    vector = embeddings.detach().cpu().numpy()
    vector = np.float32(vector)
    faiss.normalize_L2(vector)

    # Search the preloaded index for the 50 nearest neighbors
    distances, indices = index.search(vector, 50)

    matches = []

    for idx, matching_gamerpic in enumerate(indices[0]):
        matches.append(
            {
                "id": images[matching_gamerpic],
                # Map the raw L2 distance to a rough similarity percentage
                "score": str(round(1 / (distances[0][idx] + 1) * 100, 2)) + "%",
            }
        )

    return matches


# Create a Gradio interface
iface = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil"),  # Hand the upload to process_image as a PIL.Image
    outputs="json",  # Render the returned list of matches as JSON
).queue()

# Launch the Gradio app
iface.launch()
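
# Minimal client sketch (an assumption: the default local endpoint and the
# auto-generated "/predict" route; adjust both for your deployment):
#
#   from gradio_client import Client, handle_file
#
#   client = Client("http://127.0.0.1:7860/")
#   matches = client.predict(handle_file("some-gamerpic.png"), api_name="/predict")
#   print(matches)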