import torch
import torch.nn as nn
from abc import ABC, abstractmethod
from starvector.model.adapters.adapter import Adapter
from starvector.model.image_encoder.image_encoder import ImageEncoder
from starvector.util import print_trainable_parameters
from transformers.generation.stopping_criteria import StoppingCriteria, StoppingCriteriaList

class StoppingCriteriaSub(StoppingCriteria):
    def __init__(self, stops=None):
        super().__init__()
        # Avoid a mutable default argument
        self.stops = stops if stops is not None else []

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Stop when any stop sequence matches the tail of the first batch item
        for stop_ids in self.stops:
            if input_ids[0][-len(stop_ids):].tolist() == stop_ids:
                return True
        return False
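
# Hypothetical usage sketch (names are illustrative): stop decoding as soon as
# the model emits the "</svg>" closing tag.
#   end_ids = tokenizer("</svg>", add_special_tokens=False)["input_ids"]
#   criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=[end_ids])])
#   model.generate(..., stopping_criteria=criteria)
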
class StarVectorBase(nn.Module, ABC):
    def __init__(self, config, **kwargs):
        super().__init__()
        # Task-specific layers
        self.task = kwargs.get('task', 'im2svg')
        self.model_precision = kwargs.get('model_precision', config.torch_dtype)
        # Build code LLM (StarCoder)
        self.svg_transformer = self._get_svg_transformer(config, **kwargs)
        if self.use_image_encoder():
            # Build image encoder
            self.image_encoder = ImageEncoder(config, **kwargs)
            # Build adapter (sets self.query_length as a side effect)
            self.image_projection = self.get_adapter(config, **kwargs).to(dtype=self.model_precision)
        else:
            self.query_length = 0
        self.max_length = config.max_length_train - self.query_length - 4  # room for added special tokens
        self.train_image_encoder = kwargs.get('train_image_encoder', False)
        self.train_LLM = kwargs.get('train_LLM', False)
        self.train_connector = kwargs.get('train_connector', False)
        # Freeze parameters
        self.freeze_parameters(self.train_image_encoder, self.train_LLM, self.train_connector)
        print_trainable_parameters(self)

    @abstractmethod
    def _get_svg_transformer(self, config, **kwargs):
        """Get the SVG transformer (code LLM) - implementation differs between versions"""

    def freeze_parameters(self, train_image_encoder, train_LLM, train_connector):
        """Freeze or unfreeze each module group according to the training flags"""
        if self.use_image_encoder():
            for _, param in self.image_encoder.named_parameters():
                param.requires_grad = train_image_encoder
            # Adapter (connector) parameters
            for _, param in self.image_projection.named_parameters():
                param.requires_grad = train_connector
        for _, param in self.svg_transformer.named_parameters():
            param.requires_grad = train_LLM

    def use_image_encoder(self):
        """Determine if image encoder should be used based on task"""
        return self.task == 'im2svg'

    def get_adapter(self, config, **kwargs):
        """Get adapter layer for image projection"""
        vision_hidden_size, self.query_length = self.get_hidden_size_and_query_length(config.image_encoder_type)
        llm_hidden_size = self.svg_transformer.transformer.config.hidden_size
        image_projection = Adapter(
            vision_hidden_size,
            llm_hidden_size,
            adapter_norm=config.adapter_norm,
            query_length=self.query_length,
            dropout_prob=kwargs.get('dropout', 0.1)
        )
        return image_projection

    def get_hidden_size_and_query_length(self, image_encoder_type):
        """Get hidden size and query length based on encoder type"""
        if image_encoder_type == 'clip':
            hidden_size = self.image_encoder.visual_encoder.num_features
            query_length = 257
        elif image_encoder_type == 'open-clip':
            hidden_size = self.image_encoder.visual_encoder.transformer.width
            query_length = 256
        elif image_encoder_type == 'vqgan':
            hidden_size = 256
            query_length = 196
        elif image_encoder_type == 'convnext':
            hidden_size = 1024
            query_length = 49
        elif 'siglip' in image_encoder_type:
            hidden_size = self.image_encoder.visual_encoder.head.mlp.fc2.out_features
            if '512' in image_encoder_type:
                query_length = 1024
            elif '384' in image_encoder_type:
                query_length = 576
            else:
                raise ValueError(f"Unsupported SigLIP resolution: {image_encoder_type}")
        else:
            raise ValueError(f"Unknown image encoder type: {image_encoder_type}")
        return hidden_size, query_length

    def _tokenize(self, text, max_length, device, add_special_tokens=True):
        """Common tokenization logic"""
        tokens = self.svg_transformer.tokenizer(
            text,
            truncation=True,
            add_special_tokens=add_special_tokens,
            padding='longest',
            max_length=max_length,
            return_tensors="pt"
        ).to(device)
        return tokens

    def _create_targets(self, tokens):
        """Create targets with padding mask"""
        target_mask = (tokens.input_ids == self.svg_transformer.tokenizer.pad_token_id)
        return tokens.input_ids.masked_fill(target_mask, -100)
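        # Example (illustrative): input_ids [[12, 34, PAD, PAD]] becomes targets
        # [[12, 34, -100, -100]]; cross-entropy loss ignores -100 positions, so
        # padding never contributes to training.
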
    @abstractmethod
    def _get_embeddings(self, input_ids):
        """Get embeddings from input ids - implementation differs between v1 and v2"""

    def embed_text_to_svg(self, batch, device):
        """Common text to SVG embedding logic"""
        captions = batch["caption"]
        svgs = batch["svg"]
        samples = [captions[i] + self.svg_transformer.svg_start_token + svgs[i] + self.svg_transformer.tokenizer.eos_token
                   for i in range(len(captions))]
        tokens = self._tokenize(samples, self.max_length, device)
        targets = self._create_targets(tokens)
        inputs_embeds = self._get_embeddings(tokens.input_ids)
        return inputs_embeds, tokens.attention_mask, targets

    def get_image_embeddings(self, batch, device):
        """Get image embeddings"""
        image = batch["image"].to(dtype=self.model_precision)
        embedded_image = self.image_encoder(image)
        conditioning_embeds = self.image_projection(embedded_image)
        return conditioning_embeds

    def embed_im_to_svg(self, batch, device):
        """Common image to SVG embedding logic"""
        # Encode the image and project it into the LLM's embedding space
        image = batch["image"].to(dtype=self.model_precision)
        embedded_image = self.image_encoder(image)
        conditioning_embeds = self.image_projection(embedded_image)
        conditioning_embeds_att = torch.ones(conditioning_embeds.size()[:-1], dtype=torch.long, device=device)
        # Get SVG text with appropriate end tokens (implemented by subclasses)
        svg_text = self._get_svg_text(batch["svg"])
        svg_tokens = self._tokenize(svg_text, self.max_length, device)
        svg_tokens_embeds = self._get_embeddings(svg_tokens.input_ids)
        inputs_embeds = torch.cat([conditioning_embeds, svg_tokens_embeds], dim=1)
        svg_targets = self._create_targets(svg_tokens)
        # Image positions carry no language-modeling targets
        empty_targets = torch.full(conditioning_embeds_att.size(), -100, dtype=torch.long, device=device)
        targets = torch.cat([empty_targets, svg_targets], dim=1)
        attention_mask = torch.cat([conditioning_embeds_att, svg_tokens.attention_mask], dim=1)
        return inputs_embeds, attention_mask, targets
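        # Resulting layout (illustrative), with Q image query tokens and T SVG tokens:
        #   inputs_embeds:  [ image queries (Q) | SVG token embeds (T) ]
        #   targets:        [ -100 ...      (Q) | SVG token ids    (T) ]
        # so the loss is computed only over the SVG tokens.
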
    def forward(self, batch):
        """Forward pass"""
        device = batch["image"].device
        task = self.task
        # Build the multimodal input sequence depending on the task
        if task == 'text2svg':
            inputs_embeds, attention_mask, targets = self.embed_text_to_svg(batch, device)
        elif task == 'im2svg':
            inputs_embeds, attention_mask, targets = self.embed_im_to_svg(batch, device)
        else:
            raise ValueError(f"Unknown task: {task}")
        outputs = self.svg_transformer.transformer(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            labels=targets,
            return_dict=True,
            output_hidden_states=True,
            use_cache=False,
        )
        loss = outputs.loss
        return loss
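        # Minimal training-step sketch (hypothetical; model, batch, and optimizer
        # are assumed to be built elsewhere):
        #   loss = model(batch)
        #   loss.backward()
        #   optimizer.step(); optimizer.zero_grad()
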
    @abstractmethod
    def _get_svg_text(self, svg_list):
        """Get SVG text with appropriate end tokens - implementation differs between v1 and v2"""

    def _prepare_generation_inputs(self, batch, prompt, device):
        """Common preparation for generation inputs"""
        image = batch["image"]
        image = image.to(device).to(self.model_precision)
        embedded_image = self.image_encoder(image)
        embedded_image = self.image_projection(embedded_image)
        embedded_att = torch.ones(embedded_image.size()[:-1], dtype=torch.long, device=device)
        if prompt is None:
            prompt = self.svg_transformer.prompt
        prompt = [prompt] * image.size(0)
        prompt_tokens = self._tokenize(prompt, None, device, add_special_tokens=False)
        attention_mask = torch.cat([embedded_att, prompt_tokens.attention_mask], dim=1)
        inputs_embeds = self._get_embeddings(prompt_tokens.input_ids)
        inputs_embeds = torch.cat([embedded_image, inputs_embeds], dim=1)
        return inputs_embeds, attention_mask, prompt_tokens
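        # Layout of the returned tensors (illustrative):
        #   inputs_embeds:  [ image queries | prompt token embeds ]
        #   attention_mask: [ ones          | prompt attention    ]
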
    def _get_generation_kwargs(self, base_kwargs):
        """Common generation kwargs preparation"""
        # Token IDs for the "</svg>" closing tag, used as a stop sequence
        end_sequence = self.svg_transformer.tokenizer("</svg>", add_special_tokens=False)['input_ids']
        stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=[end_sequence])])
        return {
            'inputs_embeds': base_kwargs['inputs_embeds'],
            'attention_mask': base_kwargs['attention_mask'],
            'do_sample': base_kwargs.get('use_nucleus_sampling', True),
            'top_p': base_kwargs.get('top_p', 0.9),
            'temperature': base_kwargs.get('temperature', 1.0),
            'num_beams': base_kwargs.get('num_beams', 2),
            'max_length': base_kwargs.get('max_length', 30),
            'min_length': base_kwargs.get('min_length', 1),
            'repetition_penalty': base_kwargs.get('repetition_penalty', 1.0),
            'length_penalty': base_kwargs.get('length_penalty', 1.0),
            'use_cache': base_kwargs.get('use_cache', True),
            'stopping_criteria': stopping_criteria
        }

    def generate_im2svg(self, batch, **kwargs):
        """Base implementation of image to SVG generation"""
        inputs_embeds, attention_mask, prompt_tokens = self._prepare_generation_inputs(
            batch, kwargs.get('prompt'), batch["image"].device
        )
        generation_kwargs = self._get_generation_kwargs(
            {**kwargs, 'inputs_embeds': inputs_embeds, 'attention_mask': attention_mask}
        )
        # Let subclasses override these defaults if needed
        generation_kwargs.update(self._get_im2svg_specific_kwargs(kwargs))
        outputs = self.svg_transformer.transformer.generate(**generation_kwargs)
        # Prepend the prompt tokens so the decoded string contains the full SVG
        outputs = torch.cat([prompt_tokens.input_ids, outputs], dim=1)
        raw_svg = self.svg_transformer.tokenizer.batch_decode(outputs, skip_special_tokens=True)
        return raw_svg
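        # Hypothetical call (values are illustrative, not project defaults):
        #   batch = {"image": images}  # images: float tensor of shape (B, 3, H, W)
        #   svgs = model.generate_im2svg(batch, max_length=4000, temperature=0.7)
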
    def generate_im2svg_grpo(self, batch, **kwargs):
        """Image to SVG generation that returns multiple sampled sequences per input (for GRPO)"""
        inputs_embeds, attention_mask, prompt_tokens = self._prepare_generation_inputs(
            batch, kwargs.get('prompt'), batch["image"].device
        )
        generation_kwargs = self._get_generation_kwargs(
            {**kwargs, 'inputs_embeds': inputs_embeds, 'attention_mask': attention_mask}
        )
        # Let subclasses override these defaults if needed
        generation_kwargs.update(self._get_im2svg_specific_kwargs(kwargs))
        num_return_sequences = kwargs.get('num_return_sequences', 1)
        if num_return_sequences > 1:
            generation_kwargs['num_return_sequences'] = num_return_sequences
            generation_kwargs['num_beams'] = 1
        outputs = self.svg_transformer.transformer.generate(**generation_kwargs)
        # repeat_interleave matches generate()'s output ordering: all sequences
        # for input i are contiguous before those of input i+1
        outputs = torch.cat([prompt_tokens.input_ids.repeat_interleave(num_return_sequences, dim=0), outputs], dim=1)
        raw_svg = self.svg_transformer.tokenizer.batch_decode(outputs, skip_special_tokens=True)
        return {
            "raw_svg": raw_svg,
            "outputs": outputs,
            "inputs_embeds": inputs_embeds,
        }

    def _get_im2svg_specific_kwargs(self, kwargs):
        """Default implementation of im2svg specific generation kwargs.
        Subclasses can override this to customize generation behavior."""
        return {
            'early_stopping': True,
            'pad_token_id': self.svg_transformer.tokenizer.pad_token_id
        }

    def generate_text2svg(self, batch, **kwargs):
        """Base implementation of text to SVG generation"""
        device = batch["image"].device
        prompt = batch["caption"]
        prompt_tokens = self._tokenize(
            prompt,
            max_length=kwargs.get('max_length', 30),
            device=device,
            add_special_tokens=False
        )
        trigger_token = self._tokenize(
            [self.svg_transformer.svg_start_token for _ in batch["caption"]],
            max_length=None,
            device=device,
            add_special_tokens=False
        )
        input_tokens = torch.cat([prompt_tokens.input_ids, trigger_token.input_ids], dim=1)
        attention_mask = torch.cat([prompt_tokens.attention_mask, trigger_token.attention_mask], dim=1)
        inputs_embeds = self._get_embeddings(input_tokens)
        # Budget for generated tokens: total max length minus the prompt length
        max_length = kwargs.get('max_length', 30) - input_tokens.size(1)
        generation_kwargs = self._get_generation_kwargs(
            {**kwargs, 'inputs_embeds': inputs_embeds, 'attention_mask': attention_mask}
        )
        # Let subclasses override these defaults if needed
        generation_kwargs.update(self._get_text2svg_specific_kwargs(kwargs))
        generation_kwargs['max_length'] = max_length
        outputs = self.svg_transformer.transformer.generate(**generation_kwargs)
        return outputs

    def _get_text2svg_specific_kwargs(self, kwargs):
        """Default implementation of text2svg specific generation kwargs.
        Subclasses can override this to customize generation behavior."""
        return {
            'eos_token_id': self.svg_transformer.tokenizer.eos_token_id,
            'early_stopping': True,
            'length_penalty': kwargs.get('length_penalty', 1.0)
        }
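
# A minimal sketch of a concrete subclass, showing the hooks a version-specific
# model fills in. The StarCoderModel import path is an assumption for
# illustration, not necessarily the project's actual v1/v2 wiring.
class StarVectorExampleModel(StarVectorBase):
    def _get_svg_transformer(self, config, **kwargs):
        # Hypothetical import; substitute the real code-LLM wrapper here
        from starvector.model.llm.starcoder import StarCoderModel
        return StarCoderModel(config, **kwargs)

    def _get_embeddings(self, input_ids):
        # Look up token embeddings from the underlying HF transformer
        return self.svg_transformer.transformer.get_input_embeddings()(input_ids)

    def _get_svg_text(self, svg_list):
        # Close each sample with the EOS token so targets terminate cleanly
        return [svg + self.svg_transformer.tokenizer.eos_token for svg in svg_list]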