Dan Bochman committed on
Commit d46329a · unverified · 1 Parent(s): 4375c79

adapt to use ZeroGPU (spaces.GPU)

Files changed (2):
  1. app.py +13 -49
  2. requirements.txt +2 -1
app.py CHANGED

@@ -4,14 +4,17 @@ import os
 import gradio as gr
 import matplotlib.colors as mcolors
 import numpy as np
+import spaces
 import torch
 from gradio.themes.utils import sizes
-from matplotlib import pyplot as plt
-from matplotlib.patches import Patch
 from PIL import Image
 from torchvision import transforms
 
-# ----------------- HELPER FUNCTIONS ----------------- #
+# ----------------- ENV ----------------- #
+
+if torch.cuda.get_device_properties(0).major >= 8:
+    torch.backends.cuda.matmul.allow_tf32 = True
+    torch.backends.cudnn.allow_tf32 = True
 
 ASSETS_DIR = os.path.join(os.path.dirname(__file__), "assets")
 
@@ -47,6 +50,9 @@ LABELS_TO_IDS = {
 }
 
 
+# ----------------- HELPER FUNCTIONS ----------------- #
+
+
 def get_palette(num_cls):
     palette = [0] * (256 * 3)
     palette[0:3] = [0, 0, 0]
@@ -85,51 +91,6 @@ def visualize_mask_with_overlay(img: Image.Image, mask: Image.Image, labels_to_i
     return blended
 
 
-def create_legend_image(labels_to_ids: dict[str, int], filename="legend.png"):
-    num_cls = len(labels_to_ids)
-    palette = get_palette(num_cls)
-    colormap = create_colormap(palette)
-
-    fig, ax = plt.subplots(figsize=(4, 6), facecolor="white")
-
-    ax.axis("off")
-
-    legend_elements = [
-        Patch(facecolor=colormap(i), edgecolor="black", label=label)
-        for label, i in sorted(labels_to_ids.items(), key=lambda x: x[1])
-    ]
-
-    plt.title("Legend", fontsize=16, fontweight="bold", pad=20)
-
-    legend = ax.legend(
-        handles=legend_elements,
-        loc="center",
-        bbox_to_anchor=(0.5, 0.5),
-        ncol=2,
-        frameon=True,
-        fancybox=True,
-        shadow=True,
-        fontsize=10,
-        title_fontsize=12,
-        borderpad=1,
-        labelspacing=1.2,
-        handletextpad=0.5,
-        handlelength=1.5,
-        columnspacing=1.5,
-    )
-
-    legend.get_frame().set_facecolor("#FAFAFA")
-    legend.get_frame().set_edgecolor("gray")
-
-    # Adjust layout and save
-    plt.tight_layout()
-    plt.savefig(filename, dpi=300, bbox_inches="tight")
-    plt.close()
-
-
-# create_legend_image(LABELS_TO_IDS, filename=os.path.join(ASSETS_DIR, "legend.png"))
-
-
 # ----------------- MODEL ----------------- #
 
 URL = "https://huggingface.co/facebook/sapiens/resolve/main/sapiens_lite_host/torchscript/seg/checkpoints/sapiens_0.3b/sapiens_0.3b_goliath_best_goliath_mIoU_7673_epoch_194_torchscript.pt2?download=true"
@@ -146,9 +107,12 @@ if not os.path.exists(model_path):
 
 model = torch.jit.load(model_path)
 model.eval()
+model.to("cuda")
 
 
-@torch.no_grad()
+@spaces.GPU
+@torch.inference_mode()
+@torch.autocast(device_type="cuda", dtype=torch.bfloat16)
 def run_model(input_tensor, height, width):
     output = model(input_tensor)
     output = torch.nn.functional.interpolate(output, size=(height, width), mode="bilinear", align_corners=False)
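Taken together, the app.py changes move the Space to the ZeroGPU execution model: the model is pushed to CUDA once at import time, and a GPU is requested only while the decorated inference function runs. Below is a minimal, self-contained sketch of that pattern, with a stand-in module in place of the TorchScript checkpoint and with the post-interpolation handling assumed, since the diff cuts off after the interpolate call:

import spaces
import torch

# Stand-in for the TorchScript Sapiens checkpoint loaded in app.py;
# any nn.Module behaves the same way under the decorators (assumption).
model = torch.nn.Conv2d(3, 28, kernel_size=1)
model.eval()
model.to("cuda")  # mirrors app.py: the model is moved to CUDA at import time


@spaces.GPU                  # attach a GPU only for the duration of this call
@torch.inference_mode()      # inference-only mode, replacing @torch.no_grad()
@torch.autocast(device_type="cuda", dtype=torch.bfloat16)  # mixed precision on the GPU
def run_model(input_tensor: torch.Tensor, height: int, width: int) -> torch.Tensor:
    output = model(input_tensor.to("cuda"))
    # resize the class logits back to the original image resolution
    output = torch.nn.functional.interpolate(
        output, size=(height, width), mode="bilinear", align_corners=False
    )
    return output  # assumption: the caller argmaxes/visualizes these logits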
requirements.txt CHANGED

@@ -4,4 +4,5 @@ torch
 torchvision
 matplotlib
 pillow
-requests
+requests
+spaces
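On the dependency side, `spaces` is the substantive addition; it provides the `@spaces.GPU` decorator now imported in app.py. `requests` remains listed, presumably for fetching the checkpoint, as the `if not os.path.exists(model_path)` hunk context suggests. A hedged sketch of that download step follows, with the local `model_path` value assumed since it is not shown in this diff:

import os
import requests

# URL is taken verbatim from app.py; model_path is an assumed local filename.
URL = "https://huggingface.co/facebook/sapiens/resolve/main/sapiens_lite_host/torchscript/seg/checkpoints/sapiens_0.3b/sapiens_0.3b_goliath_best_goliath_mIoU_7673_epoch_194_torchscript.pt2?download=true"
model_path = "sapiens_0.3b_goliath.pt2"

if not os.path.exists(model_path):
    # stream the checkpoint to disk instead of buffering it all in memory
    with requests.get(URL, stream=True, timeout=60) as resp:
        resp.raise_for_status()
        with open(model_path, "wb") as f:
            for chunk in resp.iter_content(chunk_size=1 << 20):
                f.write(chunk)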