wwnvp01 committed on
Commit
11b1ae2
·
verified ·
1 Parent(s): 4994143

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -0
app.py CHANGED
@@ -14,3 +14,57 @@ from torchcam.utils import overlay_mask
14
  from torchvision.transforms.functional import to_pil_image
15
  from sklearn.metrics.pairwise import cosine_similarity
16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  from torchvision.transforms.functional import to_pil_image
15
  from sklearn.metrics.pairwise import cosine_similarity
16
 
17
# Load the DinoV2 ViT-S/14 backbone from torch hub; it is used as a frozen
# feature extractor for the dress classifier below.
dinov2_vits14 = torch.hub.load("facebookresearch/dinov2", "dinov2_vits14")

# This deployment is CPU-only; keep everything on the CPU device.
device = torch.device("cpu")
dinov2_vits14.to(device)

# Preprocessing: to tensor, resize shorter side to 224, normalize to [-1, 1].
transform_image = T.Compose([T.ToTensor(), T.Resize(224), T.Normalize([0.5], [0.5])])

# Load the trained classification head. Bug fix: pass map_location so a
# checkpoint that was saved on a GPU machine still loads on this CPU-only
# host (the original torch.load would raise a deserialization error there).
model = torch.load("dress_model.pth", map_location=device)
model.eval()

# Precomputed artifacts for the app (presumably the category mapping and the
# mean feature vector used for the typicality score — TODO confirm contents).
# NOTE(review): pickle.load is unsafe on untrusted files; this artifact must
# ship with the app and never come from user input.
with open('saved_dress_morph.pkl', 'rb') as f:
    loaded_dict = pickle.load(f)
30
+
31
def detect(image):
    """Classify a dress image and compute a typicality score.

    Args:
        image: PIL image uploaded by the user (must expose ``.size`` and be
            pasteable via ``Image.paste``).

    Returns:
        Tuple of (prediction string, cosine-similarity typicality score
        rounded to 2 decimals).
    """
    # Pad the image onto a square black canvas so the later resize to 224
    # preserves the aspect ratio instead of distorting the dress.
    size = max(image.size)
    new_im = Image.new('RGB', (size, size), color=0)
    new_im.paste(image)

    with torch.no_grad():
        # Preprocess and move to the model's device.
        image_tensor = transform_image(new_im).to(device)

        # Extract features ONCE with DinoV2 and reuse the result. The original
        # ran the backbone twice on the same input (once for the tensor, once
        # for the numpy copy) — a wasted full forward pass.
        dino_embedding = dinov2_vits14(image_tensor.unsqueeze(0)).cpu()
        dino_numpy = dino_embedding.numpy()

        # Classify the embedding. (The per-call model.eval() from the original
        # is dropped: eval mode is already set once at module load.)
        outputs = model(dino_embedding)
        pred_dress_cat = round(torch.argmax(outputs, dim=1).tolist()[0])

    # dress_dict / mean_features are module-level artifacts — presumably
    # unpacked from saved_dress_morph.pkl; TODO confirm where they are bound.
    pred_dress = dress_dict[pred_dress_cat]
    pred_dress_s = f"Predicted Dress Category: {pred_dress}"

    # Typicality: cosine similarity between this embedding and the mean
    # feature vector. Bug fix: the original assigned `cosine_sim` but then
    # read the undefined name `cosin_sim`, raising NameError on every call.
    cosine_sim = cosine_similarity(dino_numpy.reshape(1, -1), mean_features.reshape(1, -1)).item()
    cosin = round(float(cosine_sim), 2)

    return pred_dress_s, cosin
61
+
62
# Gradio UI wiring. Two bug fixes versus the original:
#  1. A comma was missing between the `outputs` list and `title=`, which is a
#     SyntaxError — the app could not even start.
#  2. `inputs` used type="numpy", but detect() calls image.size and
#     Image.paste, which require a PIL image (a numpy array's .size is an int,
#     so max(image.size) would raise TypeError). type="pil" hands detect()
#     the PIL object it expects.
demo = gr.Interface(
    fn=detect,
    inputs=gr.Image(type="pil", label="Upload an image"),
    outputs=[
        gr.Textbox(label="Predictions"),
        gr.Number(label="Typicality Score"),
    ],
    title='Dress Classification')