Spaces:
Sleeping
Sleeping
# Import packages | |
import pickle | |
import torch | |
import torch.nn as nn | |
from torchvision import transforms, models | |
from PIL import Image | |
from PIL import Image, ImageDraw, ImageFont | |
import gradio as gr | |
import numpy as np | |
import pandas as pd | |
import torch.nn.functional as F | |
from torchcam.methods import SmoothGradCAMpp | |
from torchcam.utils import overlay_mask | |
from torchvision.transforms.functional import to_pil_image | |
from sklearn.metrics.pairwise import cosine_similarity | |
# Inference runs on the CPU; every tensor and model below is kept there.
device = torch.device("cpu")

# DINOv2 ViT-S/14 backbone used as a frozen feature extractor
# (downloaded through torch.hub on first use).
dinov2_vits14 = torch.hub.load("facebookresearch/dinov2", "dinov2_vits14")
dinov2_vits14.to(device)
dinov2_vits14.eval()  # inference only

# Preprocessing: convert to tensor, resize (shorter edge -> 224), normalize.
transform_image = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(224),
    transforms.Normalize([0.5], [0.5]),
])

# Classification head applied to the DINOv2 embedding.
# map_location keeps a checkpoint that was saved from a GPU loadable on this
# CPU-only setup.
# NOTE(review): this unpickles a whole model object; recent PyTorch releases
# default to weights_only=True, which may reject it -- confirm against the
# installed version.
model = torch.load("dress_model.pth", map_location=device)
model.eval()

# NOTE: pickle.load can execute arbitrary code -- only load trusted files.
with open('saved_dress_morph.pkl', 'rb') as f:
    loaded_dict = pickle.load(f)
# NOTE(review): detect() reads the globals `dress_dict` and `mean_features`,
# which are not defined anywhere in the visible source; presumably they are
# meant to be taken from `loaded_dict` -- confirm the keys and bind them here.
def detect(image):
    """Classify a dress image and score how typical it is.

    The image is padded to a square on a black canvas, embedded with the
    DINOv2 backbone, classified by ``model``, and compared (cosine
    similarity) against ``mean_features``.

    Relies on module-level names: ``transform_image``, ``device``,
    ``dinov2_vits14``, ``model``, ``dress_dict`` and ``mean_features``.

    Args:
        image: A PIL image, or a NumPy array (converted with
            ``Image.fromarray``).

    Returns:
        A tuple ``(prediction_text, typicality)`` where ``prediction_text``
        is ``"Predicted Dress Category: <label>"`` and ``typicality`` is the
        cosine similarity rounded to two decimals.

    Raises:
        ValueError: If ``image`` is ``None``.
    """
    if image is None:
        raise ValueError("No image provided")

    # The padding code below uses the PIL API; accept NumPy input as well.
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    # Pad to a square black canvas so resizing does not distort the aspect
    # ratio; the original is pasted at the top-left corner.
    size = max(image.size)
    new_im = Image.new('RGB', (size, size), color=0)
    new_im.paste(image)

    with torch.no_grad():
        image_tensor = transform_image(new_im).to(device)
        # Single forward pass; the embedding is reused for both the
        # classifier and the similarity score.
        dino_embedding = dinov2_vits14(image_tensor.unsqueeze(0)).cpu()
        outputs = model(dino_embedding)

    pred_dress_cat = torch.argmax(outputs, dim=1).tolist()[0]
    pred_dress = dress_dict[pred_dress_cat]
    pred_dress_s = f"Predicted Dress Category: {pred_dress}"

    dino_numpy = dino_embedding.numpy()
    cosine_sim = cosine_similarity(
        dino_numpy.reshape(1, -1),
        mean_features.reshape(1, -1),
    ).item()
    cosin = round(float(cosine_sim), 2)
    return pred_dress_s, cosin
# Gradio UI: one image in, predicted category text and typicality score out.
# detect() uses the PIL API (image.size as a (w, h) tuple, Image.paste), so
# the input component must deliver a PIL image rather than a NumPy array.
demo = gr.Interface(
    fn=detect,
    inputs=gr.Image(type="pil", label="Upload an image"),
    outputs=[
        gr.Textbox(label="Predictions"),
        gr.Number(label="Typicality Score"),
    ],
    title='Dress Classification',
)