Spaces:
Running
Running
import os | |
from starvector.data.base import SVGDatasetBase | |
from starvector.data.augmentation import SVGTransforms | |
import random | |
from transformers import AutoProcessor | |
from starvector.data.util import ImageTrainProcessor | |
# Instruction prefixes for the text-to-SVG task. One of these is picked at
# random and prepended to a sample's caption, so every entry ends with
# "of " (including the trailing space) to read naturally once joined.
text2svg_captions = [
    "Draw an SVG of ",
    "Draw an SVG image of ",
    "Draw an SVG picture of ",
    "Generate an SVG of ",
    "Create an SVG of ",
    "Design an SVG of ",
    "Make an SVG of ",
]
class SVGStackDataset(SVGDatasetBase):
    """Dataset yielding an SVG, its image, and a text-to-SVG prompt per sample.

    Builds on ``SVGDatasetBase``; ``self.data``, ``self.im_size`` and
    ``get_svg_and_image`` are used here but defined outside this class
    (presumably by the base class -- confirm there).
    """

    def __init__(self, dataset_name, split, im_size, num_samples=-1, **kwargs):
        """Set up optional filtering, subsampling and the image processor.

        Args:
            dataset_name: Forwarded unchanged to ``SVGDatasetBase.__init__``.
            split: Forwarded unchanged to ``SVGDatasetBase.__init__``.
            im_size: Forwarded to the base class; ``self.im_size`` is later
                passed to ``ImageTrainProcessor`` when no SigLIP processor
                is requested.
            num_samples: Number of leading rows to keep, or ``-1`` to keep
                all rows. Values larger than the dataset keep every row.
            **kwargs: Forwarded to the base class. Keys read here:
                ``random_caption`` (default ``True``),
                ``select_dataset_name`` (keep only rows whose ``model_name``
                equals this value; falsy disables the filter), and
                ``image_processor`` (``'siglip_512'``, ``'siglip_384'`` or
                ``'siglip_256'`` to load the matching SigLIP processor).

        Raises:
            KeyError: If ``image_processor`` contains ``'siglip'`` but is
                not one of the three supported aliases.
        """
        super().__init__(dataset_name, split, im_size, num_samples, **kwargs)
        self.color_changer = SVGTransforms({
            'color_change': True,
            'colors': ['#ff0000', '#0000ff', '#00ff00', '#ffff00', '#000000'],
        })

        # Text2SVG specific
        # NOTE(review): random_caption is stored but not consulted anywhere
        # in this class -- confirm whether the base class or a caller uses it.
        self.random_caption = kwargs.get('random_caption', True)

        select_dataset_name = kwargs.get('select_dataset_name', False)
        if select_dataset_name:
            self.data = self.data.filter(
                lambda example: example["model_name"] == select_dataset_name
            )

        self.num_samples = num_samples
        if self.num_samples != -1:
            # Clamp to the current size: the filter above may have left
            # fewer rows than requested, and selecting indices past the end
            # would fail instead of simply keeping everything available.
            keep = min(self.num_samples, len(self.data))
            self.data = self.data.select(range(keep))

        self.image_processor = kwargs.get('image_processor', None)
        if self.image_processor and 'siglip' in self.image_processor:
            # Short alias -> pretrained checkpoint name.
            model_name = {
                'siglip_512': 'google/siglip-base-patch16-512',
                'siglip_384': 'google/siglip-large-patch16-384',
                'siglip_256': 'google/siglip-base-patch16-256',
            }[self.image_processor]
            self.processor = AutoProcessor.from_pretrained(model_name).image_processor
        else:
            self.processor = ImageTrainProcessor(size=self.im_size)

    def __len__(self):
        """Return the number of rows left after filtering and subsampling."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the sample at ``idx`` with a randomly worded caption.

        Returns:
            dict with keys ``'svg'`` and ``'image'`` (as produced by
            ``get_svg_and_image``), ``'id'`` (the row's ``'Filename'``) and
            ``'caption'`` (a random prefix from ``text2svg_captions``
            followed by the row's BLIP-2 or LLaVA caption, or by nothing if
            that caption is missing or ``None``).
        """
        # Fetch the row once instead of re-indexing the dataset per field.
        row = self.data[idx]
        svg_str = row['Svg']
        sample_id = row['Filename']
        svg, image = self.get_svg_and_image(svg_str, sample_id)

        # Randomly choose between 'caption_blip2' and 'caption_llava'
        caption_column = random.choice(['caption_blip2', 'caption_llava'])
        # `or ""` also covers a column that exists but holds None, which a
        # plain .get(..., "") default would let through to the concatenation.
        caption_text = row.get(caption_column) or ""
        caption = random.choice(text2svg_captions) + caption_text

        return {
            'svg': svg,
            'image': image,
            'id': sample_id,
            'caption': caption,
        }