Spaces:
Running
Running
from PIL import Image | |
from torchvision import transforms | |
from torchvision.transforms.functional import InterpolationMode | |
import numpy as np | |
import matplotlib.pyplot as plt | |
from bs4 import BeautifulSoup | |
import re | |
from svgpathtools import svgstr2paths | |
import numpy as np | |
from PIL import Image | |
import cairosvg | |
from io import BytesIO | |
import numpy as np | |
import textwrap | |
import os | |
import base64 | |
import io | |
# Minimal SVG document containing a single circle whose centre and radius
# are all given as 50%.
CIRCLE_SVG = "<svg><circle cx='50%' cy='50%' r='50%' /></svg>"
# Empty SVG document; this is the value returned by use_placeholder().
# NOTE(review): the name looks like a typo of VOID_SVG, but it is referenced
# under this spelling elsewhere in the file, so it is left as-is.
VOID_SVF = "<svg></svg>"
def load_transforms():
    """Return the per-split transform registry.

    Returns:
        A freshly built dict with the keys ``'train'`` and ``'eval'`` (in
        that order), both mapped to ``None``.
    """
    return dict.fromkeys(('train', 'eval'))
class ImageBaseProcessor():
    """Base image processor that owns the channel-normalisation transform."""

    def __init__(self, mean=None, std=None):
        """Build ``self.normalize`` from per-channel statistics.

        Args:
            mean: Per-channel means; a built-in 3-tuple is used when ``None``.
            std: Per-channel standard deviations; a built-in 3-tuple is used
                when ``None``.
        """
        # NOTE(review): the fallback constants look like the CLIP
        # preprocessing statistics -- confirm against the model in use.
        mean = (0.48145466, 0.4578275, 0.40821073) if mean is None else mean
        std = (0.26862954, 0.26130258, 0.27577711) if std is None else std
        self.normalize = transforms.Normalize(mean=mean, std=std)
class ImageTrainProcessor(ImageBaseProcessor):
    """Image processor applying resize -> tensor conversion -> normalisation."""

    def __init__(self, mean=None, std=None, size=224, **kwargs):
        """Assemble the transform pipeline.

        Args:
            mean: Forwarded to ``ImageBaseProcessor``.
            std: Forwarded to ``ImageBaseProcessor``.
            size: Target size handed to ``transforms.Resize``.
            **kwargs: Accepted and ignored.
        """
        super().__init__(mean, std)
        self.size = size
        resize = transforms.Resize(self.size, interpolation=InterpolationMode.BICUBIC)
        pipeline = [resize, transforms.ToTensor(), self.normalize]
        self.transform = transforms.Compose(pipeline)

    def __call__(self, item):
        """Run ``item`` through the composed transform and return the result."""
        return self.transform(item)
def encode_image_base64(pil_image):
    """Encode an image as a base64 string of its JPEG serialisation.

    Args:
        pil_image: Image object exposing ``mode``, ``convert`` and ``save``
            (a ``PIL.Image.Image`` in this file).

    Returns:
        The JPEG bytes of the image, base64-encoded and decoded to ``str``.
    """
    # Modes the JPEG writer accepts as-is. Anything else (RGBA, LA, P, I, F,
    # ...) cannot be written as JPEG and is converted to RGB first; the
    # original code only handled RGBA and failed on the other modes.
    # NOTE(review): list taken from Pillow's JPEG plugin -- confirm against
    # the installed Pillow version.
    jpeg_modes = ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr")
    if pil_image.mode not in jpeg_modes:
        pil_image = pil_image.convert('RGB')

    buffered = io.BytesIO()
    pil_image.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")
# -------------- Generation utils -------------- | |
def is_valid_svg(svg_text):
    """Report whether ``svgstr2paths`` can parse ``svg_text``.

    Args:
        svg_text: SVG markup as a string.

    Returns:
        ``True`` when parsing succeeds; ``False`` (after printing the error)
        when ``svgstr2paths`` raises any ``Exception``.
    """
    try:
        svgstr2paths(svg_text)
    except Exception as e:
        print(f"Invalid SVG: {str(e)}")
        return False
    else:
        return True
