hysts (HF staff) committed
Commit ef06509 · 1 parent: 2994b00
Files changed (5):
  1. .pre-commit-config.yaml  +46 -0
  2. .style.yapf  +5 -0
  3. app.py  +75 -117
  4. model.py  +95 -0
  5. style.css  +11 -0
.pre-commit-config.yaml ADDED
@@ -0,0 +1,46 @@
+exclude: ^StyleSwin
+repos:
+- repo: https://github.com/pre-commit/pre-commit-hooks
+  rev: v4.2.0
+  hooks:
+  - id: check-executables-have-shebangs
+  - id: check-json
+  - id: check-merge-conflict
+  - id: check-shebang-scripts-are-executable
+  - id: check-toml
+  - id: check-yaml
+  - id: double-quote-string-fixer
+  - id: end-of-file-fixer
+  - id: mixed-line-ending
+    args: ['--fix=lf']
+  - id: requirements-txt-fixer
+  - id: trailing-whitespace
+- repo: https://github.com/myint/docformatter
+  rev: v1.4
+  hooks:
+  - id: docformatter
+    args: ['--in-place']
+- repo: https://github.com/pycqa/isort
+  rev: 5.10.1
+  hooks:
+  - id: isort
+- repo: https://github.com/pre-commit/mirrors-mypy
+  rev: v0.812
+  hooks:
+  - id: mypy
+    args: ['--ignore-missing-imports']
+- repo: https://github.com/google/yapf
+  rev: v0.32.0
+  hooks:
+  - id: yapf
+    args: ['--parallel', '--in-place']
+- repo: https://github.com/kynan/nbstripout
+  rev: 0.5.0
+  hooks:
+  - id: nbstripout
+    args: ['--extra-keys', 'metadata.interpreter metadata.kernelspec cell.metadata.pycharm']
+- repo: https://github.com/nbQA-dev/nbQA
+  rev: 1.3.1
+  hooks:
+  - id: nbqa-isort
+  - id: nbqa-yapf
.style.yapf ADDED
@@ -0,0 +1,5 @@
+[style]
+based_on_style = pep8
+blank_line_before_nested_class_or_def = false
+spaces_before_comment = 2
+split_before_logical_operator = true
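For context, a minimal sketch (not part of the commit) of what two of these yapf options do to ordinary code: `spaces_before_comment = 2` pads trailing comments, and `split_before_logical_operator = true` breaks long conditions before `and`/`or`.

```python
def is_valid(v: int) -> bool:
    return v > 0


def is_even(v: int) -> bool:
    return v % 2 == 0


x = 42  # spaces_before_comment = 2: two spaces before this '#'

# split_before_logical_operator = true: the line break lands *before* 'and'
if (is_valid(x)
        and is_even(x)):
    print(x)
```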
app.py CHANGED
@@ -3,149 +3,107 @@
 from __future__ import annotations

 import argparse
-import functools
-import os
-import sys

 import gradio as gr
-import huggingface_hub
 import numpy as np
-import PIL.Image
-import torch
-import torch.nn as nn

-if os.environ.get('SYSTEM') == 'spaces':
-    os.system("sed -i '14,21d' StyleSwin/op/fused_act.py")
-    os.system("sed -i '12,19d' StyleSwin/op/upfirdn2d.py")
+from model import Model

-sys.path.insert(0, 'StyleSwin')
-
-from models.generator import Generator
-
-TITLE = 'microsoft/StyleSwin'
-DESCRIPTION = '''This is an unofficial demo for https://github.com/microsoft/StyleSwin.
+TITLE = '# microsoft/StyleSwin'
+DESCRIPTION = '''This is an unofficial demo for [https://github.com/microsoft/StyleSwin](https://github.com/microsoft/StyleSwin).

 Expected execution time on Hugging Face Spaces: 3s (for 256x256 images), 7s (for 1024x1024 images)
 '''
-SAMPLE_IMAGE_DIR = 'https://huggingface.co/spaces/hysts/StyleSwin/resolve/main/samples'
-ARTICLE = f'''## Generated images
-### CelebA-HQ
-- size: 1024x1024
-- seed: 0-99
-![CelebA-HQ samples]({SAMPLE_IMAGE_DIR}/celeba-hq.jpg)
-### FFHQ
-- size: 1024x1024
-- seed: 0-99
-![FFHQ samples]({SAMPLE_IMAGE_DIR}/ffhq.jpg)
-### LSUN Church
-- size: 256x256
-- seed: 0-99
-![LSUN Church samples]({SAMPLE_IMAGE_DIR}/lsun-church.jpg)
-
-<center><img src="https://visitor-badge.glitch.me/badge?page_id=hysts.styleswin" alt="visitor badge"/></center>
-'''
-
-TOKEN = os.environ['TOKEN']
-
-MODEL_REPO = 'hysts/StyleSwin'
-MODEL_NAMES = [
-    'CelebAHQ_256',
-    'FFHQ_256',
-    'LSUNChurch_256',
-    'CelebAHQ_1024',
-    'FFHQ_1024',
-]
+FOOTER = '<img id="visitor-badge" src="https://visitor-badge.glitch.me/badge?page_id=hysts.styleswin" alt="visitor badge" />'


 def parse_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser()
     parser.add_argument('--device', type=str, default='cpu')
     parser.add_argument('--theme', type=str)
-    parser.add_argument('--live', action='store_true')
     parser.add_argument('--share', action='store_true')
     parser.add_argument('--port', type=int)
     parser.add_argument('--disable-queue',
                         dest='enable_queue',
                         action='store_false')
-    parser.add_argument('--allow-flagging', type=str, default='never')
     return parser.parse_args()


-def load_model(model_name: str, device: torch.device) -> nn.Module:
-    size = int(model_name.split('_')[1])
-    channel_multiplier = 1 if size == 1024 else 2
-    model = Generator(size,
-                      style_dim=512,
-                      n_mlp=8,
-                      channel_multiplier=channel_multiplier)
-    ckpt_path = huggingface_hub.hf_hub_download(MODEL_REPO,
-                                                f'models/{model_name}.pt',
-                                                use_auth_token=TOKEN)
-    ckpt = torch.load(ckpt_path)
-    model.load_state_dict(ckpt['g_ema'])
-    model.to(device)
-    model.eval()
-    return model
-
-
-def generate_z(seed: int, device: torch.device) -> torch.Tensor:
-    return torch.from_numpy(np.random.RandomState(seed).randn(
-        1, 512)).to(device).float()
-
-
-def postprocess(tensors: torch.Tensor) -> torch.Tensor:
-    assert tensors.dim() == 4
-    tensors = tensors.cpu()
-    std = torch.FloatTensor([0.229, 0.224, 0.225])[None, :, None, None]
-    mean = torch.FloatTensor([0.485, 0.456, 0.406])[None, :, None, None]
-    tensors = tensors * std + mean
-    tensors = (tensors * 255).clamp(0, 255).to(torch.uint8)
-    return tensors
-
-
-@torch.inference_mode()
-def generate_image(model_name: str, seed: int, model_dict: dict,
-                   device: torch.device) -> PIL.Image.Image:
-    model = model_dict[model_name]
-    seed = int(np.clip(seed, 0, np.iinfo(np.uint32).max))
-    z = generate_z(seed, device)
-    out, _ = model(z)
-    out = postprocess(out)
-    out = out.numpy()[0].transpose(1, 2, 0)
-    return PIL.Image.fromarray(out, 'RGB')
+def get_sample_image_url(name: str) -> str:
+    sample_image_dir = 'https://huggingface.co/spaces/hysts/StyleSwin/resolve/main/samples'
+    return f'{sample_image_dir}/{name}.jpg'


-def main():
-    gr.close_all()
+def get_sample_image_markdown(name: str) -> str:
+    url = get_sample_image_url(name)
+    if name == 'celeba-hq':
+        size = 1024
+    elif name == 'ffhq':
+        size = 1024
+    elif name == 'lsun-church':
+        size = 256
+    else:
+        raise ValueError
+    seed = '0-99'
+    return f'''
+    - size: {size}x{size}
+    - seed: {seed}
+    ![sample images]({url})'''

+
+def main():
     args = parse_args()
-    device = torch.device(args.device)
-
-    model_dict = {name: load_model(name, device) for name in MODEL_NAMES}
-
-    func = functools.partial(generate_image,
-                             model_dict=model_dict,
-                             device=device)
-    func = functools.update_wrapper(func, generate_image)
-
-    gr.Interface(
-        func,
-        [
-            gr.inputs.Radio(MODEL_NAMES,
-                            type='value',
-                            default='FFHQ_256',
-                            label='Model',
-                            optional=False),
-            gr.inputs.Slider(0, 2147483647, step=1, default=0, label='Seed'),
-        ],
-        gr.outputs.Image(type='pil', label='Output'),
-        title=TITLE,
-        description=DESCRIPTION,
-        article=ARTICLE,
-        theme=args.theme,
-        allow_flagging=args.allow_flagging,
-        live=args.live,
-    ).launch(
+    model = Model(args.device)
+
+    with gr.Blocks(theme=args.theme, css='style.css') as demo:
+        gr.Markdown(TITLE)
+        gr.Markdown(DESCRIPTION)
+
+        with gr.Tabs():
+            with gr.TabItem('App'):
+                with gr.Row():
+                    with gr.Column():
+                        with gr.Group():
+                            model_name = gr.Dropdown(
+                                model.MODEL_NAMES,
+                                value=model.MODEL_NAMES[3],
+                                label='Model')
+                            seed = gr.Slider(0,
+                                             np.iinfo(np.uint32).max,
+                                             step=1,
+                                             value=0,
+                                             label='Seed')
+                            run_button = gr.Button('Run')
+                    with gr.Column():
+                        result = gr.Image(label='Result', elem_id='result')
+
+            with gr.TabItem('Sample Images'):
+                with gr.Row():
+                    model_name2 = gr.Dropdown([
+                        'celeba-hq',
+                        'ffhq',
+                        'lsun-church',
+                    ],
+                                              value='celeba-hq',
+                                              label='Model')
+                with gr.Row():
+                    text = get_sample_image_markdown(model_name2.value)
+                    sample_images = gr.Markdown(text)
+
+        gr.Markdown(FOOTER)
+
+        model_name.change(fn=model.set_model, inputs=model_name, outputs=None)
+        run_button.click(fn=model.set_model_and_generate_image,
+                         inputs=[
+                             model_name,
+                             seed,
+                         ],
+                         outputs=result)
+        model_name2.change(fn=get_sample_image_markdown,
+                           inputs=model_name2,
+                           outputs=sample_images)
+
+    demo.launch(
         enable_queue=args.enable_queue,
         server_port=args.port,
         share=args.share,
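The app.py rewrite swaps the monolithic `gr.Interface` call for `gr.Blocks`, where each component event is wired explicitly. A minimal standalone sketch of that pattern (hypothetical dropdown and callback, not the Space's actual UI):

```python
import gradio as gr


def describe(name: str) -> str:
    # Hypothetical callback: recomputes the markdown shown below.
    return f'You selected **{name}**.'


with gr.Blocks() as demo:
    choice = gr.Dropdown(['alpha', 'beta'], value='alpha', label='Choice')
    output = gr.Markdown(describe('alpha'))
    # Same wiring style as model_name2.change(...) in the diff above.
    choice.change(fn=describe, inputs=choice, outputs=output)

demo.launch()
```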
model.py ADDED
@@ -0,0 +1,95 @@
+from __future__ import annotations
+
+import os
+import pathlib
+import sys
+
+import huggingface_hub
+import numpy as np
+import PIL.Image
+import torch
+import torch.nn as nn
+
+if os.environ.get('SYSTEM') == 'spaces':
+    os.system("sed -i '14,21d' StyleSwin/op/fused_act.py")
+    os.system("sed -i '12,19d' StyleSwin/op/upfirdn2d.py")
+
+current_dir = pathlib.Path(__file__).parent
+submodule_dir = current_dir / 'StyleSwin'
+sys.path.insert(0, submodule_dir.as_posix())
+
+from models.generator import Generator
+
+HF_TOKEN = os.environ['HF_TOKEN']
+
+
+class Model:
+    MODEL_NAMES = [
+        'CelebAHQ_256',
+        'FFHQ_256',
+        'LSUNChurch_256',
+        'CelebAHQ_1024',
+        'FFHQ_1024',
+    ]
+
+    def __init__(self, device: str | torch.device):
+        self.device = torch.device(device)
+        self._download_all_models()
+        self.model_name = self.MODEL_NAMES[3]
+        self.model = self._load_model(self.model_name)
+
+        self.std = torch.FloatTensor([0.229, 0.224,
+                                      0.225])[None, :, None,
+                                              None].to(self.device)
+        self.mean = torch.FloatTensor([0.485, 0.456,
+                                       0.406])[None, :, None,
+                                               None].to(self.device)
+
+    def _load_model(self, model_name: str) -> nn.Module:
+        size = int(model_name.split('_')[1])
+        channel_multiplier = 1 if size == 1024 else 2
+        model = Generator(size,
+                          style_dim=512,
+                          n_mlp=8,
+                          channel_multiplier=channel_multiplier)
+        ckpt_path = huggingface_hub.hf_hub_download('hysts/StyleSwin',
+                                                    f'models/{model_name}.pt',
+                                                    use_auth_token=HF_TOKEN)
+        ckpt = torch.load(ckpt_path)
+        model.load_state_dict(ckpt['g_ema'])
+        model.to(self.device)
+        model.eval()
+        return model
+
+    def set_model(self, model_name: str) -> None:
+        if model_name == self.model_name:
+            return
+        self.model_name = model_name
+        self.model = self._load_model(model_name)
+
+    def _download_all_models(self):
+        for name in self.MODEL_NAMES:
+            self._load_model(name)
+
+    def generate_z(self, seed: int) -> torch.Tensor:
+        seed = int(np.clip(seed, 0, np.iinfo(np.uint32).max))
+        z = np.random.RandomState(seed).randn(1, 512)
+        return torch.from_numpy(z).float().to(self.device)
+
+    def postprocess(self, tensors: torch.Tensor) -> np.ndarray:
+        assert tensors.dim() == 4
+        tensors = tensors * self.std + self.mean
+        tensors = (tensors * 255).clamp(0, 255).to(torch.uint8)
+        return tensors.permute(0, 2, 3, 1).cpu().numpy()
+
+    @torch.inference_mode()
+    def generate_image(self, seed: int) -> np.ndarray:
+        z = self.generate_z(seed)
+        out, _ = self.model(z)
+        out = self.postprocess(out)
+        return out[0]
+
+    def set_model_and_generate_image(self, model_name: str,
+                                     seed: int) -> np.ndarray:
+        self.set_model(model_name)
+        return self.generate_image(seed)
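For orientation, a usage sketch of the new `Model` class (assumes `HF_TOKEN` is set and the `StyleSwin` submodule is checked out, as the module itself requires):

```python
from model import Model

model = Model('cpu')  # __init__ pre-downloads every checkpoint

# No-op model switch if 'FFHQ_256' is already loaded, then one sample.
image = model.set_model_and_generate_image('FFHQ_256', 0)
print(image.shape, image.dtype)  # (256, 256, 3) uint8
```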
style.css ADDED
@@ -0,0 +1,11 @@
+h1 {
+  text-align: center;
+}
+div#result {
+  max-width: 600px;
+  max-height: 600px;
+}
+img#visitor-badge {
+  display: block;
+  margin: auto;
+}