Spaces:
Running
Running
File size: 7,691 Bytes
966ae59 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 |
import random
import pathlib
import omegaconf
import pydiffvg
import torch
from pytorch_svgrender.diffvg_warp import DiffVGState
class Painter(DiffVGState):
    """Stroke-based vector painter on top of a differentiable diffvg renderer.

    Keeps a list of open stroke paths (``self.shapes``) with one
    ``pydiffvg.ShapeGroup`` per path (``self.shape_groups``), exposes their
    points / widths / colors as optimizable tensors, and renders them
    composited over a white background.

    Args:
        method_cfg: method config; reads ``max_width``, ``num_stages``,
            ``black_stroke_color`` and ``path_svg``.
        diffvg_cfg: diffvg config; reads ``print_timing``.
        num_strokes: total number of strokes on stage 0 (including any loaded
            from ``path_svg``) and the number of strokes added on each later
            stage.
        canvas_size: width and height of the square canvas in pixels.
        device: forwarded to ``DiffVGState`` (which presumably exposes it as
            ``self.device``); stroke points and rendered images are moved to
            ``self.device``.
    """

    def __init__(
            self,
            method_cfg: omegaconf.DictConfig,
            diffvg_cfg: omegaconf.DictConfig,
            num_strokes: int = 4,
            canvas_size: int = 224,
            device: torch.device = None,
    ):
        super().__init__(device, print_timing=diffvg_cfg.print_timing,
                         canvas_width=canvas_size, canvas_height=canvas_size)

        self.method_cfg = method_cfg
        self.num_paths = num_strokes
        self.max_width = method_cfg.max_width  # upper clamp for stroke widths
        self.num_stages = method_cfg.num_stages
        self.black_stroke_color = method_cfg.black_stroke_color
        self.path_svg = method_cfg.path_svg
        self.strokes_per_stage = self.num_paths

        # optimize_flag[i] decides whether shapes[i] / shape_groups[i] are
        # handed to the optimizer by set_parameters()
        self.optimize_flag = []
        self.strokes_counter = 0  # counts the number of calls to "get_path"

    def init_image(self, stage=0):
        """Create the strokes for ``stage`` and return the initial rendering.

        Stage 0 optionally loads ``path_svg`` (when the file exists), then tops
        the shape list up to ``num_paths`` random strokes; every shape is
        marked optimizable. Later stages freeze all existing shapes and append
        ``strokes_per_stage`` new opaque-black strokes, which are the only
        ones marked optimizable.

        Returns:
            The rendering composited over white, shaped ``(1, 3, H, W)``.
        """
        if stage > 0:
            # Multi-stage training: add new strokes on top of the existing
            # ones and don't optimize the previous strokes.
            self.optimize_flag = [False] * len(self.shapes)
            for _ in range(self.strokes_per_stage):
                self._append_stroke(torch.tensor([0.0, 0.0, 0.0, 1.0]))
                self.optimize_flag.append(True)
        else:
            num_paths_exists = 0
            if self.path_svg is not None and pathlib.Path(self.path_svg).exists():
                print(f"-> init svg from '{self.path_svg}' ...")
                self.canvas_width, self.canvas_height, self.shapes, self.shape_groups = self.load_svg(self.path_svg)
                # loaded shapes count toward num_paths and are optimized
                # together with the newly added strokes
                num_paths_exists = len(self.shapes)

            for _ in range(num_paths_exists, self.num_paths):
                self._append_stroke(self._sample_stroke_color())
            self.optimize_flag = [True] * len(self.shapes)

        return self._composite_on_white(self.render_warp())

    def _sample_stroke_color(self):
        """Return opaque black, or a uniformly random RGBA color when ``black_stroke_color`` is off."""
        if self.black_stroke_color:
            return torch.tensor([0.0, 0.0, 0.0, 1.0])
        return torch.tensor([random.random(), random.random(), random.random(), random.random()])

    def _append_stroke(self, stroke_color):
        """Sample a new path via ``get_path`` and register it in its own unfilled shape group."""
        path = self.get_path()
        self.shapes.append(path)
        path_group = pydiffvg.ShapeGroup(shape_ids=torch.tensor([len(self.shapes) - 1]),
                                         fill_color=None,
                                         stroke_color=stroke_color)
        self.shape_groups.append(path_group)

    def get_image(self, step=0):
        """Render the current strokes (``render_warp(step)``) over white as a ``(1, 3, H, W)`` tensor."""
        return self._composite_on_white(self.render_warp(step))

    def _composite_on_white(self, img):
        """Alpha-composite an ``(H, W, C>=4)`` rendering over white; return ``(1, 3, H, W)``."""
        # assumes channel 3 is an alpha in [0, 1] as produced by render_warp — TODO confirm
        opacity = img[:, :, 3:4]
        white = torch.ones(img.shape[0], img.shape[1], 3, device=self.device)
        rgb = opacity * img[:, :, :3] + white * (1 - opacity)
        # HWC -> NHWC (unsqueeze) -> NCHW (permute)
        return rgb.unsqueeze(0).permute(0, 3, 1, 2).to(self.device)

    def get_path(self):
        """Sample a short random open stroke made of 1-3 connected segments.

        The stroke starts at a uniformly random point of the unit square; each
        following point is the previous one shifted by up to ``radius / 2`` on
        each axis. Points are then scaled to pixel coordinates. Every segment
        has 2 control points (presumably a cubic Bézier in diffvg's
        convention), so the path holds ``1 + 3 * num_segments`` points.
        Increments ``strokes_counter``.
        """
        radius = 0.1  # jitter range, in normalized canvas units, between consecutive points
        num_segments = random.randint(1, 3)
        num_control_points = torch.zeros(num_segments, dtype=torch.int32) + 2

        def jitter(point):
            # x is sampled before y, matching the required RNG call order
            return (point[0] + radius * (random.random() - 0.5),
                    point[1] + radius * (random.random() - 0.5))

        p0 = (random.random(), random.random())
        points = [p0]
        for _ in range(num_segments):
            p1 = jitter(p0)
            p2 = jitter(p1)
            p3 = jitter(p2)
            points.extend((p1, p2, p3))
            p0 = p3  # next segment starts where this one ends

        points = torch.tensor(points).to(self.device)
        points[:, 0] *= self.canvas_width
        points[:, 1] *= self.canvas_height
        path = pydiffvg.Path(num_control_points=num_control_points,
                             points=points,
                             stroke_width=torch.tensor(1.0),
                             is_closed=False)
        self.strokes_counter += 1
        return path

    def clip_curve_shape(self):
        """Clamp stroke widths to ``[1, max_width]`` and stroke colors to ``[0, 1]`` in place."""
        for path in self.shapes:
            path.stroke_width.data.clamp_(1.0, self.max_width)
        for group in self.shape_groups:
            group.stroke_color.data.clamp_(0.0, 1.0)

    def set_parameters(self):
        """Enable gradients on the flagged strokes and collect them.

        Returns:
            ``(point_vars, width_vars, color_vars)`` lists for the shapes whose
            ``optimize_flag`` entry is true; also stored on ``self``.
        """
        # stroke's location and width optimization
        self.point_vars = []
        self.width_vars = []
        for i, path in enumerate(self.shapes):
            if self.optimize_flag[i]:
                path.points.requires_grad = True
                self.point_vars.append(path.points)
                path.stroke_width.requires_grad = True
                self.width_vars.append(path.stroke_width)

        # stroke's color optimization
        self.color_vars = []
        for i, group in enumerate(self.shape_groups):
            if self.optimize_flag[i]:
                group.stroke_color.requires_grad = True
                self.color_vars.append(group.stroke_color)

        return self.point_vars, self.width_vars, self.color_vars

    def learnable_parameters(self):
        """Return all tensors collected by the last ``set_parameters()`` call as one list."""
        return self.point_vars + self.width_vars + self.color_vars

    def save_svg(self, output_dir, name):
        """Write the current shapes to ``<output_dir>/<name>.svg`` via ``pydiffvg.save_svg``."""
        pydiffvg.save_svg('{}/{}.svg'.format(output_dir, name),
                          self.canvas_width, self.canvas_height,
                          self.shapes, self.shape_groups)
