File size: 2,172 Bytes
8734f34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11b1ae2
 
 
 
 
 
cf741e2
11b1ae2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dd1ec10
11b1ae2
 
dd1ec10
 
11b1ae2
dd1ec10
11b1ae2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d37c0df
 
11b1ae2
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
# Import packages
import pickle
import torch
import torch.nn as nn
from torchvision import transforms, models
from PIL import Image
from PIL import Image, ImageDraw, ImageFont
import gradio as gr
import numpy as np
import pandas as pd
import torch.nn.functional as F
from torchcam.methods import SmoothGradCAMpp
from torchcam.utils import overlay_mask
from torchvision.transforms.functional import to_pil_image
from sklearn.metrics.pairwise import cosine_similarity

# Load the DINOv2 ViT-S/14 backbone from torch.hub; it is used purely as a
# feature extractor whose embeddings feed the dress classifier below.
dinov2_vits14 = torch.hub.load("facebookresearch/dinov2", "dinov2_vits14")
# Inference runs on the CPU. (`detect` also moves embeddings back with
# `.cpu()` before they reach the classifier.)
device = torch.device("cpu")
dinov2_vits14.to(device)
# Inference only: switch layers with train/eval behaviour to evaluation mode.
dinov2_vits14.eval()

# Define the transformations: convert to a [0, 1] tensor, resize the shorter
# side to 224 (`detect` pads the image to a square first, so the result is
# 224x224), and normalize every channel from [0, 1] to [-1, 1].
transform_image = transforms.Compose(
    [transforms.ToTensor(), transforms.Resize(224), transforms.Normalize([0.5], [0.5])]
)

# Classifier head that maps DINOv2 embeddings to dress-category logits.
# The checkpoint appears to hold a whole pickled nn.Module (it is used directly
# via `model.eval()` / `model(...)`), so `weights_only=False` is required on
# PyTorch >= 2.6, where it defaults to True. Unpickling can execute arbitrary
# code -- only load trusted checkpoints. `map_location` lets a checkpoint that
# was saved on a GPU load on this CPU-only setup.
model = torch.load("dress_model.pth", map_location=device, weights_only=False)
model.eval()

# NOTE(review): `loaded_dict` is not referenced in the visible code; confirm
# whether `dress_dict` / `mean_features` used by `detect` should come from it.
# pickle.load can execute arbitrary code -- only ship trusted .pkl files.
with open('saved_dress_morph.pkl', 'rb') as f:
    loaded_dict = pickle.load(f)

def detect(image):
    """Classify a dress image and score how typical it is.

    The image is padded to a square with black pixels (anchored at the
    top-left corner), embedded once with DINOv2, classified by ``model``,
    and compared against ``mean_features`` with cosine similarity.

    Args:
        image: The uploaded image, either a NumPy array (what
            ``gr.Image(type="numpy")`` delivers) or a ``PIL.Image.Image``.

    Returns:
        A tuple ``(label, score)``: a human-readable string naming the
        predicted dress category, and the cosine similarity between the
        image embedding and ``mean_features``, rounded to two decimals.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    # Gradio passes a NumPy array here; the PIL calls below (`.size` as a
    # (width, height) tuple, `paste`) need a PIL image.
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    image = image.convert("RGB")

    # Pad to a square canvas so the later resize does not distort the
    # aspect ratio.
    size = max(image.size)
    new_im = Image.new('RGB', (size, size), color=0)  # Create a squared black image
    new_im.paste(image)

    with torch.no_grad():
        # Apply transformations to the image and move it to the appropriate device
        image_tensor = transform_image(new_im).to(device)
        # Extract features with the DinoV2 model once; the same embedding
        # feeds both the classifier and the similarity score.
        dino_embedding = dinov2_vits14(image_tensor.unsqueeze(0)).cpu()
        outputs = model(dino_embedding)
        pred_dress_cat = int(torch.argmax(outputs, dim=1)[0])

    # NOTE(review): `dress_dict` and `mean_features` are not defined in the
    # visible part of this file -- confirm where they are created (possibly
    # from `loaded_dict`).
    pred_dress = dress_dict[pred_dress_cat]
    pred_dress_s = f"Predicted Dress Category:  {pred_dress}"

    dino_numpy = dino_embedding.numpy()
    cosine_sim = cosine_similarity(dino_numpy.reshape(1, -1), mean_features.reshape(1, -1)).item()
    cosin = round(float(cosine_sim), 2)

    return pred_dress_s, cosin


# Gradio UI: one image input; outputs are the predicted label and the
# typicality (cosine-similarity) score returned by `detect`. Defined at module
# level so it is actually created when the script runs.
# NOTE(review): no `demo.launch()` is visible here; confirm the app is launched.
demo = gr.Interface(
    fn=detect,
    inputs=gr.Image(type="numpy", label="Upload an image"),
    outputs=[gr.Textbox(label="Predictions"),
             gr.Number(label="Typicality Score")],
    title='Dress Classification')