Spaces:
Running
Running
File size: 1,871 Bytes
72f684c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 |
import os
from starvector.data.base import SVGDatasetBase
from starvector.data.augmentation import SVGTransforms
from starvector.data.util import ImageTrainProcessor
from transformers import AutoProcessor
class SVGDataset(SVGDatasetBase):
    """SVG dataset yielding the SVG, its image, the sample id and a caption.

    Rows are read from ``self.data`` (populated by ``SVGDatasetBase``) using
    the ``'Svg'`` and ``'Filename'`` columns, plus ``'Caption'`` when present.

    Recognised keyword arguments (all are also forwarded to the base class):
        select_dataset_name: if truthy, keep only rows whose ``"model_name"``
            column equals this value.
        image_processor: processor alias. Values containing ``'siglip'`` must
            be a key of ``_SIGLIP_CHECKPOINTS``; any other value, or no value,
            falls back to ``ImageTrainProcessor``.
    """

    # Processor aliases mapped to the checkpoint whose image processor is
    # loaded through ``AutoProcessor``.
    _SIGLIP_CHECKPOINTS = {
        'siglip_512': 'google/siglip-base-patch16-512',
        'siglip_384': 'google/siglip-large-patch16-384',
        'siglip_256': 'google/siglip-base-patch16-256',
    }

    def __init__(self, dataset_name, split, im_size, num_samples=None, **kwargs):
        """Load the split, optionally filter and truncate it, and pick a processor.

        Args:
            dataset_name: forwarded to ``SVGDatasetBase``.
            split: forwarded to ``SVGDatasetBase``.
            im_size: forwarded to ``SVGDatasetBase``; the fallback
                ``ImageTrainProcessor`` is built with ``self.im_size``
                (presumably set by the base class from this value).
            num_samples: number of leading rows to keep. ``None`` (the
                default) and ``-1`` both keep every row.
            **kwargs: see the class docstring.

        Raises:
            KeyError: if ``image_processor`` contains ``'siglip'`` but is not
                one of the known aliases.
        """
        super().__init__(dataset_name, split, im_size, num_samples, **kwargs)
        self.color_changer = SVGTransforms({
            'color_change': True,
            'colors': ['#ff0000', '#0000ff', '#00ff00', '#ffff00', '#000000'],
        })

        select_dataset_name = kwargs.get('select_dataset_name', False)
        if select_dataset_name:
            self.data = self.data.filter(
                lambda example: example["model_name"] == select_dataset_name
            )

        self.num_samples = num_samples
        # The default is None, and range(None) raises TypeError, so None has
        # to be treated like the explicit -1 "use everything" sentinel.
        if self.num_samples is not None and self.num_samples != -1:
            self.data = self.data.select(range(self.num_samples))

        self.image_processor = kwargs.get('image_processor', None)
        # Guard against None: `'siglip' in None` raises TypeError, which would
        # make the ImageTrainProcessor fallback unreachable by default.
        if self.image_processor is not None and 'siglip' in self.image_processor:
            model_name = self._SIGLIP_CHECKPOINTS[self.image_processor]
            self.processor = AutoProcessor.from_pretrained(model_name).image_processor
        else:
            self.processor = ImageTrainProcessor(size=self.im_size)

    def __len__(self):
        """Return the number of rows left after filtering and truncation."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the sample at ``idx`` as a dict.

        Returns:
            dict with keys ``'svg'`` and ``'image'`` (both produced by the
            inherited ``get_svg_and_image``), ``'id'`` (the row's
            ``'Filename'``) and ``'caption'`` (the row's ``'Caption'``, or
            ``""`` when the column is absent).
        """
        # Fetch the row once instead of re-indexing the dataset per field.
        row = self.data[idx]
        svg_str = row['Svg']
        sample_id = row['Filename']
        svg, image = self.get_svg_and_image(svg_str, sample_id)
        caption = row.get('Caption', "")
        return {
            'svg': svg,
            'image': image,
            'id': sample_id,
            'caption': caption,
        }