File size: 2,398 Bytes
72f684c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import os
from starvector.data.base import SVGDatasetBase
from starvector.data.augmentation import SVGTransforms
import random
from transformers import AutoProcessor
from starvector.data.util import ImageTrainProcessor

text2svg_captions = [
    "Draw an SVG of ",
    "Draw an SVG image of ",
    "Draw an SVG picture of ",
    "Generate an SVG of ",
    "Create an SVG of ",
    "Design an SVG of ",
    "Make an SVG of ",
]

class SVGStackDataset(SVGDatasetBase):
    def __init__(self, dataset_name, split, im_size, num_samples=-1, **kwargs):
        super().__init__(dataset_name, split, im_size, num_samples, **kwargs)
        self.color_changer = SVGTransforms({'color_change' : True, 'colors' : ['#ff0000', '#0000ff', '#00ff00', '#ffff00', '#000000']})
        
        # Text2SVG specific
        self.random_caption = kwargs.get('random_caption', True)
        select_dataset_name = kwargs.get('select_dataset_name', False)
        if select_dataset_name:
            self.data = self.data.filter(lambda example: example["model_name"]==select_dataset_name)
        
        self.num_samples = num_samples
        if self.num_samples != -1:
            self.data = self.data.select(range(self.num_samples))

        self.image_processor = kwargs.get('image_processor', None)
        if self.image_processor and 'siglip' in self.image_processor:
            model_name = {'siglip_512': 'google/siglip-base-patch16-512', 
                        'siglip_384': 'google/siglip-large-patch16-384', 
                        'siglip_256': 'google/siglip-base-patch16-256'}[self.image_processor]
            self.processor = AutoProcessor.from_pretrained(model_name).image_processor
        else:
            self.processor = ImageTrainProcessor(size=self.im_size)


    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        svg_str = self.data[idx]['Svg']
        sample_id = self.data[idx]['Filename']
        svg, image = self.get_svg_and_image(svg_str, sample_id)
        
        # Randomly choose between 'caption_blip' and 'caption_llava'
        caption_column = random.choice(['caption_blip2', 'caption_llava'])
        caption = random.choice(text2svg_captions) + self.data[idx].get(caption_column, "")
        return {
            'svg': svg,
            'image': image,
            'id': sample_id,
            'caption': caption,
            }