# File size: 2,735 Bytes
# 72f684c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
from torch.utils.data import Dataset
from starvector.data.util import ImageTrainProcessor, use_placeholder, rasterize_svg
from starvector.util import instantiate_from_config
import numpy as np
from datasets import load_dataset

class SVGDatasetBase(Dataset):
    """Base class for SVG datasets loaded via ``datasets.load_dataset``.

    The constructor loads the requested split, builds an
    ``ImageTrainProcessor`` and optionally instantiates an augmentation
    object.  Subclasses must implement ``__getitem__``; they can use
    ``get_svg_and_image`` to turn an SVG string into an
    ``(svg, processed_image)`` pair.
    """

    def __init__(self, dataset_name, split, im_size, num_samples=-1, **kwargs):
        """Load the dataset split and set up image processing.

        Args:
            dataset_name: Dataset identifier forwarded to ``load_dataset``.
            split: Split name forwarded to ``load_dataset``.
            im_size: Size forwarded to ``ImageTrainProcessor`` and
                ``rasterize_svg``.
            num_samples: Not used by this base class; kept in the signature
                (presumably so subclasses can subsample -- confirm there).
            **kwargs: Optional settings read here:
                ``transforms``: config passed to ``instantiate_from_config``;
                    the resulting object must expose ``p`` (augmentation
                    probability) and ``augment(svg_str)``.
                ``normalize``: mapping with optional ``mean`` / ``std``
                    sequences forwarded to ``ImageTrainProcessor``.
                ``image_processor``: processor name; when it contains
                    ``'siglip'`` the processed image is read from
                    ``pixel_values[0]``.
        """
        self.split = split
        self.im_size = im_size

        # Previously this attribute was read in get_svg_and_image() but never
        # assigned here.  A value already present on the class is kept as the
        # fallback so subclasses that define it are not overridden.
        self.image_processor = kwargs.get(
            'image_processor', getattr(self, 'image_processor', None)
        )

        transforms = kwargs.get('transforms')
        if transforms:
            self.transforms = instantiate_from_config(transforms)
            self.p = self.transforms.p
        else:
            self.transforms = None
            self.p = 0.0

        # A normalize config may omit either key; tuple(None) would raise, so
        # a missing value is passed through as None.
        normalization = kwargs.get('normalize') or {}
        mean = normalization.get('mean')
        std = normalization.get('std')
        mean = tuple(mean) if mean is not None else None
        std = tuple(std) if std is not None else None

        self.processor = ImageTrainProcessor(size=self.im_size, mean=mean, std=std)
        self.data = load_dataset(dataset_name, split=split)

        print(f"Loaded {len(self.data)} samples from {dataset_name} {split} split")

    def __len__(self):
        """Return the number of samples in the loaded split."""
        # The constructor stores the split in ``self.data``; there is no
        # ``data_json`` attribute.
        return len(self.data)

    def get_svg_and_image(self, svg_str, sample_id):
        """Return ``(svg, image)`` for one sample, with fallbacks.

        Order of attempts:
          1. With probability ``self.p`` (and only if transforms are
             configured) call ``self.transforms.augment(svg_str)``.
          2. Otherwise, or if augmentation raised, rasterize ``svg_str``.
          3. If rasterization raised, or the image is entirely white
             (mean pixel value 255), substitute the placeholder SVG.

        Args:
            svg_str: SVG source of the sample.
            sample_id: Identifier used only in diagnostic messages.

        Returns:
            Tuple ``(svg, image)`` where ``svg`` is the SVG actually used and
            ``image`` is the output of ``self.processor``.
        """
        do_augment = np.random.choice([True, False], p=[self.p, 1 - self.p])
        svg, image = None, None

        # Try to augment the image if conditions are met
        if self.transforms is not None and do_augment:
            try:
                svg, image = self.transforms.augment(svg_str)
            except Exception as e:
                # Best effort: fall through to plain rasterization below.
                print(f"Error augmenting {sample_id} due to {str(e)}, trying to rasterize SVG")

        # If augmentation failed or wasn't attempted, try to rasterize the SVG
        if svg is None or image is None:
            try:
                svg, image = svg_str, rasterize_svg(svg_str, self.im_size)
            except Exception as e:
                print(f"Error rasterizing {sample_id} due to {str(e)}, using placeholder image")
                svg = use_placeholder()
                image = rasterize_svg(svg, self.im_size)

        # If the image is completely white, use a placeholder image
        if np.array(image).mean() == 255.0:
            print(f"Image is full white, using placeholder image for {sample_id}")
            svg = use_placeholder()
            # Pass the size explicitly, as the other rasterize calls do.
            image = rasterize_svg(svg, self.im_size)

        # Process the image; siglip processors wrap the result in an object
        # exposing ``pixel_values``.
        processor_name = getattr(self, 'image_processor', None)
        if processor_name and 'siglip' in processor_name:
            image = self.processor(image).pixel_values[0]
        else:
            image = self.processor(image)

        return svg, image

    def __getitem__(self, idx):
        """Return the sample at ``idx``; must be provided by subclasses."""
        raise NotImplementedError("This method should be implemented by subclasses")