def clean_svg(svg_text, output_width=None, output_height=None):
    """Normalise SVG markup by round-tripping it through BeautifulSoup and CairoSVG.

    The text is parsed as XML and prettified with BeautifulSoup, re-emitted
    by ``cairosvg.svg2svg`` and finally stripped of any ``<?xml ...`` header
    line.

    Args:
        svg_text: SVG markup to clean.
        output_width: Forwarded to ``cairosvg.svg2svg``.
        output_height: Forwarded to ``cairosvg.svg2svg``.

    Returns:
        The cleaned SVG string, or ``<svg></svg>`` when the conversion was
        interrupted by the 5-second timeout.

    Raises:
        Exception: Any non-timeout error raised by BeautifulSoup or CairoSVG
            propagates to the caller.
    """
    import signal
    import threading

    svg_bs4 = BeautifulSoup(svg_text, 'xml').prettify()  # Parse as XML, get a string back

    # SIGALRM does not exist on Windows, and signal handlers may only be
    # installed from the main thread. Where the alarm is unavailable the
    # conversion runs without a timeout instead of always failing.
    can_use_alarm = (
        hasattr(signal, "SIGALRM")
        and threading.current_thread() is threading.main_thread()
    )

    def timeout_handler(signum, frame):
        raise TimeoutError("SVG processing timed out")

    # Remember the previous handler so it can be restored afterwards.
    original_handler = signal.getsignal(signal.SIGALRM) if can_use_alarm else None

    try:
        if can_use_alarm:
            # Arm a 5-second alarm to prevent the conversion from hanging.
            signal.signal(signal.SIGALRM, timeout_handler)
            signal.alarm(5)
        # Convert the prettified markup with CairoSVG.
        svg_cairo = cairosvg.svg2svg(svg_bs4, output_width=output_width, output_height=output_height).decode()
    except TimeoutError:
        print("SVG conversion timed out, using fallback method")
        svg_cairo = """<svg></svg>"""
    finally:
        if can_use_alarm:
            # Always cancel the alarm, regardless of success or failure.
            signal.alarm(0)
            # getsignal() returns None when the previous handler was not
            # installed from Python; signal.signal() rejects None, so only
            # restore a real handler.
            if original_handler is not None:
                signal.signal(signal.SIGALRM, original_handler)

    # Remove the xml header line(s).
    svg_clean = "\n".join(
        line for line in svg_cairo.split("\n") if not line.strip().startswith("<?xml")
    )
    return svg_clean
def use_placeholder():
    """Return the module-level ``VOID_SVF`` placeholder SVG string."""
    return VOID_SVF
def process_and_rasterize_svg(svg_string, resolution=256, dpi = 128, scale=2):
    """Validate (and if needed repair) an SVG string, then rasterize it.

    The SVG is first checked with ``svgstr2paths``. If that fails, it is
    passed through ``clean_svg`` and checked again; if that also fails, the
    placeholder from ``use_placeholder`` is used instead.

    Args:
        svg_string: SVG markup to process.
        resolution: Forwarded to ``rasterize_svg``.
        dpi: Forwarded to ``rasterize_svg``.
        scale: Forwarded to ``rasterize_svg``.

    Returns:
        A ``(out_svg, raster_image)`` tuple: the SVG string that was finally
        used and the result of ``rasterize_svg`` on it.
    """
    out_svg = svg_string
    try:
        svgstr2paths(out_svg)  # Raises if the svg is not valid
    except Exception:
        # `except Exception` rather than a bare `except`, so that
        # KeyboardInterrupt / SystemExit are not swallowed by the fallback.
        try:
            out_svg = clean_svg(svg_string)
            svgstr2paths(out_svg)  # Raises if the svg is still not valid
        except Exception:
            out_svg = use_placeholder()

    raster_image = rasterize_svg(out_svg, resolution, dpi, scale)
    return out_svg, raster_image
def rasterize_svg(svg_string, resolution=224, dpi = 128, scale=2):
    """Render an SVG string to an image on a white background.

    Rendering is attempted on the raw string first, then on the output of
    ``clean_svg``. If both attempts fail, a blank white RGB image is returned.

    Args:
        svg_string: SVG markup to render.
        resolution: Used as both ``output_width`` and ``output_height`` of
            the PNG, and as the side length of the blank fallback image.
        dpi: Forwarded to ``cairosvg.svg2png``.
        scale: Forwarded to ``cairosvg.svg2png``.

    Returns:
        The image opened from the rendered PNG bytes, or the white fallback.
    """

    def _render(svg):
        # Single place for the CairoSVG call shared by both attempts.
        png_bytes = cairosvg.svg2png(
            bytestring=svg,
            background_color='white',
            output_width=resolution,
            output_height=resolution,
            dpi=dpi,
            scale=scale)
        return Image.open(BytesIO(png_bytes))

    # `except Exception` rather than a bare `except`, so that
    # KeyboardInterrupt / SystemExit still propagate.
    try:
        return _render(svg_string)
    except Exception:
        try:
            return _render(clean_svg(svg_string))
        except Exception:
            return Image.new('RGB', (resolution, resolution), color='white')
