Spaces:
Running
Running
from torch.utils.data import Dataset | |
from starvector.data.util import ImageTrainProcessor, use_placeholder, rasterize_svg | |
from starvector.util import instantiate_from_config | |
import numpy as np | |
from datasets import load_dataset | |
class SVGDatasetBase(Dataset): | |
def __init__(self, dataset_name, split, im_size, num_samples=-1, **kwargs): | |
self.split = split | |
self.im_size = im_size | |
transforms = kwargs.get('transforms', False) | |
if transforms: | |
self.transforms = instantiate_from_config(transforms) | |
self.p = self.transforms.p | |
else: | |
self.transforms = None | |
self.p = 0.0 | |
normalization = kwargs.get('normalize', False) | |
if normalization: | |
mean = tuple(normalization.get('mean', None)) | |
std = tuple(normalization.get('std', None)) | |
else: | |
mean = None | |
std = None | |
self.processor = ImageTrainProcessor(size=self.im_size, mean=mean, std=std) | |
self.data = load_dataset(dataset_name, split=split) | |
print(f"Loaded {len(self.data)} samples from {dataset_name} {split} split") | |
def __len__(self): | |
return len(self.data_json) | |
def get_svg_and_image(self, svg_str, sample_id): | |
do_augment = np.random.choice([True, False], p=[self.p, 1 - self.p]) | |
svg, image = None, None | |
# Try to augment the image if conditions are met | |
if self.transforms is not None and do_augment: | |
try: | |
svg, image = self.transforms.augment(svg_str) | |
except Exception as e: | |
print(f"Error augmenting {sample_id} due to {str(e)}, trying to rasterize SVG") | |
# If augmentation failed or wasn't attempted, try to rasterize the SVG | |
if svg is None or image is None: | |
try: | |
svg, image = svg_str, rasterize_svg(svg_str, self.im_size) | |
except Exception as e: | |
print(f"Error rasterizing {sample_id} due to {str(e)}, using placeholder image") | |
svg = use_placeholder() | |
image = rasterize_svg(svg, self.im_size) | |
# If the image is completely white, use a placeholder image | |
if np.array(image).mean() == 255.0: | |
print(f"Image is full white, using placeholder image for {sample_id}") | |
svg = use_placeholder() | |
image = rasterize_svg(svg) | |
# Process the image | |
if 'siglip' in self.image_processor: | |
image = self.processor(image).pixel_values[0] | |
else: | |
image = self.processor(image) | |
return svg, image | |
def __getitem__(self, idx): | |
raise NotImplementedError("This method should be implemented by subclasses") | |