Spaces:
Running
Running
# -*- coding: utf-8 -*- | |
# Author: ximing | |
# Description: DiffVG painter and optimizer | |
# Copyright (c) 2023, XiMing Xing. | |
# License: MPL-2.0 License | |
import copy | |
import random | |
from typing import List | |
import omegaconf | |
import numpy as np | |
import pydiffvg | |
import torch | |
from torch.optim.lr_scheduler import LambdaLR | |
from pytorch_svgrender.diffvg_warp import DiffVGState | |
class Painter(DiffVGState):
    """Random Bezier-path initializer and renderer built on ``DiffVGState``.

    Open ('unclosed') paths are optimized through their points, stroke width
    and stroke color; closed paths through their points and fill color.
    """

    def __init__(
            self,
            target_img: torch.Tensor,
            diffvg_cfg: omegaconf.DictConfig,
            canvas_size: List,
            path_type: str = 'unclosed',
            max_width: float = 3.0,
            device: torch.device = None,
    ):
        """
        Args:
            target_img: target image, stored as ``self.target_img``.
            diffvg_cfg: config; only ``print_timing`` is read here.
            canvas_size: ``(width, height)`` of the canvas.
            path_type: ``'unclosed'`` (stroked open paths) or ``'closed'``
                (filled closed paths).
            max_width: upper bound used by ``clip_curve_shape`` for stroke widths.
            device: torch device forwarded to ``DiffVGState``.
        """
        super().__init__(device, print_timing=diffvg_cfg.print_timing,
                         canvas_width=canvas_size[0], canvas_height=canvas_size[1])
        self.target_img = target_img
        self.path_type: str = path_type
        self.max_width = max_width
        # open paths train stroke width/color; closed paths train fill color
        self.train_stroke: bool = path_type == 'unclosed'
        self.strokes_counter: int = 0  # counts the number of calls to "get_path"

    def init_image(self, num_paths=0):
        """Add ``num_paths`` random paths, each in its own shape group, and render.

        Returns:
            The render composited over white, as an NCHW tensor on ``self.device``.
        """
        for _ in range(num_paths):
            path = self.get_path()
            self.shapes.append(path)

            # Both colors are sampled on every iteration; only the one that
            # matches the path type is attached to the group.
            fill_color_init = torch.FloatTensor(np.random.uniform(size=[4]))
            stroke_color_init = torch.FloatTensor(np.random.uniform(size=[4]))
            path_group = pydiffvg.ShapeGroup(
                shape_ids=torch.tensor([len(self.shapes) - 1]),
                fill_color=None if self.train_stroke else fill_color_init,
                stroke_color=stroke_color_init if self.train_stroke else None
            )
            self.shape_groups.append(path_group)

        return self._composite_on_white(self.render_warp())

    def get_image(self, step: int = 0):
        """Render the current scene (``step`` is passed as ``seed``) over white, as NCHW."""
        return self._composite_on_white(self.render_warp(seed=step))

    def _composite_on_white(self, img):
        """Alpha-composite an RGBA render over a white background and return it as NCHW."""
        # img is indexed as (H, W, C) with the alpha channel at index 3
        alpha = img[:, :, 3:4]
        white = torch.ones(img.shape[0], img.shape[1], 3, device=self.device)
        img = alpha * img[:, :, :3] + white * (1 - alpha)
        img = img.unsqueeze(0)  # HWC -> NHWC
        img = img.permute(0, 3, 1, 2).to(self.device)  # NHWC -> NCHW
        return img

    def get_path(self):
        """Create a random ``pydiffvg.Path`` of type ``self.path_type``.

        'unclosed' paths get 1-3 segments, 'closed' paths 3-5; every segment
        has two control points and the stroke width starts at 1.0.

        Raises:
            ValueError: if ``path_type`` is neither 'unclosed' nor 'closed'.
        """
        if self.path_type == 'unclosed':
            num_segments = random.randint(1, 3)
            is_closed = False
        elif self.path_type == 'closed':
            num_segments = random.randint(3, 5)
            is_closed = True
        else:
            raise ValueError(
                f"Unsupported path_type: {self.path_type!r} (expected 'unclosed' or 'closed')"
            )
        num_control_points = torch.zeros(num_segments, dtype=torch.int32) + 2
        points = self._sample_points(num_segments, is_closed)
        path = pydiffvg.Path(num_control_points=num_control_points,
                             points=points,
                             stroke_width=torch.tensor(1.0),
                             is_closed=is_closed)
        self.strokes_counter += 1
        return path

    def _sample_points(self, num_segments: int, is_closed: bool) -> torch.Tensor:
        """Sample a random chain of segment points near a random start, scaled to the canvas."""
        radius = 0.05

        def jitter(p):
            # shift each coordinate by a uniform offset in [-radius/2, radius/2)
            return (p[0] + radius * (random.random() - 0.5),
                    p[1] + radius * (random.random() - 0.5))

        p0 = (random.random(), random.random())
        points = [p0]
        for j in range(num_segments):
            p1 = jitter(p0)
            p2 = jitter(p1)
            p3 = jitter(p2)
            points.append(p1)
            points.append(p2)
            # For closed paths the last segment's end point is sampled but not
            # stored (closed paths presumably wrap back to the first point).
            if not is_closed or j < num_segments - 1:
                points.append(p3)
            p0 = p3
        points = torch.tensor(points)
        points[:, 0] *= self.canvas_width
        points[:, 1] *= self.canvas_height
        return points

    def clip_curve_shape(self):
        """Clamp trainable values in place after an optimizer step.

        Stroke widths go to [1.0, max_width] and stroke colors to [0, 1] for open
        paths; fill colors go to [0, 1] for closed paths.
        """
        if self.train_stroke:  # open-form path
            for path in self.shapes:
                path.stroke_width.data.clamp_(1.0, self.max_width)
            for group in self.shape_groups:
                group.stroke_color.data.clamp_(0.0, 1.0)
        else:  # closed-form path
            for group in self.shape_groups:
                group.fill_color.data.clamp_(0.0, 1.0)

    def set_parameters(self):
        """Enable gradients on the trainable tensors and collect them.

        ``point_vars``, ``width_vars`` and ``color_vars`` are rebuilt from scratch
        on every call, so calling this repeatedly never registers a tensor twice.
        """
        # stroke's location optimization (and width, for open paths)
        self.point_vars = []
        self.width_vars = []
        for path in self.shapes:
            path.points.requires_grad = True
            self.point_vars.append(path.points)
            if self.train_stroke:
                path.stroke_width.requires_grad = True
                self.width_vars.append(path.stroke_width)

        # stroke/fill color optimization
        self.color_vars = []
        for group in self.shape_groups:
            if self.train_stroke:
                group.stroke_color.requires_grad = True
                self.color_vars.append(group.stroke_color)
            else:
                group.fill_color.requires_grad = True
                self.color_vars.append(group.fill_color)

    def get_point_parameters(self):
        """Return the point tensors collected by ``set_parameters``."""
        return self.point_vars

    def get_color_parameters(self):
        """Return the color tensors collected by ``set_parameters``."""
        return self.color_vars

    def get_stroke_parameters(self):
        """Return ``(width_vars, color_vars)`` collected by ``set_parameters``."""
        return self.width_vars, self.get_color_parameters()

    def save_svg(self, fpath):
        """Write the current shapes and shape groups to ``fpath`` as an SVG file."""
        pydiffvg.save_svg(f'{fpath}', self.canvas_width, self.canvas_height, self.shapes, self.shape_groups)
