wwnvp01's picture
Update app.py
cf741e2 verified
raw
history blame
2.17 kB
# Import packages
import pickle
import torch
import torch.nn as nn
from torchvision import transforms, models
from PIL import Image
from PIL import Image, ImageDraw, ImageFont
import gradio as gr
import numpy as np
import pandas as pd
import torch.nn.functional as F
from torchcam.methods import SmoothGradCAMpp
from torchcam.utils import overlay_mask
from torchvision.transforms.functional import to_pil_image
from sklearn.metrics.pairwise import cosine_similarity
# Load the DINOv2 ViT-S/14 backbone from torch.hub (fetched from GitHub on first use).
dinov2_vits14 = torch.hub.load("facebookresearch/dinov2", "dinov2_vits14")
# All inference in this app runs on the CPU.
device = torch.device("cpu")
dinov2_vits14.to(device)
# Inference only: make sure the backbone is not left in training mode.
dinov2_vits14.eval()
# Convert to a [0, 1] tensor, resize the shorter side to 224 px, and map every
# channel to [-1, 1].
transform_image = transforms.Compose([transforms.ToTensor(), transforms.Resize(224), transforms.Normalize([0.5], [0.5])])
# Classifier head applied to DINOv2 embeddings; presumably a whole pickled
# nn.Module (it is called directly and has .eval()) — confirm.
# map_location lets a checkpoint that was saved on GPU load on a CPU-only host.
# weights_only=False is required to unpickle a full module (torch >= 2.6
# defaults to weights_only=True). This executes pickled code: only ever load
# trusted local files here.
model = torch.load("dress_model.pth", map_location=device, weights_only=False)
model.eval()
# NOTE(review): pickle.load runs arbitrary code — this file must be trusted.
with open('saved_dress_morph.pkl', 'rb') as f:
    loaded_dict = pickle.load(f)
# NOTE(review): detect() uses `dress_dict` (class index -> label) and
# `mean_features` (reference embedding), but neither is defined in the visible
# code. Presumably both come from `loaded_dict` — confirm and bind them here.
def _pad_to_square(image):
    """Return *image* as an RGB PIL image padded with black to a square.

    The original pixels are pasted at the top-left corner, so the padding
    fills the right or bottom edge.

    Args:
        image: A PIL image, or a NumPy array such as the one produced by
            ``gr.Image(type="numpy")``.

    Returns:
        PIL.Image.Image: A square RGB image whose side is the larger of the
        input's width and height.
    """
    if isinstance(image, np.ndarray):
        # An ndarray's ``.size`` is its element count, not (width, height),
        # and it cannot be pasted directly, so convert it to PIL first.
        image = Image.fromarray(image)
    image = image.convert("RGB")
    side = max(image.size)
    canvas = Image.new("RGB", (side, side), color=0)  # black square background
    canvas.paste(image)
    return canvas


def detect(image):
    """Classify a dress image and score how typical it is.

    The image is padded to a square and embedded with the DINOv2 backbone.
    The embedding is then classified by ``model`` and compared with
    ``mean_features`` using cosine similarity.

    Args:
        image: The uploaded image, either a PIL image or a NumPy array.

    Returns:
        tuple[str, float]: A ``"Predicted Dress Category: ..."`` message and
        the cosine similarity to ``mean_features``, rounded to 2 decimals.

    Raises:
        gr.Error: If no image was supplied.
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    square = _pad_to_square(image)
    with torch.no_grad():
        image_tensor = transform_image(square).to(device)
        # Run the backbone a single time; the same embedding feeds both the
        # classifier and the typicality score.
        dino_embedding = dinov2_vits14(image_tensor.unsqueeze(0)).cpu()
        outputs = model(dino_embedding)
    pred_dress_cat = int(torch.argmax(outputs, dim=1)[0])
    # NOTE(review): `dress_dict` and `mean_features` are not defined in the
    # visible code; presumably derived from `loaded_dict` — confirm.
    pred_dress = dress_dict[pred_dress_cat]
    pred_dress_s = f"Predicted Dress Category: {pred_dress}"
    dino_numpy = dino_embedding.numpy()
    cosine_sim = cosine_similarity(dino_numpy.reshape(1, -1), mean_features.reshape(1, -1)).item()
    cosin = round(float(cosine_sim), 2)
    return pred_dress_s, cosin
# Gradio UI: one image in, a prediction message and a typicality score out.
demo = gr.Interface(
    fn=detect,
    # detect() is written against PIL images (.size as (width, height),
    # Image.paste), so ask Gradio for a PIL image instead of a NumPy array.
    inputs=gr.Image(type="pil", label="Upload an image"),
    outputs=[
        gr.Textbox(label="Predictions"),
        gr.Number(label="Typicality Score"),
    ],
    title='Dress Classification',
)