Update
Browse files
app.py
CHANGED
@@ -3,7 +3,6 @@
|
|
3 |
from __future__ import annotations
|
4 |
|
5 |
import functools
|
6 |
-
import os
|
7 |
import sys
|
8 |
from typing import Callable
|
9 |
|
@@ -12,6 +11,7 @@ import gradio as gr
|
|
12 |
import huggingface_hub
|
13 |
import numpy as np
|
14 |
import PIL.Image
|
|
|
15 |
import torch
|
16 |
import torch.nn as nn
|
17 |
import torch.nn.functional as F
|
@@ -66,6 +66,7 @@ def crop_face(image: np.ndarray, box: tuple[int, int, int, int]) -> np.ndarray:
|
|
66 |
return image
|
67 |
|
68 |
|
|
|
69 |
@torch.inference_mode()
|
70 |
def predict(image: np.ndarray, transform: Callable, model: nn.Module, device: torch.device) -> np.ndarray:
|
71 |
indices = torch.arange(66).float().to(device)
|
@@ -129,7 +130,9 @@ def run(
|
|
129 |
|
130 |
|
131 |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
132 |
-
face_detector = RetinaFacePredictor(threshold=0.8, device=
|
|
|
|
|
133 |
|
134 |
model_names = [
|
135 |
"hopenet_alpha1",
|
@@ -157,7 +160,13 @@ with gr.Blocks(css="style.css") as demo:
|
|
157 |
inputs=[image, model_name],
|
158 |
outputs=result,
|
159 |
fn=fn,
|
160 |
-
cache_examples=os.getenv("CACHE_EXAMPLES") == "1",
|
161 |
)
|
162 |
-
run_button.click(
|
163 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
from __future__ import annotations
|
4 |
|
5 |
import functools
|
|
|
6 |
import sys
|
7 |
from typing import Callable
|
8 |
|
|
|
11 |
import huggingface_hub
|
12 |
import numpy as np
|
13 |
import PIL.Image
|
14 |
+
import spaces
|
15 |
import torch
|
16 |
import torch.nn as nn
|
17 |
import torch.nn.functional as F
|
|
|
66 |
return image
|
67 |
|
68 |
|
69 |
+
@spaces.GPU
|
70 |
@torch.inference_mode()
|
71 |
def predict(image: np.ndarray, transform: Callable, model: nn.Module, device: torch.device) -> np.ndarray:
|
72 |
indices = torch.arange(66).float().to(device)
|
|
|
130 |
|
131 |
|
132 |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
133 |
+
face_detector = RetinaFacePredictor(threshold=0.8, device="cpu", model=RetinaFacePredictor.get_model("mobilenet0.25"))
|
134 |
+
face_detector.device = device
|
135 |
+
face_detector.net.to(device)
|
136 |
|
137 |
model_names = [
|
138 |
"hopenet_alpha1",
|
|
|
160 |
inputs=[image, model_name],
|
161 |
outputs=result,
|
162 |
fn=fn,
|
|
|
163 |
)
|
164 |
+
run_button.click(
|
165 |
+
fn=fn,
|
166 |
+
inputs=[image, model_name],
|
167 |
+
outputs=result,
|
168 |
+
api_name="run",
|
169 |
+
)
|
170 |
+
|
171 |
+
if __name__ == "__main__":
|
172 |
+
demo.queue().launch()
|