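# NOTE(review): the three lines below look like web-page residue rather than
# Python; they are kept as comments so the module can be parsed.
# hz2475's picture
# init
# 72f684c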
import torch
import torch.nn as nn
from abc import ABC, abstractmethod
from starvector.model.adapters.adapter import Adapter
from starvector.model.image_encoder.image_encoder import ImageEncoder
from starvector.util import print_trainable_parameters
from transformers.generation.stopping_criteria import StoppingCriteria, StoppingCriteriaList
class StoppingCriteriaSub(StoppingCriteria):
    """Stopping criterion that fires when sequence 0 ends with a stop sequence.

    Only the first row of ``input_ids`` is inspected, and a single Python
    ``bool`` is returned (not a per-sequence tensor).
    """

    def __init__(self, stops=None):
        """
        Args:
            stops: Iterable of token-id sequences (e.g. lists of ints). Each
                is a sequence whose appearance at the end of row 0 should
                stop generation. Defaults to no stop sequences.
        """
        super().__init__()
        # A None default avoids one mutable list being shared by every instance.
        self.stops = [] if stops is None else stops

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True if ``input_ids[0]`` ends with any non-empty stop sequence.

        ``scores`` and any extra keyword arguments are accepted to match the
        ``StoppingCriteria`` call signature; they are not used.
        """
        first_row = input_ids[0]
        for stop_ids in self.stops:
            # Skip empty stops: slicing with -0 would select the whole row.
            if len(stop_ids) == 0:
                continue
            # list(...) so tuple-valued stop sequences compare correctly.
            if first_row[-len(stop_ids):].tolist() == list(stop_ids):
                return True
        return False
class StarVectorBase(nn.Module, ABC):
    """Shared scaffolding for StarVector text/image-to-SVG models.

    Combines an SVG code LLM (built by the subclass via
    ``_get_svg_transformer``) with, for the ``'im2svg'`` task, an
    ``ImageEncoder`` and an ``Adapter`` that projects image features into the
    LLM's hidden size. Subclasses supply the version-specific pieces:
    ``_get_svg_transformer``, ``_get_embeddings`` and ``_get_svg_text``.
    """

    def __init__(self, config, **kwargs):
        """Build the sub-modules and configure which parts are trainable.

        Args:
            config: Model config. Reads ``torch_dtype`` and
                ``max_length_train``; for ``'im2svg'`` also
                ``image_encoder_type`` and ``adapter_norm`` (via
                ``get_adapter``).
            **kwargs: Optional overrides. These are ``task`` (default
                ``'im2svg'``), ``model_precision`` (default
                ``config.torch_dtype``), ``train_image_encoder`` /
                ``train_LLM`` / ``train_connector`` (default ``False``), and
                ``dropout`` for the adapter. All kwargs are also forwarded to
                ``_get_svg_transformer`` and ``ImageEncoder``.
        """
        super().__init__()
        self.task = kwargs.get('task', 'im2svg')
        self.model_precision = kwargs.get('model_precision', config.torch_dtype)
        # Build the code LLM (StarCoder) first: get_adapter reads its hidden size.
        self.svg_transformer = self._get_svg_transformer(config, **kwargs)
        if self.use_image_encoder():
            # The encoder must exist before get_adapter, which inspects it and
            # also sets self.query_length as a side effect.
            self.image_encoder = ImageEncoder(config, **kwargs)
            self.image_projection = self.get_adapter(config, **kwargs).to(dtype=self.model_precision)
        else:
            self.query_length = 0
        # Reserve room for the image query positions and added special tokens.
        self.max_length = config.max_length_train - self.query_length - 4
        self.train_image_encoder = kwargs.get('train_image_encoder', False)
        self.train_LLM = kwargs.get('train_LLM', False)
        self.train_connector = kwargs.get('train_connector', False)
        self.freze_parameters(self.train_image_encoder, self.train_LLM, self.train_connector)
        print_trainable_parameters(self)

    @abstractmethod
    def _get_svg_transformer(self, config, **kwargs):
        """Build the SVG transformer wrapper; implementation differs between versions.

        The returned object is used elsewhere in this class via its
        ``transformer``, ``tokenizer``, ``svg_start_token`` and ``prompt``
        attributes.
        """
        pass

    def freze_parameters(self, train_image_encoder, train_LLM, train_connector):
        """Set ``requires_grad`` on each sub-module's parameters (V2 behaviour).

        The method name keeps its original spelling so existing callers and
        overrides keep working.

        Args:
            train_image_encoder: Trainability of the image encoder (im2svg only).
            train_LLM: Trainability of the SVG transformer.
            train_connector: Trainability of the image adapter (im2svg only).
        """
        if self.use_image_encoder():
            for _, param in self.image_encoder.named_parameters():
                param.requires_grad = train_image_encoder
            for _, param in self.image_projection.named_parameters():
                param.requires_grad = train_connector
        for _, param in self.svg_transformer.named_parameters():
            param.requires_grad = train_LLM

    def use_image_encoder(self):
        """Return True when the task is ``'im2svg'`` (i.e. image-conditioned)."""
        return self.task == 'im2svg'

    def get_adapter(self, config, **kwargs):
        """Build the adapter that maps image features to the LLM hidden size.

        Side effect: sets ``self.query_length`` for the configured encoder.

        Args:
            config: Reads ``image_encoder_type`` and ``adapter_norm``.
            **kwargs: ``dropout`` (default 0.1) is passed as ``dropout_prob``.
        """
        vision_hidden_size, self.query_length = self.get_hidden_size_and_query_length(config.image_encoder_type)
        llm_hidden_size = self.svg_transformer.transformer.config.hidden_size
        return Adapter(
            vision_hidden_size,
            llm_hidden_size,
            adapter_norm=config.adapter_norm,
            query_length=self.query_length,
            dropout_prob=kwargs.get('dropout', 0.1),
        )

    def get_hidden_size_and_query_length(self, image_encoder_type):
        """Return ``(vision_hidden_size, query_length)`` for an encoder type.

        Raises:
            ValueError: if ``image_encoder_type`` is not recognised, or is a
                SigLIP name that contains neither ``'512'`` nor ``'384'``.
                Previously both cases crashed with ``UnboundLocalError``.
        """
        if image_encoder_type == 'clip':
            return self.image_encoder.visual_encoder.num_features, 257
        if image_encoder_type == 'open-clip':
            return self.image_encoder.visual_encoder.transformer.width, 256
        if image_encoder_type == 'vqgan':
            return 256, 196
        if image_encoder_type == 'convnext':
            return 1024, 49
        if 'siglip' in image_encoder_type:
            hidden_size = self.image_encoder.visual_encoder.head.mlp.fc2.out_features
            if '512' in image_encoder_type:
                return hidden_size, 1024
            if '384' in image_encoder_type:
                return hidden_size, 576
            raise ValueError(f"Unsupported SigLIP variant: {image_encoder_type!r}")
        raise ValueError(f"Unsupported image encoder type: {image_encoder_type!r}")

    def _tokenize(self, text, max_length, device, add_special_tokens=True):
        """Tokenize ``text`` with the SVG transformer's tokenizer.

        Pads to the longest item, truncates to ``max_length``, returns PyTorch
        tensors and moves the result to ``device``.
        """
        return self.svg_transformer.tokenizer(
            text,
            truncation=True,
            add_special_tokens=add_special_tokens,
            padding='longest',
            max_length=max_length,
            return_tensors="pt",
        ).to(device)

    def _create_targets(self, tokens):
        """Return ``tokens.input_ids`` with pad-token positions replaced by -100.

        -100 is the default ``ignore_index`` of PyTorch's cross-entropy loss,
        so padded positions presumably do not contribute to the LM loss.
        """
        pad_mask = tokens.input_ids == self.svg_transformer.tokenizer.pad_token_id
        return tokens.input_ids.masked_fill(pad_mask, -100)

    @abstractmethod
    def _get_embeddings(self, input_ids):
        """Map token ids to input embeddings; implementation differs between v1 and v2."""
        pass

    def embed_text_to_svg(self, batch, device):
        """Embed ``caption + svg_start_token + svg + eos`` for each sample.

        Args:
            batch: Mapping with parallel lists ``"caption"`` and ``"svg"``.
            device: Device the tokens are moved to.

        Returns:
            ``(inputs_embeds, attention_mask, targets)`` for the LM loss.
        """
        start = self.svg_transformer.svg_start_token
        eos = self.svg_transformer.tokenizer.eos_token
        samples = [caption + start + svg + eos for caption, svg in zip(batch["caption"], batch["svg"])]
        tokens = self._tokenize(samples, self.max_length, device)
        targets = self._create_targets(tokens)
        inputs_embeds = self._get_embeddings(tokens.input_ids)
        return inputs_embeds, tokens.attention_mask, targets

    def get_image_embeddings(self, batch, device):
        """Encode ``batch["image"]`` and project it with the adapter.

        ``device`` is accepted for interface parity but is not used; the image
        is only cast to ``self.model_precision``.
        """
        image = batch["image"].to(dtype=self.model_precision)
        embedded_image = self.image_encoder(image)
        return self.image_projection(embedded_image)

    def embed_im_to_svg(self, batch, device):
        """Build LM inputs: projected image embeddings followed by SVG token embeddings.

        Returns:
            ``(inputs_embeds, attention_mask, targets)``, where the image
            positions are attended to but always carry target -100.
        """
        image = batch["image"].to(dtype=self.model_precision)
        conditioning_embeds = self.image_projection(self.image_encoder(image))
        conditioning_att = torch.ones(conditioning_embeds.shape[:-1], dtype=torch.long, device=device)
        # Subclasses decide which end tokens are appended to each SVG string.
        svg_tokens = self._tokenize(self._get_svg_text(batch["svg"]), self.max_length, device)
        svg_embeds = self._get_embeddings(svg_tokens.input_ids)
        inputs_embeds = torch.cat([conditioning_embeds, svg_embeds], dim=1)
        # No loss on the image-conditioning positions.
        ignore_targets = torch.full(conditioning_att.shape, -100, dtype=torch.long, device=device)
        targets = torch.cat([ignore_targets, self._create_targets(svg_tokens)], dim=1)
        attention_mask = torch.cat([conditioning_att, svg_tokens.attention_mask], dim=1)
        return inputs_embeds, attention_mask, targets

    def _infer_device(self, batch):
        """Return ``batch["image"]``'s device if the batch has an image, else the
        device of the model's first parameter (text-only batches)."""
        if "image" in batch:
            return batch["image"].device
        return next(self.parameters()).device

    def forward(self, batch):
        """Compute the language-modelling loss for ``self.task``.

        Raises:
            ValueError: if ``self.task`` is neither ``'text2svg'`` nor
                ``'im2svg'`` (previously an ``UnboundLocalError``).
        """
        device = self._infer_device(batch)
        if self.task == 'text2svg':
            inputs_embeds, attention_mask, targets = self.embed_text_to_svg(batch, device)
        elif self.task == 'im2svg':
            inputs_embeds, attention_mask, targets = self.embed_im_to_svg(batch, device)
        else:
            raise ValueError(f"Unsupported task: {self.task!r}")
        outputs = self.svg_transformer.transformer(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            labels=targets,
            return_dict=True,
            output_hidden_states=True,
            use_cache=False,
        )
        return outputs.loss

    @abstractmethod
    def _get_svg_text(self, svg_list):
        """Append the appropriate end tokens to each SVG; differs between v1 and v2."""
        pass

    def _prepare_generation_inputs(self, batch, prompt, device):
        """Build generation inputs: image embeddings followed by prompt embeddings.

        Args:
            batch: Mapping with ``"image"``.
            prompt: Text prompt; ``None`` uses ``self.svg_transformer.prompt``.
                The same prompt is used for every image in the batch.
            device: Device the image and tokens are moved to.

        Returns:
            ``(inputs_embeds, attention_mask, prompt_tokens)``.
        """
        image = batch["image"].to(device).to(self.model_precision)
        embedded_image = self.image_projection(self.image_encoder(image))
        embedded_att = torch.ones(embedded_image.shape[:-1], dtype=torch.long, device=device)
        if prompt is None:
            prompt = self.svg_transformer.prompt
        prompt_tokens = self._tokenize([prompt] * image.size(0), None, device, add_special_tokens=False)
        attention_mask = torch.cat([embedded_att, prompt_tokens.attention_mask], dim=1)
        prompt_embeds = self._get_embeddings(prompt_tokens.input_ids)
        inputs_embeds = torch.cat([embedded_image, prompt_embeds], dim=1)
        return inputs_embeds, attention_mask, prompt_tokens

    def _get_generation_kwargs(self, base_kwargs):
        """Translate user kwargs into ``generate`` kwargs, with defaults.

        ``base_kwargs`` must contain ``inputs_embeds`` and ``attention_mask``.
        Defaults: sampling on (``use_nucleus_sampling``), ``top_p`` 0.9,
        ``temperature`` 1, ``num_beams`` 2, ``max_length`` 30,
        ``min_length`` 1, ``repetition_penalty`` 1.0,
        ``length_penalty`` 1.0, ``use_cache`` True. A stopping criterion for
        the token ids of ``"</svg>"`` is attached.
        """
        end_sequence = self.svg_transformer.tokenizer("</svg>", add_special_tokens=False)['input_ids']
        stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=[end_sequence])])
        return {
            'inputs_embeds': base_kwargs['inputs_embeds'],
            'attention_mask': base_kwargs['attention_mask'],
            'do_sample': base_kwargs.get('use_nucleus_sampling', True),
            'top_p': base_kwargs.get('top_p', 0.9),
            'temperature': base_kwargs.get('temperature', 1),
            'num_beams': base_kwargs.get('num_beams', 2),
            'max_length': base_kwargs.get('max_length', 30),
            'min_length': base_kwargs.get('min_length', 1),
            'repetition_penalty': base_kwargs.get('repetition_penalty', 1.0),
            'length_penalty': base_kwargs.get('length_penalty', 1.0),
            'use_cache': base_kwargs.get('use_cache', True),
            'stopping_criteria': stopping_criteria,
        }

    def generate_im2svg(self, batch, **kwargs):
        """Generate SVG text from ``batch["image"]``.

        Returns:
            List of decoded strings (prompt tokens followed by generated
            tokens, special tokens skipped), one per image.
        """
        inputs_embeds, attention_mask, prompt_tokens = self._prepare_generation_inputs(
            batch, kwargs.get('prompt'), batch["image"].device
        )
        generation_kwargs = self._get_generation_kwargs(
            {**kwargs, 'inputs_embeds': inputs_embeds, 'attention_mask': attention_mask}
        )
        # Subclasses may override these defaults.
        generation_kwargs.update(self._get_im2svg_specific_kwargs(kwargs))
        outputs = self.svg_transformer.transformer.generate(**generation_kwargs)
        outputs = torch.cat([prompt_tokens.input_ids, outputs], dim=1)
        return self.svg_transformer.tokenizer.batch_decode(outputs, skip_special_tokens=True)

    def generate_im2svg_grpo(self, batch, **kwargs):
        """Generate possibly several SVG samples per image.

        When ``num_return_sequences`` > 1, beam search is disabled
        (``num_beams=1``).

        Returns:
            Dict with ``"raw_svg"`` (decoded strings), ``"outputs"`` (prompt
            ids + generated ids) and ``"inputs_embeds"`` (the conditioning
            embeddings).
        """
        inputs_embeds, attention_mask, prompt_tokens = self._prepare_generation_inputs(
            batch, kwargs.get('prompt'), batch["image"].device
        )
        generation_kwargs = self._get_generation_kwargs(
            {**kwargs, 'inputs_embeds': inputs_embeds, 'attention_mask': attention_mask}
        )
        # Subclasses may override these defaults.
        generation_kwargs.update(self._get_im2svg_specific_kwargs(kwargs))
        num_return_sequences = kwargs.get('num_return_sequences', 1)
        if num_return_sequences > 1:
            generation_kwargs['num_return_sequences'] = num_return_sequences
            generation_kwargs['num_beams'] = 1
        outputs = self.svg_transformer.transformer.generate(**generation_kwargs)
        # generate() groups the returned sequences per input row, so expand
        # each prompt row in place (interleave) rather than tiling the batch.
        prompt_ids = prompt_tokens.input_ids.repeat_interleave(num_return_sequences, dim=0)
        outputs = torch.cat([prompt_ids, outputs], dim=1)
        raw_svg = self.svg_transformer.tokenizer.batch_decode(outputs, skip_special_tokens=True)
        return {
            "raw_svg": raw_svg,
            "outputs": outputs,
            "inputs_embeds": inputs_embeds,
        }

    def _get_im2svg_specific_kwargs(self, kwargs):
        """Extra ``generate`` kwargs for im2svg; subclasses may override."""
        return {
            'early_stopping': True,
            'pad_token_id': self.svg_transformer.tokenizer.pad_token_id,
        }

    def generate_text2svg(self, batch, **kwargs):
        """Generate SVG continuations conditioned on ``batch["caption"]``.

        Each caption is followed by ``svg_start_token``. The concatenation is
        embedded and passed to ``generate``.

        Args:
            batch: Mapping with ``"caption"`` (list of str). ``"image"``, if
                present, only determines the device.
            **kwargs: Generation options (see ``_get_generation_kwargs``).
                ``max_length`` (default 30) truncates the captions; the
                generation ``max_length`` is that value minus the
                conditioning length.

        Returns:
            The return value of ``svg_transformer.transformer.generate``
            (not decoded).
        """
        device = self._infer_device(batch)
        total_max_length = kwargs.get('max_length', 30)
        prompt_tokens = self._tokenize(
            batch["caption"],
            max_length=total_max_length,
            device=device,
            add_special_tokens=False,
        )
        trigger_token = self._tokenize(
            [self.svg_transformer.svg_start_token for _ in batch["caption"]],
            max_length=None,
            device=device,
            add_special_tokens=False,
        )
        input_tokens = torch.cat([prompt_tokens.input_ids, trigger_token.input_ids], dim=1)
        attention_mask = torch.cat([prompt_tokens.attention_mask, trigger_token.attention_mask], dim=1)
        inputs_embeds = self._get_embeddings(input_tokens)
        # _get_generation_kwargs accepts only base_kwargs; passing an extra
        # positional argument here raised TypeError.
        generation_kwargs = self._get_generation_kwargs(
            {**kwargs, 'inputs_embeds': inputs_embeds, 'attention_mask': attention_mask}
        )
        # Subclasses may override these defaults.
        generation_kwargs.update(self._get_text2svg_specific_kwargs(kwargs))
        generation_kwargs['max_length'] = total_max_length - input_tokens.size(1)
        return self.svg_transformer.transformer.generate(**generation_kwargs)

    def _get_text2svg_specific_kwargs(self, kwargs):
        """Extra ``generate`` kwargs for text2svg; subclasses may override."""
        return {
            'eos_token_id': self.svg_transformer.tokenizer.eos_token_id,
            'early_stopping': True,
            'length_penalty': kwargs.get('length_penalty', 1.0),
        }