class LinearDecayLR:
    """Piecewise-linear learning-rate multiplier, suitable as a ``LambdaLR`` lambda.

    With ``k = n // decay_every``, the multiplier for step ``n`` is linearly
    interpolated between ``decay_ratio ** k`` (at the start of the period) and
    ``decay_ratio ** (k + 1)`` (reached at the start of the next period).
    """

    def __init__(self, decay_every, decay_ratio):
        self.decay_every = decay_every
        self.decay_ratio = decay_ratio

    def __call__(self, n):
        period, offset = divmod(n, self.decay_every)
        frac = offset / self.decay_every
        start = self.decay_ratio ** period
        end = self.decay_ratio ** (period + 1)
        return start * (1 - frac) + end * frac
class PainterOptimizer:
    """Adam optimizer with a linear-decay LR schedule over a ``Painter``'s parameters."""

    def __init__(self,
                 renderer: Painter,
                 num_iter: int,
                 lr_config: omegaconf.DictConfig,
                 trainable_stroke: bool = False):
        """
        Args:
            renderer: the painter whose parameters are optimized.
            num_iter: length of one decay period of the LR schedule.
            lr_config: must provide ``point``, ``color``, ``stroke_width`` and
                ``stroke_color`` learning rates.
            trainable_stroke: if True, optimize stroke width and stroke color;
                otherwise optimize the renderer's color parameters under the
                ``'color'`` learning rate.
        """
        self.renderer = renderer
        self.num_iter = num_iter
        self.trainable_stroke = trainable_stroke

        self.lr_base = {
            'point': lr_config.point,
            'color': lr_config.color,
            'stroke_width': lr_config.stroke_width,
            'stroke_color': lr_config.stroke_color,
        }

        self.learnable_params = []  # list of optimizer param-group dicts
        self.optimizer = None
        self.scheduler = None

    def init_optimizer(self):
        """Collect the renderer's parameters and build the Adam optimizer and LR scheduler."""
        # optimizers
        params = {}
        self.renderer.set_parameters()
        params['point'] = self.renderer.get_point_parameters()

        if self.trainable_stroke:
            params['stroke_width'], params['stroke_color'] = self.renderer.get_stroke_parameters()
        else:
            params['color'] = self.renderer.get_color_parameters()

        self.learnable_params = [
            # '_id' names each group so that update_params() can locate it
            {'params': params[ki], 'lr': self.lr_base[ki], '_id': ki}
            for ki in sorted(params.keys())
        ]
        self.optimizer = torch.optim.Adam(self.learnable_params)

        # lr schedule: multiplier decays by 0.4 every num_iter steps, linearly in between
        lr_lambda_fn = LinearDecayLR(self.num_iter, 0.4)
        self.scheduler = LambdaLR(self.optimizer, lr_lambda=lr_lambda_fn, last_epoch=-1)

    def update_params(self, name: str, value: torch.Tensor):
        """Replace the tensors of the param group whose ``_id`` equals ``name``.

        ``value`` may be a single tensor or an iterable of tensors; it is stored
        as a list so the group's ``'params'`` entry is always iterable.
        """
        params = [value] if isinstance(value, torch.Tensor) else list(value)
        for param_group in self.learnable_params:
            if param_group.get('_id') == name:
                param_group['params'] = params

    def update_lr(self):
        """Advance the LR scheduler by one step."""
        self.scheduler.step()

    def zero_grad_(self):
        """Zero the gradients of all optimized tensors."""
        self.optimizer.zero_grad()

    def step_(self):
        """Apply one optimizer step."""
        self.optimizer.step()

    def get_lr(self):
        """Return the current learning rate of the first param group."""
        return self.optimizer.param_groups[0]['lr']