Spaces:
Running
Running
File size: 872 Bytes
72f684c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 |
import torch
import torch.nn as nn
from starvector.model.models.starvector_base import StarVectorBase
from transformers import AutoProcessor
class StarVectorStarCoder(StarVectorBase):
    """StarVector model variant that uses the StarCoder (V1) SVG transformer.

    Supplies the V1-specific pieces on top of ``StarVectorBase``: which SVG
    transformer to build, how token embeddings are looked up, and how SVG
    strings are terminated.
    NOTE(review): presumably the base class invokes these ``_get_*`` hooks;
    confirm against ``StarVectorBase``.
    """

    def __init__(self, config, **kwargs):
        """Initialise the base model, then attach a processor.

        Args:
            config: Model configuration; its ``_name_or_path`` attribute is
                the identifier handed to ``AutoProcessor.from_pretrained``.
            **kwargs: Forwarded unchanged to ``StarVectorBase.__init__``.
        """
        super().__init__(config, **kwargs)
        checkpoint = config._name_or_path
        self.processor = AutoProcessor.from_pretrained(checkpoint)

    def _get_svg_transformer(self, config, **kwargs):
        """Construct and return the StarCoder (V1) SVG transformer.

        Args:
            config: Passed straight through to ``StarCoderModel``.
            **kwargs: Passed straight through to ``StarCoderModel``.
        """
        # Imported inside the method rather than at module scope.
        from starvector.model.llm.starcoder import StarCoderModel

        svg_transformer = StarCoderModel(config, **kwargs)
        return svg_transformer

    def _get_embeddings(self, input_ids):
        """Return the V1 embeddings for ``input_ids``.

        Calls the ``wte`` module found at
        ``self.svg_transformer.transformer.transformer`` on ``input_ids``.
        """
        token_embedding = self.svg_transformer.transformer.transformer.wte
        return token_embedding(input_ids)

    def _get_svg_text(self, svg_list):
        """Return a new list with the tokenizer's EOS token appended to each SVG.

        Args:
            svg_list: Iterable of SVG strings.
        """
        terminated = []
        for svg in svg_list:
            # Looked up per item, so an empty input never touches the tokenizer.
            terminated.append(svg + self.svg_transformer.tokenizer.eos_token)
        return terminated
|