kindjeff and hysts committed
Commit 286c216 · 0 Parent(s)

Duplicate from Gradio-Blocks/HairCLIP

Co-authored-by: hysts <[email protected]>

.gitattributes ADDED
@@ -0,0 +1,28 @@
+ *.jpg filter=lfs diff=lfs merge=lfs -text
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zstandard filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitmodules ADDED
@@ -0,0 +1,6 @@
+ [submodule "HairCLIP"]
+   path = HairCLIP
+   url = https://github.com/wty-ustc/HairCLIP
+ [submodule "encoder4editing"]
+   path = encoder4editing
+   url = https://github.com/omertov/encoder4editing
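
Note: the Space code depends on the two pinned submodules above (HairCLIP and encoder4editing), which are added as gitlinks rather than vendored files, so a fresh clone or duplicate has to check them out before `app.py` can import from them. A minimal sketch, not part of the commit, written with `subprocess` in the same spirit as `model.py`:

```python
# Hypothetical helper (not in this commit): fetch the HairCLIP and
# encoder4editing submodules after cloning/duplicating the repo.
import subprocess

subprocess.run('git submodule update --init'.split(), check=True)
```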
.pre-commit-config.yaml ADDED
@@ -0,0 +1,37 @@
+ exclude: ^patch.*
+ repos:
+ - repo: https://github.com/pre-commit/pre-commit-hooks
+   rev: v4.2.0
+   hooks:
+   - id: check-executables-have-shebangs
+   - id: check-json
+   - id: check-merge-conflict
+   - id: check-shebang-scripts-are-executable
+   - id: check-toml
+   - id: check-yaml
+   - id: double-quote-string-fixer
+   - id: end-of-file-fixer
+   - id: mixed-line-ending
+     args: ['--fix=lf']
+   - id: requirements-txt-fixer
+   - id: trailing-whitespace
+ - repo: https://github.com/myint/docformatter
+   rev: v1.4
+   hooks:
+   - id: docformatter
+     args: ['--in-place']
+ - repo: https://github.com/pycqa/isort
+   rev: 5.12.0
+   hooks:
+   - id: isort
+ - repo: https://github.com/pre-commit/mirrors-mypy
+   rev: v0.991
+   hooks:
+   - id: mypy
+     args: ['--ignore-missing-imports']
+     additional_dependencies: ['types-python-slugify']
+ - repo: https://github.com/google/yapf
+   rev: v0.32.0
+   hooks:
+   - id: yapf
+     args: ['--parallel', '--in-place']
.style.yapf ADDED
@@ -0,0 +1,5 @@
+ [style]
+ based_on_style = pep8
+ blank_line_before_nested_class_or_def = false
+ spaces_before_comment = 2
+ split_before_logical_operator = true
HairCLIP ADDED
@@ -0,0 +1 @@
+ Subproject commit 29290cf5bdca0f21ff27e0ec2e93bdd1ebbe3605
README.md ADDED
@@ -0,0 +1,16 @@
+ ---
+ title: HairCLIP
+ emoji: ⚡
+ colorFrom: indigo
+ colorTo: red
+ sdk: gradio
+ sdk_version: 3.36.1
+ app_file: app.py
+ pinned: false
+ suggested_hardware: t4-small
+ duplicated_from: Gradio-Blocks/HairCLIP
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces#reference
+
+ https://arxiv.org/abs/2112.05142
app.py ADDED
@@ -0,0 +1,104 @@
+ #!/usr/bin/env python
+
+ from __future__ import annotations
+
+ import pathlib
+
+ import gradio as gr
+
+ from model import Model
+
+ DESCRIPTION = '''# [HairCLIP](https://github.com/wty-ustc/HairCLIP)
+
+ <center><img id="teaser" src="https://raw.githubusercontent.com/wty-ustc/HairCLIP/main/assets/teaser.png" alt="teaser"></center>
+ '''
+
+
+ def load_hairstyle_list() -> list[str]:
+     with open('HairCLIP/mapper/hairstyle_list.txt') as f:
+         lines = [line.strip() for line in f.readlines()]
+     lines = [line[:-10] for line in lines]
+     return lines
+
+
+ def set_example_image(example: list) -> dict:
+     return gr.Image.update(value=example[0])
+
+
+ def update_step2_components(choice: str) -> tuple[dict, dict]:
+     return (
+         gr.Dropdown.update(visible=choice in ['hairstyle', 'both']),
+         gr.Textbox.update(visible=choice in ['color', 'both']),
+     )
+
+
+ model = Model()
+
+ with gr.Blocks(css='style.css') as demo:
+     gr.Markdown(DESCRIPTION)
+     with gr.Box():
+         gr.Markdown('## Step 1')
+         with gr.Row():
+             with gr.Column():
+                 with gr.Row():
+                     input_image = gr.Image(label='Input Image',
+                                            type='filepath')
+                 with gr.Row():
+                     preprocess_button = gr.Button('Preprocess')
+             with gr.Column():
+                 aligned_face = gr.Image(label='Aligned Face',
+                                         type='pil',
+                                         interactive=False)
+             with gr.Column():
+                 reconstructed_face = gr.Image(label='Reconstructed Face',
+                                               type='numpy')
+                 latent = gr.Variable()
+
+         with gr.Row():
+             paths = sorted(pathlib.Path('images').glob('*.jpg'))
+             gr.Examples(examples=[[path.as_posix()] for path in paths],
+                         inputs=input_image)
+
+     with gr.Box():
+         gr.Markdown('## Step 2')
+         with gr.Row():
+             with gr.Column():
+                 with gr.Row():
+                     editing_type = gr.Radio(
+                         label='Editing Type',
+                         choices=['hairstyle', 'color', 'both'],
+                         value='both',
+                         type='value')
+                 with gr.Row():
+                     hairstyles = load_hairstyle_list()
+                     hairstyle_index = gr.Dropdown(label='Hairstyle',
+                                                   choices=hairstyles,
+                                                   value='afro',
+                                                   type='index')
+                 with gr.Row():
+                     color_description = gr.Textbox(label='Color', value='red')
+                 with gr.Row():
+                     run_button = gr.Button('Run')
+
+             with gr.Column():
+                 result = gr.Image(label='Result')
+
+     preprocess_button.click(fn=model.detect_and_align_face,
+                             inputs=input_image,
+                             outputs=aligned_face)
+     aligned_face.change(fn=model.reconstruct_face,
+                         inputs=aligned_face,
+                         outputs=[reconstructed_face, latent])
+     editing_type.change(fn=update_step2_components,
+                         inputs=editing_type,
+                         outputs=[hairstyle_index, color_description])
+     run_button.click(fn=model.generate,
+                      inputs=[
+                          editing_type,
+                          hairstyle_index,
+                          color_description,
+                          latent,
+                      ],
+                      outputs=result)
+
+ demo.queue(max_size=10).launch()
encoder4editing ADDED
@@ -0,0 +1 @@
+ Subproject commit 99ea50578695d2e8a1cf7259d8ee89b23eea942b
images/95UF6LXe-Lo.jpg ADDED

Git LFS Details

  • SHA256: 9ba751a6519822fa683e062ee3a383e748f15b41d4ca87d14c4fa73f9beed845
  • Pointer size: 131 Bytes
  • Size of remote file: 503 kB
images/ILip77SbmOE.jpg ADDED

Git LFS Details

  • SHA256: 3eed82923bc76a90f067415f148d56239fdfa4a1aca9eef1d459bc6050c9dde8
  • Pointer size: 131 Bytes
  • Size of remote file: 939 kB
images/README.md ADDED
@@ -0,0 +1,6 @@
+ These images are freely-usable ones from [Unsplash](https://unsplash.com/).
+
+ - https://unsplash.com/photos/rDEOVtE7vOs
+ - https://unsplash.com/photos/et_78QkMMQs
+ - https://unsplash.com/photos/ILip77SbmOE
+ - https://unsplash.com/photos/95UF6LXe-Lo
images/et_78QkMMQs.jpg ADDED

Git LFS Details

  • SHA256: c63a2e9de5eda3cb28012cfc8e4ba9384daeda8cca7a8989ad90b21a1293cc6f
  • Pointer size: 131 Bytes
  • Size of remote file: 371 kB
images/rDEOVtE7vOs.jpg ADDED

Git LFS Details

  • SHA256: b136bf195fef5599f277a563f0eef79af5301d9352d4ebf82bd7a0a061b7bdc0
  • Pointer size: 131 Bytes
  • Size of remote file: 155 kB
model.py ADDED
@@ -0,0 +1,160 @@
+ from __future__ import annotations
+
+ import argparse
+ import os
+ import pathlib
+ import subprocess
+ import sys
+ from typing import Callable, Union
+
+ import dlib
+ import huggingface_hub
+ import numpy as np
+ import PIL.Image
+ import torch
+ import torch.nn as nn
+ import torchvision.transforms as T
+
+ if os.getenv('SYSTEM') == 'spaces' and not torch.cuda.is_available():
+     with open('patch.e4e') as f:
+         subprocess.run('patch -p1'.split(), cwd='encoder4editing', stdin=f)
+     with open('patch.hairclip') as f:
+         subprocess.run('patch -p1'.split(), cwd='HairCLIP', stdin=f)
+
+ app_dir = pathlib.Path(__file__).parent
+
+ e4e_dir = app_dir / 'encoder4editing'
+ sys.path.insert(0, e4e_dir.as_posix())
+
+ from models.psp import pSp
+ from utils.alignment import align_face
+
+ hairclip_dir = app_dir / 'HairCLIP'
+ mapper_dir = hairclip_dir / 'mapper'
+ sys.path.insert(0, hairclip_dir.as_posix())
+ sys.path.insert(0, mapper_dir.as_posix())
+
+ from mapper.datasets.latents_dataset_inference import LatentsDatasetInference
+ from mapper.hairclip_mapper import HairCLIPMapper
+
+
+ class Model:
+     def __init__(self):
+         self.device = torch.device(
+             'cuda:0' if torch.cuda.is_available() else 'cpu')
+         self.landmark_model = self._create_dlib_landmark_model()
+         self.e4e = self._load_e4e()
+         self.hairclip = self._load_hairclip()
+         self.transform = self._create_transform()
+
+     @staticmethod
+     def _create_dlib_landmark_model():
+         path = huggingface_hub.hf_hub_download(
+             'public-data/dlib_face_landmark_model',
+             'shape_predictor_68_face_landmarks.dat')
+         return dlib.shape_predictor(path)
+
+     def _load_e4e(self) -> nn.Module:
+         ckpt_path = huggingface_hub.hf_hub_download('public-data/e4e',
+                                                     'e4e_ffhq_encode.pt')
+         ckpt = torch.load(ckpt_path, map_location='cpu')
+         opts = ckpt['opts']
+         opts['device'] = self.device.type
+         opts['checkpoint_path'] = ckpt_path
+         opts = argparse.Namespace(**opts)
+         model = pSp(opts)
+         model.to(self.device)
+         model.eval()
+         return model
+
+     def _load_hairclip(self) -> nn.Module:
+         ckpt_path = huggingface_hub.hf_hub_download('public-data/HairCLIP',
+                                                     'hairclip.pt')
+         ckpt = torch.load(ckpt_path, map_location='cpu')
+         opts = ckpt['opts']
+         opts['device'] = self.device.type
+         opts['checkpoint_path'] = ckpt_path
+         opts['editing_type'] = 'both'
+         opts['input_type'] = 'text'
+         opts['hairstyle_description'] = 'HairCLIP/mapper/hairstyle_list.txt'
+         opts['color_description'] = 'red'
+         opts = argparse.Namespace(**opts)
+         model = HairCLIPMapper(opts)
+         model.to(self.device)
+         model.eval()
+         return model
+
+     @staticmethod
+     def _create_transform() -> Callable:
+         transform = T.Compose([
+             T.Resize(256),
+             T.CenterCrop(256),
+             T.ToTensor(),
+             T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
+         ])
+         return transform
+
+     def detect_and_align_face(self, image: str) -> PIL.Image.Image:
+         image = align_face(filepath=image, predictor=self.landmark_model)
+         return image
+
+     @staticmethod
+     def denormalize(tensor: torch.Tensor) -> torch.Tensor:
+         return torch.clamp((tensor + 1) / 2 * 255, 0, 255).to(torch.uint8)
+
+     def postprocess(self, tensor: torch.Tensor) -> np.ndarray:
+         tensor = self.denormalize(tensor)
+         return tensor.cpu().numpy().transpose(1, 2, 0)
+
+     @torch.inference_mode()
+     def reconstruct_face(
+             self, image: PIL.Image.Image) -> tuple[np.ndarray, torch.Tensor]:
+         input_data = self.transform(image).unsqueeze(0).to(self.device)
+         reconstructed_images, latents = self.e4e(input_data,
+                                                  randomize_noise=False,
+                                                  return_latents=True)
+         reconstructed = torch.clamp(reconstructed_images[0].detach(), -1, 1)
+         reconstructed = self.postprocess(reconstructed)
+         return reconstructed, latents[0]
+
+     @torch.inference_mode()
+     def generate(self, editing_type: str, hairstyle_index: int,
+                  color_description: str, latent: torch.Tensor) -> np.ndarray:
+         opts = self.hairclip.opts
+         opts.editing_type = editing_type
+         opts.color_description = color_description
+
+         if editing_type == 'color':
+             hairstyle_index = 0
+
+         device = torch.device(opts.device)
+
+         dataset = LatentsDatasetInference(latents=latent.unsqueeze(0).cpu(),
+                                           opts=opts)
+         w, hairstyle_text_inputs_list, color_text_inputs_list = dataset[0][:3]
+
+         w = w.unsqueeze(0).to(device)
+         hairstyle_text_inputs = hairstyle_text_inputs_list[
+             hairstyle_index].unsqueeze(0).to(device)
+         color_text_inputs = color_text_inputs_list[0].unsqueeze(0).to(device)
+
+         hairstyle_tensor_hairmasked = torch.Tensor([0]).unsqueeze(0).to(device)
+         color_tensor_hairmasked = torch.Tensor([0]).unsqueeze(0).to(device)
+
+         w_hat = w + 0.1 * self.hairclip.mapper(
+             w,
+             hairstyle_text_inputs,
+             color_text_inputs,
+             hairstyle_tensor_hairmasked,
+             color_tensor_hairmasked,
+         )
+         x_hat, _ = self.hairclip.decoder(
+             [w_hat],
+             input_is_latent=True,
+             return_latents=True,
+             randomize_noise=False,
+             truncation=1,
+         )
+         res = torch.clamp(x_hat[0].detach(), -1, 1)
+         res = self.postprocess(res)
+         return res
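
Taken together with `app.py`, `model.py` exposes the whole pipeline as three calls: dlib-based face alignment, e4e inversion into a StyleGAN latent, and HairCLIP mapping of that latent. A minimal headless sketch of the same flow, not part of the commit, assuming the dlib, e4e, and HairCLIP checkpoints download successfully from the Hub (the example path is one of the bundled images):

```python
# Hedged usage sketch: drive Model directly, without the Gradio UI.
from model import Model

model = Model()
aligned = model.detect_and_align_face('images/95UF6LXe-Lo.jpg')  # PIL image
_, latent = model.reconstruct_face(aligned)  # (reconstruction, latent tensor)
result = model.generate(editing_type='both',
                        hairstyle_index=0,  # row index into hairstyle_list.txt
                        color_description='red',
                        latent=latent)
print(result.shape)  # H x W x 3 uint8 array
```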
packages.txt ADDED
@@ -0,0 +1,2 @@
+ cmake
+ ninja-build
patch.e4e ADDED
@@ -0,0 +1,131 @@
+ diff --git a/models/stylegan2/op/fused_act.py b/models/stylegan2/op/fused_act.py
+ index 973a84f..6854b97 100644
+ --- a/models/stylegan2/op/fused_act.py
+ +++ b/models/stylegan2/op/fused_act.py
+ @@ -2,17 +2,18 @@ import os
+
+ import torch
+ from torch import nn
+ +from torch.nn import functional as F
+ from torch.autograd import Function
+ from torch.utils.cpp_extension import load
+
+ -module_path = os.path.dirname(__file__)
+ -fused = load(
+ - 'fused',
+ - sources=[
+ - os.path.join(module_path, 'fused_bias_act.cpp'),
+ - os.path.join(module_path, 'fused_bias_act_kernel.cu'),
+ - ],
+ -)
+ +#module_path = os.path.dirname(__file__)
+ +#fused = load(
+ +# 'fused',
+ +# sources=[
+ +# os.path.join(module_path, 'fused_bias_act.cpp'),
+ +# os.path.join(module_path, 'fused_bias_act_kernel.cu'),
+ +# ],
+ +#)
+
+
+ class FusedLeakyReLUFunctionBackward(Function):
+ @@ -82,4 +83,18 @@ class FusedLeakyReLU(nn.Module):
+
+
+ def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
+ - return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
+ + if input.device.type == "cpu":
+ + if bias is not None:
+ + rest_dim = [1] * (input.ndim - bias.ndim - 1)
+ + return (
+ + F.leaky_relu(
+ + input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2
+ + )
+ + * scale
+ + )
+ +
+ + else:
+ + return F.leaky_relu(input, negative_slope=0.2) * scale
+ +
+ + else:
+ + return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
+ diff --git a/models/stylegan2/op/upfirdn2d.py b/models/stylegan2/op/upfirdn2d.py
+ index 7bc5a1e..5465d1a 100644
+ --- a/models/stylegan2/op/upfirdn2d.py
+ +++ b/models/stylegan2/op/upfirdn2d.py
+ @@ -1,17 +1,18 @@
+ import os
+
+ import torch
+ +from torch.nn import functional as F
+ from torch.autograd import Function
+ from torch.utils.cpp_extension import load
+
+ -module_path = os.path.dirname(__file__)
+ -upfirdn2d_op = load(
+ - 'upfirdn2d',
+ - sources=[
+ - os.path.join(module_path, 'upfirdn2d.cpp'),
+ - os.path.join(module_path, 'upfirdn2d_kernel.cu'),
+ - ],
+ -)
+ +#module_path = os.path.dirname(__file__)
+ +#upfirdn2d_op = load(
+ +# 'upfirdn2d',
+ +# sources=[
+ +# os.path.join(module_path, 'upfirdn2d.cpp'),
+ +# os.path.join(module_path, 'upfirdn2d_kernel.cu'),
+ +# ],
+ +#)
+
+
+ class UpFirDn2dBackward(Function):
+ @@ -97,8 +98,8 @@ class UpFirDn2d(Function):
+
+ ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
+
+ - out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
+ - out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
+ + out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y
+ + out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x
+ ctx.out_size = (out_h, out_w)
+
+ ctx.up = (up_x, up_y)
+ @@ -140,9 +141,13 @@ class UpFirDn2d(Function):
+
+
+ def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
+ - out = UpFirDn2d.apply(
+ - input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1])
+ - )
+ + if input.device.type == "cpu":
+ + out = upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1])
+ +
+ + else:
+ + out = UpFirDn2d.apply(
+ + input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1])
+ + )
+
+ return out
+
+ @@ -150,6 +155,9 @@ def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
+ def upfirdn2d_native(
+ input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
+ ):
+ + _, channel, in_h, in_w = input.shape
+ + input = input.reshape(-1, in_h, in_w, 1)
+ +
+ _, in_h, in_w, minor = input.shape
+ kernel_h, kernel_w = kernel.shape
+
+ @@ -180,5 +188,9 @@ def upfirdn2d_native(
+ in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
+ )
+ out = out.permute(0, 2, 3, 1)
+ + out = out[:, ::down_y, ::down_x, :]
+ +
+ + out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y
+ + out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x
+
+ - return out[:, ::down_y, ::down_x, :]
+ + return out.view(-1, channel, out_h, out_w)
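
The patch above is applied by `model.py` on CPU-only Spaces hardware to avoid compiling the custom CUDA extensions: `fused_leaky_relu` falls back to a plain leaky ReLU on `input + bias` scaled by sqrt(2), and `upfirdn2d` falls back to `upfirdn2d_native` with corrected output-size handling. A small standalone check of what the CPU branch computes (illustrative only, not part of the commit):

```python
# What the patched CPU path of fused_leaky_relu reduces to.
import torch
import torch.nn.functional as F

x = torch.randn(2, 8, 4, 4)  # NCHW activations
bias = torch.randn(8)        # per-channel bias
scale = 2 ** 0.5

rest_dim = [1] * (x.ndim - bias.ndim - 1)  # broadcast bias over H and W
out = F.leaky_relu(x + bias.view(1, bias.shape[0], *rest_dim),
                   negative_slope=0.2) * scale
print(out.shape)  # torch.Size([2, 8, 4, 4])
```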
patch.hairclip ADDED
@@ -0,0 +1,61 @@
+ diff --git a/mapper/latent_mappers.py b/mapper/latent_mappers.py
+ index 56b9c55..f0dd005 100644
+ --- a/mapper/latent_mappers.py
+ +++ b/mapper/latent_mappers.py
+ @@ -19,7 +19,7 @@ class ModulationModule(Module):
+
+ def forward(self, x, embedding, cut_flag):
+ x = self.fc(x)
+ - x = self.norm(x)
+ + x = self.norm(x)
+ if cut_flag == 1:
+ return x
+ gamma = self.gamma_function(embedding.float())
+ @@ -39,20 +39,20 @@ class SubHairMapper(Module):
+ def forward(self, x, embedding, cut_flag=0):
+ x = self.pixelnorm(x)
+ for modulation_module in self.modulation_module_list:
+ - x = modulation_module(x, embedding, cut_flag)
+ + x = modulation_module(x, embedding, cut_flag)
+ return x
+
+ -class HairMapper(Module):
+ +class HairMapper(Module):
+ def __init__(self, opts):
+ super(HairMapper, self).__init__()
+ self.opts = opts
+ - self.clip_model, self.preprocess = clip.load("ViT-B/32", device="cuda")
+ + self.clip_model, self.preprocess = clip.load("ViT-B/32", device=opts.device)
+ self.transform = transforms.Compose([transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))])
+ self.face_pool = torch.nn.AdaptiveAvgPool2d((224, 224))
+ self.hairstyle_cut_flag = 0
+ self.color_cut_flag = 0
+
+ - if not opts.no_coarse_mapper:
+ + if not opts.no_coarse_mapper:
+ self.course_mapping = SubHairMapper(opts, 4)
+ if not opts.no_medium_mapper:
+ self.medium_mapping = SubHairMapper(opts, 4)
+ @@ -70,13 +70,13 @@ class HairMapper(Module):
+ elif hairstyle_tensor.shape[1] != 1:
+ hairstyle_embedding = self.gen_image_embedding(hairstyle_tensor, self.clip_model, self.preprocess).unsqueeze(1).repeat(1, 18, 1).detach()
+ else:
+ - hairstyle_embedding = torch.ones(x.shape[0], 18, 512).cuda()
+ + hairstyle_embedding = torch.ones(x.shape[0], 18, 512).to(self.opts.device)
+ if color_text_inputs.shape[1] != 1:
+ color_embedding = self.clip_model.encode_text(color_text_inputs).unsqueeze(1).repeat(1, 18, 1).detach()
+ elif color_tensor.shape[1] != 1:
+ color_embedding = self.gen_image_embedding(color_tensor, self.clip_model, self.preprocess).unsqueeze(1).repeat(1, 18, 1).detach()
+ else:
+ - color_embedding = torch.ones(x.shape[0], 18, 512).cuda()
+ + color_embedding = torch.ones(x.shape[0], 18, 512).to(self.opts.device)
+
+
+ if (hairstyle_text_inputs.shape[1] == 1) and (hairstyle_tensor.shape[1] == 1):
+ @@ -106,4 +106,4 @@ class HairMapper(Module):
+ x_fine = torch.zeros_like(x_fine)
+
+ out = torch.cat([x_coarse, x_medium, x_fine], dim=1)
+ - return out
+
+ + return out
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ dlib==19.23.0
+ git+https://github.com/openai/CLIP.git
+ numpy==1.22.3
+ opencv-python-headless==4.5.5.64
+ Pillow==9.1.0
+ scipy==1.8.0
+ torch==1.11.0
+ torchvision==0.12.0
style.css ADDED
@@ -0,0 +1,8 @@
+ h1 {
+   text-align: center;
+ }
+
+ img#teaser {
+   max-width: 1000px;
+   max-height: 600px;
+ }