hjc-owo
init repo
966ae59
# -*- coding: utf-8 -*-
# Author: ximing
# Description: parent class
# Copyright (c) 2023, XiMing Xing.
# License: MIT License
import pathlib
from typing import AnyStr, List, Union
import xml.etree.ElementTree as etree
import torch
import pydiffvg
def init_pydiffvg(device: torch.device,
use_gpu: bool = torch.cuda.is_available(),
print_timing: bool = False):
pydiffvg.set_use_gpu(use_gpu)
pydiffvg.set_device(device)
pydiffvg.set_print_timing(print_timing)
class DiffVGState(torch.nn.Module):
def __init__(self,
device: torch.device,
use_gpu: bool = torch.cuda.is_available(),
print_timing: bool = False,
canvas_width: int = None,
canvas_height: int = None):
super(DiffVGState, self).__init__()
# pydiffvg device setting
self.device = device
init_pydiffvg(device, use_gpu, print_timing)
# canvas size
self.canvas_width = canvas_width
self.canvas_height = canvas_height
# record all paths
self.shapes = []
self.shape_groups = []
# record the current optimized path
self.cur_shapes = []
self.cur_shape_groups = []
# learnable SVG params
self.point_vars = []
self.color_vars = []
self.width_vars = []
def clip_curve_shape(self, *args, **kwargs):
raise NotImplementedError
def render_warp(self, seed=0):
self.clip_curve_shape()
scene_args = pydiffvg.RenderFunction.serialize_scene(
self.canvas_width, self.canvas_height, self.shapes, self.shape_groups
)
_render = pydiffvg.RenderFunction.apply
img = _render(self.canvas_width, # width
self.canvas_height, # height
2, # num_samples_x
2, # num_samples_y
seed, # seed
None,
*scene_args)
return img
@staticmethod
def load_svg(path_svg):
canvas_width, canvas_height, shapes, shape_groups = pydiffvg.svg_to_scene(path_svg)
return canvas_width, canvas_height, shapes, shape_groups
def save_svg(self,
filename: Union[AnyStr, pathlib.Path],
width: int = None,
height: int = None,
shapes: List = None,
shape_groups: List = None,
use_gamma: bool = False,
background: str = None):
"""
Save an SVG file with specified parameters and shapes.
Noting: New version of SVG saving function that is an adaptation of pydiffvg.save_svg.
The original version saved words resulting in incomplete glyphs.
Args:
filename (str): The path to save the SVG file.
width (int): The width of the SVG canvas.
height (int): The height of the SVG canvas.
shapes (list): A list of shapes to be included in the SVG.
shape_groups (list): A list of shape groups.
use_gamma (bool): Flag indicating whether to apply gamma correction.
background (str, optional): The background color of the SVG.
Returns:
None
"""
root = etree.Element('svg')
root.set('version', '1.1')
root.set('xmlns', 'http://www.w3.org/2000/svg')
root.set('width', str(width))
root.set('height', str(height))
if background is not None:
print(f"setting background to {background}")
root.set('style', str(background))
defs = etree.SubElement(root, 'defs')
g = etree.SubElement(root, 'g')
if use_gamma:
f = etree.SubElement(defs, 'filter')
f.set('id', 'gamma')
f.set('x', '0')
f.set('y', '0')
f.set('width', '100%')
f.set('height', '100%')
gamma = etree.SubElement(f, 'feComponentTransfer')
gamma.set('color-interpolation-filters', 'sRGB')
feFuncR = etree.SubElement(gamma, 'feFuncR')
feFuncR.set('type', 'gamma')
feFuncR.set('amplitude', str(1))
feFuncR.set('exponent', str(1 / 2.2))
feFuncG = etree.SubElement(gamma, 'feFuncG')
feFuncG.set('type', 'gamma')
feFuncG.set('amplitude', str(1))
feFuncG.set('exponent', str(1 / 2.2))
feFuncB = etree.SubElement(gamma, 'feFuncB')
feFuncB.set('type', 'gamma')
feFuncB.set('amplitude', str(1))
feFuncB.set('exponent', str(1 / 2.2))
feFuncA = etree.SubElement(gamma, 'feFuncA')
feFuncA.set('type', 'gamma')
feFuncA.set('amplitude', str(1))
feFuncA.set('exponent', str(1 / 2.2))
g.set('style', 'filter:url(#gamma)')
# Store color
for i, shape_group in enumerate(shape_groups):
def add_color(shape_color, name):
if isinstance(shape_color, pydiffvg.LinearGradient):
lg = shape_color
color = etree.SubElement(defs, 'linearGradient')
color.set('id', name)
color.set('x1', str(lg.begin[0].item()))
color.set('y1', str(lg.begin[1].item()))
color.set('x2', str(lg.end[0].item()))
color.set('y2', str(lg.end[1].item()))
offsets = lg.offsets.data.cpu().numpy()
stop_colors = lg.stop_colors.data.cpu().numpy()
for j in range(offsets.shape[0]):
stop = etree.SubElement(color, 'stop')
stop.set('offset', str(offsets[j]))
c = lg.stop_colors[j, :]
stop.set('stop-color', 'rgb({}, {}, {})'.format(
int(255 * c[0]), int(255 * c[1]), int(255 * c[2])
))
stop.set('stop-opacity', '{}'.format(c[3]))
if isinstance(shape_color, pydiffvg.RadialGradient):
lg = shape_color
color = etree.SubElement(defs, 'radialGradient')
color.set('id', name)
color.set('cx', str(lg.center[0].item() / width))
color.set('cy', str(lg.center[1].item() / height))
# this only support width=height
color.set('r', str(lg.radius[0].item() / width))
offsets = lg.offsets.data.cpu().numpy()
stop_colors = lg.stop_colors.data.cpu().numpy()
for j in range(offsets.shape[0]):
stop = etree.SubElement(color, 'stop')
stop.set('offset', str(offsets[j]))
c = lg.stop_colors[j, :]
stop.set('stop-color', 'rgb({}, {}, {})'.format(
int(255 * c[0]), int(255 * c[1]), int(255 * c[2])
))
stop.set('stop-opacity', '{}'.format(c[3]))
if shape_group.fill_color is not None:
add_color(shape_group.fill_color, 'shape_{}_fill'.format(i))
if shape_group.stroke_color is not None:
add_color(shape_group.stroke_color, 'shape_{}_stroke'.format(i))
for i, shape_group in enumerate(shape_groups):
shape = shapes[shape_group.shape_ids[0]]
if isinstance(shape, pydiffvg.Circle):
shape_node = etree.SubElement(g, 'circle')
shape_node.set('r', str(shape.radius.item()))
shape_node.set('cx', str(shape.center[0].item()))
shape_node.set('cy', str(shape.center[1].item()))
elif isinstance(shape, pydiffvg.Polygon):
shape_node = etree.SubElement(g, 'polygon')
points = shape.points.data.cpu().numpy()
path_str = ''
for j in range(0, shape.points.shape[0]):
path_str += '{} {}'.format(points[j, 0], points[j, 1])
if j != shape.points.shape[0] - 1:
path_str += ' '
shape_node.set('points', path_str)
elif isinstance(shape, pydiffvg.Path):
for j, id in enumerate(shape_group.shape_ids):
shape = shapes[id]
if isinstance(shape, pydiffvg.Path):
if j == 0:
shape_node = etree.SubElement(g, 'path')
node_id = shape_node.get('id')
path_str = ''
num_segments = shape.num_control_points.shape[0]
num_control_points = shape.num_control_points.data.cpu().numpy()
points = shape.points.data.cpu().numpy()
num_points = shape.points.shape[0]
path_str += 'M {} {}'.format(points[0, 0], points[0, 1])
point_id = 1
for j in range(0, num_segments):
if num_control_points[j] == 0:
p = point_id % num_points
path_str += ' L {} {}'.format(
points[p, 0], points[p, 1])
point_id += 1
elif num_control_points[j] == 1:
p1 = (point_id + 1) % num_points
path_str += ' Q {} {} {} {}'.format(
points[point_id, 0], points[point_id, 1],
points[p1, 0], points[p1, 1])
point_id += 2
elif num_control_points[j] == 2:
p2 = (point_id + 2) % num_points
path_str += ' C {} {} {} {} {} {}'.format(
points[point_id, 0], points[point_id, 1],
points[point_id + 1, 0], points[point_id + 1, 1],
points[p2, 0], points[p2, 1])
point_id += 3
if node_id is not None:
shape_node.set('id', node_id) # add id to Path
shape_node.set('d', path_str)
elif isinstance(shape, pydiffvg.Rect):
shape_node = etree.SubElement(g, 'rect')
shape_node.set('x', str(shape.p_min[0].item()))
shape_node.set('y', str(shape.p_min[1].item()))
shape_node.set('width', str(shape.p_max[0].item() - shape.p_min[0].item()))
shape_node.set('height', str(shape.p_max[1].item() - shape.p_min[1].item()))
elif isinstance(shape, pydiffvg.Ellipse):
shape_node = etree.SubElement(g, 'ellipse')
shape_node.set('cx', str(shape.center[0].item()))
shape_node.set('cy', str(shape.center[1].item()))
shape_node.set('rx', str(shape.radius[0].item()))
shape_node.set('ry', str(shape.radius[1].item()))
else:
raise NotImplementedError(f'shape type: {type(shape)} is not involved in pydiffvg.')
shape_node.set('stroke-width', str(2 * shape.stroke_width.data.cpu().item()))
if shape_group.fill_color is not None:
if isinstance(shape_group.fill_color, pydiffvg.LinearGradient):
shape_node.set('fill', 'url(#shape_{}_fill)'.format(i))
else:
c = shape_group.fill_color.data.cpu().numpy()
shape_node.set('fill', 'rgb({}, {}, {})'.format(
int(255 * c[0]), int(255 * c[1]), int(255 * c[2])))
shape_node.set('opacity', str(c[3]))
else:
shape_node.set('fill', 'none')
if shape_group.stroke_color is not None:
if isinstance(shape_group.stroke_color, pydiffvg.LinearGradient):
shape_node.set('stroke', 'url(#shape_{}_stroke)'.format(i))
else:
c = shape_group.stroke_color.data.cpu().numpy()
shape_node.set('stroke', 'rgb({}, {}, {})'.format(
int(255 * c[0]), int(255 * c[1]), int(255 * c[2])))
shape_node.set('stroke-opacity', str(c[3]))
shape_node.set('stroke-linecap', 'round')
shape_node.set('stroke-linejoin', 'round')
with open(filename, "w") as f:
f.write(pydiffvg.prettify(root))
@staticmethod
def save_image(img, filename, gamma=1):
if torch.is_tensor(img) and torch.device != 'cpu':
img = img.detach().cpu()
pydiffvg.imwrite(img, filename, gamma=gamma)