# hz2475's picture
# init
# 72f684c
import os
from starvector.data.base import SVGDatasetBase
from starvector.data.augmentation import SVGTransforms
import random
from transformers import AutoProcessor
from starvector.data.util import ImageTrainProcessor
# Instruction prefixes for text-to-SVG prompts. Every entry ends with "of "
# so a caption string can be concatenated directly after it.
text2svg_captions = [
    "Draw an SVG of ",
    "Draw an SVG image of ",
    "Draw an SVG picture of ",
    "Generate an SVG of ",
    "Create an SVG of ",
    "Design an SVG of ",
    "Make an SVG of ",
]
class SVGStackDataset(SVGDatasetBase):
    """Dataset yielding an SVG, its image, its id and a text-to-SVG prompt.

    Each item's caption is a random prefix from ``text2svg_captions`` followed
    by the text of a randomly chosen caption column of the row.
    """

    # Row columns a caption may be taken from; one is picked per sample.
    _CAPTION_COLUMNS = ['caption_blip2', 'caption_llava']

    def __init__(self, dataset_name, split, im_size, num_samples=-1, **kwargs):
        """Initialize via the base class, then optionally filter and truncate.

        Args:
            dataset_name: Forwarded to ``SVGDatasetBase``.
            split: Forwarded to ``SVGDatasetBase``.
            im_size: Forwarded to ``SVGDatasetBase``. ``self.im_size`` (presumably
                set from this by the base class -- verify) is the output size of
                ``ImageTrainProcessor`` when no SigLIP processor is requested.
            num_samples: Keep only the first ``num_samples`` rows, applied after
                the optional ``select_dataset_name`` filter; ``-1`` keeps all.
                Values larger than the remaining row count keep every row.
            **kwargs: Forwarded to the base class. Also read here:
                ``random_caption`` (stored as ``self.random_caption``, default
                True); ``select_dataset_name`` (keep only rows whose
                ``model_name`` equals it); ``image_processor`` (``'siglip_512'``,
                ``'siglip_384'`` or ``'siglip_256'`` selects a Hugging Face SigLIP
                image processor; any other value containing ``'siglip'`` raises
                ``KeyError``; anything else uses ``ImageTrainProcessor``).
        """
        super().__init__(dataset_name, split, im_size, num_samples, **kwargs)
        # NOTE(review): not used by the methods of this class; presumably
        # consumed by the base class -- verify.
        self.color_changer = SVGTransforms({'color_change' : True, 'colors' : ['#ff0000', '#0000ff', '#00ff00', '#ffff00', '#000000']})

        # Text2SVG specific
        # NOTE(review): not read by __getitem__ below, which always randomizes
        # the caption -- confirm whether the flag is meant to disable that.
        self.random_caption = kwargs.get('random_caption', True)
        select_dataset_name = kwargs.get('select_dataset_name', False)
        if select_dataset_name:
            self.data = self.data.filter(lambda example: example["model_name"] == select_dataset_name)

        self.num_samples = num_samples
        if self.num_samples != -1:
            # Clamp: after filtering, fewer rows than requested may remain, and
            # ``select`` raises on out-of-range indices.
            self.data = self.data.select(range(min(self.num_samples, len(self.data))))

        self.image_processor = kwargs.get('image_processor', None)
        if self.image_processor and 'siglip' in self.image_processor:
            model_name = {'siglip_512': 'google/siglip-base-patch16-512',
                          'siglip_384': 'google/siglip-large-patch16-384',
                          'siglip_256': 'google/siglip-base-patch16-256'}[self.image_processor]
            self.processor = AutoProcessor.from_pretrained(model_name).image_processor
        else:
            self.processor = ImageTrainProcessor(size=self.im_size)

    def __len__(self):
        """Return the number of rows after filtering/truncation."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return a dict with keys ``'svg'``, ``'image'``, ``'id'`` and ``'caption'``.

        ``'svg'`` and ``'image'`` come from ``self.get_svg_and_image`` applied to
        the row's ``'Svg'`` text; ``'id'`` is the row's ``'Filename'``.
        """
        # Fetch the row once; indexing an HF dataset decodes the row each time.
        row = self.data[idx]
        svg_str = row['Svg']
        sample_id = row['Filename']
        svg, image = self.get_svg_and_image(svg_str, sample_id)

        # Randomly choose between the 'caption_blip2' and 'caption_llava'
        # columns. The column is drawn before the prefix, as before, so seeded
        # runs consume the RNG in the same order.
        caption_column = random.choice(self._CAPTION_COLUMNS)
        prefix = random.choice(text2svg_captions)
        # Rows may lack the column or hold None; fall back to an empty caption.
        caption = prefix + (row.get(caption_column) or "")

        return {
            'svg': svg,
            'image': image,
            'id': sample_id,
            'caption': caption,
        }