Spaces:
Running
Running
# -*- coding: utf-8 -*- | |
# Author: ximing | |
# Description: parent class | |
# Copyright (c) 2023, XiMing Xing. | |
# License: MIT License | |
import pathlib | |
from typing import AnyStr, List, Union | |
import xml.etree.ElementTree as etree | |
import torch | |
import pydiffvg | |
def init_pydiffvg(device: torch.device, | |
use_gpu: bool = torch.cuda.is_available(), | |
print_timing: bool = False): | |
pydiffvg.set_use_gpu(use_gpu) | |
pydiffvg.set_device(device) | |
pydiffvg.set_print_timing(print_timing) | |
class DiffVGState(torch.nn.Module): | |
def __init__(self, | |
device: torch.device, | |
use_gpu: bool = torch.cuda.is_available(), | |
print_timing: bool = False, | |
canvas_width: int = None, | |
canvas_height: int = None): | |
super(DiffVGState, self).__init__() | |
# pydiffvg device setting | |
self.device = device | |
init_pydiffvg(device, use_gpu, print_timing) | |
# canvas size | |
self.canvas_width = canvas_width | |
self.canvas_height = canvas_height | |
# record all paths | |
self.shapes = [] | |
self.shape_groups = [] | |
# record the current optimized path | |
self.cur_shapes = [] | |
self.cur_shape_groups = [] | |
# learnable SVG params | |
self.point_vars = [] | |
self.color_vars = [] | |
self.width_vars = [] | |
def clip_curve_shape(self, *args, **kwargs): | |
raise NotImplementedError | |
def render_warp(self, seed=0): | |
self.clip_curve_shape() | |
scene_args = pydiffvg.RenderFunction.serialize_scene( | |
self.canvas_width, self.canvas_height, self.shapes, self.shape_groups | |
) | |
_render = pydiffvg.RenderFunction.apply | |
img = _render(self.canvas_width, # width | |
self.canvas_height, # height | |
2, # num_samples_x | |
2, # num_samples_y | |
seed, # seed | |
None, | |
*scene_args) | |
return img | |
def load_svg(path_svg): | |
canvas_width, canvas_height, shapes, shape_groups = pydiffvg.svg_to_scene(path_svg) | |
return canvas_width, canvas_height, shapes, shape_groups | |
def save_svg(self, | |
filename: Union[AnyStr, pathlib.Path], | |
width: int = None, | |
height: int = None, | |
shapes: List = None, | |
shape_groups: List = None, | |
use_gamma: bool = False, | |
background: str = None): | |
""" | |
Save an SVG file with specified parameters and shapes. | |
Noting: New version of SVG saving function that is an adaptation of pydiffvg.save_svg. | |
The original version saved words resulting in incomplete glyphs. | |
Args: | |
filename (str): The path to save the SVG file. | |
width (int): The width of the SVG canvas. | |
height (int): The height of the SVG canvas. | |
shapes (list): A list of shapes to be included in the SVG. | |
shape_groups (list): A list of shape groups. | |
use_gamma (bool): Flag indicating whether to apply gamma correction. | |
background (str, optional): The background color of the SVG. | |
Returns: | |
None | |
""" | |
root = etree.Element('svg') | |
root.set('version', '1.1') | |
root.set('xmlns', 'http://www.w3.org/2000/svg') | |
root.set('width', str(width)) | |
root.set('height', str(height)) | |
if background is not None: | |
print(f"setting background to {background}") | |
root.set('style', str(background)) | |
defs = etree.SubElement(root, 'defs') | |
g = etree.SubElement(root, 'g') | |
if use_gamma: | |
f = etree.SubElement(defs, 'filter') | |
f.set('id', 'gamma') | |
f.set('x', '0') | |
f.set('y', '0') | |
f.set('width', '100%') | |
f.set('height', '100%') | |
gamma = etree.SubElement(f, 'feComponentTransfer') | |
gamma.set('color-interpolation-filters', 'sRGB') | |
feFuncR = etree.SubElement(gamma, 'feFuncR') | |
feFuncR.set('type', 'gamma') | |
feFuncR.set('amplitude', str(1)) | |
feFuncR.set('exponent', str(1 / 2.2)) | |
feFuncG = etree.SubElement(gamma, 'feFuncG') | |
feFuncG.set('type', 'gamma') | |
feFuncG.set('amplitude', str(1)) | |
feFuncG.set('exponent', str(1 / 2.2)) | |
feFuncB = etree.SubElement(gamma, 'feFuncB') | |
feFuncB.set('type', 'gamma') | |
feFuncB.set('amplitude', str(1)) | |
feFuncB.set('exponent', str(1 / 2.2)) | |
feFuncA = etree.SubElement(gamma, 'feFuncA') | |
feFuncA.set('type', 'gamma') | |
feFuncA.set('amplitude', str(1)) | |
feFuncA.set('exponent', str(1 / 2.2)) | |
g.set('style', 'filter:url(#gamma)') | |
# Store color | |
for i, shape_group in enumerate(shape_groups): | |
def add_color(shape_color, name): | |
if isinstance(shape_color, pydiffvg.LinearGradient): | |
lg = shape_color | |
color = etree.SubElement(defs, 'linearGradient') | |
color.set('id', name) | |
color.set('x1', str(lg.begin[0].item())) | |
color.set('y1', str(lg.begin[1].item())) | |
color.set('x2', str(lg.end[0].item())) | |
color.set('y2', str(lg.end[1].item())) | |
offsets = lg.offsets.data.cpu().numpy() | |
stop_colors = lg.stop_colors.data.cpu().numpy() | |
for j in range(offsets.shape[0]): | |
stop = etree.SubElement(color, 'stop') | |
stop.set('offset', str(offsets[j])) | |
c = lg.stop_colors[j, :] | |
stop.set('stop-color', 'rgb({}, {}, {})'.format( | |
int(255 * c[0]), int(255 * c[1]), int(255 * c[2]) | |
)) | |
stop.set('stop-opacity', '{}'.format(c[3])) | |
if isinstance(shape_color, pydiffvg.RadialGradient): | |
lg = shape_color | |
color = etree.SubElement(defs, 'radialGradient') | |
color.set('id', name) | |
color.set('cx', str(lg.center[0].item() / width)) | |
color.set('cy', str(lg.center[1].item() / height)) | |
# this only support width=height | |
color.set('r', str(lg.radius[0].item() / width)) | |
offsets = lg.offsets.data.cpu().numpy() | |
stop_colors = lg.stop_colors.data.cpu().numpy() | |
for j in range(offsets.shape[0]): | |
stop = etree.SubElement(color, 'stop') | |
stop.set('offset', str(offsets[j])) | |
c = lg.stop_colors[j, :] | |
stop.set('stop-color', 'rgb({}, {}, {})'.format( | |
int(255 * c[0]), int(255 * c[1]), int(255 * c[2]) | |
)) | |
stop.set('stop-opacity', '{}'.format(c[3])) | |
if shape_group.fill_color is not None: | |
add_color(shape_group.fill_color, 'shape_{}_fill'.format(i)) | |
if shape_group.stroke_color is not None: | |
add_color(shape_group.stroke_color, 'shape_{}_stroke'.format(i)) | |
for i, shape_group in enumerate(shape_groups): | |
shape = shapes[shape_group.shape_ids[0]] | |
if isinstance(shape, pydiffvg.Circle): | |
shape_node = etree.SubElement(g, 'circle') | |
shape_node.set('r', str(shape.radius.item())) | |
shape_node.set('cx', str(shape.center[0].item())) | |
shape_node.set('cy', str(shape.center[1].item())) | |
elif isinstance(shape, pydiffvg.Polygon): | |
shape_node = etree.SubElement(g, 'polygon') | |
points = shape.points.data.cpu().numpy() | |
path_str = '' | |
for j in range(0, shape.points.shape[0]): | |
path_str += '{} {}'.format(points[j, 0], points[j, 1]) | |
if j != shape.points.shape[0] - 1: | |
path_str += ' ' | |
shape_node.set('points', path_str) | |
elif isinstance(shape, pydiffvg.Path): | |
for j, id in enumerate(shape_group.shape_ids): | |
shape = shapes[id] | |
if isinstance(shape, pydiffvg.Path): | |
if j == 0: | |
shape_node = etree.SubElement(g, 'path') | |
node_id = shape_node.get('id') | |
path_str = '' | |
num_segments = shape.num_control_points.shape[0] | |
num_control_points = shape.num_control_points.data.cpu().numpy() | |
points = shape.points.data.cpu().numpy() | |
num_points = shape.points.shape[0] | |
path_str += 'M {} {}'.format(points[0, 0], points[0, 1]) | |
point_id = 1 | |
for j in range(0, num_segments): | |
if num_control_points[j] == 0: | |
p = point_id % num_points | |
path_str += ' L {} {}'.format( | |
points[p, 0], points[p, 1]) | |
point_id += 1 | |
elif num_control_points[j] == 1: | |
p1 = (point_id + 1) % num_points | |
path_str += ' Q {} {} {} {}'.format( | |
points[point_id, 0], points[point_id, 1], | |
points[p1, 0], points[p1, 1]) | |
point_id += 2 | |
elif num_control_points[j] == 2: | |
p2 = (point_id + 2) % num_points | |
path_str += ' C {} {} {} {} {} {}'.format( | |
points[point_id, 0], points[point_id, 1], | |
points[point_id + 1, 0], points[point_id + 1, 1], | |
points[p2, 0], points[p2, 1]) | |
point_id += 3 | |
if node_id is not None: | |
shape_node.set('id', node_id) # add id to Path | |
shape_node.set('d', path_str) | |
elif isinstance(shape, pydiffvg.Rect): | |
shape_node = etree.SubElement(g, 'rect') | |
shape_node.set('x', str(shape.p_min[0].item())) | |
shape_node.set('y', str(shape.p_min[1].item())) | |
shape_node.set('width', str(shape.p_max[0].item() - shape.p_min[0].item())) | |
shape_node.set('height', str(shape.p_max[1].item() - shape.p_min[1].item())) | |
elif isinstance(shape, pydiffvg.Ellipse): | |
shape_node = etree.SubElement(g, 'ellipse') | |
shape_node.set('cx', str(shape.center[0].item())) | |
shape_node.set('cy', str(shape.center[1].item())) | |
shape_node.set('rx', str(shape.radius[0].item())) | |
shape_node.set('ry', str(shape.radius[1].item())) | |
else: | |
raise NotImplementedError(f'shape type: {type(shape)} is not involved in pydiffvg.') | |
shape_node.set('stroke-width', str(2 * shape.stroke_width.data.cpu().item())) | |
if shape_group.fill_color is not None: | |
if isinstance(shape_group.fill_color, pydiffvg.LinearGradient): | |
shape_node.set('fill', 'url(#shape_{}_fill)'.format(i)) | |
else: | |
c = shape_group.fill_color.data.cpu().numpy() | |
shape_node.set('fill', 'rgb({}, {}, {})'.format( | |
int(255 * c[0]), int(255 * c[1]), int(255 * c[2]))) | |
shape_node.set('opacity', str(c[3])) | |
else: | |
shape_node.set('fill', 'none') | |
if shape_group.stroke_color is not None: | |
if isinstance(shape_group.stroke_color, pydiffvg.LinearGradient): | |
shape_node.set('stroke', 'url(#shape_{}_stroke)'.format(i)) | |
else: | |
c = shape_group.stroke_color.data.cpu().numpy() | |
shape_node.set('stroke', 'rgb({}, {}, {})'.format( | |
int(255 * c[0]), int(255 * c[1]), int(255 * c[2]))) | |
shape_node.set('stroke-opacity', str(c[3])) | |
shape_node.set('stroke-linecap', 'round') | |
shape_node.set('stroke-linejoin', 'round') | |
with open(filename, "w") as f: | |
f.write(pydiffvg.prettify(root)) | |
def save_image(img, filename, gamma=1): | |
if torch.is_tensor(img) and torch.device != 'cpu': | |
img = img.detach().cpu() | |
pydiffvg.imwrite(img, filename, gamma=gamma) | |