Spaces:
Running
Running
File size: 2,735 Bytes
72f684c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 |
from torch.utils.data import Dataset
from starvector.data.util import ImageTrainProcessor, use_placeholder, rasterize_svg
from starvector.util import instantiate_from_config
import numpy as np
from datasets import load_dataset
class SVGDatasetBase(Dataset):
    """Base dataset of SVG samples loaded from a Hugging Face ``datasets`` split.

    Subclasses implement ``__getitem__``. This base class:

    - loads the split;
    - sets up optional augmentation and normalization;
    - provides ``get_svg_and_image``, which turns an SVG string into an
      ``(svg, processed_image)`` pair, with placeholder fallbacks.
    """

    def __init__(self, dataset_name, split, im_size, num_samples=-1, **kwargs):
        """Load ``split`` of ``dataset_name`` and configure image processing.

        Args:
            dataset_name: Dataset name/path passed to ``load_dataset``.
            split: Split to load.
            im_size: Image size passed to the processor and to ``rasterize_svg``.
            num_samples: Accepted but not used by this base class.
                NOTE(review): presumably subclasses handle subsetting — confirm.
            **kwargs: Optional keys:
                ``transforms``: a config passed to ``instantiate_from_config``.
                    The resulting object must expose ``p`` and ``augment``.
                ``normalize``: a mapping with optional ``mean`` and ``std``
                    sequences.
        """
        self.split = split
        self.im_size = im_size

        transforms = kwargs.get('transforms', False)
        if transforms:
            self.transforms = instantiate_from_config(transforms)
            # Probability that get_svg_and_image attempts augmentation.
            self.p = self.transforms.p
        else:
            self.transforms = None
            self.p = 0.0

        normalization = kwargs.get('normalize', False)
        if normalization:
            mean = normalization.get('mean')
            std = normalization.get('std')
            # Convert only values that are present; tuple(None) would raise TypeError.
            mean = tuple(mean) if mean is not None else None
            std = tuple(std) if std is not None else None
        else:
            mean = None
            std = None

        self.processor = ImageTrainProcessor(size=self.im_size, mean=mean, std=std)
        self.data = load_dataset(dataset_name, split=split)
        print(f"Loaded {len(self.data)} samples from {dataset_name} {split} split")

    def __len__(self):
        """Return the number of loaded samples."""
        # __init__ stores the dataset in self.data; the old self.data_json was never set.
        return len(self.data)

    def get_svg_and_image(self, svg_str, sample_id):
        """Return ``(svg, image)`` for ``svg_str``, with the image run through ``self.processor``.

        Steps:

        1. When ``self.transforms`` is set, augmentation is attempted with
           probability ``self.p``.
        2. If augmentation is skipped or fails, ``svg_str`` is rasterized
           instead.
        3. If rasterizing fails, a placeholder SVG is used.
        4. If the resulting image's mean pixel value is exactly 255.0 (all
           white), it is also replaced with the placeholder.

        Args:
            svg_str: SVG source string.
            sample_id: Identifier used only in log messages.

        Returns:
            Tuple ``(svg, image)``. ``svg`` is the (possibly augmented or
            placeholder) SVG. ``image`` is the processor output; when
            ``self.image_processor`` contains ``'siglip'``, it is
            ``.pixel_values[0]`` of that output.
        """
        do_augment = np.random.choice([True, False], p=[self.p, 1 - self.p])
        svg, image = None, None

        # Try to augment the image if conditions are met.
        if self.transforms is not None and do_augment:
            try:
                svg, image = self.transforms.augment(svg_str)
            except Exception as e:
                # Best effort: fall through to plain rasterization below.
                print(f"Error augmenting {sample_id} due to {str(e)}, trying to rasterize SVG")

        # If augmentation failed or wasn't attempted, rasterize the original SVG.
        if svg is None or image is None:
            try:
                svg, image = svg_str, rasterize_svg(svg_str, self.im_size)
            except Exception as e:
                print(f"Error rasterizing {sample_id} due to {str(e)}, using placeholder image")
                svg = use_placeholder()
                image = rasterize_svg(svg, self.im_size)

        # A completely white image carries no signal; substitute the placeholder.
        if np.array(image).mean() == 255.0:
            print(f"Image is full white, using placeholder image for {sample_id}")
            svg = use_placeholder()
            # Rasterize at the configured size, consistent with the fallback above.
            image = rasterize_svg(svg, self.im_size)

        # image_processor is not set by this base class (subclasses may set it, possibly to None).
        image_processor = getattr(self, 'image_processor', None)
        if image_processor and 'siglip' in image_processor:
            image = self.processor(image).pixel_values[0]
        else:
            image = self.processor(image)
        return svg, image

    def __getitem__(self, idx):
        raise NotImplementedError("This method should be implemented by subclasses")
|