from abc import ABC, abstractmethod
import os
import json
import pandas as pd
from starvector.metrics.metrics import SVGMetrics
from copy import deepcopy
import numpy as np
from starvector.data.util import rasterize_svg
import importlib
from typing import Type
from omegaconf import OmegaConf
from tqdm import tqdm
from datetime import datetime
import re
from starvector.data.util import clean_svg, use_placeholder
from svgpathtools import svgstr2paths

# Registry for SVGValidator subclasses
validator_registry = {}


def register_validator(cls: Type['SVGValidator']):
    """
    Decorator to register SVGValidator subclasses.
    """
    validator_registry[cls.__name__] = cls
    return cls
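
# Illustrative usage (a sketch, not part of this module): a hypothetical subclass
# such as `MyHFValidator` would register itself with the decorator so it can later
# be resolved by name through the registry or `get_validator`:
#
#   @register_validator
#   class MyHFValidator(SVGValidator):
#       def generate_svg(self, batch, generate_config):
#           ...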

class SVGValidator(ABC):
    def __init__(self, config):
        self.task = config.model.task
        # Flag to determine if we should report to wandb
        self.report_to_wandb = config.run.report_to == 'wandb'
        date_time = datetime.now().strftime("%Y%m%d_%H%M%S")
        if config.model.from_checkpoint:
            chkp_dir = self.get_checkpoint_dir(config.model.from_checkpoint)
            config.model.from_checkpoint = chkp_dir
            self.resume_from_checkpoint = chkp_dir
            self.out_dir = chkp_dir + '/' + config.run.out_dir + '/' + config.model.generation_engine + '_' + config.dataset.dataset_name + '_' + date_time
        else:
            self.out_dir = config.run.out_dir + '/' + config.model.generation_engine + '_' + config.model.name + '_' + config.dataset.dataset_name + '_' + date_time
        os.makedirs(self.out_dir, exist_ok=True)

        self.model_name = config.model.name

        # Save config to yaml file
        config_path = os.path.join(self.out_dir, "config.yaml")
        self.config = config
        with open(config_path, "w") as f:
            OmegaConf.save(config=self.config, f=f)
        print(f"Out dir: {self.out_dir}")

        metrics_config_path = f"configs/metrics/{self.task}.yaml"
        default_metrics_config = OmegaConf.load(metrics_config_path)
        self.metrics = SVGMetrics(default_metrics_config['metrics'])
        self.results = {}

        # If wandb reporting is enabled, initialize wandb and a table to record sample results.
        if self.report_to_wandb:
            try:
                import wandb
                wandb.init(
                    project=config.run.project_name,
                    name=config.run.run_id,
                    config=OmegaConf.to_container(config, resolve=True)
                )
                # Create a wandb table with columns for all relevant data.
                self.results_table = wandb.Table(columns=[
                    "sample_id", "svg", "svg_raw", "svg_gt",
                    "no_compile", "post_processed", "original_image", "generated_image",
                    "comparison_image"
                ])
                # Dictionary to hold table rows indexed by sample_id
                self.table_data = {}
                print("Initialized wandb run with results table")
            except Exception as e:
                print(f"Failed to initialize wandb: {e}")

    def get_checkpoint_dir(self, checkpoint_path):
        """Resolve a checkpoint directory: if the path already points at a specific
        checkpoint, return it as-is; otherwise return the contained checkpoint
        directory with the highest step."""
        if re.search(r'checkpoint-\d+$', checkpoint_path):
            return checkpoint_path

        # Find all directories matching the checkpoint pattern
        checkpoint_dirs = []
        for d in os.listdir(checkpoint_path):
            if re.search(r'checkpoint-(\d+)$', d):
                checkpoint_dirs.append(d)
        if not checkpoint_dirs:
            return None

        # Extract step numbers and find the highest one
        latest_dir = max(checkpoint_dirs, key=lambda x: int(re.search(r'checkpoint-(\d+)$', x).group(1)))
        return os.path.join(checkpoint_path, latest_dir)
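
    # Example (illustrative, assumed directory layout): given a run folder containing
    #   runs/exp1/checkpoint-500/  and  runs/exp1/checkpoint-1000/
    # get_checkpoint_dir("runs/exp1") returns "runs/exp1/checkpoint-1000", while
    # get_checkpoint_dir("runs/exp1/checkpoint-500") is returned unchanged.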

    def _hash_config(self, config):
        """Create a deterministic hash of the config for caching/identification."""
        import hashlib

        # Convert OmegaConf to a plain dict; keys are sorted at JSON serialization
        # time so the hash is deterministic.
        config_dict = OmegaConf.to_container(config, resolve=True)

        # Remove non-deterministic or irrelevant fields
        if 'run' in config_dict:
            config_dict['run'].pop('out_dir', None)  # Remove output directory
            config_dict['run'].pop('device', None)   # Remove device specification

        # Convert to sorted JSON string and hash it
        config_str = json.dumps(config_dict, sort_keys=True)
        return hashlib.md5(config_str.encode()).hexdigest()

    @abstractmethod
    def generate_svg(self, batch, generate_config):
        """Generate raw SVG outputs for a batch. Implemented by subclasses."""
        pass
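
    # Expected contract (inferred from generate_and_process_batch below): generate_svg
    # returns an iterable of raw SVG strings, one per sample in the batch, e.g.
    #   outputs = validator.generate_svg(batch, config.generation_params)
    #   results = [validator.post_process_svg(svg_text) for svg_text in outputs]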

    def create_comparison_plot(self, sample_id, gt_raster, gen_raster, metrics, output_path):
        """
        Creates and saves a comparison plot showing the ground truth and generated SVG images,
        along with computed metrics.

        Args:
            sample_id (str): Identifier for the sample.
            gt_raster (PIL.Image.Image): Rasterized ground truth SVG image.
            gen_raster (PIL.Image.Image): Rasterized generated SVG image.
            metrics (dict): Dictionary of metric names and their values.
            output_path (str): File path where the plot is saved.

        Returns:
            PIL.Image.Image: The generated comparison plot image.
        """
        import matplotlib.pyplot as plt
        import numpy as np
        from io import BytesIO
        from PIL import Image

        # Create figure with two subplots: one for metrics text, one for the images
        fig, (ax_metrics, ax_images) = plt.subplots(2, 1, figsize=(12, 8), gridspec_kw={'height_ratios': [1, 4]})
        fig.suptitle(f'Generation Results for {sample_id}', fontsize=16)

        # Build text for metrics
        if metrics:
            metrics_text = "Metrics:\n"
            for key, val in metrics.items():
                if isinstance(val, list) and val:
                    metrics_text += f"{key}: {val[-1]:.4f}\n"
                elif isinstance(val, (int, float)):
                    metrics_text += f"{key}: {val:.4f}\n"
                else:
                    metrics_text += f"{key}: {val}\n"
        else:
            metrics_text = "No metrics available."

        # Add metrics text in the upper subplot
        ax_metrics.text(0.5, 0.5, metrics_text, fontfamily='monospace',
                        horizontalalignment='center', verticalalignment='center')
        ax_metrics.axis('off')

        # Set title and prepare the images subplot
        ax_images.set_title('Ground Truth (left) vs Generated (right)')
        gt_array = np.array(gt_raster)
        gen_array = np.array(gen_raster)
        combined = np.hstack((gt_array, gen_array))
        ax_images.imshow(combined)
        ax_images.axis('off')

        # Save figure to buffer and file path
        buf = BytesIO()
        plt.savefig(buf, format='png', bbox_inches='tight', dpi=300)
        plt.savefig(output_path, format='png', bbox_inches='tight', dpi=300)
        plt.close(fig)
        buf.seek(0)
        return Image.open(buf)

    def create_comparison_plots_with_metrics(self, all_metrics):
        """
        Create and save comparison plots with metrics for all samples based on computed metrics.
        """
        for sample_id, metrics in all_metrics.items():
            if sample_id not in self.results:
                continue  # Skip if the sample does not exist in the results
            result = self.results[sample_id]
            sample_dir = os.path.join(self.out_dir, sample_id)

            # Retrieve the already rasterized images from the result
            gt_raster = result.get('gt_im')
            gen_raster = result.get('gen_im')
            if gt_raster is None or gen_raster is None:
                continue

            # Define the output path for the comparison plot image
            output_path = os.path.join(sample_dir, f"{sample_id}_comparison.png")
            comp_img = self.create_comparison_plot(sample_id, gt_raster, gen_raster, metrics, output_path)

            # Save the generated plot image in the result for later use
            result['comparison_image'] = comp_img

            # Also update the row in the internal table_data with the comparison image.
            if self.report_to_wandb and sample_id in self.table_data and self.config.run.log_images:
                import wandb
                row = list(self.table_data[sample_id])
                row[-1] = wandb.Image(comp_img)
                self.table_data[sample_id] = tuple(row)
        self.update_results_table_log()

    def save_results(self, results, batch, batch_idx):
        """Save results from generation."""
        out_path = self.out_dir
        for i, sample in enumerate(batch['Svg']):
            sample_id = str(batch['Filename'][i]).split('.')[0]
            res = results[i]
            res['sample_id'] = sample_id
            res['gt_svg'] = sample
            sample_dir = os.path.join(out_path, sample_id)
            os.makedirs(sample_dir, exist_ok=True)

            # Save SVG files and rasterized images using the base class method
            svg_raster, gt_svg_raster = self._save_svg_files(sample_dir, sample_id, res)

            # Save metadata to disk
            with open(os.path.join(sample_dir, 'metadata.json'), 'w') as f:
                json.dump(res, f, indent=4, sort_keys=True)

            res['gen_im'] = svg_raster
            res['gt_im'] = gt_svg_raster
            self.results[sample_id] = res

            # Instead of logging individual sample fields directly, add an entry (row)
            # to the internal table_data with a placeholder for comparison_image.
            if self.report_to_wandb and self.config.run.log_images:
                import wandb
                row = (
                    sample_id,
                    res['svg'],
                    res['svg_raw'],
                    res['gt_svg'],
                    res['no_compile'],
                    res['post_processed'],
                    wandb.Image(gt_svg_raster),
                    wandb.Image(svg_raster),
                    None  # Placeholder for comparison_image
                )
                self.table_data[sample_id] = row
        self.update_results_table_log()

    def _save_svg_files(self, sample_dir, outpath_filename, res):
        """Save SVG files and rasterized images."""
        # Save SVG files
        with open(os.path.join(sample_dir, f"{outpath_filename}.svg"), 'w', encoding='utf-8') as f:
            f.write(res['svg'])
        with open(os.path.join(sample_dir, f"{outpath_filename}_raw.svg"), 'w', encoding='utf-8') as f:
            f.write(res['svg_raw'])
        with open(os.path.join(sample_dir, f"{outpath_filename}_gt.svg"), 'w', encoding='utf-8') as f:
            f.write(res['gt_svg'])

        # Rasterize and save PNGs
        svg_raster = rasterize_svg(res['svg'], resolution=512, dpi=100, scale=1)
        gt_svg_raster = rasterize_svg(res['gt_svg'], resolution=512, dpi=100, scale=1)
        svg_raster.save(os.path.join(sample_dir, f"{outpath_filename}_generated.png"))
        gt_svg_raster.save(os.path.join(sample_dir, f"{outpath_filename}_original.png"))
        return svg_raster, gt_svg_raster

    def run_temperature_sweep(self, batch):
        """Run generation with different temperatures"""
        out_dict = {}
        sampling_temperatures = np.linspace(
            self.config.generation_sweep.min_temperature,
            self.config.generation_sweep.max_temperature,
            self.config.generation_sweep.num_generations_different_temp
        ).tolist()
        for temp in sampling_temperatures:
            current_args = deepcopy(self.config.generation_params)
            current_args['temperature'] = temp
            results = self.generate_and_process_batch(batch, current_args)
            for i, sample_id in enumerate(batch['id']):
                sample_id = str(sample_id).split('.')[0]
                if sample_id not in out_dict:
                    out_dict[sample_id] = {}
                out_dict[sample_id][temp] = results[i]
        return out_dict

    def validate(self):
        """Main validation loop"""
        for i, batch in enumerate(tqdm(self.dataloader, desc="Validating")):
            if self.config.generation_params.generation_sweep:
                results = self.run_temperature_sweep(batch)
            else:
                results = self.generate_and_process_batch(batch, self.config.generation_params)
            self.save_results(results, batch, i)
            self.release_memory()

        # Calculate and save metrics
        self.calculate_and_save_metrics()

        # Final logging of the complete results table.
        if self.report_to_wandb and self.config.run.log_images:
            try:
                import wandb
                wandb.log({"results_table": self.results_table})
            except Exception as e:
                print(f"Failed to log final results table to wandb: {e}")

    def calculate_and_save_metrics(self):
        """Calculate and save metrics"""
        batch_results = self.preprocess_results()
        avg_results, all_results = self.metrics.calculate_metrics(batch_results)

        out_path_results = os.path.join(self.out_dir, 'results')
        os.makedirs(out_path_results, exist_ok=True)

        # Save average results
        with open(os.path.join(out_path_results, 'results_avg.json'), 'w') as f:
            json.dump(avg_results, f, indent=4, sort_keys=True)

        # Save detailed results
        df = pd.DataFrame.from_dict(all_results, orient='index')
        df.to_csv(os.path.join(out_path_results, 'all_results.csv'))

        # Log average metrics to wandb if enabled
        if self.report_to_wandb:
            try:
                import wandb
                wandb.log({'avg_metrics': avg_results})
            except Exception as e:
                print(f"Error logging average metrics to wandb: {e}")

        # Create comparison plots with metrics
        self.create_comparison_plots_with_metrics(all_results)

    def preprocess_results(self):
        """Preprocess results from self.results into batch format with lists"""
        batch = {
            'gen_svg': [],
            'gt_svg': [],
            'gen_im': [],
            'gt_im': [],
            'json': []
        }
        for sample_id, result_dict in self.results.items():
            # For the single-temperature case, result_dict holds one result;
            # for a temperature sweep, take the first temperature's result.
            if self.config.generation_params.generation_sweep:
                result = result_dict[list(result_dict.keys())[0]]
            else:
                result = result_dict
            batch['gen_svg'].append(result['svg'])
            batch['gt_svg'].append(result['gt_svg'])
            batch['gen_im'].append(result['gen_im'])
            batch['gt_im'].append(result['gt_im'])
            batch['json'].append(result)
        return batch

    def generate_and_process_batch(self, batch, generate_config):
        """Generate and post-process SVGs for a batch"""
        generated_outputs = self.generate_svg(batch, generate_config)
        processed_results = [self.post_process_svg(output) for output in generated_outputs]
        return processed_results

    def post_process_svg(self, text):
        """Post-process a single generated SVG string."""
        try:
            # Parse as-is; if it compiles, keep the raw output.
            svgstr2paths(text)
            return {
                'svg': text,
                'svg_raw': text,
                'post_processed': False,
                'no_compile': False
            }
        except Exception:
            try:
                # Try to repair the SVG and parse again.
                cleaned_svg = clean_svg(text)
                svgstr2paths(cleaned_svg)
                return {
                    'svg': cleaned_svg,
                    'svg_raw': text,
                    'post_processed': True,
                    'no_compile': False
                }
            except Exception:
                # Fall back to a placeholder SVG when the output cannot be compiled.
                return {
                    'svg': use_placeholder(),
                    'svg_raw': text,
                    'post_processed': True,
                    'no_compile': True
                }
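
    # Post-processing outcomes (illustrative summary of the method above): a well-formed
    # SVG string comes back unchanged with post_processed=False; a string that only parses
    # after clean_svg() is returned cleaned with post_processed=True; anything else falls
    # back to use_placeholder() with no_compile=True.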

    @classmethod
    def get_validator(cls, key, args, validator_configs):
        """
        Factory method to get the appropriate SVGValidator subclass based on the key.

        Args:
            key (str): The key name to select the validator.
            args (argparse.Namespace): Parsed command-line arguments.
            validator_configs (dict): Mapping of validator keys to class paths.

        Returns:
            SVGValidator: An instance of a subclass of SVGValidator.

        Raises:
            ValueError: If the provided key is not in the mapping.
        """
        if key not in validator_configs:
            available_validators = list(validator_configs.keys())
            raise ValueError(f"Validator '{key}' is not recognized. Available validators: {available_validators}")

        class_path = validator_configs[key]
        module_path, class_name = class_path.rsplit('.', 1)
        module = importlib.import_module(module_path)
        validator_class = getattr(module, class_name)
        return validator_class(args)
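
    # Illustrative usage (a sketch; the class path below is assumed, not guaranteed
    # to exist in this repository):
    #
    #   validator_configs = {
    #       "hf": "starvector.validation.hf_validator.HFValidator",  # hypothetical entry
    #   }
    #   validator = SVGValidator.get_validator("hf", args, validator_configs)
    #   validator.validate()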

    def update_results_table_log(self):
        """Rebuild and log the results table from self.table_data."""
        if self.report_to_wandb and self.config.run.log_images:
            try:
                import wandb
                table = wandb.Table(columns=[
                    "sample_id", "svg", "svg_raw", "svg_gt",
                    "no_compile", "post_processed",
                    "original_image", "generated_image", "comparison_image"
                ])
                for row in self.table_data.values():
                    table.add_data(*row)
                wandb.log({"results_table": table})
                self.results_table = table
            except Exception as e:
                print(f"Failed to update results table to wandb: {e}")