hysts HF Staff committed on
Commit
d5d90d2
·
1 Parent(s): 31b0c59
Files changed (4) [hide/show]
  1. .pre-commit-config.yaml +35 -0
  2. .style.yapf +5 -0
  3. README.md +1 -1
  4. app.py +29 -61
.pre-commit-config.yaml ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ repos:
2
+ - repo: https://github.com/pre-commit/pre-commit-hooks
3
+ rev: v4.2.0
4
+ hooks:
5
+ - id: check-executables-have-shebangs
6
+ - id: check-json
7
+ - id: check-merge-conflict
8
+ - id: check-shebang-scripts-are-executable
9
+ - id: check-toml
10
+ - id: check-yaml
11
+ - id: double-quote-string-fixer
12
+ - id: end-of-file-fixer
13
+ - id: mixed-line-ending
14
+ args: ['--fix=lf']
15
+ - id: requirements-txt-fixer
16
+ - id: trailing-whitespace
17
+ - repo: https://github.com/myint/docformatter
18
+ rev: v1.4
19
+ hooks:
20
+ - id: docformatter
21
+ args: ['--in-place']
22
+ - repo: https://github.com/pycqa/isort
23
+ rev: 5.12.0
24
+ hooks:
25
+ - id: isort
26
+ - repo: https://github.com/pre-commit/mirrors-mypy
27
+ rev: v0.991
28
+ hooks:
29
+ - id: mypy
30
+ args: ['--ignore-missing-imports']
31
+ - repo: https://github.com/google/yapf
32
+ rev: v0.32.0
33
+ hooks:
34
+ - id: yapf
35
+ args: ['--parallel', '--in-place']
.style.yapf ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ [style]
2
+ based_on_style = pep8
3
+ blank_line_before_nested_class_or_def = false
4
+ spaces_before_comment = 2
5
+ split_before_logical_operator = true
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 😻
4
  colorFrom: indigo
5
  colorTo: green
6
  sdk: gradio
7
- sdk_version: 3.0.5
8
  app_file: app.py
9
  pinned: false
10
  ---
 
4
  colorFrom: indigo
5
  colorTo: green
6
  sdk: gradio
7
+ sdk_version: 3.19.1
8
  app_file: app.py
9
  pinned: false
10
  ---
app.py CHANGED
@@ -2,7 +2,6 @@
2
 
3
  from __future__ import annotations
4
 
5
- import argparse
6
  import functools
7
  import os
8
  import pathlib
@@ -24,23 +23,8 @@ from utils import generate_label
24
 
25
  TITLE = 'CelebAMask-HQ Face Parsing'
26
  DESCRIPTION = 'This is an unofficial demo for the model provided in https://github.com/switchablenorms/CelebAMask-HQ.'
27
- ARTICLE = '<center><img src="https://visitor-badge.glitch.me/badge?page_id=hysts.celebamask-hq-face-parsing" alt="visitor badge"/></center>'
28
 
29
- TOKEN = os.environ['TOKEN']
30
-
31
-
32
- def parse_args() -> argparse.Namespace:
33
- parser = argparse.ArgumentParser()
34
- parser.add_argument('--device', type=str, default='cpu')
35
- parser.add_argument('--theme', type=str)
36
- parser.add_argument('--live', action='store_true')
37
- parser.add_argument('--share', action='store_true')
38
- parser.add_argument('--port', type=int)
39
- parser.add_argument('--disable-queue',
40
- dest='enable_queue',
41
- action='store_false')
42
- parser.add_argument('--allow-flagging', type=str, default='never')
43
- return parser.parse_args()
44
 
45
 
46
  @torch.inference_mode()
@@ -62,7 +46,7 @@ def predict(image: PIL.Image.Image, model: nn.Module, transform: Callable,
62
  def load_model(device: torch.device) -> nn.Module:
63
  path = hf_hub_download('hysts/CelebAMask-HQ-Face-Parsing',
64
  'models/model.pth',
65
- use_auth_token=TOKEN)
66
  state_dict = torch.load(path, map_location='cpu')
67
  model = unet()
68
  model.load_state_dict(state_dict)
@@ -71,46 +55,30 @@ def load_model(device: torch.device) -> nn.Module:
71
  return model
72
 
73
 
74
- def main():
75
- args = parse_args()
76
- device = torch.device(args.device)
77
-
78
- model = load_model(device)
79
- transform = T.Compose([
80
- T.Resize((512, 512), interpolation=PIL.Image.NEAREST),
81
- T.ToTensor(),
82
- T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
83
- ])
84
-
85
- func = functools.partial(predict,
86
- model=model,
87
- transform=transform,
88
- device=device)
89
- func = functools.update_wrapper(func, predict)
90
-
91
- image_dir = pathlib.Path('images')
92
- examples = [[path.as_posix()] for path in sorted(image_dir.glob('*.jpg'))]
93
-
94
- gr.Interface(
95
- func,
96
- gr.inputs.Image(type='pil', label='Input'),
97
- [
98
- gr.outputs.Image(type='numpy', label='Predicted Labels'),
99
- gr.outputs.Image(type='numpy', label='Masked'),
100
- ],
101
- examples=examples,
102
- title=TITLE,
103
- description=DESCRIPTION,
104
- article=ARTICLE,
105
- theme=args.theme,
106
- allow_flagging=args.allow_flagging,
107
- live=args.live,
108
- ).launch(
109
- enable_queue=args.enable_queue,
110
- server_port=args.port,
111
- share=args.share,
112
- )
113
-
114
-
115
- if __name__ == '__main__':
116
- main()
 
2
 
3
  from __future__ import annotations
4
 
 
5
  import functools
6
  import os
7
  import pathlib
 
23
 
24
  TITLE = 'CelebAMask-HQ Face Parsing'
25
  DESCRIPTION = 'This is an unofficial demo for the model provided in https://github.com/switchablenorms/CelebAMask-HQ.'
 
26
 
27
+ HF_TOKEN = os.getenv('HF_TOKEN')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
 
30
  @torch.inference_mode()
 
46
  def load_model(device: torch.device) -> nn.Module:
47
  path = hf_hub_download('hysts/CelebAMask-HQ-Face-Parsing',
48
  'models/model.pth',
49
+ use_auth_token=HF_TOKEN)
50
  state_dict = torch.load(path, map_location='cpu')
51
  model = unet()
52
  model.load_state_dict(state_dict)
 
55
  return model
56
 
57
 
58
+ device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
59
+ model = load_model(device)
60
+ transform = T.Compose([
61
+ T.Resize((512, 512), interpolation=PIL.Image.NEAREST),
62
+ T.ToTensor(),
63
+ T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
64
+ ])
65
+
66
+ func = functools.partial(predict,
67
+ model=model,
68
+ transform=transform,
69
+ device=device)
70
+
71
+ image_dir = pathlib.Path('images')
72
+ examples = [[path.as_posix()] for path in sorted(image_dir.glob('*.jpg'))]
73
+
74
+ gr.Interface(
75
+ fn=func,
76
+ inputs=gr.Image(label='Input', type='pil'),
77
+ outputs=[
78
+ gr.Image(label='Predicted Labels', type='numpy'),
79
+ gr.Image(label='Masked', type='numpy'),
80
+ ],
81
+ examples=examples,
82
+ title=TITLE,
83
+ description=DESCRIPTION,
84
+ ).queue().launch(show_api=False)