# -*- coding: utf-8 -*-
# Copyright (c) XiMing Xing. All rights reserved.
# Author: XiMing Xing
# Description: StyleCLIPDraw pipeline - text-guided vector sketch synthesis
#   combining CLIP guidance with a STROTSS-based style loss.
import shutil
from PIL import Image
from pathlib import Path

import torch
from torchvision import transforms
import clip
from tqdm.auto import tqdm
import numpy as np

from pytorch_svgrender.libs.engine import ModelState
from pytorch_svgrender.painter.style_clipdraw import (
    Painter, PainterOptimizer, VGG16Extractor, StyleLoss, sample_indices
)
from pytorch_svgrender.plt import plot_img, plot_couple


class StyleCLIPDrawPipeline(ModelState):

    def __init__(self, args):
        logdir_ = f"sd{args.seed}" \
                  f"-P{args.x.num_paths}" \
                  f"-style{args.x.style_strength}" \
                  f"-n{args.x.num_aug}"
        super().__init__(args, log_path_suffix=logdir_)

        # create log dir
        self.png_logs_dir = self.result_path / "png_logs"
        self.svg_logs_dir = self.result_path / "svg_logs"
        if self.accelerator.is_main_process:
            self.png_logs_dir.mkdir(parents=True, exist_ok=True)
            self.svg_logs_dir.mkdir(parents=True, exist_ok=True)

        # make video log
        self.make_video = self.args.mv
        if self.make_video:
            self.frame_idx = 0
            self.frame_log_dir = self.result_path / "frame_logs"
            self.frame_log_dir.mkdir(parents=True, exist_ok=True)

        self.clip, self.tokenize_fn = self.init_clip()

        self.style_extractor = VGG16Extractor(space="normal").to(self.device)
        self.style_loss = StyleLoss()

    def init_clip(self):
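        # ViT-B/32 backbone; jit=False loads the eager (non-scripted) model so
        # gradients can flow back through encode_image during optimization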
        model, _ = clip.load('ViT-B/32', self.device, jit=False)
        return model, clip.tokenize

    def drawing_augment(self, image):
        augment_trans = transforms.Compose([
            transforms.RandomPerspective(fill=1, p=1, distortion_scale=0.5),
            transforms.RandomResizedCrop(224, scale=(0.7, 0.9)),
            transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
        ])

        # following CLIPDraw, encode several randomly augmented views of the
        # drawing so the optimization cannot exploit CLIP adversarially
        img_augs = [augment_trans(image) for _ in range(self.x_cfg.num_aug)]
        im_batch = torch.cat(img_augs)
        # clip visual encoding
        image_features = self.clip.encode_image(im_batch)

        return image_features

    def style_file_preprocess(self, style_file):
        process_comp = transforms.Compose([
            transforms.Resize(size=(224, 224)),
            transforms.ToTensor(),  # PIL image -> float tensor in [0, 1]
            transforms.Lambda(lambda t: t.unsqueeze(0)),  # add a batch dimension
            transforms.Lambda(lambda t: (t + 1) / 2),  # rescale to [0.5, 1]
        ])
        style_file = process_comp(style_file)
        style_file = style_file.to(self.device)
        return style_file

    def painterly_rendering(self, prompt, style_fpath):
        # load style file
        style_path = Path(style_fpath)
        assert style_path.exists(), f"{style_fpath} does not exist!"
        self.print(f"load style file from: {style_path.as_posix()}")
        style_pil = Image.open(style_path.as_posix()).convert("RGB")
        style_img = self.style_file_preprocess(style_pil)
        shutil.copy(style_fpath, self.result_path)  # copy style file

        # extract style features from the style image: 5 passes of 1000 sampled
        # VGG16 hypercolumns are concatenated into a single bank of style descriptors
        feat_style = None
        for _ in range(5):
            with torch.no_grad():
                feat_e = self.style_extractor.forward_samples_hypercolumn(style_img, samps=1000)
                feat_style = feat_e if feat_style is None else torch.cat((feat_style, feat_e), dim=2)

        # text prompt encoding
        self.print(f"prompt: {prompt}")
        text_tokenize = self.tokenize_fn(prompt).to(self.device)
        with torch.no_grad():
            text_features = self.clip.encode_text(text_tokenize)

        renderer = Painter(self.x_cfg,
                           self.args.diffvg,
                           num_strokes=self.x_cfg.num_paths,
                           canvas_size=self.x_cfg.image_size,
                           device=self.device)
        img = renderer.init_image(stage=0)
        self.print("init_image shape: ", img.shape)
        plot_img(img, self.result_path, fname="init_img")

        optimizer = PainterOptimizer(renderer, self.x_cfg.lr, self.x_cfg.width_lr, self.x_cfg.color_lr)
        optimizer.init_optimizers()

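        # map style_strength in [0, 100] linearly onto a style-loss weight in [0, 4]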
        style_weight = 4 * (self.x_cfg.style_strength / 100)
        self.print(f'style_weight: {style_weight}')

        total_step = self.x_cfg.num_iter
        with tqdm(initial=self.step, total=total_step, disable=not self.accelerator.is_main_process) as pbar:
            while self.step < total_step:
                rendering = renderer.get_image(self.step).to(self.device)

                if self.make_video and (self.step % self.args.framefreq == 0 or self.step == total_step - 1):
                    plot_img(rendering, self.frame_log_dir, fname=f"iter{self.frame_idx}")
                    self.frame_idx += 1

                rendering_aug = self.drawing_augment(rendering)

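                # total objective: CLIP guidance term + weighted style term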
                loss = torch.tensor(0., device=self.device)

                # CLIP guidance: maximize cosine similarity between the text
                # embedding and every augmented view of the drawing; guidance is
                # applied only for the first 90% of the steps
                if self.step < 0.9 * total_step:
                    for n in range(self.x_cfg.num_aug):
                        loss -= torch.cosine_similarity(text_features, rendering_aug[n:n + 1], dim=1).mean()

                # style optimization: extract features from the current rendering
                # and compare them against the style features, following STROTSS
                # [Kolkin et al., 2019]
                feat_content = self.style_extractor(rendering)

                xx, xy = sample_indices(feat_content[0], feat_style)

                np.random.shuffle(xx)
                np.random.shuffle(xy)

                L_style = self.style_loss.forward(feat_content, feat_content, feat_style, [xx, xy], 0)

                loss += L_style * style_weight

                pbar.set_description(
                    f"lr: {optimizer.get_lr():.3f}, "
                    f"L_train: {loss.item():.4f}, "
                    f"L_style: {L_style.item():.4f}"
                )

                # optimization
                optimizer.zero_grad_()
                loss.backward()
                optimizer.step_()

                # clamp path parameters (e.g. stroke widths, colors) back into valid ranges
                renderer.clip_curve_shape()

                if self.x_cfg.lr_schedule:
                    optimizer.update_lr(self.step)

                if self.step % self.args.save_step == 0 and self.accelerator.is_main_process:
                    plot_couple(style_img,
                                rendering,
                                self.step,
                                prompt=prompt,
                                output_dir=self.png_logs_dir.as_posix(),
                                fname=f"iter{self.step}")
                    renderer.save_svg(self.svg_logs_dir.as_posix(), f"svg_iter{self.step}")

                self.step += 1
                pbar.update(1)

        plot_couple(style_img,
                    rendering,
                    self.step,
                    prompt=prompt,
                    output_dir=self.result_path.as_posix(),
                    fname="final_iter")
        renderer.save_svg(self.result_path.as_posix(), "final_svg")

        if self.make_video:
            from subprocess import call
            call([
                "ffmpeg",
                "-framerate", f"{self.args.framerate}",
                "-i", (self.frame_log_dir / "iter%d.png").as_posix(),
                "-vb", "20M",
                (self.result_path / "styleclipdraw_rendering.mp4").as_posix()
            ])

        self.close(msg="painterly rendering complete.")
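
# ---------------------------------------------------------------------------
# Usage sketch (hypothetical): the project normally builds `args` via its own
# config system, so the invocation below is an illustrative assumption, not
# the actual CLI.
#
#   pipe = StyleCLIPDrawPipeline(args)
#   pipe.painterly_rendering(prompt="a painting of a starry night",
#                            style_fpath="./assets/style.jpg")
# ---------------------------------------------------------------------------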