"hi"
Browse files- .gitignore +6 -0
- README.md +1 -1
- README.yaml +0 -99
- app.py +131 -0
- example.ipynb +150 -0
- requirements.txt +7 -0
- sample_images/image_five.jpg +0 -0
- sample_images/image_four.jpg +0 -0
- sample_images/image_six.jpg +0 -0
- yolo/BodyMask.py +248 -0
- yolo/utils.py +291 -0
.gitignore
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
gradio_cached_examples/
|
3 |
+
checkpoint-*
|
4 |
+
*/example.ipynb
|
5 |
+
|
6 |
+
*.pyc
|
README.md
CHANGED
@@ -56,7 +56,7 @@ To use this model, you'll need to have the appropriate YOLO framework installed.
|
|
56 |
To use the model for inference, you can use the following Python script:
|
57 |
|
58 |
```python
|
59 |
-
from
|
60 |
|
61 |
# Load the model
|
62 |
model = YOLO('path/to/your/model.pt')
|
|
|
56 |
To use the model for inference, you can use the following Python script:
|
57 |
|
58 |
```python
|
59 |
+
from ultralytics import YOLO
|
60 |
|
61 |
# Load the model
|
62 |
model = YOLO('path/to/your/model.pt')
|
README.yaml
DELETED
@@ -1,99 +0,0 @@
|
|
1 |
-
---
|
2 |
-
language:
|
3 |
-
- "en"
|
4 |
-
thumbnail: "https://example.com/path/to/your/thumbnail.jpg"
|
5 |
-
tags:
|
6 |
-
- yolo
|
7 |
-
- object-detection
|
8 |
-
- image-segmentation
|
9 |
-
- computer-vision
|
10 |
-
- human-body-parts
|
11 |
-
license: "mit"
|
12 |
-
datasets:
|
13 |
-
- custom_human_body_parts_dataset
|
14 |
-
metrics:
|
15 |
-
- mean_average_precision
|
16 |
-
- intersection_over_union
|
17 |
-
base_model: "ultralytics/yolov5yolov8x-seg"
|
18 |
-
---
|
19 |
-
|
20 |
-
# YOLO Segmentation Model for Human Body Parts and Objects
|
21 |
-
|
22 |
-
This model is a fine-tuned version of YOLOv5 for segmenting human body parts and objects. It can detect and segment 11 different classes including various body parts, outfits, and phones.
|
23 |
-
|
24 |
-
## Model Details
|
25 |
-
|
26 |
-
- **Model Type:** YOLOv8 for Instance Segmentation
|
27 |
-
- **Task:** Segmentation
|
28 |
-
- **Fine-tuning Dataset:** Custom dataset of human body parts and objects
|
29 |
-
- **Number of Classes:** 11
|
30 |
-
|
31 |
-
## Classes
|
32 |
-
|
33 |
-
The model can detect and segment the following classes:
|
34 |
-
|
35 |
-
0. Hair
|
36 |
-
1. Face
|
37 |
-
2. Neck
|
38 |
-
3. Arm
|
39 |
-
4. Hand
|
40 |
-
5. Back
|
41 |
-
6. Leg
|
42 |
-
7. Foot
|
43 |
-
8. Outfit
|
44 |
-
9. Person
|
45 |
-
10. Phone
|
46 |
-
|
47 |
-
## Usage
|
48 |
-
|
49 |
-
This model can be used for various applications, including:
|
50 |
-
|
51 |
-
- Human pose estimation
|
52 |
-
- Gesture recognition
|
53 |
-
- Fashion analysis
|
54 |
-
- Person tracking
|
55 |
-
- Human-computer interaction
|
56 |
-
|
57 |
-
For detailed usage instructions, please refer to the model's README file.
|
58 |
-
|
59 |
-
## Training Procedure
|
60 |
-
|
61 |
-
The model was fine-tuned on a custom dataset of annotated images containing human body parts and objects. The training process involved transfer learning from the base YOLOv8 model, with adjustments made to the final layers to accommodate the new class structure.
|
62 |
-
|
63 |
-
## Evaluation Results
|
64 |
-
|
65 |
-
(Note: Replace these placeholder metrics with your actual evaluation results)
|
66 |
-
|
67 |
-
lr/pg0:0.000572628
|
68 |
-
lr/pg1:0.000572628
|
69 |
-
lr/pg2:0.000572628
|
70 |
-
metrics/mAP50-95(B):0.53001
|
71 |
-
metrics/mAP50-95(M):0.42367
|
72 |
-
metrics/mAP50(B):0.69407
|
73 |
-
metrics/mAP50(M):0.61714
|
74 |
-
metrics/precision(B):0.7047
|
75 |
-
metrics/precision(M):0.68041
|
76 |
-
metrics/recall(B):0.68802
|
77 |
-
metrics/recall(M):0.62248
|
78 |
-
model/GFLOPs:344.557
|
79 |
-
model/parameters:71,761,441
|
80 |
-
model/speed_PyTorch(ms):5.813
|
81 |
-
train/box_loss:0.54718
|
82 |
-
train/cls_loss:0.52977
|
83 |
-
train/dfl_loss:0.95171
|
84 |
-
train/seg_loss:1.34628
|
85 |
-
val/box_loss:0.80538
|
86 |
-
val/cls_loss:0.83434
|
87 |
-
val/dfl_loss:1.18352
|
88 |
-
val/seg_loss:2.19488
|
89 |
-
|
90 |
-
|
91 |
-
## Limitations and Biases
|
92 |
-
|
93 |
-
- The model's performance may vary depending on lighting conditions and image quality.
|
94 |
-
- It may have difficulty with occluded or partially visible body parts.
|
95 |
-
- The model's performance on diverse body types and skin tones should be carefully evaluated to ensure fairness and inclusivity.
|
96 |
-
|
97 |
-
## Ethical Considerations
|
98 |
-
|
99 |
-
Users of this model should be aware of privacy concerns related to human body detection and ensure they have appropriate consent for its application. The model should not be used for surveillance or any application that could infringe on personal privacy without explicit consent.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
app.py
ADDED
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import os
|
3 |
+
from ultralytics import YOLO
|
4 |
+
from yolo.BodyMask import BodyMask
|
5 |
+
import numpy as np
|
6 |
+
import matplotlib.pyplot as plt
|
7 |
+
from matplotlib import patches
|
8 |
+
from skimage.transform import resize
|
9 |
+
from PIL import Image
|
10 |
+
import io
|
11 |
+
|
12 |
+
model_id = os.path.abspath("yolo-human-parse-epoch-125.pt")
|
13 |
+
|
14 |
+
|
15 |
+
def display_image_with_masks(image, results, cols=4):
|
16 |
+
# Convert PIL Image to numpy array
|
17 |
+
image_np = np.array(image)
|
18 |
+
|
19 |
+
# Check image dimensions
|
20 |
+
if image_np.ndim != 3 or image_np.shape[2] != 3:
|
21 |
+
raise ValueError("Image must be a 3-dimensional array with 3 color channels")
|
22 |
+
|
23 |
+
# Number of masks
|
24 |
+
n = len(results)
|
25 |
+
rows = (n + cols - 1) // cols # Calculate required number of rows
|
26 |
+
|
27 |
+
# Setting up the plot
|
28 |
+
fig, axs = plt.subplots(rows, cols, figsize=(5 * cols, 5 * rows))
|
29 |
+
axs = np.array(axs).reshape(-1) # Flatten axs array for easy indexing
|
30 |
+
|
31 |
+
for i, result in enumerate(results):
|
32 |
+
mask = result["mask"]
|
33 |
+
label = result["label"]
|
34 |
+
score = float(result["score"])
|
35 |
+
|
36 |
+
# Convert PIL mask to numpy array and resize if necessary
|
37 |
+
mask_np = np.array(mask)
|
38 |
+
if mask_np.shape != image_np.shape[:2]:
|
39 |
+
mask_np = resize(
|
40 |
+
mask_np, image_np.shape[:2], mode="constant", anti_aliasing=False
|
41 |
+
)
|
42 |
+
mask_np = (mask_np > 0.5).astype(
|
43 |
+
np.uint8
|
44 |
+
) # Threshold back to binary after resize
|
45 |
+
|
46 |
+
# Create an overlay where mask is True
|
47 |
+
overlay = np.zeros_like(image_np)
|
48 |
+
overlay[mask_np > 0] = [0, 0, 255] # Applying blue color on the mask area
|
49 |
+
|
50 |
+
# Combine the image and the overlay
|
51 |
+
combined = image_np.copy()
|
52 |
+
indices = np.where(mask_np > 0)
|
53 |
+
combined[indices] = combined[indices] * 0.5 + overlay[indices] * 0.5
|
54 |
+
|
55 |
+
# Show the combined image
|
56 |
+
ax = axs[i]
|
57 |
+
ax.imshow(combined)
|
58 |
+
ax.axis("off")
|
59 |
+
ax.set_title(f"Label: {label}, Score: {score:.2f}", fontsize=12)
|
60 |
+
rect = patches.Rectangle(
|
61 |
+
(0, 0),
|
62 |
+
image_np.shape[1],
|
63 |
+
image_np.shape[0],
|
64 |
+
linewidth=1,
|
65 |
+
edgecolor="r",
|
66 |
+
facecolor="none",
|
67 |
+
)
|
68 |
+
ax.add_patch(rect)
|
69 |
+
|
70 |
+
# Hide unused subplots if the total number of masks is not a multiple of cols
|
71 |
+
for idx in range(i + 1, rows * cols):
|
72 |
+
axs[idx].axis("off")
|
73 |
+
|
74 |
+
plt.tight_layout()
|
75 |
+
|
76 |
+
# Save the plot to a bytes buffer
|
77 |
+
buf = io.BytesIO()
|
78 |
+
plt.savefig(buf, format="png")
|
79 |
+
buf.seek(0)
|
80 |
+
|
81 |
+
# Clear the current figure
|
82 |
+
plt.close(fig)
|
83 |
+
|
84 |
+
return buf
|
85 |
+
|
86 |
+
|
87 |
+
def perform_segmentation(input_image):
|
88 |
+
bm = BodyMask(input_image, model_id=model_id, resize_to=640)
|
89 |
+
results = bm.results
|
90 |
+
buf = display_image_with_masks(input_image, results)
|
91 |
+
|
92 |
+
# Convert BytesIO to PIL Image
|
93 |
+
img = Image.open(buf)
|
94 |
+
return img
|
95 |
+
|
96 |
+
|
97 |
+
# Get example images
|
98 |
+
example_images = [
|
99 |
+
os.path.join("sample_images", f)
|
100 |
+
for f in os.listdir("sample_images")
|
101 |
+
if f.endswith((".png", ".jpg", ".jpeg"))
|
102 |
+
]
|
103 |
+
|
104 |
+
with gr.Blocks() as demo:
|
105 |
+
gr.Markdown("# YOLO Segmentation Demo with BodyMask")
|
106 |
+
gr.Markdown(
|
107 |
+
"Upload an image or select an example to see the YOLO segmentation results."
|
108 |
+
)
|
109 |
+
|
110 |
+
with gr.Row():
|
111 |
+
with gr.Column():
|
112 |
+
input_image = gr.Image(type="pil", label="Input Image", height=512)
|
113 |
+
segment_button = gr.Button("Perform Segmentation")
|
114 |
+
|
115 |
+
output_image = gr.Image(label="Segmentation Result")
|
116 |
+
|
117 |
+
gr.Examples(
|
118 |
+
examples=example_images,
|
119 |
+
inputs=input_image,
|
120 |
+
outputs=output_image,
|
121 |
+
fn=perform_segmentation,
|
122 |
+
cache_examples=True,
|
123 |
+
)
|
124 |
+
|
125 |
+
segment_button.click(
|
126 |
+
fn=perform_segmentation,
|
127 |
+
inputs=input_image,
|
128 |
+
outputs=output_image,
|
129 |
+
)
|
130 |
+
|
131 |
+
demo.launch()
|
example.ipynb
ADDED
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": null,
|
6 |
+
"metadata": {},
|
7 |
+
"outputs": [],
|
8 |
+
"source": [
|
9 |
+
"import os \n",
|
10 |
+
"from ultralytics import YOLO\n",
|
11 |
+
"from yolo.BodyMask import BodyMask\n",
|
12 |
+
"\n",
|
13 |
+
"\n",
|
14 |
+
"model_id = os.path.abspath(\"yolo-human-parse-epoch-125.pt\")\n",
|
15 |
+
"\n",
|
16 |
+
"example_images = [\n",
|
17 |
+
" os.path.join(\"sample_images\", f)\n",
|
18 |
+
" for f in os.listdir(\"sample_images\")\n",
|
19 |
+
" if f.endswith((\".png\", \".jpg\", \".jpeg\"))\n",
|
20 |
+
"]\n",
|
21 |
+
"\n",
|
22 |
+
"image = example_images[0]\n",
|
23 |
+
"\n",
|
24 |
+
"bm = BodyMask(image, model_id=model_id)"
|
25 |
+
]
|
26 |
+
},
|
27 |
+
{
|
28 |
+
"cell_type": "code",
|
29 |
+
"execution_count": null,
|
30 |
+
"metadata": {},
|
31 |
+
"outputs": [],
|
32 |
+
"source": [
|
33 |
+
"bm.display_results()"
|
34 |
+
]
|
35 |
+
},
|
36 |
+
{
|
37 |
+
"cell_type": "code",
|
38 |
+
"execution_count": 8,
|
39 |
+
"metadata": {},
|
40 |
+
"outputs": [
|
41 |
+
{
|
42 |
+
"name": "stdout",
|
43 |
+
"output_type": "stream",
|
44 |
+
"text": [
|
45 |
+
"\u001b[0;31mInit signature:\u001b[0m\n",
|
46 |
+
"\u001b[0mgr\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mImage\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\u001b[0m\n",
|
47 |
+
"\u001b[0;34m\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'str | PIL.Image.Image | np.ndarray | Callable | None'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
|
48 |
+
"\u001b[0;34m\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
|
49 |
+
"\u001b[0;34m\u001b[0m \u001b[0mformat\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'str'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m'webp'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
|
50 |
+
"\u001b[0;34m\u001b[0m \u001b[0mheight\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'int | str | None'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
|
51 |
+
"\u001b[0;34m\u001b[0m \u001b[0mwidth\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'int | str | None'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
|
52 |
+
"\u001b[0;34m\u001b[0m \u001b[0mimage_mode\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m\"Literal['1', 'L', 'P', 'RGB', 'RGBA', 'CMYK', 'YCbCr', 'LAB', 'HSV', 'I', 'F'] | None\"\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m'RGB'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
|
53 |
+
"\u001b[0;34m\u001b[0m \u001b[0msources\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m\"list[Literal['upload', 'webcam', 'clipboard']] | Literal['upload', 'webcam', 'clipboard'] | None\"\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
|
54 |
+
"\u001b[0;34m\u001b[0m \u001b[0mtype\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m\"Literal['numpy', 'pil', 'filepath']\"\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m'numpy'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
|
55 |
+
"\u001b[0;34m\u001b[0m \u001b[0mlabel\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'str | None'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
|
56 |
+
"\u001b[0;34m\u001b[0m \u001b[0mevery\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'Timer | float | None'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
|
57 |
+
"\u001b[0;34m\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'Component | Sequence[Component] | set[Component] | None'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
|
58 |
+
"\u001b[0;34m\u001b[0m \u001b[0mshow_label\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'bool | None'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
|
59 |
+
"\u001b[0;34m\u001b[0m \u001b[0mshow_download_button\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'bool'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
|
60 |
+
"\u001b[0;34m\u001b[0m \u001b[0mcontainer\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'bool'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
|
61 |
+
"\u001b[0;34m\u001b[0m \u001b[0mscale\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'int | None'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
|
62 |
+
"\u001b[0;34m\u001b[0m \u001b[0mmin_width\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'int'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m160\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
|
63 |
+
"\u001b[0;34m\u001b[0m \u001b[0minteractive\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'bool | None'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
|
64 |
+
"\u001b[0;34m\u001b[0m \u001b[0mvisible\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'bool'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
|
65 |
+
"\u001b[0;34m\u001b[0m \u001b[0mstreaming\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'bool'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
|
66 |
+
"\u001b[0;34m\u001b[0m \u001b[0melem_id\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'str | None'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
|
67 |
+
"\u001b[0;34m\u001b[0m \u001b[0melem_classes\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'list[str] | str | None'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
|
68 |
+
"\u001b[0;34m\u001b[0m \u001b[0mrender\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'bool'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
|
69 |
+
"\u001b[0;34m\u001b[0m \u001b[0mkey\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'int | str | None'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
|
70 |
+
"\u001b[0;34m\u001b[0m \u001b[0mmirror_webcam\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'bool'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
|
71 |
+
"\u001b[0;34m\u001b[0m \u001b[0mshow_share_button\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'bool | None'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
|
72 |
+
"\u001b[0;34m\u001b[0m \u001b[0mplaceholder\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'str | None'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
|
73 |
+
"\u001b[0;34m\u001b[0m \u001b[0mshow_fullscreen_button\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'bool'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
|
74 |
+
"\u001b[0;34m\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
75 |
+
"\u001b[0;31mDocstring:\u001b[0m \n",
|
76 |
+
"Creates an image component that can be used to upload images (as an input) or display images (as an output).\n",
|
77 |
+
"\n",
|
78 |
+
"Demos: sepia_filter, fake_diffusion\n",
|
79 |
+
"Guides: image-classification-in-pytorch, image-classification-in-tensorflow, image-classification-with-vision-transformers, create-your-own-friends-with-a-gan\n",
|
80 |
+
"\u001b[0;31mInit docstring:\u001b[0m\n",
|
81 |
+
"Parameters:\n",
|
82 |
+
" value: A PIL Image, numpy array, path or URL for the default value that Image component is going to take. If callable, the function will be called whenever the app loads to set the initial value of the component.\n",
|
83 |
+
" format: File format (e.g. \"png\" or \"gif\") to save image if it does not already have a valid format (e.g. if the image is being returned to the frontend as a numpy array or PIL Image). The format should be supported by the PIL library. This parameter has no effect on SVG files.\n",
|
84 |
+
" height: The height of the displayed image, specified in pixels if a number is passed, or in CSS units if a string is passed.\n",
|
85 |
+
" width: The width of the displayed image, specified in pixels if a number is passed, or in CSS units if a string is passed.\n",
|
86 |
+
" image_mode: \"RGB\" if color, or \"L\" if black and white. See https://pillow.readthedocs.io/en/stable/handbook/concepts.html for other supported image modes and their meaning. This parameter has no effect on SVG or GIF files. If set to None, the image_mode will be inferred from the image file.\n",
|
87 |
+
" sources: List of sources for the image. \"upload\" creates a box where user can drop an image file, \"webcam\" allows user to take snapshot from their webcam, \"clipboard\" allows users to paste an image from the clipboard. If None, defaults to [\"upload\", \"webcam\", \"clipboard\"] if streaming is False, otherwise defaults to [\"webcam\"].\n",
|
88 |
+
" type: The format the image is converted before being passed into the prediction function. \"numpy\" converts the image to a numpy array with shape (height, width, 3) and values from 0 to 255, \"pil\" converts the image to a PIL image object, \"filepath\" passes a str path to a temporary file containing the image. If the image is SVG, the `type` is ignored and the filepath of the SVG is returned. To support animated GIFs in input, the `type` should be set to \"filepath\" or \"pil\".\n",
|
89 |
+
" label: The label for this component. Appears above the component and is also used as the header if there are a table of examples for this component. If None and used in a `gr.Interface`, the label will be the name of the parameter this component is assigned to.\n",
|
90 |
+
" every: Continously calls `value` to recalculate it if `value` is a function (has no effect otherwise). Can provide a Timer whose tick resets `value`, or a float that provides the regular interval for the reset Timer.\n",
|
91 |
+
" inputs: Components that are used as inputs to calculate `value` if `value` is a function (has no effect otherwise). `value` is recalculated any time the inputs change.\n",
|
92 |
+
" show_label: if True, will display label.\n",
|
93 |
+
" show_download_button: If True, will display button to download image.\n",
|
94 |
+
" container: If True, will place the component in a container - providing some extra padding around the border.\n",
|
95 |
+
" scale: relative size compared to adjacent Components. For example if Components A and B are in a Row, and A has scale=2, and B has scale=1, A will be twice as wide as B. Should be an integer. scale applies in Rows, and to top-level Components in Blocks where fill_height=True.\n",
|
96 |
+
" min_width: minimum pixel width, will wrap if not sufficient screen space to satisfy this value. If a certain scale value results in this Component being narrower than min_width, the min_width parameter will be respected first.\n",
|
97 |
+
" interactive: if True, will allow users to upload and edit an image; if False, can only be used to display images. If not provided, this is inferred based on whether the component is used as an input or output.\n",
|
98 |
+
" visible: If False, component will be hidden.\n",
|
99 |
+
" streaming: If True when used in a `live` interface, will automatically stream webcam feed. Only valid is source is 'webcam'.\n",
|
100 |
+
" elem_id: An optional string that is assigned as the id of this component in the HTML DOM. Can be used for targeting CSS styles.\n",
|
101 |
+
" elem_classes: An optional list of strings that are assigned as the classes of this component in the HTML DOM. Can be used for targeting CSS styles.\n",
|
102 |
+
" render: If False, component will not render be rendered in the Blocks context. Should be used if the intention is to assign event listeners now but render the component later.\n",
|
103 |
+
" key: if assigned, will be used to assume identity across a re-render. Components that have the same key across a re-render will have their value preserved.\n",
|
104 |
+
" mirror_webcam: If True webcam will be mirrored. Default is True.\n",
|
105 |
+
" show_share_button: If True, will show a share icon in the corner of the component that allows user to share outputs to Hugging Face Spaces Discussions. If False, icon does not appear. If set to None (default behavior), then the icon appears if this Gradio app is launched on Spaces, but not otherwise.\n",
|
106 |
+
" placeholder: Custom text for the upload area. Overrides default upload messages when provided. Accepts new lines and `#` to designate a heading.\n",
|
107 |
+
" show_fullscreen_button: If True, will show a fullscreen icon in the corner of the component that allows user to view the image in fullscreen mode. If False, icon does not appear.\n",
|
108 |
+
"\u001b[0;31mFile:\u001b[0m /opt/homebrew/Caskroom/miniforge/base/envs/lemons/lib/python3.10/site-packages/gradio/components/image.py\n",
|
109 |
+
"\u001b[0;31mType:\u001b[0m ComponentMeta\n",
|
110 |
+
"\u001b[0;31mSubclasses:\u001b[0m "
|
111 |
+
]
|
112 |
+
}
|
113 |
+
],
|
114 |
+
"source": [
|
115 |
+
"import gradio as gr \n",
|
116 |
+
"\n",
|
117 |
+
"gr.Image?"
|
118 |
+
]
|
119 |
+
},
|
120 |
+
{
|
121 |
+
"cell_type": "code",
|
122 |
+
"execution_count": null,
|
123 |
+
"metadata": {},
|
124 |
+
"outputs": [],
|
125 |
+
"source": []
|
126 |
+
}
|
127 |
+
],
|
128 |
+
"metadata": {
|
129 |
+
"kernelspec": {
|
130 |
+
"display_name": "lemons",
|
131 |
+
"language": "python",
|
132 |
+
"name": "lemons"
|
133 |
+
},
|
134 |
+
"language_info": {
|
135 |
+
"codemirror_mode": {
|
136 |
+
"name": "ipython",
|
137 |
+
"version": 3
|
138 |
+
},
|
139 |
+
"file_extension": ".py",
|
140 |
+
"mimetype": "text/x-python",
|
141 |
+
"name": "python",
|
142 |
+
"nbconvert_exporter": "python",
|
143 |
+
"pygments_lexer": "ipython3",
|
144 |
+
"version": "3.10.14"
|
145 |
+
},
|
146 |
+
"orig_nbformat": 4
|
147 |
+
},
|
148 |
+
"nbformat": 4,
|
149 |
+
"nbformat_minor": 2
|
150 |
+
}
|
requirements.txt
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
diffusers==0.30.3
|
2 |
+
gradio==4.44.0
|
3 |
+
matplotlib==3.8.4
|
4 |
+
numpy==2.1.1
|
5 |
+
Pillow==10.4.0
|
6 |
+
skimage==0.0
|
7 |
+
ultralytics==8.2.97
|
sample_images/image_five.jpg
ADDED
![]() |
sample_images/image_four.jpg
ADDED
![]() |
sample_images/image_six.jpg
ADDED
![]() |
yolo/BodyMask.py
ADDED
@@ -0,0 +1,248 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from functools import lru_cache
|
3 |
+
from typing import List
|
4 |
+
|
5 |
+
import cv2
|
6 |
+
import numpy as np
|
7 |
+
from diffusers.utils import load_image
|
8 |
+
from PIL import Image, ImageChops, ImageFilter
|
9 |
+
from ultralytics import YOLO
|
10 |
+
from .utils import *
|
11 |
+
|
12 |
+
|
13 |
+
def dilate_mask(mask, dilate_factor=6, blur_radius=2, erosion_factor=2):
|
14 |
+
if not mask:
|
15 |
+
return None
|
16 |
+
# Convert PIL image to NumPy array if necessary
|
17 |
+
if isinstance(mask, Image.Image):
|
18 |
+
mask = np.array(mask)
|
19 |
+
|
20 |
+
# Ensure mask is in uint8 format
|
21 |
+
mask = mask.astype(np.uint8)
|
22 |
+
|
23 |
+
# Apply dilation
|
24 |
+
kernel = np.ones((dilate_factor, dilate_factor), np.uint8)
|
25 |
+
dilated_mask = cv2.dilate(mask, kernel, iterations=1)
|
26 |
+
|
27 |
+
# Apply erosion for refinement
|
28 |
+
kernel = np.ones((erosion_factor, erosion_factor), np.uint8)
|
29 |
+
eroded_mask = cv2.erode(dilated_mask, kernel, iterations=1)
|
30 |
+
|
31 |
+
# Apply Gaussian blur to smooth the edges
|
32 |
+
blurred_mask = cv2.GaussianBlur(
|
33 |
+
eroded_mask, (2 * blur_radius + 1, 2 * blur_radius + 1), 0
|
34 |
+
)
|
35 |
+
|
36 |
+
# Convert back to PIL image
|
37 |
+
smoothed_mask = Image.fromarray(blurred_mask).convert("L")
|
38 |
+
|
39 |
+
# Optionally, apply an additional blur for extra smoothness using PIL
|
40 |
+
smoothed_mask = smoothed_mask.filter(ImageFilter.GaussianBlur(radius=blur_radius))
|
41 |
+
|
42 |
+
return smoothed_mask
|
43 |
+
|
44 |
+
|
45 |
+
@lru_cache(maxsize=1)
|
46 |
+
def get_model(model_id):
|
47 |
+
model = YOLO(model=model_id)
|
48 |
+
return model
|
49 |
+
|
50 |
+
|
51 |
+
def combine_masks(masks: List[dict], labels: List[str], is_label=True) -> Image.Image:
|
52 |
+
"""
|
53 |
+
Combine masks with the specified labels into a single mask, optimized for speed and non-overlapping of excluded masks.
|
54 |
+
|
55 |
+
Parameters:
|
56 |
+
- masks (List[dict]): A list of dictionaries, each containing the mask under a 'mask' key and its label under a 'label' key.
|
57 |
+
- labels (List[str]): A list of labels to include in the combination.
|
58 |
+
|
59 |
+
Returns:
|
60 |
+
- Image.Image: The combined mask as a PIL Image object, or None if no masks are combined.
|
61 |
+
"""
|
62 |
+
labels_set = set(labels) # Convert labels list to a set for O(1) lookups
|
63 |
+
|
64 |
+
# Filter and convert mask images based on the specified labels
|
65 |
+
mask_images = [
|
66 |
+
mask["mask"].convert("L")
|
67 |
+
for mask in masks
|
68 |
+
if (mask["label"] in labels_set) == is_label
|
69 |
+
]
|
70 |
+
|
71 |
+
# Ensure there is at least one mask to combine
|
72 |
+
if not mask_images:
|
73 |
+
return None # Or raise an appropriate error, e.g., ValueError("No masks found for the specified labels.")
|
74 |
+
|
75 |
+
# Initialize the combined mask with the first mask
|
76 |
+
combined_mask = mask_images[0]
|
77 |
+
|
78 |
+
# Combine the remaining masks with the existing combined_mask using a bitwise OR operation to ensure non-overlap
|
79 |
+
for mask in mask_images[1:]:
|
80 |
+
combined_mask = ImageChops.lighter(combined_mask, mask)
|
81 |
+
|
82 |
+
return combined_mask
|
83 |
+
|
84 |
+
|
85 |
+
body_labels = ["hair", "face", "arm", "hand", "leg", "foot", "outfit"]
|
86 |
+
|
87 |
+
|
88 |
+
class BodyMask:
|
89 |
+
|
90 |
+
def __init__(
|
91 |
+
self,
|
92 |
+
image_path,
|
93 |
+
model_id,
|
94 |
+
labels=body_labels,
|
95 |
+
overlay="mask",
|
96 |
+
widen_box=0,
|
97 |
+
elongate_box=0,
|
98 |
+
resize_to=640,
|
99 |
+
dilate_factor=0,
|
100 |
+
is_label=False,
|
101 |
+
resize_to_nearest_eight=False,
|
102 |
+
verbose=True,
|
103 |
+
remove_overlap=True,
|
104 |
+
):
|
105 |
+
self.image_path = image_path
|
106 |
+
self.image = self.get_image(
|
107 |
+
resize_to=resize_to, resize_to_nearest_eight=resize_to_nearest_eight
|
108 |
+
)
|
109 |
+
self.labels = labels
|
110 |
+
self.is_label = is_label
|
111 |
+
self.model_id = model_id
|
112 |
+
self.model = get_model(self.model_id)
|
113 |
+
self.model_labels = self.model.names
|
114 |
+
self.verbose = verbose
|
115 |
+
self.results = self.get_results()
|
116 |
+
self.dilate_factor = dilate_factor
|
117 |
+
self.body_mask = self.get_body_mask()
|
118 |
+
self.box = get_bounding_box(self.body_mask)
|
119 |
+
self.body_box = self.get_body_box(
|
120 |
+
remove_overlap=remove_overlap, widen=widen_box, elongate=elongate_box
|
121 |
+
)
|
122 |
+
if overlay == "box":
|
123 |
+
self.overlay = overlay_mask(
|
124 |
+
self.image, self.body_box, opacity=0.9, color="red"
|
125 |
+
)
|
126 |
+
else:
|
127 |
+
self.overlay = overlay_mask(
|
128 |
+
self.image, self.body_mask, opacity=0.9, color="red"
|
129 |
+
)
|
130 |
+
|
131 |
+
def get_image(self, resize_to, resize_to_nearest_eight):
|
132 |
+
image = load_image(self.image_path)
|
133 |
+
if resize_to:
|
134 |
+
image = resize_preserve_aspect_ratio(image, resize_to)
|
135 |
+
if resize_to_nearest_eight:
|
136 |
+
image = resize_image_to_nearest_eight(image)
|
137 |
+
else:
|
138 |
+
image = image
|
139 |
+
return image
|
140 |
+
|
141 |
+
def get_body_mask(self):
|
142 |
+
body_mask = combine_masks(self.results, self.labels, self.is_label)
|
143 |
+
return dilate_mask(body_mask, self.dilate_factor)
|
144 |
+
|
145 |
+
def get_results(self):
|
146 |
+
imgsz = max(self.image.size)
|
147 |
+
results = self.model(
|
148 |
+
self.image, retina_masks=True, imgsz=imgsz, verbose=self.verbose
|
149 |
+
)[0]
|
150 |
+
self.masks, self.boxes, self.scores, self.phrases = unload(
|
151 |
+
results, self.model_labels
|
152 |
+
)
|
153 |
+
results = format_results(
|
154 |
+
self.masks,
|
155 |
+
self.boxes,
|
156 |
+
self.scores,
|
157 |
+
self.phrases,
|
158 |
+
self.model_labels,
|
159 |
+
person_masks_only=False,
|
160 |
+
)
|
161 |
+
|
162 |
+
# filter out lower score results
|
163 |
+
masks_to_filter = ["hair"]
|
164 |
+
results = filter_highest_score(results, ["hair", "face", "phone"])
|
165 |
+
return results
|
166 |
+
|
167 |
+
def display_results(self):
|
168 |
+
if len(self.masks) < 4:
|
169 |
+
cols = len(self.masks)
|
170 |
+
else:
|
171 |
+
cols = 4
|
172 |
+
display_image_with_masks(self.image, self.results, cols=cols)
|
173 |
+
|
174 |
+
def get_mask(self, mask_label):
|
175 |
+
assert mask_label in self.phrases, "Mask label not found in results"
|
176 |
+
return [f for f in self.results if f.get("label") == mask_label]
|
177 |
+
|
178 |
+
def combine_masks(self, mask_labels: List, no_labels=None, is_label=True):
|
179 |
+
"""
|
180 |
+
Combine the masks included in the labels list or all of the masks not in the list
|
181 |
+
"""
|
182 |
+
if not is_label:
|
183 |
+
mask_labels = [
|
184 |
+
phrase for phrase in self.phrases if phrase not in mask_labels
|
185 |
+
]
|
186 |
+
masks = [
|
187 |
+
row.get("mask") for row in self.results if row.get("label") in mask_labels
|
188 |
+
]
|
189 |
+
if len(masks) == 0:
|
190 |
+
return None
|
191 |
+
combined_mask = masks[0]
|
192 |
+
for mask in masks[1:]:
|
193 |
+
combined_mask = ImageChops.lighter(combined_mask, mask)
|
194 |
+
return combined_mask
|
195 |
+
|
196 |
+
def get_body_box(self, remove_overlap=True, widen=0, elongate=0):
|
197 |
+
body_box = get_bounding_box_mask(self.body_mask, widen=widen, elongate=elongate)
|
198 |
+
if remove_overlap:
|
199 |
+
body_box = self.remove_overlap(body_box)
|
200 |
+
return body_box
|
201 |
+
|
202 |
+
def remove_overlap(self, body_box):
|
203 |
+
"""
|
204 |
+
Remove mask regions that overlap with unwanted labels
|
205 |
+
"""
|
206 |
+
# convert mask to numpy array
|
207 |
+
box_array = np.array(body_box)
|
208 |
+
|
209 |
+
# combine the masks for those labels
|
210 |
+
mask = self.combine_masks(mask_labels=self.labels, is_label=True)
|
211 |
+
|
212 |
+
# convert mask to numpy array
|
213 |
+
mask_array = np.array(mask)
|
214 |
+
|
215 |
+
# where the mask array is white set the box array to black
|
216 |
+
box_array[mask_array == 255] = 0
|
217 |
+
|
218 |
+
# convert the box array to an image
|
219 |
+
mask_image = Image.fromarray(box_array)
|
220 |
+
return mask_image
|
221 |
+
|
222 |
+
|
223 |
+
if __name__ == "__main__":
|
224 |
+
url = "https://sjc1.vultrobjects.com/photo-storage/images/525d1f68-314c-455b-a8b6-f5dc3fa044e4.jpeg"
|
225 |
+
image_name = url.split("/")[-1]
|
226 |
+
labels = ["face", "hair", "phone", "hand"]
|
227 |
+
image = load_image(url)
|
228 |
+
image_size = image.size
|
229 |
+
# Get the original size of the image
|
230 |
+
original_size = image.size
|
231 |
+
|
232 |
+
# Create body mask
|
233 |
+
body_mask = BodyMask(
|
234 |
+
image,
|
235 |
+
overlay="box",
|
236 |
+
labels=labels,
|
237 |
+
widen_box=50,
|
238 |
+
elongate_box=10,
|
239 |
+
dilate_factor=0,
|
240 |
+
resize_to=640,
|
241 |
+
is_label=False,
|
242 |
+
remove_overlap=True,
|
243 |
+
verbose=False,
|
244 |
+
)
|
245 |
+
|
246 |
+
# Resize the image back to the original size
|
247 |
+
image = body_mask.image.resize(original_size)
|
248 |
+
body_mask.body_box.save(image_name)
|
yolo/utils.py
ADDED
@@ -0,0 +1,291 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import matplotlib.patches as patches
|
2 |
+
import matplotlib.pyplot as plt
|
3 |
+
import numpy as np
|
4 |
+
from PIL import Image, ImageDraw
|
5 |
+
|
6 |
+
|
7 |
+
def unload_mask(mask):
|
8 |
+
mask = mask.cpu().numpy().squeeze()
|
9 |
+
mask = mask.astype(np.uint8) * 255
|
10 |
+
return Image.fromarray(mask)
|
11 |
+
|
12 |
+
|
13 |
+
def unload_box(box):
|
14 |
+
return box.cpu().numpy().tolist()
|
15 |
+
|
16 |
+
|
17 |
+
def masks_overlap(mask1, mask2):
|
18 |
+
return np.any(np.logical_and(mask1, mask2))
|
19 |
+
|
20 |
+
|
21 |
+
def remove_non_person_masks(person_mask, formatted_results):
|
22 |
+
return [
|
23 |
+
f
|
24 |
+
for f in formatted_results
|
25 |
+
if f.get("label") == "person" or masks_overlap(person_mask, f.get("mask"))
|
26 |
+
]
|
27 |
+
|
28 |
+
|
29 |
+
def format_masks(masks):
|
30 |
+
return [unload_mask(mask) for mask in masks]
|
31 |
+
|
32 |
+
|
33 |
+
def format_boxes(boxes):
|
34 |
+
return [unload_box(box) for box in boxes]
|
35 |
+
|
36 |
+
|
37 |
+
def format_scores(scores):
|
38 |
+
return scores.cpu().numpy().tolist()
|
39 |
+
|
40 |
+
|
41 |
+
def unload(result, labels_dict):
|
42 |
+
masks = format_masks(result.masks.data)
|
43 |
+
boxes = format_boxes(result.boxes.xyxy)
|
44 |
+
scores = format_scores(result.boxes.conf)
|
45 |
+
labels = result.boxes.cls
|
46 |
+
labels = [int(label.item()) for label in labels]
|
47 |
+
phrases = [labels_dict[label] for label in labels]
|
48 |
+
return masks, boxes, scores, phrases
|
49 |
+
|
50 |
+
|
51 |
+
def format_results(masks, boxes, scores, labels, labels_dict, person_masks_only=True):
|
52 |
+
if isinstance(list(labels_dict.keys())[0], int):
|
53 |
+
labels_dict = {v: k for k, v in labels_dict.items()}
|
54 |
+
|
55 |
+
# check that the person mask is present
|
56 |
+
if person_masks_only:
|
57 |
+
assert "person" in labels, "Person mask not present in results"
|
58 |
+
results_dict = []
|
59 |
+
for row in zip(labels, scores, boxes, masks):
|
60 |
+
label, score, box, mask = row
|
61 |
+
label_id = labels_dict[label]
|
62 |
+
results_row = dict(
|
63 |
+
label=label, score=score, mask=mask, box=box, label_id=label_id
|
64 |
+
)
|
65 |
+
results_dict.append(results_row)
|
66 |
+
results_dict = sorted(results_dict, key=lambda x: x["label"])
|
67 |
+
if person_masks_only:
|
68 |
+
# Get the person mask
|
69 |
+
person_mask = [f for f in results_dict if f.get("label") == "person"][0]["mask"]
|
70 |
+
assert person_mask is not None, "Person mask not found in results"
|
71 |
+
|
72 |
+
# Remove any results that do no overlap with the person
|
73 |
+
results_dict = remove_non_person_masks(person_mask, results_dict)
|
74 |
+
return results_dict
|
75 |
+
|
76 |
+
|
77 |
+
def filter_highest_score(results, labels):
|
78 |
+
"""
|
79 |
+
Filter results to remove entries with lower scores for specified labels.
|
80 |
+
|
81 |
+
Args:
|
82 |
+
results (list): List of dictionaries containing 'label', 'score', and other keys.
|
83 |
+
labels (list): List of labels to filter.
|
84 |
+
|
85 |
+
Returns:
|
86 |
+
list: Filtered results with only the highest score for each specified label.
|
87 |
+
"""
|
88 |
+
# Dictionary to keep track of the highest score entry for each label
|
89 |
+
label_highest = {}
|
90 |
+
|
91 |
+
# First pass: identify the highest score for each label
|
92 |
+
for result in results:
|
93 |
+
label = result["label"]
|
94 |
+
if label in labels:
|
95 |
+
if (
|
96 |
+
label not in label_highest
|
97 |
+
or result["score"] > label_highest[label]["score"]
|
98 |
+
):
|
99 |
+
label_highest[label] = result
|
100 |
+
|
101 |
+
# Second pass: construct the filtered list while preserving the order
|
102 |
+
filtered_results = []
|
103 |
+
seen_labels = set()
|
104 |
+
|
105 |
+
for result in results:
|
106 |
+
label = result["label"]
|
107 |
+
if label in labels:
|
108 |
+
if label in seen_labels:
|
109 |
+
continue
|
110 |
+
if result == label_highest[label]:
|
111 |
+
filtered_results.append(result)
|
112 |
+
seen_labels.add(label)
|
113 |
+
else:
|
114 |
+
filtered_results.append(result)
|
115 |
+
|
116 |
+
return filtered_results
|
117 |
+
|
118 |
+
|
119 |
+
def display_image_with_masks(image, results, cols=4, return_images=False):
|
120 |
+
# Convert PIL Image to numpy array
|
121 |
+
image_np = np.array(image)
|
122 |
+
|
123 |
+
# Check image dimensions
|
124 |
+
if image_np.ndim != 3 or image_np.shape[2] != 3:
|
125 |
+
raise ValueError("Image must be a 3-dimensional array with 3 color channels")
|
126 |
+
|
127 |
+
# Number of masks
|
128 |
+
n = len(results)
|
129 |
+
rows = (n + cols - 1) // cols # Calculate required number of rows
|
130 |
+
|
131 |
+
# Setting up the plot
|
132 |
+
fig, axs = plt.subplots(rows, cols, figsize=(5 * cols, 5 * rows))
|
133 |
+
axs = np.array(axs).reshape(-1) # Flatten axs array for easy indexing
|
134 |
+
for i, result in enumerate(results):
|
135 |
+
mask = result["mask"]
|
136 |
+
label = result["label"]
|
137 |
+
score = float(result["score"])
|
138 |
+
|
139 |
+
# Convert PIL mask to numpy array and resize if necessary
|
140 |
+
mask_np = np.array(mask)
|
141 |
+
if mask_np.shape != image_np.shape[:2]:
|
142 |
+
mask_np = resize(
|
143 |
+
mask_np, image_np.shape[:2], mode="constant", anti_aliasing=False
|
144 |
+
)
|
145 |
+
mask_np = (mask_np > 0.5).astype(
|
146 |
+
np.uint8
|
147 |
+
) # Threshold back to binary after resize
|
148 |
+
|
149 |
+
# Create an overlay where mask is True
|
150 |
+
overlay = np.zeros_like(image_np)
|
151 |
+
overlay[mask_np > 0] = [0, 0, 255] # Applying blue color on the mask area
|
152 |
+
|
153 |
+
# Combine the image and the overlay
|
154 |
+
combined = image_np.copy()
|
155 |
+
indices = np.where(mask_np > 0)
|
156 |
+
combined[indices] = combined[indices] * 0.5 + overlay[indices] * 0.5
|
157 |
+
|
158 |
+
# Show the combined image
|
159 |
+
ax = axs[i]
|
160 |
+
ax.imshow(combined)
|
161 |
+
ax.axis("off")
|
162 |
+
ax.set_title(f"Label: {label}, Score: {score:.2f}", fontsize=12)
|
163 |
+
rect = patches.Rectangle(
|
164 |
+
(0, 0),
|
165 |
+
image_np.shape[1],
|
166 |
+
image_np.shape[0],
|
167 |
+
linewidth=1,
|
168 |
+
edgecolor="r",
|
169 |
+
facecolor="none",
|
170 |
+
)
|
171 |
+
ax.add_patch(rect)
|
172 |
+
|
173 |
+
# Hide unused subplots if the total number of masks is not a multiple of cols
|
174 |
+
for idx in range(i + 1, rows * cols):
|
175 |
+
axs[idx].axis("off")
|
176 |
+
plt.tight_layout()
|
177 |
+
plt.show()
|
178 |
+
|
179 |
+
|
180 |
+
def get_bounding_box(mask):
|
181 |
+
"""
|
182 |
+
Given a segmentation mask, return the bounding box for the mask object.
|
183 |
+
"""
|
184 |
+
# Find indices where the mask is non-zero
|
185 |
+
coords = np.argwhere(mask)
|
186 |
+
# Get the minimum and maximum x and y coordinates
|
187 |
+
x_min, y_min = np.min(coords, axis=0)
|
188 |
+
x_max, y_max = np.max(coords, axis=0)
|
189 |
+
# Return the bounding box coordinates
|
190 |
+
return (y_min, x_min, y_max, x_max)
|
191 |
+
|
192 |
+
|
193 |
+
def get_bounding_box_mask(segmentation_mask, widen=0, elongate=0):
|
194 |
+
# Convert the PIL segmentation mask to a NumPy array
|
195 |
+
mask_array = np.array(segmentation_mask)
|
196 |
+
|
197 |
+
# Find the coordinates of the non-zero pixels
|
198 |
+
non_zero_y, non_zero_x = np.nonzero(mask_array)
|
199 |
+
|
200 |
+
# Calculate the bounding box coordinates
|
201 |
+
min_x, max_x = np.min(non_zero_x), np.max(non_zero_x)
|
202 |
+
min_y, max_y = np.min(non_zero_y), np.max(non_zero_y)
|
203 |
+
|
204 |
+
if widen > 0:
|
205 |
+
min_x = max(0, min_x - widen)
|
206 |
+
max_x = min(mask_array.shape[1], max_x + widen)
|
207 |
+
|
208 |
+
if elongate > 0:
|
209 |
+
min_y = max(0, min_y - elongate)
|
210 |
+
max_y = min(mask_array.shape[0], max_y + elongate)
|
211 |
+
|
212 |
+
# Create a new blank image for the bounding box mask
|
213 |
+
bounding_box_mask = Image.new("1", segmentation_mask.size)
|
214 |
+
|
215 |
+
# Draw the filled bounding box on the blank image
|
216 |
+
draw = ImageDraw.Draw(bounding_box_mask)
|
217 |
+
draw.rectangle([(min_x, min_y), (max_x, max_y)], fill=1)
|
218 |
+
|
219 |
+
return bounding_box_mask
|
220 |
+
|
221 |
+
|
222 |
+
colors = {
|
223 |
+
"blue": (136, 207, 249),
|
224 |
+
"red": (255, 0, 0),
|
225 |
+
"green": (0, 255, 0),
|
226 |
+
"yellow": (255, 255, 0),
|
227 |
+
"purple": (128, 0, 128),
|
228 |
+
"cyan": (0, 255, 255),
|
229 |
+
"magenta": (255, 0, 255),
|
230 |
+
"orange": (255, 165, 0),
|
231 |
+
"lime": (50, 205, 50),
|
232 |
+
"pink": (255, 192, 203),
|
233 |
+
"brown": (139, 69, 19),
|
234 |
+
"gray": (128, 128, 128),
|
235 |
+
"black": (0, 0, 0),
|
236 |
+
"white": (255, 255, 255),
|
237 |
+
"gold": (255, 215, 0),
|
238 |
+
"silver": (192, 192, 192),
|
239 |
+
"beige": (245, 245, 220),
|
240 |
+
"navy": (0, 0, 128),
|
241 |
+
"maroon": (128, 0, 0),
|
242 |
+
"olive": (128, 128, 0),
|
243 |
+
}
|
244 |
+
|
245 |
+
|
246 |
+
def overlay_mask(image, mask, opacity=0.5, color="blue"):
|
247 |
+
"""
|
248 |
+
Takes in a PIL image and a PIL boolean image mask. Overlay the mask on the image
|
249 |
+
and color the mask with a low opacity blue with hex #88CFF9.
|
250 |
+
"""
|
251 |
+
# Convert the boolean mask to an image with alpha channel
|
252 |
+
alpha = mask.convert("L").point(lambda x: 255 if x == 255 else 0, mode="1")
|
253 |
+
|
254 |
+
# Choose the color
|
255 |
+
r, g, b = colors[color]
|
256 |
+
|
257 |
+
color_mask = Image.new("RGBA", mask.size, (r, g, b, int(opacity * 255)))
|
258 |
+
mask_rgba = Image.composite(
|
259 |
+
color_mask, Image.new("RGBA", mask.size, (0, 0, 0, 0)), alpha
|
260 |
+
)
|
261 |
+
|
262 |
+
# Create a new RGBA image to overlay the mask on
|
263 |
+
overlay = Image.new("RGBA", image.size, (0, 0, 0, 0))
|
264 |
+
|
265 |
+
# Paste the mask onto the overlay
|
266 |
+
overlay.paste(mask_rgba, (0, 0))
|
267 |
+
|
268 |
+
# Create a new image to return by blending the original image and the overlay
|
269 |
+
result = Image.alpha_composite(image.convert("RGBA"), overlay)
|
270 |
+
|
271 |
+
# Convert the result back to the original mode and return it
|
272 |
+
return result.convert(image.mode)
|
273 |
+
|
274 |
+
|
275 |
+
def resize_preserve_aspect_ratio(image, max_side=512):
|
276 |
+
width, height = image.size
|
277 |
+
scale = min(max_side / width, max_side / height)
|
278 |
+
new_width = int(width * scale)
|
279 |
+
new_height = int(height * scale)
|
280 |
+
return image.resize((new_width, new_height))
|
281 |
+
|
282 |
+
|
283 |
+
def round_to_nearest_eigth(value):
|
284 |
+
return int((value // 8 * 8))
|
285 |
+
|
286 |
+
|
287 |
+
def resize_image_to_nearest_eight(image):
|
288 |
+
width, height = image.size
|
289 |
+
width, height = round_to_nearest_eigth(width), round_to_nearest_eigth(height)
|
290 |
+
image = image.resize((width, height))
|
291 |
+
return image
|