def find_unclosed_tags(svg_content):
    """List tag names that are opened more often than they are closed.

    A tag counts as closed either by a self-closing form (``<tag ... />``)
    or by a literal ``</tag>`` occurrence in the text.

    Args:
        svg_content: SVG/XML markup as a string.

    Returns:
        Tag names with more openings than closings, without duplicates, in
        order of first appearance.
    """
    from collections import Counter

    # Tally once instead of calling list.count()/str.count() for every tag
    # occurrence, which made the original quadratic in the number of tags.
    opened = Counter(re.findall(r"<(\w+)", svg_content))
    self_closed = Counter(re.findall(r"<(\w+)[^>]*\/>", svg_content))

    # Counter preserves insertion order, so the result keeps the order in
    # which each tag name first appears.
    return [
        tag
        for tag, n_open in opened.items()
        if n_open > self_closed[tag] + svg_content.count('</' + tag + '>')
    ]
# -------------- Plotting utils -------------- | |
def plot_images_side_by_side_with_metrics(image1, image2, l2_dist, CD, post_processed, out_path):
    """Save a three-panel figure (generated, gt, difference) titled with metrics.

    Args:
        image1: Image shown under the title ``generated_svg``.
        image2: Image shown under the title ``gt``.
        l2_dist: Value printed as ``MSE`` in the figure title.
        CD: Value printed as ``CD`` in the figure title.
        post_processed: Value printed after ``post-processed:`` in the title.
        out_path: File path the figure is saved to.

    Returns:
        The saved figure, re-opened with ``Image.open``.
    """
    # Absolute per-pixel difference, computed in float32 and stored as uint8.
    first, second = (np.array(img).astype(np.float32) for img in (image1, image2))
    diff = np.abs(first - second).astype(np.uint8)

    fig, axes = plt.subplots(1, 3, figsize=(10, 5))
    panels = ((image1, 'generated_svg'), (image2, 'gt'), (diff, 'Difference'))
    for ax, (content, title) in zip(axes, panels):
        ax.imshow(content)
        ax.set_title(title)
        ax.axis('off')

    plt.suptitle(f"MSE: {l2_dist:.4f}, CD: {CD:.4f}, post-processed: {str(post_processed)}", fontsize=16, y=1.05)
    plt.savefig(out_path, bbox_inches='tight', pad_inches=0.1)
    image = Image.open(out_path)
    plt.close(fig)
    return image
def plot_images_side_by_side(image1, image2, out_path):
    """Save a three-panel figure: generated image, gt image and their difference.

    Args:
        image1: Image shown under the title ``generated_svg``.
        image2: Image shown under the title ``gt``.
        out_path: File path the figure is saved to.

    Returns:
        The saved figure, re-opened with ``Image.open``.
    """
    # Absolute per-pixel difference, computed in float32 and stored as uint8.
    first, second = (np.array(img).astype(np.float32) for img in (image1, image2))
    diff = np.abs(first - second).astype(np.uint8)

    fig, axes = plt.subplots(1, 3, figsize=(10, 5))
    panels = ((image1, 'generated_svg'), (image2, 'gt'), (diff, 'Difference'))
    for ax, (content, title) in zip(axes, panels):
        ax.imshow(content)
        ax.set_title(title)
        ax.axis('off')

    plt.savefig(out_path, bbox_inches='tight', pad_inches=0.1)
    image = Image.open(out_path)
    plt.close(fig)
    return image
