Spaces:
Running
on
Zero
Running
on
Zero
Dan Bochman
commited on
adapt to use ZERO spaces.GPU
Browse files- app.py +13 -49
- requirements.txt +2 -1
app.py
CHANGED
@@ -4,14 +4,17 @@ import os
|
|
4 |
import gradio as gr
|
5 |
import matplotlib.colors as mcolors
|
6 |
import numpy as np
|
|
|
7 |
import torch
|
8 |
from gradio.themes.utils import sizes
|
9 |
-
from matplotlib import pyplot as plt
|
10 |
-
from matplotlib.patches import Patch
|
11 |
from PIL import Image
|
12 |
from torchvision import transforms
|
13 |
|
14 |
-
# -----------------
|
|
|
|
|
|
|
|
|
15 |
|
16 |
ASSETS_DIR = os.path.join(os.path.dirname(__file__), "assets")
|
17 |
|
@@ -47,6 +50,9 @@ LABELS_TO_IDS = {
|
|
47 |
}
|
48 |
|
49 |
|
|
|
|
|
|
|
50 |
def get_palette(num_cls):
|
51 |
palette = [0] * (256 * 3)
|
52 |
palette[0:3] = [0, 0, 0]
|
@@ -85,51 +91,6 @@ def visualize_mask_with_overlay(img: Image.Image, mask: Image.Image, labels_to_i
|
|
85 |
return blended
|
86 |
|
87 |
|
88 |
-
def create_legend_image(labels_to_ids: dict[str, int], filename="legend.png"):
|
89 |
-
num_cls = len(labels_to_ids)
|
90 |
-
palette = get_palette(num_cls)
|
91 |
-
colormap = create_colormap(palette)
|
92 |
-
|
93 |
-
fig, ax = plt.subplots(figsize=(4, 6), facecolor="white")
|
94 |
-
|
95 |
-
ax.axis("off")
|
96 |
-
|
97 |
-
legend_elements = [
|
98 |
-
Patch(facecolor=colormap(i), edgecolor="black", label=label)
|
99 |
-
for label, i in sorted(labels_to_ids.items(), key=lambda x: x[1])
|
100 |
-
]
|
101 |
-
|
102 |
-
plt.title("Legend", fontsize=16, fontweight="bold", pad=20)
|
103 |
-
|
104 |
-
legend = ax.legend(
|
105 |
-
handles=legend_elements,
|
106 |
-
loc="center",
|
107 |
-
bbox_to_anchor=(0.5, 0.5),
|
108 |
-
ncol=2,
|
109 |
-
frameon=True,
|
110 |
-
fancybox=True,
|
111 |
-
shadow=True,
|
112 |
-
fontsize=10,
|
113 |
-
title_fontsize=12,
|
114 |
-
borderpad=1,
|
115 |
-
labelspacing=1.2,
|
116 |
-
handletextpad=0.5,
|
117 |
-
handlelength=1.5,
|
118 |
-
columnspacing=1.5,
|
119 |
-
)
|
120 |
-
|
121 |
-
legend.get_frame().set_facecolor("#FAFAFA")
|
122 |
-
legend.get_frame().set_edgecolor("gray")
|
123 |
-
|
124 |
-
# Adjust layout and save
|
125 |
-
plt.tight_layout()
|
126 |
-
plt.savefig(filename, dpi=300, bbox_inches="tight")
|
127 |
-
plt.close()
|
128 |
-
|
129 |
-
|
130 |
-
# create_legend_image(LABELS_TO_IDS, filename=os.path.join(ASSETS_DIR, "legend.png"))
|
131 |
-
|
132 |
-
|
133 |
# ----------------- MODEL ----------------- #
|
134 |
|
135 |
URL = "https://huggingface.co/facebook/sapiens/resolve/main/sapiens_lite_host/torchscript/seg/checkpoints/sapiens_0.3b/sapiens_0.3b_goliath_best_goliath_mIoU_7673_epoch_194_torchscript.pt2?download=true"
|
@@ -146,9 +107,12 @@ if not os.path.exists(model_path):
|
|
146 |
|
147 |
model = torch.jit.load(model_path)
|
148 |
model.eval()
|
|
|
149 |
|
150 |
|
151 |
-
@
|
|
|
|
|
152 |
def run_model(input_tensor, height, width):
|
153 |
output = model(input_tensor)
|
154 |
output = torch.nn.functional.interpolate(output, size=(height, width), mode="bilinear", align_corners=False)
|
|
|
4 |
import gradio as gr
|
5 |
import matplotlib.colors as mcolors
|
6 |
import numpy as np
|
7 |
+
import spaces
|
8 |
import torch
|
9 |
from gradio.themes.utils import sizes
|
|
|
|
|
10 |
from PIL import Image
|
11 |
from torchvision import transforms
|
12 |
|
13 |
+
# ----------------- ENV ----------------- #
|
14 |
+
|
15 |
+
if torch.cuda.get_device_properties(0).major >= 8:
|
16 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
17 |
+
torch.backends.cudnn.allow_tf32 = True
|
18 |
|
19 |
ASSETS_DIR = os.path.join(os.path.dirname(__file__), "assets")
|
20 |
|
|
|
50 |
}
|
51 |
|
52 |
|
53 |
+
# ----------------- HELPER FUNCTIONS ----------------- #
|
54 |
+
|
55 |
+
|
56 |
def get_palette(num_cls):
|
57 |
palette = [0] * (256 * 3)
|
58 |
palette[0:3] = [0, 0, 0]
|
|
|
91 |
return blended
|
92 |
|
93 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
94 |
# ----------------- MODEL ----------------- #
|
95 |
|
96 |
URL = "https://huggingface.co/facebook/sapiens/resolve/main/sapiens_lite_host/torchscript/seg/checkpoints/sapiens_0.3b/sapiens_0.3b_goliath_best_goliath_mIoU_7673_epoch_194_torchscript.pt2?download=true"
|
|
|
107 |
|
108 |
model = torch.jit.load(model_path)
|
109 |
model.eval()
|
110 |
+
model.to("cuda")
|
111 |
|
112 |
|
113 |
+
@spaces.GPU
|
114 |
+
@torch.inference_mode()
|
115 |
+
@torch.autocast(device_type="cuda", dtype=torch.bfloat16)
|
116 |
def run_model(input_tensor, height, width):
|
117 |
output = model(input_tensor)
|
118 |
output = torch.nn.functional.interpolate(output, size=(height, width), mode="bilinear", align_corners=False)
|
requirements.txt
CHANGED
@@ -4,4 +4,5 @@ torch
|
|
4 |
torchvision
|
5 |
matplotlib
|
6 |
pillow
|
7 |
-
requests
|
|
|
|
4 |
torchvision
|
5 |
matplotlib
|
6 |
pillow
|
7 |
+
requests
|
8 |
+
spaces
|