Spaces:
Running
Running
# Hugging Face text-generation reference: https://huggingface.co/docs/transformers/main_classes/text_generation
from starvector.validation.svg_validator_base import SVGValidator, register_validator | |
import torch | |
from transformers import AutoProcessor, AutoModelForCausalLM | |
from torch.utils.data import Dataset, DataLoader | |
from datasets import load_dataset | |
from starvector.data.util import rasterize_svg | |
class SVGValDataset(Dataset):
    """Validation dataset pairing each SVG string with its rasterized image.

    Rows come from ``load_dataset`` and must provide the ``'Svg'`` and
    ``'Filename'`` fields; ``'Caption'`` is optional and defaults to ``""``.
    """

    def __init__(self, dataset_name, config_name, split, im_size, num_samples, processor):
        """Load the requested split and optionally truncate it.

        Args:
            dataset_name: Dataset identifier passed to ``load_dataset``.
            config_name: Dataset configuration name; when falsy, the dataset
                is loaded without one.
            split: Split passed to ``load_dataset``.
            im_size: Passed as ``resolution`` to ``rasterize_svg``.
            num_samples: Keep only the first ``num_samples`` rows. ``-1`` or
                ``None`` keeps every row; a value larger than the split is
                clamped to the split size instead of raising.
            processor: Callable invoked as
                ``processor(image, return_tensors="pt")``; its result must
                expose ``'pixel_values'``.
        """
        self.dataset_name = dataset_name
        self.config_name = config_name
        self.split = split
        self.im_size = im_size
        self.num_samples = num_samples
        self.processor = processor

        if self.config_name:
            self.data = load_dataset(self.dataset_name, self.config_name, split=self.split)
        else:
            self.data = load_dataset(self.dataset_name, split=self.split)

        limit = self._sample_limit(self.num_samples, len(self.data))
        if limit is not None:
            self.data = self.data.select(range(limit))

    @staticmethod
    def _sample_limit(num_samples, available):
        """Return how many rows to keep, or ``None`` to keep all of them."""
        if num_samples is None or num_samples == -1:
            return None
        # Clamp so that asking for more rows than exist does not fail.
        return min(num_samples, available)

    def __len__(self):
        """Return the number of rows kept after optional truncation."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the sample at ``idx`` as a dict.

        The returned keys are ``'Svg'`` (the SVG source), ``'image'`` (the
        processor's ``pixel_values`` with its leading dimension squeezed),
        ``'Filename'`` and ``'Caption'``.
        """
        # Fetch the row once; each index into the dataset builds the row again.
        row = self.data[idx]
        svg_str = row['Svg']

        image = rasterize_svg(svg_str, resolution=self.im_size)
        image = self.processor(image, return_tensors="pt")['pixel_values'].squeeze(0)

        return {
            'Svg': svg_str,
            'image': image,
            'Filename': row['Filename'],
            'Caption': row.get('Caption', ""),
        }
class StarVectorHFSVGValidator(SVGValidator):
    """SVG validator that generates with a model loaded via ``AutoModelForCausalLM``."""

    def __init__(self, config):
        """Load the model and cache its tokenizer.

        Args:
            config: Run configuration. Reads ``config.model.torch_dtype``
                (one of ``'bfloat16'``, ``'float16'``, ``'float32'``; any
                other value raises ``KeyError``), ``config.model.from_checkpoint``,
                ``config.model.name`` and ``config.run.device``.
        """
        super().__init__(config)

        self.torch_dtype = {
            'bfloat16': torch.bfloat16,
            'float16': torch.float16,
            'float32': torch.float32,
        }[config.model.torch_dtype]

        # NOTE(review): self.resume_from_checkpoint is not assigned in this
        # class; presumably the base class sets it -- confirm.
        if config.model.from_checkpoint:
            model_source = self.resume_from_checkpoint
        else:
            model_source = config.model.name

        self.model = AutoModelForCausalLM.from_pretrained(
            model_source,
            trust_remote_code=True,
            torch_dtype=self.torch_dtype,
        ).to(config.run.device)

        self.tokenizer = self.model.model.svg_transformer.tokenizer
        # First token id produced when encoding the closing SVG tag.
        self.svg_end_token_id = self.tokenizer.encode("</svg>")[0]

    def get_dataloader(self):
        """Build ``self.dataset`` and ``self.dataloader`` from ``self.config.dataset``."""
        dataset_cfg = self.config.dataset
        # NOTE(review): self.processor is not assigned anywhere in this class;
        # presumably the base class provides it -- confirm.
        self.dataset = SVGValDataset(
            dataset_cfg.dataset_name,
            dataset_cfg.config_name,
            dataset_cfg.split,
            dataset_cfg.im_size,
            dataset_cfg.num_samples,
            self.processor,
        )
        self.dataloader = DataLoader(
            self.dataset,
            batch_size=dataset_cfg.batch_size,
            shuffle=False,
            num_workers=dataset_cfg.num_workers,
        )

    def release_memory(self):
        """Drop the model's SVG transformer references and clear CUDA caches.

        ``self.tokenizer`` still references the tokenizer after this call;
        only the attributes on ``self.model.model.svg_transformer`` are cleared.
        """
        self.model.model.svg_transformer.tokenizer = None
        self.model.model.svg_transformer.model = None

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()

    def generate_svg(self, batch, generate_config):
        """Generate SVG output for one batch.

        Args:
            batch: Mapping with an ``'image'`` tensor. The tensor is replaced
                in place by a copy on the configured device, cast to
                ``self.torch_dtype``.
            generate_config: Keyword arguments forwarded to the model's
                generate method. A ``temperature`` of 0 is translated to
                ``do_sample=False`` with ``temperature=1.0``. The caller's
                mapping is not modified.

        Returns:
            The result of ``generate_im2svg`` or ``generate_text2svg``
            depending on ``self.task``; an empty list for any other task.
        """
        # Copy so the caller's config is not rewritten between calls.
        gen_kwargs = dict(generate_config)
        if gen_kwargs.get('temperature') == 0:
            # Temperature 0 is expressed as non-sampling decoding; generation
            # expects a strictly positive temperature value.
            gen_kwargs['temperature'] = 1.0
            gen_kwargs['do_sample'] = False

        # Use the device the model was placed on in __init__, not a fixed 'cuda'.
        batch['image'] = batch['image'].to(self.config.run.device).to(self.torch_dtype)

        if self.task == 'im2svg':
            return self.model.model.generate_im2svg(batch=batch, **gen_kwargs)
        if self.task == 'text2svg':
            return self.model.model.generate_text2svg(batch=batch, **gen_kwargs)
        return []