def plot_images_side_by_side_temperatures(samples_temp, metrics, sample_dir, outpath_filename):
    """Save a comparison figure of the original image against each temperature.

    Images are read from ``<sample_dir>/temp_<temp>/``: the original from
    ``<outpath_filename>_or.png`` under the first temperature's directory,
    and each generated image from ``<outpath_filename>.png``. The figure is
    written to ``<sample_dir>/<outpath_filename>_comparison.png``.

    Args:
        samples_temp: Mapping keyed by temperature; only the keys are used.
        metrics: Mapping ``temp -> {"mse": ..., "cd": ...}``.
        sample_dir: Directory holding the per-temperature sub-directories.
        outpath_filename: Base file name of the images.
    """
    num_temps = len(samples_temp)
    # Top row holds the images, the (shorter) bottom row holds the captions.
    fig, axes = plt.subplots(2, num_temps + 1, figsize=(15, 4), gridspec_kw={'height_ratios': [10, 2]})
    image_row, caption_row = axes[0], axes[1]

    # Column 0: the original image.
    first_temp = list(samples_temp.keys())[0]
    gt_image = Image.open(os.path.join(sample_dir, f'temp_{first_temp}', f'{outpath_filename}_or.png'))
    image_row[0].imshow(gt_image)
    image_row[0].set_title('Original')
    image_row[0].axis('off')
    caption_row[0].text(0.5, 0.5, 'Original', horizontalalignment='center', verticalalignment='center', fontsize=16)
    caption_row[0].axis('off')

    # Columns 1..n: one generated image per temperature, with its metrics.
    for column, (temp, _sample) in enumerate(samples_temp.items(), start=1):
        gen_image = Image.open(os.path.join(sample_dir, f'temp_{temp}', f'{outpath_filename}.png'))
        image_row[column].imshow(gen_image)
        image_row[column].set_title(f'Temp {temp}')
        image_row[column].axis('off')
        caption_row[column].text(0.5, 0.5, f'MSE: {metrics[temp]["mse"]:.2f}\nCD: {metrics[temp]["cd"]:.2f}',
                                 horizontalalignment='center', verticalalignment='center', fontsize=12)
        caption_row[column].axis('off')

    comparison_path = os.path.join(sample_dir, f'{outpath_filename}_comparison.png')
    plt.tight_layout()
    plt.savefig(comparison_path)
    plt.close()
def plot_images_and_prompt(prompt, svg_raster, gt_svg_raster, out_path):
    """Save a three-column figure: prompt text, generated image, gt image.

    Args:
        prompt: Caption text; wrapped to 30 characters per line.
        svg_raster: Image shown under the title ``generated_svg``.
        gt_svg_raster: Image shown under the title ``gt``.
        out_path: File path the figure is saved to.

    Returns:
        The saved figure, re-opened with ``Image.open``.
    """
    fig, axes = plt.subplots(1, 3, figsize=(10, 5))
    text_ax, image_axes = axes[0], axes[1:]

    # First column: the prompt, broken into short lines.
    prompt_text = '\n'.join(textwrap.wrap(prompt, width=30))
    text_ax.text(0, 0.5, prompt_text, fontsize=12, ha='left', wrap=True)
    text_ax.axis('off')

    # Remaining columns: generated and ground-truth rasters.
    panels = ((svg_raster, 'generated_svg'), (gt_svg_raster, 'gt'))
    for ax, (content, title) in zip(image_axes, panels):
        ax.imshow(content)
        ax.set_title(title)
        ax.axis('off')

    plt.savefig(out_path, bbox_inches='tight', pad_inches=0.1)
    image = Image.open(out_path)
    plt.close(fig)
    return image
def plot_images_and_prompt_with_metrics(prompt, svg_raster, gt_svg_raster, clip_score, post_processed, out_path):
    """Save a prompt / generated / gt figure titled with the CLIP score.

    Args:
        prompt: Caption text; wrapped to 30 characters per line.
        svg_raster: Image shown under the title ``generated_svg``.
        gt_svg_raster: Image shown under the title ``gt``.
        clip_score: Value printed as ``CLIP Score`` in the figure title.
        post_processed: Value printed after ``post-processed:`` in the title.
        out_path: File path the figure is saved to.

    Returns:
        The saved figure, re-opened with ``Image.open``.
    """
    fig, axes = plt.subplots(1, 3, figsize=(10, 5))
    text_ax, image_axes = axes[0], axes[1:]

    # First column: the prompt, broken into short lines.
    prompt_text = '\n'.join(textwrap.wrap(prompt, width=30))
    text_ax.text(0, 0.5, prompt_text, fontsize=12, ha='left', wrap=True)
    text_ax.axis('off')

    # Remaining columns: generated and ground-truth rasters.
    panels = ((svg_raster, 'generated_svg'), (gt_svg_raster, 'gt'))
    for ax, (content, title) in zip(image_axes, panels):
        ax.imshow(content)
        ax.set_title(title)
        ax.axis('off')

    plt.suptitle(f"CLIP Score: {clip_score:.4f}, post-processed: {str(post_processed)}", fontsize=16, y=1.05)
    plt.savefig(out_path, bbox_inches='tight', pad_inches=0.1)
    image = Image.open(out_path)
    plt.close(fig)
    return image
