Spaces:
Running
Running
File size: 2,276 Bytes
72f684c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 |
import torch
import torch.nn as nn
from torch.distributed.fsdp.wrap import _module_wrap_policy, _or_policy
from functools import partial
from starvector.model.models.starvector_base import StarVectorBase
from transformers import AutoImageProcessor
class StarVectorStarCoder2(StarVectorBase):
    """StarVector V2: a ``StarVectorBase`` whose SVG decoder uses StarCoder2.

    This subclass overrides the model-specific hooks below (decoder
    construction, FSDP wrapping, embedding lookup, SVG text termination and
    generation kwargs); all other behaviour comes from ``StarVectorBase``.
    """

    def __init__(self, config, **kwargs):
        """Initialise the base model, then load the matching image processor.

        Args:
            config: Model configuration. Its ``_name_or_path`` attribute
                (presumably the checkpoint name/path the config came from)
                is handed to ``AutoImageProcessor.from_pretrained``.
            **kwargs: Forwarded unchanged to ``StarVectorBase.__init__``.
        """
        super().__init__(config, **kwargs)
        # NOTE(review): trust_remote_code=True allows transformers to run
        # custom processor code shipped with the checkpoint; only point this
        # at trusted checkpoints.
        checkpoint = config._name_or_path
        self.processor = AutoImageProcessor.from_pretrained(
            checkpoint, trust_remote_code=True
        )

    def _get_svg_transformer(self, config, **kwargs):
        """Build and return the SVG decoder model.

        The decoder is ``StarCoderModel`` from
        ``starvector.model.llm.starcoder2``. It is a StarCoder2-based model,
        distinct from the one used by V1. The import is deferred until this
        method is called.
        """
        from starvector.model.llm.starcoder2 import StarCoderModel

        decoder = StarCoderModel(config, **kwargs)
        return decoder

    def get_fsdp_wrapping_policy(self):
        """Return the V2 FSDP auto-wrap policy.

        Three sub-policies are combined with ``_or_policy``, so a module is
        wrapped when any one of them selects it:

        * every ``ImageEncoder`` module,
        * whatever ``self.svg_transformer.get_fsdp_wrapping_policy()`` selects,
        * every ``Adapter`` module.
        """
        from starvector.model.image_encoder.image_encoder import ImageEncoder

        wrap_image_encoder = partial(
            _module_wrap_policy, module_classes={ImageEncoder}
        )
        # The language model decides its own wrapping granularity.
        wrap_llm = self.svg_transformer.get_fsdp_wrapping_policy()

        from starvector.model.adapters.adapter import Adapter

        wrap_adapter = partial(_module_wrap_policy, module_classes={Adapter})

        sub_policies = [wrap_image_encoder, wrap_llm, wrap_adapter]
        return partial(_or_policy, policies=sub_policies)

    def _get_embeddings(self, input_ids):
        """Embed ``input_ids`` with the decoder's ``embed_tokens`` layer (V2 path)."""
        embed_tokens = self.svg_transformer.transformer.model.embed_tokens
        return embed_tokens(input_ids)

    def _get_svg_text(self, svg_list):
        """Return a new list with each SVG string terminated for V2.

        Each element of ``svg_list`` is followed by
        ``svg_transformer.svg_end_token`` and then by the tokenizer's
        ``eos_token``. The original order is kept.
        """
        return [
            svg
            + self.svg_transformer.svg_end_token
            + self.svg_transformer.tokenizer.eos_token
            for svg in svg_list
        ]

    def _get_im2svg_specific_kwargs(self, kwargs):
        """Return the V2-specific image-to-SVG generation kwargs (currently none).

        ``kwargs`` is not read.
        """
        # NOTE(review): no 'eos_token_id' override is returned here. An
        # override using svg_transformer.svg_end_token_id exists in the code
        # but is disabled; confirm which stop token im2svg generation
        # should use.
        return {}

    def _get_text2svg_specific_kwargs(self, kwargs):
        """Return the V2-specific text-to-SVG generation kwargs.

        Sets ``eos_token_id`` to the tokenizer's EOS token id. ``kwargs`` is
        not read.
        """
        return dict(eos_token_id=self.svg_transformer.tokenizer.eos_token_id)
|