import json |
from tqdm import tqdm |
import numpy as np |
import matplotlib.pyplot as plt |
import xml.etree.ElementTree as ET |
from xml.dom import minidom |
import os |
from PIL import Image |
import matplotlib.animation as animation |
import copy |
from PIL import ImageEnhance |
import colorsys |
import matplotlib.colors as mcolors |
from matplotlib.collections import LineCollection |
from matplotlib.patheffects import withStroke |
import random |
import warnings |
from matplotlib.figure import Figure |
from io import BytesIO |
from matplotlib.animation import FuncAnimation, FFMpegWriter, PillowWriter |
import requests |
import zipfile |
warnings.filterwarnings("ignore") |
def get_svg_content(svg_path):
    """Read an SVG file and return its whole contents as a single string.

    Args:
        svg_path: Path of the file to read; it is opened in text mode.

    Returns:
        The complete text of the file.
    """
    with open(svg_path, "r") as svg_file:
        svg_text = svg_file.read()
    return svg_text
def download_file(url, filename, timeout=60):
    """Download ``url`` with an HTTP GET and save the body to ``filename``.

    Args:
        url: Address to fetch.
        filename: Destination path. It is opened in binary write mode, so an
            existing file at that path is overwritten.
        timeout: Seconds to wait for the server before giving up. Pass
            ``None`` to wait indefinitely.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
        requests.RequestException: On connection failures or timeouts.
    """
    response = requests.get(url, timeout=timeout)
    # Fail loudly on an HTTP error instead of writing the error page to disk,
    # where it would be mistaken for the requested file.
    response.raise_for_status()
    with open(filename, "wb") as f:
        f.write(response.content)
def unzip_file(filename, extract_to="."):
    """Extract every member of a zip archive into a directory.

    Args:
        filename: Path of the zip archive to read.
        extract_to: Directory the members are extracted into; defaults to the
            current working directory.
    """
    with zipfile.ZipFile(filename, mode="r") as archive:
        archive.extractall(path=extract_to)
