import torch
import torch.nn as nn
from torch.distributed.fsdp.wrap import _module_wrap_policy, _or_policy
from functools import partial
from starvector.model.models.starvector_base import StarVectorBase
from transformers import AutoImageProcessor

class StarVectorStarCoder2(StarVectorBase):
    """StarVector V2: StarVectorBase with a StarCoder2-based SVG transformer."""

    def __init__(self, config, **kwargs):
        super().__init__(config, **kwargs)

        self.processor = AutoImageProcessor.from_pretrained(config._name_or_path, trust_remote_code=True)

    def _get_svg_transformer(self, config, **kwargs):
        from starvector.model.llm.starcoder2 import StarCoderModel  # V2 uses a StarCoder2-based model, unlike V1
        return StarCoderModel(config, **kwargs)

    def get_fsdp_wrapping_policy(self):
        """Return the V2 FSDP auto-wrap policy.

        A module is wrapped if it is an ImageEncoder, an Adapter, or is
        selected by the SVG transformer's own FSDP wrapping policy.
        """
        from starvector.model.image_encoder.image_encoder import ImageEncoder
        from starvector.model.adapters.adapter import Adapter

        image_encoder_wrapping_policy = partial(
            _module_wrap_policy,
            module_classes={ImageEncoder},
        )

        llm_fsdp_wrapping_policy = self.svg_transformer.get_fsdp_wrapping_policy()

        adapter_wrapping_policy = partial(
            _module_wrap_policy,
            module_classes={Adapter},
        )

        # Wrap a module if any one of the three policies selects it.
        return partial(
            _or_policy,
            policies=[
                image_encoder_wrapping_policy,
                llm_fsdp_wrapping_policy,
                adapter_wrapping_policy,
            ],
        )

    def _get_embeddings(self, input_ids):
        """Embed input_ids with the StarCoder2 backbone's token embedding layer."""
        return self.svg_transformer.transformer.model.embed_tokens(input_ids)

    def _get_svg_text(self, svg_list):
        """Append the SVG end token and the tokenizer's EOS token to each SVG string."""
        return [t + self.svg_transformer.svg_end_token + self.svg_transformer.tokenizer.eos_token for t in svg_list]

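    def _get_im2svg_specific_kwargs(self, kwargs):
        """V2-specific image-to-SVG generation kwargs.

        Currently empty: the override that would stop generation at the SVG
        end token is commented out.
        """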
        return {
            # 'eos_token_id': self.svg_transformer.svg_end_token_id,
        }

    def _get_text2svg_specific_kwargs(self, kwargs):
        """V2-specific text-to-SVG generation kwargs: stop at the tokenizer's EOS token."""
        return {
            'eos_token_id': self.svg_transformer.tokenizer.eos_token_id,
        }
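

# Illustrative sketch (not part of StarVector or PyTorch).
#
# get_fsdp_wrapping_policy() above ORs three policies together. The helpers
# below are a minimal pure-Python mirror of that composition, assuming the
# semantics of torch's private `_module_wrap_policy` / `_or_policy` helpers:
# while FSDP is still recursing, the class-based policy always says "keep
# descending"; at the wrap decision it wraps only instances of the listed
# classes, and the OR policy wraps a module if any sub-policy would.
# The `_sketch_*` functions and `_Demo*` classes are hypothetical stand-ins,
# not the real ImageEncoder / Adapter / torch implementations.
from functools import partial


def _sketch_module_wrap_policy(module, recurse, nonwrapped_numel, module_classes):
    """Keep recursing into children; wrap only instances of `module_classes`."""
    if recurse:
        return True
    return isinstance(module, tuple(module_classes))


def _sketch_or_policy(module, recurse, nonwrapped_numel, policies):
    """Return True if any component policy returns True for this module."""
    return any(
        policy(module=module, recurse=recurse, nonwrapped_numel=nonwrapped_numel)
        for policy in policies
    )


class _DemoImageEncoder:  # hypothetical stand-in for ImageEncoder
    pass


class _DemoAdapter:  # hypothetical stand-in for Adapter
    pass


class _DemoOtherLayer:  # hypothetical module class that no policy selects
    pass


# Same shape as the policy built in get_fsdp_wrapping_policy(), minus the
# LLM-specific sub-policy.
_demo_policy = partial(
    _sketch_or_policy,
    policies=[
        partial(_sketch_module_wrap_policy, module_classes={_DemoImageEncoder}),
        partial(_sketch_module_wrap_policy, module_classes={_DemoAdapter}),
    ],
)

# Usage:
#   _demo_policy(module=_DemoAdapter(), recurse=False, nonwrapped_numel=0)     -> True
#   _demo_policy(module=_DemoOtherLayer(), recurse=False, nonwrapped_numel=0)  -> False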