class PainterOptimizer:
    """Separate Adam optimizers for a painter's stroke points, widths and colors.

    Args:
        renderer: painter whose ``set_parameters()`` returns the
            ``(point_vars, width_vars, color_vars)`` lists to optimize.
        points_lr: initial learning rate for stroke points.
        width_lr: learning rate for stroke widths.
        color_lr: learning rate for stroke colors.
    """

    def __init__(self, renderer: "Painter", points_lr: float, width_lr: float, color_lr: float):
        self.renderer = renderer

        self.points_lr = points_lr
        self.width_lr = width_lr
        self.color_lr = color_lr

        # created lazily by init_optimizers()
        self.points_optimizer, self.width_optimizer, self.color_optimizer = None, None, None

    def init_optimizers(self):
        """(Re)build the three Adam optimizers from ``renderer.set_parameters()``."""
        point_vars, width_vars, color_vars = self.renderer.set_parameters()
        self.points_optimizer = torch.optim.Adam(point_vars, lr=self.points_lr)
        self.width_optimizer = torch.optim.Adam(width_vars, lr=self.width_lr)
        self.color_optimizer = torch.optim.Adam(color_vars, lr=self.color_lr)

    def update_lr(self, step, decay_steps=(500, 750), decay_lrs=(0.4, 0.1)):
        """Step-decay the learning rate of the points optimizer.

        Before ``decay_steps[0]`` the lr is left untouched. From
        ``decay_steps[0]`` on, it is ``decay_lrs[0]``. From ``decay_steps[1]``
        on, it is ``decay_lrs[1]``. Width and color learning rates are never
        changed.

        Args:
            step: current optimization step.
            decay_steps: the two step thresholds.
            decay_lrs: the points lr applied from each threshold on.
        """
        # Thresholds instead of ``step % n == 0``: the modulo form fired at
        # step 0 (dropping the lr to the final value before training started)
        # and made the lr bounce between the two values at later multiples.
        if step >= decay_steps[1]:
            new_lr = decay_lrs[1]
        elif step >= decay_steps[0]:
            new_lr = decay_lrs[0]
        else:
            return
        for param_group in self.points_optimizer.param_groups:
            param_group['lr'] = new_lr

    def zero_grad_(self):
        """Clear gradients in all three optimizers."""
        self.points_optimizer.zero_grad()
        self.width_optimizer.zero_grad()
        self.color_optimizer.zero_grad()

    def step_(self):
        """Apply one update step in all three optimizers."""
        self.points_optimizer.step()
        self.width_optimizer.step()
        self.color_optimizer.step()

    def get_lr(self):
        """Return the current learning rate of the points optimizer."""
        return self.points_optimizer.param_groups[0]['lr']
|