def plot_images_and_prompt_temperatures(prompt, samples_temp, metrics, sample_dir, outpath_filename):
    """Save a one-row figure: prompt, GT image, then one image per temperature.

    Images are read from ``<sample_dir>/temp_<temp>/``: the GT image from
    ``<outpath_filename>_or.png`` under the first temperature's directory,
    and each generated image from ``<outpath_filename>.png``.

    Args:
        prompt: Caption text; wrapped to 30 characters per line.
        samples_temp: Mapping keyed by temperature; only the keys are used.
        metrics: Mapping ``temp -> {"clip_score": ...}``.
        sample_dir: Directory holding the per-temperature sub-directories.
        outpath_filename: Base file name of the images.

    Returns:
        Path of the saved figure,
        ``<sample_dir>/<outpath_filename>_comparison.png``.
    """
    num_temps = len(samples_temp)
    # One cell for the prompt, one for the GT image, one per temperature.
    fig, axes = plt.subplots(1, num_temps + 2, figsize=(5 + 3 * (num_temps + 1), 6))
    prompt_ax, gt_ax = axes[0], axes[1]

    # Cell 0: the prompt, broken into short lines.
    prompt_text = '\n'.join(textwrap.wrap(prompt, width=30))
    prompt_ax.text(0, 0.5, prompt_text, fontsize=12, ha='left', wrap=True)
    prompt_ax.axis('off')

    # Cell 1: the ground-truth image.
    first_temp = list(samples_temp.keys())[0]
    gt_image = Image.open(os.path.join(sample_dir, f'temp_{first_temp}', f'{outpath_filename}_or.png'))
    gt_ax.imshow(gt_image)
    gt_ax.set_title('GT Image')
    gt_ax.axis('off')

    # Cells 2..: generated images, each captioned with its CLIP score.
    for position, (temp, _sample) in enumerate(samples_temp.items(), start=2):
        panel = axes[position]
        gen_image = Image.open(os.path.join(sample_dir, f'temp_{temp}', f'{outpath_filename}.png'))
        panel.imshow(gen_image)
        panel.set_title(f'Temp {temp}')
        panel.axis('off')
        clip_score = metrics[temp]["clip_score"]
        panel.text(0.5, -0.1, f'CLIP: {clip_score:.4f}', horizontalalignment='center', verticalalignment='center', fontsize=12, transform=panel.transAxes)

    comparison_path = os.path.join(sample_dir, f'{outpath_filename}_comparison.png')
    plt.tight_layout()
    plt.savefig(comparison_path)
    plt.close()
    return comparison_path
def plot_image_tensor(image):
    """Save the first element of an image batch to ``tmp/output_image.jpg``.

    Args:
        image: Indexable batch whose first element is a tensor that supports
            ``permute(1, 2, 0)``.
            NOTE(review): presumably a (batch, channels, height, width) torch
            tensor with values in [0, 1] -- confirm against the caller.
    """
    import numpy as np
    from PIL import Image

    # detach() so tensors that still require grad can be converted to numpy.
    tensor = image[0].detach().cpu().float()
    tensor = tensor.permute(1, 2, 0)  # move the leading axis to the end

    # Clip before the uint8 cast: values outside 0..255 would otherwise wrap
    # around instead of saturating.
    array = (tensor.numpy() * 255).clip(0, 255).astype(np.uint8)

    # Make sure the output directory exists; the original failed when
    # "tmp/" was missing.
    os.makedirs("tmp", exist_ok=True)
    Image.fromarray(array).save("tmp/output_image.jpg")
def plot_grid_samples(images, num_cols=5, out_path = 'grid.png'):
    """Plot images on a grid, save the figure and return it as an image.

    Args:
        images: Sequence of images; ``str`` entries are treated as file paths
            and opened with ``Image.open``, anything else is passed to
            ``imshow`` directly.
        num_cols: Number of grid columns.
        out_path: File path the figure is saved to.

    Returns:
        The saved figure, re-opened with ``Image.open``.
    """
    num_images = len(images)
    # Ceil-divide to get the row count; keep at least one row so an empty
    # input still produces a (blank) figure instead of an error.
    num_rows = max(1, (num_images + num_cols - 1) // num_cols)

    # squeeze=False keeps `axes` two-dimensional. Without it a single-row
    # grid comes back as a 1-D array and `axes[row, col]` raises IndexError.
    fig, axes = plt.subplots(num_rows, num_cols, figsize=(12, 8), squeeze=False)

    for i, image in enumerate(images):
        row, col = divmod(i, num_cols)
        img = Image.open(image) if isinstance(image, str) else image
        axes[row, col].imshow(img)
        axes[row, col].axis('off')

    # Remove the unused cells at the end of the last row.
    for i in range(num_images, num_rows * num_cols):
        row, col = divmod(i, num_cols)
        fig.delaxes(axes[row, col])

    plt.tight_layout()
    plt.savefig(out_path, dpi=300)
    image = Image.open(out_path)
    plt.close(fig)
    return image