Dan Bochman commited on
Commit
4375c79
·
unverified ·
1 Parent(s): bf03d7d

download checkpoint instead of uploading it

Browse files
Files changed (4) hide show
  1. .gitattributes +0 -1
  2. .gitignore +2 -1
  3. app.py +10 -2
  4. requirements.txt +2 -1
.gitattributes CHANGED
@@ -1,4 +1,3 @@
1
- *.pt2 filter=lfs diff=lfs merge=lfs -text
2
  *.jpg filter=lfs diff=lfs merge=lfs -text
3
  *.jpeg filter=lfs diff=lfs merge=lfs -text
4
  *.png filter=lfs diff=lfs merge=lfs -text
 
 
1
  *.jpg filter=lfs diff=lfs merge=lfs -text
2
  *.jpeg filter=lfs diff=lfs merge=lfs -text
3
  *.png filter=lfs diff=lfs merge=lfs -text
.gitignore CHANGED
@@ -1 +1,2 @@
1
- .DS_Store
 
 
1
+ .DS_Store
2
+ *.pt2
app.py CHANGED
@@ -132,9 +132,18 @@ def create_legend_image(labels_to_ids: dict[str, int], filename="legend.png"):
132
 
133
  # ----------------- MODEL ----------------- #
134
 
 
135
  CHECKPOINTS_DIR = os.path.join(ASSETS_DIR, "checkpoints")
136
  model_path = os.path.join(CHECKPOINTS_DIR, "sapiens_0.3b_goliath_best_goliath_mIoU_7673_epoch_194_torchscript.pt2")
137
 
 
 
 
 
 
 
 
 
138
  model = torch.jit.load(model_path)
139
  model.eval()
140
 
@@ -201,8 +210,7 @@ with gr.Blocks(css=CUSTOM_CSS, theme=gr.themes.Monochrome(radius_size=sizes.radi
201
  with gr.Column():
202
  result_image = gr.Image(label="Segmentation Result", format="png")
203
  run_button = gr.Button("Run")
204
-
205
-
206
  gr.Image(os.path.join(ASSETS_DIR, "legend.png"), label="Legend", type="filepath")
207
 
208
  run_button.click(
 
132
 
133
  # ----------------- MODEL ----------------- #
134
 
135
+ URL = "https://huggingface.co/facebook/sapiens/resolve/main/sapiens_lite_host/torchscript/seg/checkpoints/sapiens_0.3b/sapiens_0.3b_goliath_best_goliath_mIoU_7673_epoch_194_torchscript.pt2?download=true"
136
  CHECKPOINTS_DIR = os.path.join(ASSETS_DIR, "checkpoints")
137
  model_path = os.path.join(CHECKPOINTS_DIR, "sapiens_0.3b_goliath_best_goliath_mIoU_7673_epoch_194_torchscript.pt2")
138
 
139
+ if not os.path.exists(model_path):
140
+ os.makedirs(CHECKPOINTS_DIR, exist_ok=True)
141
+ import requests
142
+
143
+ response = requests.get(URL)
144
+ with open(model_path, "wb") as file:
145
+ file.write(response.content)
146
+
147
  model = torch.jit.load(model_path)
148
  model.eval()
149
 
 
210
  with gr.Column():
211
  result_image = gr.Image(label="Segmentation Result", format="png")
212
  run_button = gr.Button("Run")
213
+
 
214
  gr.Image(os.path.join(ASSETS_DIR, "legend.png"), label="Legend", type="filepath")
215
 
216
  run_button.click(
requirements.txt CHANGED
@@ -3,4 +3,5 @@ numpy
3
  torch
4
  torchvision
5
  matplotlib
6
- pillow
 
 
3
  torch
4
  torchvision
5
  matplotlib
6
+ pillow
7
+ requests