# -*- coding: utf-8 -*-
# Author: ximing
# Description: DiffVG painter and optimizer
# Copyright (c) 2023, XiMing Xing.
# License: MPL-2.0 License

import copy
import random
from typing import List

import omegaconf
import numpy as np
import pydiffvg
import torch
from torch.optim.lr_scheduler import LambdaLR

from pytorch_svgrender.diffvg_warp import DiffVGState


class Painter(DiffVGState):
    """Holds a set of randomly initialised Bezier paths and renders them.

    Paths are either open strokes (``path_type='unclosed'``), for which the
    stroke colour and stroke width are trained, or closed shapes
    (``path_type='closed'``), for which the fill colour is trained.
    """

    def __init__(
            self,
            target_img: torch.Tensor,
            diffvg_cfg: omegaconf.DictConfig,
            canvas_size: List,
            path_type: str = 'unclosed',
            max_width: float = 3.0,
            device: torch.device = None,
    ):
        """Create a painter.

        Args:
            target_img: target image; only stored on the instance here.
            diffvg_cfg: config providing ``print_timing`` for the base class.
            canvas_size: ``[width, height]`` of the canvas.
            path_type: ``'unclosed'`` (strokes) or ``'closed'`` (filled shapes).
            max_width: upper bound applied to stroke widths when clamping.
            device: device forwarded to the base class.
        """
        super(Painter, self).__init__(device, print_timing=diffvg_cfg.print_timing,
                                      canvas_width=canvas_size[0], canvas_height=canvas_size[1])

        self.target_img = target_img
        self.path_type: str = path_type
        self.max_width = max_width
        # strokes are trained only for open paths; any other type trains fill colours
        self.train_stroke: bool = path_type == 'unclosed'

        self.strokes_counter: int = 0  # counts the number of calls to "get_path"

    def _composite_on_white(self, img: torch.Tensor) -> torch.Tensor:
        """Blend a rendered image over a white background and return it batched.

        The input is indexed as (H, W, C) with channel 3 used as alpha; the
        result has a leading batch axis and channels moved in front of H and W.
        """
        alpha = img[:, :, 3:4]
        white = torch.ones(img.shape[0], img.shape[1], 3, device=self.device)
        img = alpha * img[:, :, :3] + white * (1 - alpha)
        img = img.unsqueeze(0)  # convert img from HWC to NCHW
        img = img.permute(0, 3, 1, 2).to(self.device)  # NHWC -> NCHW
        return img

    def init_image(self, num_paths=0):
        """Create ``num_paths`` random paths with random colours and render them.

        Each path gets exactly one entry in ``self.shapes`` and one matching
        shape group in ``self.shape_groups``.

        Returns:
            The rendered image composited on white, laid out as NCHW.
        """
        for _ in range(num_paths):
            path = self.get_path()
            # register each path once: a duplicate entry would be rendered and
            # saved twice and would reach the optimizer as a repeated parameter
            self.shapes.append(path)

            # both colours are always sampled so the numpy RNG stream does not
            # depend on the path type
            fill_color_init = torch.FloatTensor(np.random.uniform(size=[4]))
            stroke_color_init = torch.FloatTensor(np.random.uniform(size=[4]))
            path_group = pydiffvg.ShapeGroup(
                shape_ids=torch.tensor([len(self.shapes) - 1]),
                fill_color=None if self.train_stroke else fill_color_init,
                stroke_color=stroke_color_init if self.train_stroke else None
            )
            self.shape_groups.append(path_group)

        img = self.render_warp()
        return self._composite_on_white(img)

    def get_image(self, step: int = 0):
        """Render the current shapes, using ``step`` as the render seed.

        Returns:
            The rendered image composited on white, laid out as NCHW.
        """
        img = self.render_warp(seed=step)
        return self._composite_on_white(img)

    def get_path(self):
        """Build one random cubic Bezier path scaled to the canvas.

        Open paths use 1-3 segments, closed paths 3-5 segments. Every segment
        has two control points. For closed paths the end point of the last
        segment is not stored.

        Returns:
            A ``pydiffvg.Path`` with stroke width 1.0.

        Raises:
            ValueError: if ``self.path_type`` is neither ``'unclosed'`` nor
                ``'closed'``.
        """
        if self.path_type == 'unclosed':
            is_closed = False
            num_segments = random.randint(1, 3)
        elif self.path_type == 'closed':
            is_closed = True
            num_segments = random.randint(3, 5)
        else:
            raise ValueError(
                f"unsupported path_type {self.path_type!r}; expected 'unclosed' or 'closed'"
            )

        num_control_points = torch.zeros(num_segments, dtype=torch.int32) + 2
        radius = 0.05

        def _jitter(point):
            # offset x first, then y, so the order of random draws is fixed
            return (point[0] + radius * (random.random() - 0.5),
                    point[1] + radius * (random.random() - 0.5))

        p0 = (random.random(), random.random())
        points = [p0]
        for j in range(num_segments):
            p1 = _jitter(p0)
            p2 = _jitter(p1)
            # p3 is always sampled, even when it is not stored, because it is
            # the start of the next segment
            p3 = _jitter(p2)
            points.append(p1)
            points.append(p2)
            if not is_closed or j < num_segments - 1:
                points.append(p3)
            p0 = p3

        # points were sampled in unit coordinates; scale them to the canvas
        points = torch.tensor(points)
        points[:, 0] *= self.canvas_width
        points[:, 1] *= self.canvas_height
        path = pydiffvg.Path(num_control_points=num_control_points,
                             points=points,
                             stroke_width=torch.tensor(1.0),
                             is_closed=is_closed)

        self.strokes_counter += 1
        return path

    def clip_curve_shape(self):
        """Clamp trained values in place to their valid ranges.

        For strokes: widths to ``[1.0, max_width]`` and stroke colours to
        ``[0, 1]``. Otherwise: fill colours to ``[0, 1]``.
        """
        if self.train_stroke:  # open-form path
            for path in self.shapes:
                path.stroke_width.data.clamp_(1.0, self.max_width)
            for group in self.shape_groups:
                group.stroke_color.data.clamp_(0.0, 1.0)
        else:  # closed-form path
            for group in self.shape_groups:
                group.fill_color.data.clamp_(0.0, 1.0)

    def set_parameters(self):
        """Enable gradients on the trainable tensors and collect them.

        Rebuilds ``point_vars``, ``width_vars`` and ``color_vars`` from the
        current shapes and shape groups.
        """
        # stroke's location (and, for strokes, width) optimization
        self.point_vars = []
        # reset together with the other lists so repeated calls do not
        # accumulate stale entries
        self.width_vars = []
        for path in self.shapes:
            path.points.requires_grad = True
            self.point_vars.append(path.points)
            if self.train_stroke:
                path.stroke_width.requires_grad = True
                self.width_vars.append(path.stroke_width)

        # stroke's colour optimization
        self.color_vars = []
        for group in self.shape_groups:
            if self.train_stroke:
                group.stroke_color.requires_grad = True
                self.color_vars.append(group.stroke_color)
            else:
                group.fill_color.requires_grad = True
                self.color_vars.append(group.fill_color)

    def get_point_parameters(self):
        """Return the point tensors collected by ``set_parameters``."""
        return self.point_vars

    def get_color_parameters(self):
        """Return the colour tensors collected by ``set_parameters``."""
        return self.color_vars

    def get_stroke_parameters(self):
        """Return ``(width_vars, color_vars)`` collected by ``set_parameters``."""
        return self.width_vars, self.get_color_parameters()

    def save_svg(self, fpath):
        """Write the current shapes and shape groups to ``fpath`` as SVG."""
        pydiffvg.save_svg(f'{fpath}',
                          self.canvas_width,
                          self.canvas_height,
                          self.shapes,
                          self.shape_groups)


