Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -9,7 +9,7 @@ import json
|
|
9 |
import numpy as np
|
10 |
import cv2
|
11 |
import chromadb
|
12 |
-
from
|
13 |
|
14 |
# Load CLIP model and tokenizer
|
15 |
@st.cache_resource
|
@@ -24,15 +24,13 @@ clip_model, preprocess_val, tokenizer, device = load_clip_model()
|
|
24 |
|
25 |
# Load YOLOS model
|
26 |
@st.cache_resource
|
27 |
-
def
|
28 |
-
|
29 |
-
model = YolosForObjectDetection.from_pretrained("valentinafeve/yolos-fashionpedia")
|
30 |
-
return processor, model
|
31 |
|
32 |
-
|
33 |
|
34 |
# Define the categories
|
35 |
-
CATS = ['shirt, blouse', 'top, t-shirt, sweatshirt', 'sweater', 'cardigan', 'jacket', 'vest', 'pants', 'shorts', 'skirt', 'coat', 'dress', 'jumpsuit', 'cape', 'glasses', 'hat', 'glove', 'shoe', 'bag', 'wallet', 'umbrella', 'hood', 'collar', 'lapel', 'epaulette', 'applique', 'bead', 'bow', 'flower', 'fringe', 'ribbon', 'rivet', 'ruffle', 'sequin', 'tassel']
|
36 |
|
37 |
# Helper functions
|
38 |
def load_image_from_url(url, max_retries=3):
|
@@ -92,21 +90,22 @@ def find_similar_images(query_embedding, collection, top_k=5):
|
|
92 |
return results
|
93 |
|
94 |
def detect_clothing(image):
|
95 |
-
inputs = yolos_processor(images=image, return_tensors="pt")
|
96 |
-
outputs = yolos_model(**inputs)
|
97 |
|
98 |
-
target_sizes = torch.tensor([image.size[::-1]])
|
99 |
-
results =
|
|
|
100 |
|
101 |
categories = []
|
102 |
-
for
|
103 |
-
|
104 |
-
category =
|
105 |
-
if category in
|
106 |
categories.append({
|
107 |
'category': category,
|
108 |
-
'bbox':
|
109 |
-
'confidence':
|
110 |
})
|
111 |
return categories
|
112 |
|
|
|
9 |
import numpy as np
|
10 |
import cv2
|
11 |
import chromadb
|
12 |
+
from ultralytics import YOLO
|
13 |
|
14 |
# Load CLIP model and tokenizer
|
15 |
@st.cache_resource
|
|
|
24 |
|
25 |
# Load YOLOS model
|
26 |
@st.cache_resource
def load_yolo_model():
    """Load the fine-tuned YOLO clothing-detection model from ./best.pt.

    Cached by Streamlit so the weights are read from disk only once per
    server process, not on every script rerun.
    """
    return YOLO("./best.pt")


# Module-level handle used by detect_clothing() below.
yolo_model = load_yolo_model()
|
31 |
|
32 |
# Define the categories
|
33 |
+
#CATS = ['shirt, blouse', 'top, t-shirt, sweatshirt', 'sweater', 'cardigan', 'jacket', 'vest', 'pants', 'shorts', 'skirt', 'coat', 'dress', 'jumpsuit', 'cape', 'glasses', 'hat', 'glove', 'shoe', 'bag', 'wallet', 'umbrella', 'hood', 'collar', 'lapel', 'epaulette', 'applique', 'bead', 'bow', 'flower', 'fringe', 'ribbon', 'rivet', 'ruffle', 'sequin', 'tassel']
|
34 |
|
35 |
# Helper functions
|
36 |
def load_image_from_url(url, max_retries=3):
|
|
|
90 |
return results
|
91 |
|
92 |
def detect_clothing(image):
    """Detect wearable clothing items in an image with the fine-tuned YOLO model.

    Args:
        image: Input image in any format the Ultralytics predictor accepts
            (PIL image, numpy array, or file path).

    Returns:
        A list of dicts, one per kept detection, each with:
            'category'   -- class name string from the model's label map,
            'bbox'       -- [x1, y1, x2, y2] pixel coordinates as ints,
            'confidence' -- detection score as a native Python float.
    """
    # Only these garment classes are surfaced; other classes the model
    # may emit (e.g. accessories) are filtered out. Set gives O(1) lookup.
    wanted = {'sunglass', 'hat', 'jacket', 'shirt', 'pants',
              'shorts', 'skirt', 'dress', 'bag', 'shoe'}

    results = yolo_model(image)
    # Each row of boxes.data is (x1, y1, x2, y2, confidence, class_id);
    # move to CPU numpy so we can iterate without holding GPU tensors.
    detections = results[0].boxes.data.cpu().numpy()

    categories = []
    for x1, y1, x2, y2, conf, cls in detections:
        category = yolo_model.names[int(cls)]
        if category in wanted:
            categories.append({
                'category': category,
                'bbox': [int(x1), int(y1), int(x2), int(y2)],
                # Cast numpy scalar -> float so the result is JSON-serializable.
                'confidence': float(conf),
            })
    return categories
|
111 |
|