def load_and_pad_img_dir(file_dir, size=224):
    """Load an image, scale it to fit a square and pad it with white.

    The image is resized so that its longer side equals ``size`` while the
    aspect ratio is kept, then centred on a white ``size`` x ``size`` RGB
    canvas.

    Args:
        file_dir: Path of the image file to open.
        size: Edge length in pixels of the square output. Defaults to 224.

    Returns:
        A new RGB ``PIL.Image.Image`` of exactly ``size`` x ``size`` pixels.
    """
    with Image.open(file_dir) as image:
        width, height = image.size
        longest_side = max(width, height)
        # Integer arithmetic makes the longer side land on ``size`` exactly;
        # the float form ``int(width * (size / width))`` can truncate to
        # ``size - 1`` and yield a non-square result. The shorter side is
        # clamped to 1 px so very elongated images can still be resized.
        new_width = max(1, width * size // longest_side)
        new_height = max(1, height * size // longest_side)
        resized = image.resize((new_width, new_height))

    # Always paste onto a full square canvas so the output shape does not
    # depend on which side was the limiting one.
    padded_image = Image.new("RGB", (size, size), (255, 255, 255))
    left_padding = (size - new_width) // 2
    top_padding = (size - new_height) // 2
    padded_image.paste(resized, (left_padding, top_padding))
    return padded_image
def plot_ink(ink, ax, lw=1.8, input_image=None, with_path=True, path_color="white"):
    """Draw every stroke of ``ink`` on ``ax`` with a per-stroke rainbow colour.

    Each stroke gets its own darkened colour from the "rainbow" colormap
    (the first stroke takes the last colormap entry), and the colour's alpha
    fades from 1.0 towards 0.5 along the stroke. The axes are fixed to the
    range 0..224 on both axes, with the y axis inverted.

    Args:
        ink: Object exposing ``strokes``; each stroke exposes ``x`` and ``y``
            coordinate sequences.
        ax: Matplotlib axes to draw on.
        lw: Line width of the strokes.
        input_image: Optional image shown, darkened, behind the strokes.
        with_path: If true, outline each stroke using ``path_color``.
        path_color: Colour of the outline drawn around each stroke.
    """
    if input_image is not None:
        img = copy.deepcopy(input_image)
        enhancer = ImageEnhance.Brightness(img)
        # Darken the background so the coloured strokes stand out.
        img = enhancer.enhance(0.45)
        ax.imshow(img)

    num_strokes = len(ink.strokes)
    # ``plt.cm.get_cmap`` was removed in Matplotlib 3.9; ``plt.get_cmap``
    # accepts the same (name, lut) arguments and exists in old and new
    # releases alike.
    base_colors = plt.get_cmap("rainbow", num_strokes)

    for i, stroke in enumerate(ink.strokes):
        x, y = np.array(stroke.x), np.array(stroke.y)

        # Reverse the colormap index, then reduce the HSV value to darken.
        base_color = base_colors(num_strokes - 1 - i)
        hsv_color = colorsys.rgb_to_hsv(*base_color[:3])
        darker_color = colorsys.hsv_to_rgb(
            hsv_color[0], hsv_color[1], max(0, hsv_color[2] * 0.65)
        )

        # Consecutive points become the line segments of the collection.
        points = np.array([x, y]).T.reshape(-1, 1, 2)
        segments = np.concatenate([points[:-1], points[1:]], axis=1)

        # One colour per segment; alpha fades linearly with the point index.
        colors = [
            mcolors.to_rgba(darker_color, alpha=1 - (0.5 * j / len(x)))
            for j in range(len(segments))
        ]

        lc = LineCollection(segments, colors=colors, linewidth=lw)
        if with_path:
            lc.set_path_effects(
                [withStroke(linewidth=lw * 1.25, foreground=path_color)]
            )
        ax.add_collection(lc)

    ax.set_xlim(0, 224)
    ax.set_ylim(0, 224)
    ax.invert_yaxis()
def plot_ink_to_video(
    ink, output_name, lw=1.8, input_image=None, path_color="white", fps=30
):
    """Render ``ink`` as an animation in which the strokes appear gradually.

    Frame ``k`` shows the first ``k`` line segments of the drawing, stroke
    after stroke. The animation has ``total points + 1`` frames; because a
    stroke has one segment fewer than it has points, the final frames all
    show the completed drawing. The result is written with
    ``FFMpegWriter``.

    Args:
        ink: Object exposing ``strokes``; each stroke exposes ``x`` and ``y``
            coordinate sequences.
        output_name: Path of the video file to write.
        lw: Line width of the strokes.
        input_image: Optional image shown, darkened, behind the strokes.
        path_color: Colour of the outline drawn around each stroke.
        fps: Frames per second of the output video.
    """
    fig, ax = plt.subplots(figsize=(4, 4), dpi=150)
    try:
        img = None
        if input_image is not None:
            img = copy.deepcopy(input_image)
            enhancer = ImageEnhance.Brightness(img)
            # Darken the background so the coloured strokes stand out.
            img = enhancer.enhance(0.45)

        def setup_axes():
            # Background plus the fixed 224x224 frame with inverted y axis.
            if img is not None:
                ax.imshow(img)
            ax.set_xlim(0, 224)
            ax.set_ylim(0, 224)
            ax.invert_yaxis()
            ax.axis("off")

        setup_axes()

        num_strokes = len(ink.strokes)
        # ``plt.cm.get_cmap`` was removed in Matplotlib 3.9; ``plt.get_cmap``
        # accepts the same (name, lut) arguments.
        base_colors = plt.get_cmap("rainbow", num_strokes)
        all_points = sum(len(stroke.x) for stroke in ink.strokes)

        # Segments and colours do not depend on the frame, so build them once
        # instead of on every call of ``update``.
        prepared_strokes = []
        for stroke_index, stroke in enumerate(ink.strokes):
            x, y = np.array(stroke.x), np.array(stroke.y)
            points = np.array([x, y]).T.reshape(-1, 1, 2)
            segments = np.concatenate([points[:-1], points[1:]], axis=1)
            base_color = base_colors(num_strokes - 1 - stroke_index)
            hsv_color = colorsys.rgb_to_hsv(*base_color[:3])
            darker_color = colorsys.hsv_to_rgb(
                hsv_color[0], hsv_color[1], max(0, hsv_color[2] * 0.65)
            )
            prepared_strokes.append((segments, darker_color))

        def update(frame):
            ax.clear()
            setup_axes()

            # Number of segments consumed by the strokes handled so far.
            points_drawn = 0
            for segments, darker_color in prepared_strokes:
                remaining = frame - points_drawn
                visible_segments = (
                    segments[:remaining] if remaining < len(segments) else segments
                )

                if len(visible_segments) > 0:
                    # Alpha fades from 1.0 towards 0.5 along the visible part.
                    colors = [
                        mcolors.to_rgba(
                            darker_color,
                            alpha=1 - (0.5 * j / len(visible_segments)),
                        )
                        for j in range(len(visible_segments))
                    ]
                    lc = LineCollection(
                        visible_segments, colors=colors, linewidth=lw
                    )
                    lc.set_path_effects(
                        [withStroke(linewidth=lw * 1.25, foreground=path_color)]
                    )
                    ax.add_collection(lc)

                points_drawn += len(segments)
                # Later strokes are not visible yet in this frame.
                if points_drawn >= frame:
                    break

        ani = FuncAnimation(fig, update, frames=all_points + 1, blit=False)
        writer = FFMpegWriter(fps=fps)
        ani.save(output_name, writer=writer)
    finally:
        # Release the figure even when encoding fails.
        plt.close(fig)
class Stroke:
    """A single pen stroke stored as parallel lists of x and y coordinates."""

    def __init__(self, list_of_coordinates=None) -> None:
        """Fill the stroke from an optional iterable of points.

        Only the first two items of every point are used, as x and y.
        """
        self.x = []
        self.y = []
        # A falsy argument (None, empty list, ...) leaves the stroke empty.
        for coordinate in list_of_coordinates or ():
            self.x.append(coordinate[0])
            self.y.append(coordinate[1])

    def __len__(self):
        """Return the number of points in the stroke."""
        return len(self.x)

    def __getitem__(self, index):
        """Return the ``(x, y)`` pair stored at ``index``."""
        return self.x[index], self.y[index]
class Ink:
    """A drawing: an ordered collection of strokes."""

    def __init__(self, list_of_strokes=None) -> None:
        """Store ``list_of_strokes`` as given, or start with no strokes.

        A non-empty argument is kept by reference, not copied.
        """
        self.strokes = list_of_strokes if list_of_strokes else []

    def __len__(self):
        """Return the number of strokes."""
        return len(self.strokes)

    def __getitem__(self, index):
        """Return the stroke stored at ``index``."""
        return self.strokes[index]
def inkml_to_ink(inkml_file):
    """Convert an InkML file to an ``Ink``.

    Every ``trace`` element that is a direct child of the document root
    becomes one ``Stroke``. The trace text is read as whitespace-separated
    points, each written as ``x,y``.

    Args:
        inkml_file: Path or file object accepted by ``ET.parse``.

    Returns:
        An ``Ink`` holding one ``Stroke`` per trace, in document order. A
        trace without any text yields an empty stroke.

    Raises:
        ValueError: If a point is not exactly two comma-separated numbers.
    """
    root = ET.parse(inkml_file).getroot()
    inkml_namespace = {"inkml": "http://www.w3.org/2003/InkML"}

    strokes = []
    for trace in root.findall("inkml:trace", inkml_namespace):
        # ``text`` is None for an empty element such as <trace/>; treat it
        # like a whitespace-only trace instead of raising AttributeError.
        tokens = (trace.text or "").split()
        stroke_points = []
        for token in tokens:
            x, y = token.split(",")
            stroke_points.append((float(x), float(y)))
        strokes.append(Stroke(stroke_points))
    return Ink(strokes)
def parse_inkml_annotations(inkml_file):
    """Collect the annotations of an InkML file into a dictionary.

    Args:
        inkml_file: Path or file object accepted by ``ET.parse``.

    Returns:
        A dict mapping each annotation's ``type`` attribute to its text.
        Annotations at any depth below the root are included; when a type
        occurs more than once, the last occurrence wins.
    """
    document_root = ET.parse(inkml_file).getroot()
    annotation_nodes = document_root.findall(
        ".//{http://www.w3.org/2003/InkML}annotation"
    )
    return {node.get("type"): node.text for node in annotation_nodes}