from functools import partial

import torch
import torch.nn as nn
from torch.distributed.fsdp.wrap import _module_wrap_policy, _or_policy
from transformers import AutoImageProcessor

from starvector.model.models.starvector_base import StarVectorBase


class StarVectorStarCoder2(StarVectorBase):
    """StarVector variant ("V2") whose SVG language model is StarCoder2-based.

    The class only supplies the V2-specific hooks; everything else is
    inherited from ``StarVectorBase``.
    """

    def __init__(self, config, **kwargs):
        """Build the base model, then load the image processor.

        Args:
            config: Model configuration. Forwarded to the base class; its
                ``_name_or_path`` attribute selects the image processor.
            **kwargs: Forwarded unchanged to ``StarVectorBase.__init__``.
        """
        super().__init__(config, **kwargs)

        # NOTE: trust_remote_code=True is passed through to
        # AutoImageProcessor.from_pretrained, exactly as before.
        checkpoint = config._name_or_path
        self.processor = AutoImageProcessor.from_pretrained(
            checkpoint,
            trust_remote_code=True,
        )

    def _get_svg_transformer(self, config, **kwargs):
        """Create the SVG language model used by this variant.

        Args:
            config: Configuration forwarded to ``StarCoderModel``.
            **kwargs: Extra keyword arguments forwarded to ``StarCoderModel``.

        Returns:
            A new ``StarCoderModel`` instance.
        """
        # Not the V1 model: this one is built on StarCoder2.
        # Imported locally (presumably to avoid an import cycle -- confirm).
        from starvector.model.llm.starcoder2 import StarCoderModel

        svg_model = StarCoderModel(config, **kwargs)
        return svg_model

    def get_fsdp_wrapping_policy(self):
        """Assemble the FSDP auto-wrap policy for the V2 model.

        Returns:
            A ``functools.partial`` of ``_or_policy`` over three policies,
            in this order: ``ImageEncoder`` modules, the policy reported by
            ``self.svg_transformer``, and ``Adapter`` modules.
        """
        from starvector.model.image_encoder.image_encoder import ImageEncoder

        encoder_policy = partial(_module_wrap_policy, module_classes={ImageEncoder})

        # The language model supplies its own wrapping policy.
        llm_policy = self.svg_transformer.get_fsdp_wrapping_policy()

        from starvector.model.adapters.adapter import Adapter

        adapter_policy = partial(_module_wrap_policy, module_classes={Adapter})

        combined = [encoder_policy, llm_policy, adapter_policy]
        return partial(_or_policy, policies=combined)

    def _get_embeddings(self, input_ids):
        """Look up token embeddings for ``input_ids``.

        Args:
            input_ids: Token ids handed to the language model's
                ``embed_tokens`` layer.

        Returns:
            Whatever ``svg_transformer.transformer.model.embed_tokens``
            returns for ``input_ids``.
        """
        embed_tokens = self.svg_transformer.transformer.model.embed_tokens
        return embed_tokens(input_ids)

    def _get_svg_text(self, svg_list):
        """Append the SVG end token and the tokenizer EOS token to each SVG.

        Args:
            svg_list: Iterable of SVG strings.

        Returns:
            A list with one terminated string per input element.
        """

        def _terminate(svg):
            # Order matters: SVG end marker first, then the tokenizer EOS.
            llm = self.svg_transformer
            return svg + llm.svg_end_token + llm.tokenizer.eos_token

        return [_terminate(svg) for svg in svg_list]

    def _get_im2svg_specific_kwargs(self, kwargs):
        """Return the V2-specific generation kwargs for image-to-SVG.

        Args:
            kwargs: Caller-supplied generation kwargs (not read here).

        Returns:
            An empty dict; no overrides are applied.
        """
        # The source carries only a commented-out entry here, mapping
        # 'eos_token_id' to svg_transformer.svg_end_token_id. It is disabled.
        return {}

    def _get_text2svg_specific_kwargs(self, kwargs):
        """Return the V2-specific generation kwargs for text-to-SVG.

        Args:
            kwargs: Caller-supplied generation kwargs (not read here).

        Returns:
            A dict mapping ``'eos_token_id'`` to the tokenizer's EOS token id.
        """
        eos_id = self.svg_transformer.tokenizer.eos_token_id
        return {'eos_token_id': eos_id}