Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -14,3 +14,57 @@ from torchcam.utils import overlay_mask
|
|
14 |
from torchvision.transforms.functional import to_pil_image
|
15 |
from sklearn.metrics.pairwise import cosine_similarity
|
16 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
from torchvision.transforms.functional import to_pil_image
|
15 |
from sklearn.metrics.pairwise import cosine_similarity
|
16 |
|
17 |
+
dinov2_vits14 = torch.hub.load("facebookresearch/dinov2", "dinov2_vits14")
|
18 |
+
# Set the device to GPU if available, otherwise use CPU
|
19 |
+
device = torch.device("cpu")
|
20 |
+
dinov2_vits14.to(device)
|
21 |
+
|
22 |
+
# Define the transformations: convert to tensor, resize, and normalize
|
23 |
+
transform_image = T.Compose([T.ToTensor(), T.Resize(224), T.Normalize([0.5], [0.5])])
|
24 |
+
|
25 |
+
model = torch.load("dress_model.pth")
|
26 |
+
model.eval()
|
27 |
+
|
28 |
+
with open('saved_dress_morph.pkl', 'rb') as f:
|
29 |
+
loaded_dict = pickle.load(f)
|
30 |
+
|
31 |
+
def detect(image):
|
32 |
+
|
33 |
+
size = max(image.size)
|
34 |
+
new_im = Image.new('RGB', (size, size), color = 0) # Create a squared black image
|
35 |
+
new_im.paste(image)
|
36 |
+
|
37 |
+
|
38 |
+
with torch.no_grad():
|
39 |
+
|
40 |
+
# Apply transformations to the image and move it to the appropriate device
|
41 |
+
image_tensor = transform_image(new_im).to(device)
|
42 |
+
|
43 |
+
# Extract features using the DinoV2 model
|
44 |
+
dino_embedding = dinov2_vits14(image_tensor.unsqueeze(0)).cpu()
|
45 |
+
dino_numpy = dinov2_vits14(image_tensor.unsqueeze(0)).cpu().numpy()
|
46 |
+
|
47 |
+
model.eval()
|
48 |
+
|
49 |
+
with torch.no_grad():
|
50 |
+
outputs = model(dino_embedding)
|
51 |
+
pred_dress_cat = round(torch.argmax(outputs, dim = 1).tolist()[0])
|
52 |
+
|
53 |
+
pred_dress = dress_dict[pred_dress_cat]
|
54 |
+
|
55 |
+
pred_dress_s = f"Predicted Dress Category: {pred_dress}"
|
56 |
+
|
57 |
+
cosine_sim = cosine_similarity(dino_numpy.reshape(1, -1), mean_features.reshape(1, -1)).item()
|
58 |
+
cosin = round(float(cosin_sim), 2)
|
59 |
+
|
60 |
+
return pred_dress_s, cosin
|
61 |
+
|
62 |
+
demo = gr.Interface(
|
63 |
+
fn=detect,
|
64 |
+
inputs=gr.Image(type="numpy", label="Upload an image"),
|
65 |
+
outputs=[
|
66 |
+
gr.Textbox(label = "Predictions"),
|
67 |
+
gr.Number(label="Typicality Score")]
|
68 |
+
title='Dress Classification')
|
69 |
+
|
70 |
+
|