hysts HF Staff committed on
Commit 49d2e89 · 1 Parent(s): c120ebb
Files changed (1)
  1. app.py +31 -45
app.py CHANGED
@@ -2,34 +2,43 @@
 
 from __future__ import annotations
 
-import functools
-import os
 import pathlib
 import sys
-from typing import Callable
 
 import gradio as gr
 import numpy as np
 import PIL.Image
 import torch
-import torch.nn as nn
 import torchvision.transforms as T
 from huggingface_hub import hf_hub_download
 
-sys.path.insert(0, 'CelebAMask-HQ/face_parsing')
+sys.path.insert(0, "CelebAMask-HQ/face_parsing")
 
 from unet import unet
 from utils import generate_label
 
-TITLE = 'CelebAMask-HQ Face Parsing'
-DESCRIPTION = 'This is an unofficial demo for the model provided in https://github.com/switchablenorms/CelebAMask-HQ.'
+TITLE = "CelebAMask-HQ Face Parsing"
+DESCRIPTION = "This is an unofficial demo for the model provided in https://github.com/switchablenorms/CelebAMask-HQ."
 
-HF_TOKEN = os.getenv('HF_TOKEN')
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+transform = T.Compose(
+    [
+        T.Resize((512, 512), interpolation=PIL.Image.NEAREST),
+        T.ToTensor(),
+        T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
+    ]
+)
+
+path = hf_hub_download("hysts/CelebAMask-HQ-Face-Parsing", "models/model.pth")
+state_dict = torch.load(path, map_location="cpu")
+model = unet()
+model.load_state_dict(state_dict)
+model.eval()
+model.to(device)
 
 
 @torch.inference_mode()
-def predict(image: PIL.Image.Image, model: nn.Module, transform: Callable,
-            device: torch.device) -> np.ndarray:
+def predict(image: PIL.Image.Image) -> np.ndarray:
     data = transform(image)
     data = data.unsqueeze(0).to(device)
     out = model(data)
@@ -37,48 +46,25 @@ def predict(image: PIL.Image.Image, model: nn.Module, transform: Callable,
     out = out[0].cpu().numpy().transpose(1, 2, 0)
     out = np.clip(np.round(out * 255), 0, 255).astype(np.uint8)
 
-    res = np.asarray(image.resize(
-        (512, 512))).astype(float) * 0.5 + out.astype(float) * 0.5
+    res = np.asarray(image.resize((512, 512))).astype(float) * 0.5 + out.astype(float) * 0.5
     res = np.clip(np.round(res), 0, 255).astype(np.uint8)
     return out, res
 
 
-def load_model(device: torch.device) -> nn.Module:
-    path = hf_hub_download('hysts/CelebAMask-HQ-Face-Parsing',
-                           'models/model.pth',
-                           use_auth_token=HF_TOKEN)
-    state_dict = torch.load(path, map_location='cpu')
-    model = unet()
-    model.load_state_dict(state_dict)
-    model.eval()
-    model.to(device)
-    return model
-
-
-device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
-model = load_model(device)
-transform = T.Compose([
-    T.Resize((512, 512), interpolation=PIL.Image.NEAREST),
-    T.ToTensor(),
-    T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
-])
+image_dir = pathlib.Path("images")
+examples = [[path.as_posix()] for path in sorted(image_dir.glob("*.jpg"))]
 
-func = functools.partial(predict,
-                         model=model,
-                         transform=transform,
-                         device=device)
-
-image_dir = pathlib.Path('images')
-examples = [[path.as_posix()] for path in sorted(image_dir.glob('*.jpg'))]
-
-gr.Interface(
-    fn=func,
-    inputs=gr.Image(label='Input', type='pil'),
+demo = gr.Interface(
+    fn=predict,
+    inputs=gr.Image(label="Input", type="pil"),
     outputs=[
-        gr.Image(label='Predicted Labels', type='numpy'),
-        gr.Image(label='Masked', type='numpy'),
+        gr.Image(label="Predicted Labels", type="numpy"),
+        gr.Image(label="Masked", type="numpy"),
     ],
     examples=examples,
     title=TITLE,
    description=DESCRIPTION,
-).queue().launch(show_api=False)
+)
+
+if __name__ == "__main__":
+    demo.queue(max_size=20).launch()
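
Since the commit moves model loading and preprocessing to module level, the new predict() takes only a PIL image. A minimal local smoke test might look like the sketch below; it assumes the Space's environment (the CelebAMask-HQ checkout and model download succeed at import time), and "images/example.jpg" is a placeholder path, not a file named in this diff.

# hypothetical smoke test for the refactored app.py
import PIL.Image

from app import predict  # importing app.py loads the model, as in the new module-level setup

image = PIL.Image.open("images/example.jpg").convert("RGB")  # placeholder input path
labels, masked = predict(image)  # uint8 numpy arrays: colored label map and 50/50 blend with the resized input
print(labels.shape, masked.shape)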