Spaces:
Build error
Build error
joselobenitezg
commited on
Commit
·
94f04b7
1
Parent(s):
abe2204
wip
Browse files- .gitattributes +1 -0
- .gitignore +1 -0
- app.py +118 -4
- checkpoints/depth/sapiens_0.3b_torchscript.pt2 +3 -0
- checkpoints/depth/sapiens_0.6b_torchscript.pt2 +3 -0
- checkpoints/depth/sapiens_1b_torchscript.pt2 +3 -0
- checkpoints/depth/sapiens_2b_torchscript.pt2 +3 -0
- checkpoints/normal/sapiens_0.3b_torchscript.pt2 +3 -0
- checkpoints/normal/sapiens_0.6b_torchscript.pt2 +3 -0
- checkpoints/normal/sapiens_1b_torchscript.pt2 +3 -0
- checkpoints/normal/sapiens_2b_torchscript.pt2 +3 -0
- checkpoints/pose/sapiens_1b_torchscript.pt2 +3 -0
- checkpoints/seg/sapiens_0.3b_torchscript.pt2 +3 -0
- checkpoints/seg/sapiens_0.6b_torchscript.pt2 +3 -0
- checkpoints/seg/sapiens_1b_torchscript.pt2 +3 -0
- checkpoints/seg/sapiens_2b_torchscript.pt2 +3 -0
- config.py +55 -0
- download_checkpoints.py +42 -0
- requirements.txt +7 -0
- sapiens +1 -0
- utils/vis_utils.py +42 -0
.gitattributes
CHANGED
@@ -20,6 +20,7 @@
|
|
20 |
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
*.pt filter=lfs diff=lfs merge=lfs -text
|
|
|
23 |
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
|
|
20 |
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.pt2 filter=lfs diff=lfs merge=lfs -text
|
24 |
*.pth filter=lfs diff=lfs merge=lfs -text
|
25 |
*.rar filter=lfs diff=lfs merge=lfs -text
|
26 |
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
__pycache__
|
app.py
CHANGED
@@ -1,7 +1,121 @@
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
|
3 |
-
def greet(name):
|
4 |
-
return "Hello " + name + "!!"
|
5 |
|
6 |
-
|
7 |
-
demo.launch()
|
|
|
1 |
+
# Part of the source code is in: fashn-ai/sapiens-body-part-segmentation
|
2 |
+
import os
|
3 |
+
|
4 |
import gradio as gr
|
5 |
+
import numpy as np
|
6 |
+
import spaces
|
7 |
+
import torch
|
8 |
+
from gradio.themes.utils import sizes
|
9 |
+
from PIL import Image
|
10 |
+
from torchvision import transforms
|
11 |
+
from utils.vis_utils import get_palette, visualize_mask_with_overlay
|
12 |
+
|
13 |
+
if torch.cuda.is_available() and torch.cuda.get_device_properties(0).major >= 8:
|
14 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
15 |
+
torch.backends.cudnn.allow_tf32 = True
|
16 |
+
|
17 |
+
ASSETS_DIR = os.path.join(os.path.dirname(__file__), "assets")
|
18 |
+
|
19 |
+
|
20 |
+
CHECKPOINTS_DIR = os.path.join(ASSETS_DIR, "checkpoints")
|
21 |
+
|
22 |
+
CHECKPOINTS = {
|
23 |
+
"0.3B": "sapiens_0.3b_goliath_best_goliath_mIoU_7673_epoch_194_torchscript.pt2",
|
24 |
+
"0.6B": "sapiens_0.6b_goliath_best_goliath_mIoU_7777_epoch_178_torchscript.pt2",
|
25 |
+
"1B": "sapiens_1b_goliath_best_goliath_mIoU_7994_epoch_151_torchscript.pt2",
|
26 |
+
"2B": "sapiens_2b_goliath_best_goliath_mIoU_8179_epoch_181_torchscript.pt2",
|
27 |
+
}
|
28 |
+
|
29 |
+
|
30 |
+
def load_model(checkpoint_name: str):
|
31 |
+
checkpoint_path = os.path.join(CHECKPOINTS_DIR, CHECKPOINTS[checkpoint_name])
|
32 |
+
model = torch.jit.load(checkpoint_path)
|
33 |
+
model.eval()
|
34 |
+
model.to("cuda")
|
35 |
+
return model
|
36 |
+
|
37 |
+
|
38 |
+
MODELS = {name: load_model(name) for name in CHECKPOINTS.keys()}
|
39 |
+
|
40 |
+
|
41 |
+
@torch.inference_mode()
|
42 |
+
def run_model(model, input_tensor, height, width):
|
43 |
+
output = model(input_tensor)
|
44 |
+
output = torch.nn.functional.interpolate(output, size=(height, width), mode="bilinear", align_corners=False)
|
45 |
+
_, preds = torch.max(output, 1)
|
46 |
+
return preds
|
47 |
+
|
48 |
+
|
49 |
+
transform_fn = transforms.Compose(
|
50 |
+
[
|
51 |
+
transforms.Resize((1024, 768)),
|
52 |
+
transforms.ToTensor(),
|
53 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
54 |
+
]
|
55 |
+
)
|
56 |
+
# ----------------- CORE FUNCTION ----------------- #
|
57 |
+
|
58 |
+
|
59 |
+
@spaces.GPU
|
60 |
+
def segment(image: Image.Image, model_name: str) -> Image.Image:
|
61 |
+
input_tensor = transform_fn(image).unsqueeze(0).to("cuda")
|
62 |
+
model = MODELS[model_name]
|
63 |
+
preds = run_model(model, input_tensor, height=image.height, width=image.width)
|
64 |
+
mask = preds.squeeze(0).cpu().numpy()
|
65 |
+
mask_image = Image.fromarray(mask.astype("uint8"))
|
66 |
+
blended_image = visualize_mask_with_overlay(image, mask_image, LABELS_TO_IDS, alpha=0.5)
|
67 |
+
return blended_image
|
68 |
+
|
69 |
+
|
70 |
+
# ----------------- GRADIO UI ----------------- #
|
71 |
+
|
72 |
+
|
73 |
+
with open("banner.html", "r") as file:
|
74 |
+
banner = file.read()
|
75 |
+
with open("tips.html", "r") as file:
|
76 |
+
tips = file.read()
|
77 |
+
|
78 |
+
CUSTOM_CSS = """
|
79 |
+
.image-container img {
|
80 |
+
max-width: 512px;
|
81 |
+
max-height: 512px;
|
82 |
+
margin: 0 auto;
|
83 |
+
border-radius: 0px;
|
84 |
+
.gradio-container {background-color: #fafafa}
|
85 |
+
"""
|
86 |
+
|
87 |
+
with gr.Blocks(css=CUSTOM_CSS, theme=gr.themes.Monochrome(radius_size=sizes.radius_md)) as demo:
|
88 |
+
gr.HTML(banner)
|
89 |
+
gr.HTML(tips)
|
90 |
+
with gr.Row():
|
91 |
+
with gr.Column():
|
92 |
+
input_image = gr.Image(label="Input Image", type="pil", format="png")
|
93 |
+
model_name = gr.Dropdown(
|
94 |
+
label="Model Version",
|
95 |
+
choices=list(CHECKPOINTS.keys()),
|
96 |
+
value="0.3B",
|
97 |
+
)
|
98 |
+
|
99 |
+
example_model = gr.Examples(
|
100 |
+
inputs=input_image,
|
101 |
+
examples_per_page=10,
|
102 |
+
examples=[
|
103 |
+
os.path.join(ASSETS_DIR, "examples", img)
|
104 |
+
for img in os.listdir(os.path.join(ASSETS_DIR, "examples"))
|
105 |
+
],
|
106 |
+
)
|
107 |
+
with gr.Column():
|
108 |
+
result_image = gr.Image(label="Segmentation Result", format="png")
|
109 |
+
run_button = gr.Button("Run")
|
110 |
+
|
111 |
+
gr.Image(os.path.join(ASSETS_DIR, "legend.png"), label="Legend", type="filepath")
|
112 |
+
|
113 |
+
run_button.click(
|
114 |
+
fn=segment,
|
115 |
+
inputs=[input_image, model_name],
|
116 |
+
outputs=[result_image],
|
117 |
+
)
|
118 |
|
|
|
|
|
119 |
|
120 |
+
if __name__ == "__main__":
|
121 |
+
demo.launch(share=False)
|
checkpoints/depth/sapiens_0.3b_torchscript.pt2
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:65054e6b6083171b1edf39a9786e34a47f3bfb28c1e0098f73de2ef823b7286e
|
3 |
+
size 1280489853
|
checkpoints/depth/sapiens_0.6b_torchscript.pt2
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f18bef54e4902810172bec9877d3f4d287d5e087a1704150ac73ed09a6097892
|
3 |
+
size 2600455553
|
checkpoints/depth/sapiens_1b_torchscript.pt2
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:4ff0c7a8fa48f1d30f97a49aee05abb905f64ee4fe6a35efa805821be5756a8c
|
3 |
+
size 4625326609
|
checkpoints/depth/sapiens_2b_torchscript.pt2
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:2a93550c2849a38ffc0d83e447626caccc4af7f5864ea11a61202808a097c9ea
|
3 |
+
size 799990784
|
checkpoints/normal/sapiens_0.3b_torchscript.pt2
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:aa2db29f0033e7415843842b3c55a7806397116ca3b7dc6c9b2e7914dacba313
|
3 |
+
size 1358768084
|
checkpoints/normal/sapiens_0.6b_torchscript.pt2
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:5367e673a59e6d8cb04f5cb9ae3c675313bc20f844ef51daf53fa8dc020562b1
|
3 |
+
size 2685035027
|
checkpoints/normal/sapiens_1b_torchscript.pt2
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:00e29d62c385de04f40bc188dd4571e19cab26a8dbc1424d61a77206b3758fb2
|
3 |
+
size 4716203073
|
checkpoints/normal/sapiens_2b_torchscript.pt2
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:80f94a277f8cbd73a5ffd00c9dbdc6f2d59e66d5ffa00c56ee9706e4cf9292ea
|
3 |
+
size 8706490978
|
checkpoints/pose/sapiens_1b_torchscript.pt2
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c6218c6be17697157f9e65ee34054a94ab8ca0f637380fa5748c18e04814976e
|
3 |
+
size 4677162331
|
checkpoints/seg/sapiens_0.3b_torchscript.pt2
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:735a9a8d63fe8f3f6a4ca3d787de07e69b1f9708ad550e09bb33c9854b7eafbc
|
3 |
+
size 1358871599
|
checkpoints/seg/sapiens_0.6b_torchscript.pt2
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:86aa2cb9d7310ba1cb1971026889f1d10d80ddf655d6028aea060aae94d82082
|
3 |
+
size 2685144079
|
checkpoints/seg/sapiens_1b_torchscript.pt2
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:33bba30f3de8d9cfd44e4eaa4817b1bfdd98c188edfc87fa7cc031ba0f4edc17
|
3 |
+
size 4716314057
|
checkpoints/seg/sapiens_2b_torchscript.pt2
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f32f841135794327a434b79fd25c6cca24a72e098e314baa430be65e13dd0332
|
3 |
+
size 8706612665
|
config.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
SAPIENS_LITE_MODELS = {
|
2 |
+
"depth": {
|
3 |
+
"sapiens_0.3b": "https://huggingface.co/facebook/sapiens/resolve/main/sapiens_lite_host/torchscript/depth/checkpoints/sapiens_0.3b/sapiens_0.3b_render_people_epoch_100_torchscript.pt2?download=true",
|
4 |
+
"sapiens_0.6b": "https://huggingface.co/facebook/sapiens/resolve/main/sapiens_lite_host/torchscript/depth/checkpoints/sapiens_0.6b/sapiens_0.6b_render_people_epoch_70_torchscript.pt2?download=true",
|
5 |
+
"sapiens_1b": "https://huggingface.co/facebook/sapiens/resolve/main/sapiens_lite_host/torchscript/depth/checkpoints/sapiens_1b/sapiens_1b_render_people_epoch_88_torchscript.pt2?download=true",
|
6 |
+
"sapiens_2b": "https://huggingface.co/facebook/sapiens/resolve/main/sapiens_lite_host/torchscript/depth/checkpoints/sapiens_2b/sapiens_2b_render_people_epoch_25_torchscript.pt2?download=true"
|
7 |
+
},
|
8 |
+
"detector": {},
|
9 |
+
"normal": {
|
10 |
+
"sapiens_0.3b": "https://huggingface.co/facebook/sapiens/resolve/main/sapiens_lite_host/torchscript/normal/checkpoints/sapiens_0.3b/sapiens_0.3b_normal_render_people_epoch_66_torchscript.pt2?download=true",
|
11 |
+
"sapiens_0.6b": "https://huggingface.co/facebook/sapiens/resolve/main/sapiens_lite_host/torchscript/normal/checkpoints/sapiens_0.6b/sapiens_0.6b_normal_render_people_epoch_200_torchscript.pt2?download=true",
|
12 |
+
"sapiens_1b": "https://huggingface.co/facebook/sapiens/resolve/main/sapiens_lite_host/torchscript/normal/checkpoints/sapiens_1b/sapiens_1b_normal_render_people_epoch_115_torchscript.pt2?download=true",
|
13 |
+
"sapiens_2b": "https://huggingface.co/facebook/sapiens/resolve/main/sapiens_lite_host/torchscript/normal/checkpoints/sapiens_2b/sapiens_2b_normal_render_people_epoch_70_torchscript.pt2?download=true"
|
14 |
+
},
|
15 |
+
"pose": {
|
16 |
+
"sapiens_1b": "https://huggingface.co/facebook/sapiens/resolve/main/sapiens_lite_host/torchscript/pose/checkpoints/sapiens_1b/sapiens_1b_goliath_best_goliath_AP_640_torchscript.pt2?download=true"
|
17 |
+
},
|
18 |
+
"seg": {
|
19 |
+
"sapiens_0.3b": "https://huggingface.co/facebook/sapiens/resolve/main/sapiens_lite_host/torchscript/seg/checkpoints/sapiens_0.3b/sapiens_0.3b_goliath_best_goliath_mIoU_7673_epoch_194_torchscript.pt2?download=true",
|
20 |
+
"sapiens_0.6b": "https://huggingface.co/facebook/sapiens/resolve/main/sapiens_lite_host/torchscript/seg/checkpoints/sapiens_0.6b/sapiens_0.6b_goliath_best_goliath_mIoU_7777_epoch_178_torchscript.pt2?download=true",
|
21 |
+
"sapiens_1b": "https://huggingface.co/facebook/sapiens/resolve/main/sapiens_lite_host/torchscript/seg/checkpoints/sapiens_1b/sapiens_1b_goliath_best_goliath_mIoU_7994_epoch_151_torchscript.pt2?download=true",
|
22 |
+
"sapiens_2b": "https://huggingface.co/facebook/sapiens/resolve/main/sapiens_lite_host/torchscript/seg/checkpoints/sapiens_2b/sapiens_2b_goliath_best_goliath_mIoU_8179_epoch_181_torchscript.pt2?download=true"
|
23 |
+
}
|
24 |
+
}
|
25 |
+
|
26 |
+
LABELS_TO_IDS = {
|
27 |
+
"Background": 0,
|
28 |
+
"Apparel": 1,
|
29 |
+
"Face Neck": 2,
|
30 |
+
"Hair": 3,
|
31 |
+
"Left Foot": 4,
|
32 |
+
"Left Hand": 5,
|
33 |
+
"Left Lower Arm": 6,
|
34 |
+
"Left Lower Leg": 7,
|
35 |
+
"Left Shoe": 8,
|
36 |
+
"Left Sock": 9,
|
37 |
+
"Left Upper Arm": 10,
|
38 |
+
"Left Upper Leg": 11,
|
39 |
+
"Lower Clothing": 12,
|
40 |
+
"Right Foot": 13,
|
41 |
+
"Right Hand": 14,
|
42 |
+
"Right Lower Arm": 15,
|
43 |
+
"Right Lower Leg": 16,
|
44 |
+
"Right Shoe": 17,
|
45 |
+
"Right Sock": 18,
|
46 |
+
"Right Upper Arm": 19,
|
47 |
+
"Right Upper Leg": 20,
|
48 |
+
"Torso": 21,
|
49 |
+
"Upper Clothing": 22,
|
50 |
+
"Lower Lip": 23,
|
51 |
+
"Upper Lip": 24,
|
52 |
+
"Lower Teeth": 25,
|
53 |
+
"Upper Teeth": 26,
|
54 |
+
"Tongue": 27,
|
55 |
+
}
|
download_checkpoints.py
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
import requests
|
4 |
+
from tqdm import tqdm
|
5 |
+
from config import SAPIENS_LITE_MODELS
|
6 |
+
|
7 |
+
def download_file(url, filename):
|
8 |
+
response = requests.get(url, stream=True)
|
9 |
+
total_size = int(response.headers.get('content-length', 0))
|
10 |
+
|
11 |
+
with open(filename, 'wb') as file, tqdm(
|
12 |
+
desc=filename,
|
13 |
+
total=total_size,
|
14 |
+
unit='iB',
|
15 |
+
unit_scale=True,
|
16 |
+
unit_divisor=1024,
|
17 |
+
) as progress_bar:
|
18 |
+
for data in response.iter_content(chunk_size=1024):
|
19 |
+
size = file.write(data)
|
20 |
+
progress_bar.update(size)
|
21 |
+
|
22 |
+
def main():
|
23 |
+
# Load the JSON file with model URLs
|
24 |
+
model_urls = SAPIENS_LITE_MODELS
|
25 |
+
|
26 |
+
for task, models in model_urls.items():
|
27 |
+
checkpoints_dir = os.path.join('checkpoints', task)
|
28 |
+
os.makedirs(checkpoints_dir, exist_ok=True)
|
29 |
+
|
30 |
+
for model_name, url in models.items():
|
31 |
+
model_filename = f"{model_name}_torchscript.pt2"
|
32 |
+
model_path = os.path.join(checkpoints_dir, model_filename)
|
33 |
+
|
34 |
+
if not os.path.exists(model_path):
|
35 |
+
print(f"Downloading {task} {model_name} model...")
|
36 |
+
download_file(url, model_path)
|
37 |
+
print(f"{task} {model_name} model downloaded successfully.")
|
38 |
+
else:
|
39 |
+
print(f"{task} {model_name} model already exists. Skipping download.")
|
40 |
+
|
41 |
+
if __name__ == "__main__":
|
42 |
+
main()
|
requirements.txt
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
gradio
|
2 |
+
numpy
|
3 |
+
torch
|
4 |
+
torchvision
|
5 |
+
matplotlib
|
6 |
+
pillow
|
7 |
+
spaces
|
sapiens
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
Subproject commit 04bdc575d33ae93735f4c64887383e132951d8a4
|
utils/vis_utils.py
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# source: huggingface: fashn-ai/sapiens-body-part-segmentation
|
2 |
+
import colorsys
|
3 |
+
import matplotlib.colors as mcolors
|
4 |
+
import numpy as np
|
5 |
+
from PIL import Image
|
6 |
+
|
7 |
+
def get_palette(num_cls):
|
8 |
+
palette = [0] * (256 * 3)
|
9 |
+
palette[0:3] = [0, 0, 0]
|
10 |
+
|
11 |
+
for j in range(1, num_cls):
|
12 |
+
hue = (j - 1) / (num_cls - 1)
|
13 |
+
saturation = 1.0
|
14 |
+
value = 1.0 if j % 2 == 0 else 0.5
|
15 |
+
rgb = colorsys.hsv_to_rgb(hue, saturation, value)
|
16 |
+
r, g, b = [int(x * 255) for x in rgb]
|
17 |
+
palette[j * 3 : j * 3 + 3] = [r, g, b]
|
18 |
+
|
19 |
+
return palette
|
20 |
+
|
21 |
+
|
22 |
+
def create_colormap(palette):
|
23 |
+
colormap = np.array(palette).reshape(-1, 3) / 255.0
|
24 |
+
return mcolors.ListedColormap(colormap)
|
25 |
+
|
26 |
+
|
27 |
+
def visualize_mask_with_overlay(img: Image.Image, mask: Image.Image, labels_to_ids: dict[str, int], alpha=0.5):
|
28 |
+
img_np = np.array(img.convert("RGB"))
|
29 |
+
mask_np = np.array(mask)
|
30 |
+
|
31 |
+
num_cls = len(labels_to_ids)
|
32 |
+
palette = get_palette(num_cls)
|
33 |
+
colormap = create_colormap(palette)
|
34 |
+
|
35 |
+
overlay = np.zeros((*mask_np.shape, 3), dtype=np.uint8)
|
36 |
+
for label, idx in labels_to_ids.items():
|
37 |
+
if idx != 0:
|
38 |
+
overlay[mask_np == idx] = np.array(colormap(idx)[:3]) * 255
|
39 |
+
|
40 |
+
blended = Image.fromarray(np.uint8(img_np * (1 - alpha) + overlay * alpha))
|
41 |
+
|
42 |
+
return blended
|