File size: 872 Bytes
72f684c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import torch
import torch.nn as nn
from starvector.model.models.starvector_base import StarVectorBase
from transformers import AutoProcessor

class StarVectorStarCoder(StarVectorBase):
    """StarVector model variant backed by the StarCoder (V1) SVG transformer.

    Supplies the V1-specific pieces on top of ``StarVectorBase``: which SVG
    transformer to build, how token embeddings are looked up, and how SVG
    strings are terminated.
    NOTE(review): the underscore-prefixed methods are presumably hooks
    invoked by ``StarVectorBase`` -- confirm against the base class.
    """

    def __init__(self, config, **kwargs):
        """Initialise the base model, then load the matching processor.

        Args:
            config: Model configuration; its ``_name_or_path`` attribute is
                handed to ``AutoProcessor.from_pretrained``.
            **kwargs: Forwarded unchanged to ``StarVectorBase.__init__``.
        """
        super().__init__(config, **kwargs)

        checkpoint = config._name_or_path
        self.processor = AutoProcessor.from_pretrained(checkpoint)

    def _get_svg_transformer(self, config, **kwargs):
        """Build and return the StarCoder (V1) SVG transformer.

        Args:
            config: Configuration passed straight to ``StarCoderModel``.
            **kwargs: Extra keyword arguments passed to ``StarCoderModel``.

        Returns:
            A new ``StarCoderModel`` instance.
        """
        # Imported inside the method, as in the original layout; this is
        # the StarCoder (V1) implementation.
        from starvector.model.llm.starcoder import StarCoderModel

        svg_transformer = StarCoderModel(config, **kwargs)
        return svg_transformer

    def _get_embeddings(self, input_ids):
        """Return embeddings for ``input_ids`` (V1-specific lookup path).

        The ids are fed to the ``wte`` module found at
        ``self.svg_transformer.transformer.transformer``.
        """
        token_embedding = self.svg_transformer.transformer.transformer.wte
        return token_embedding(input_ids)

    def _get_svg_text(self, svg_list):
        """Return a new list with the tokenizer EOS token appended to each SVG.

        V1-specific SVG text preparation; ``svg_list`` itself is not modified.
        """
        terminated = []
        for svg in svg_list:
            # The EOS token is looked up per item, so an empty input never
            # touches the tokenizer.
            terminated.append(svg + self.svg_transformer.tokenizer.eos_token)
        return terminated