LucyintheSky commited on
Commit
3d52dd9
·
verified ·
1 Parent(s): 5d8dfa4

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -0
app.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torchvision.transforms as transforms
4
+ from PIL import Image
5
+ from safetensors.torch import load_model
6
+ from huggingface_hub import hf_hub_download
7
+ from timm import list_models, create_model
8
+ import os
9
+ import numpy as np
10
+
11
+ # Intialize the model
12
+ model_name='swin_s3_base_224'
13
+ model = create_model(
14
+ model_name,
15
+ num_classes=36
16
+ )
17
+ load_model(model,f'./{model_name}/model.safetensors')
18
+
19
+ # Define class names
20
+ class_names = ["3/4 Sleeve", "Accessory", "Babydoll", "Closed Back", "Corset", "Crochet", "Cutouts", "Draped", "Floral", "Gloves", "Halter", "Lace", "Long", "Long Sleeve", "Midi", "No Slit", "Off The Shoulder", "One Shoulder", "Open Back", "Pockets", "Print", "Puff Sleeve", "Ruched", "Satin", "Sequins", "Shimmer", "Short", "Short Sleeve", "Side Slit", "Square Neck", "Strapless", "Sweetheart Neck", "Tight", "V-Neck", "Velvet", "Wrap"]
21
+ label2id = {c:idx for idx,c in enumerate(class_names)}
22
+ id2label = {idx:c for idx,c in enumerate(class_names)}
23
+
24
+ def predict_features(image_path):
25
+ # Load PIL image
26
+ pil_image = Image.open(image_path).convert('RGB')
27
+
28
+ # Define transformations to resize and convert image to tensor
29
+ transform = transforms.Compose([
30
+ transforms.Resize((224, 224)),
31
+ transforms.ToTensor()
32
+ ])
33
+ tensor_image = transform(pil_image)
34
+
35
+ inputs = tensor_image.unsqueeze(0)
36
+
37
+ with torch.no_grad():
38
+ logits = model(inputs)
39
+
40
+ # apply sigmoid activation to convert logits to probabilities
41
+ # getting labels with confidence threshold of 0.5
42
+ predictions = logits.sigmoid() > 0.5
43
+
44
+ # converting one-hot encoded predictions back to list of labels
45
+ predictions = predictions.float().numpy().flatten() # convert boolean predictions to float
46
+ pred_labels = np.where(predictions==1)[0] # find indices where prediction is 1
47
+ pred_labels = ([id2label[label] for label in pred_labels]) # converting integer labels to string
48
+ print(pred_labels)
49
+ return pred_labels
50
+
51
+
52
+
53
+ def greet(image):
54
+ return str(predict_features(image))
55
+
56
+ demo = gr.Interface(fn=greet, inputs=gr.Image(), outputs="text")
57
+ demo.launch()