class LinearDecayLR:
    """Learning-rate multiplier that decays piecewise-linearly.

    Over each window of ``decay_every`` steps the multiplier moves linearly
    from ``decay_ratio ** k`` to ``decay_ratio ** (k + 1)``, where ``k`` is the
    index of the window.
    """

    def __init__(self, decay_every, decay_ratio):
        self.decay_every = decay_every
        self.decay_ratio = decay_ratio

    def __call__(self, n):
        """Return the multiplier for step ``n``."""
        decay_time = n // self.decay_every
        decay_step = n % self.decay_every
        lr_s = self.decay_ratio ** decay_time
        lr_e = self.decay_ratio ** (decay_time + 1)
        r = decay_step / self.decay_every
        lr = lr_s * (1 - r) + lr_e * r
        return lr


class PainterOptimizer:
    """Owns the Adam optimizer and LR schedule for a ``Painter``."""

    def __init__(self,
                 renderer: Painter,
                 num_iter: int,
                 lr_config: omegaconf.DictConfig,
                 trainable_stroke: bool = False):
        """Store the configuration; the optimizer is built by ``init_optimizer``.

        Args:
            renderer: painter whose parameters are optimised.
            num_iter: passed to ``LinearDecayLR`` as the decay window length.
            lr_config: config with ``point``, ``color``, ``stroke_width`` and
                ``stroke_color`` base learning rates.
            trainable_stroke: if true, optimise stroke width and stroke colour
                groups; otherwise a single colour group.
        """
        self.renderer = renderer
        self.num_iter = num_iter
        self.trainable_stroke = trainable_stroke

        self.lr_base = {
            'point': lr_config.point,
            'color': lr_config.color,
            'stroke_width': lr_config.stroke_width,
            'stroke_color': lr_config.stroke_color,
        }

        self.learnable_params = []  # list[Dict]
        self.optimizer = None
        self.scheduler = None

    def init_optimizer(self):
        """Collect the renderer's parameters and build Adam and the LR scheduler."""
        # optimizers
        params = {}
        self.renderer.set_parameters()
        params['point'] = self.renderer.get_point_parameters()

        if self.trainable_stroke:
            params['stroke_width'], params['stroke_color'] = self.renderer.get_stroke_parameters()
        else:
            params['color'] = self.renderer.get_color_parameters()

        # '_id' tags every group with its name so update_params() can find it
        self.learnable_params = [
            {'params': params[ki], 'lr': self.lr_base[ki], '_id': str(ki)}
            for ki in sorted(params.keys())
        ]
        self.optimizer = torch.optim.Adam(self.learnable_params)

        # lr schedule
        lr_lambda_fn = LinearDecayLR(self.num_iter, 0.4)
        self.scheduler = LambdaLR(self.optimizer, lr_lambda=lr_lambda_fn, last_epoch=-1)

    def update_params(self, name: str, value: torch.Tensor):
        """Replace the ``'params'`` entry of the group whose ``'_id'`` is ``name``.

        Groups with a different ``'_id'`` are left untouched; an unknown name
        changes nothing.
        """
        for param_group in self.learnable_params:
            if param_group.get('_id') == name:
                param_group['params'] = value

    def update_lr(self):
        """Advance the learning-rate scheduler by one step."""
        self.scheduler.step()

    def zero_grad_(self):
        """Clear the gradients held by the optimizer."""
        self.optimizer.zero_grad()

    def step_(self):
        """Apply one optimizer step."""
        self.optimizer.step()

    def get_lr(self):
        """Return the current learning rate of the first parameter group."""
        return self.optimizer.param_groups[0]['lr']