hysts HF Staff commited on
Commit
a5c1a92
·
1 Parent(s): a5f7743
Files changed (1) hide show
  1. app.py +14 -5
app.py CHANGED
@@ -3,7 +3,6 @@
3
  from __future__ import annotations
4
 
5
  import functools
6
- import os
7
  import sys
8
  from typing import Callable
9
 
@@ -12,6 +11,7 @@ import gradio as gr
12
  import huggingface_hub
13
  import numpy as np
14
  import PIL.Image
 
15
  import torch
16
  import torch.nn as nn
17
  import torch.nn.functional as F
@@ -66,6 +66,7 @@ def crop_face(image: np.ndarray, box: tuple[int, int, int, int]) -> np.ndarray:
66
  return image
67
 
68
 
 
69
  @torch.inference_mode()
70
  def predict(image: np.ndarray, transform: Callable, model: nn.Module, device: torch.device) -> np.ndarray:
71
  indices = torch.arange(66).float().to(device)
@@ -129,7 +130,9 @@ def run(
129
 
130
 
131
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
132
- face_detector = RetinaFacePredictor(threshold=0.8, device=device, model=RetinaFacePredictor.get_model("mobilenet0.25"))
 
 
133
 
134
  model_names = [
135
  "hopenet_alpha1",
@@ -157,7 +160,13 @@ with gr.Blocks(css="style.css") as demo:
157
  inputs=[image, model_name],
158
  outputs=result,
159
  fn=fn,
160
- cache_examples=os.getenv("CACHE_EXAMPLES") == "1",
161
  )
162
- run_button.click(fn=fn, inputs=[image, model_name], outputs=result, api_name="run")
163
- demo.queue().launch()
 
 
 
 
 
 
 
 
3
  from __future__ import annotations
4
 
5
  import functools
 
6
  import sys
7
  from typing import Callable
8
 
 
11
  import huggingface_hub
12
  import numpy as np
13
  import PIL.Image
14
+ import spaces
15
  import torch
16
  import torch.nn as nn
17
  import torch.nn.functional as F
 
66
  return image
67
 
68
 
69
+ @spaces.GPU
70
  @torch.inference_mode()
71
  def predict(image: np.ndarray, transform: Callable, model: nn.Module, device: torch.device) -> np.ndarray:
72
  indices = torch.arange(66).float().to(device)
 
130
 
131
 
132
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
133
+ face_detector = RetinaFacePredictor(threshold=0.8, device="cpu", model=RetinaFacePredictor.get_model("mobilenet0.25"))
134
+ face_detector.device = device
135
+ face_detector.net.to(device)
136
 
137
  model_names = [
138
  "hopenet_alpha1",
 
160
  inputs=[image, model_name],
161
  outputs=result,
162
  fn=fn,
 
163
  )
164
+ run_button.click(
165
+ fn=fn,
166
+ inputs=[image, model_name],
167
+ outputs=result,
168
+ api_name="run",
169
+ )
170
+
171
+ if __name__ == "__main__":
172
+ demo.queue().launch()