Spaces:
Running
Running
import torch | |
import torch.nn as nn | |
from starvector.model.models.starvector_base import StarVectorBase | |
from transformers import AutoProcessor | |
class StarVectorStarCoder(StarVectorBase):
    """StarVector model variant whose SVG transformer is StarCoder (V1).

    Provides the StarCoder-specific hooks on top of ``StarVectorBase``:
    building the language model, embedding token ids, and terminating
    SVG strings with the tokenizer's end-of-sequence token.
    """

    def __init__(self, config, **kwargs):
        """Initialise the base model, then load the matching processor.

        Args:
            config: Model configuration. It is forwarded to
                ``StarVectorBase.__init__`` and its ``_name_or_path`` is
                passed to ``AutoProcessor.from_pretrained``.
            **kwargs: Forwarded unchanged to ``StarVectorBase.__init__``.
        """
        super().__init__(config, **kwargs)
        checkpoint = config._name_or_path
        self.processor = AutoProcessor.from_pretrained(checkpoint)

    def _get_svg_transformer(self, config, **kwargs):
        """Build the StarCoder (V1) model used as the SVG transformer.

        Args:
            config: Configuration handed to ``StarCoderModel``.
            **kwargs: Extra keyword arguments handed to ``StarCoderModel``.

        Returns:
            A new ``StarCoderModel`` instance.
        """
        # Imported at call time rather than at module import time.
        from starvector.model.llm.starcoder import StarCoderModel

        language_model = StarCoderModel(config, **kwargs)
        return language_model

    def _get_embeddings(self, input_ids):
        """Embed ``input_ids`` with the V1 model's ``wte`` module.

        Args:
            input_ids: Token ids passed straight to ``wte``.

        Returns:
            Whatever ``wte`` returns for ``input_ids``.
        """
        # NOTE(review): ``self.svg_transformer`` is presumably assigned by
        # StarVectorBase from ``_get_svg_transformer`` -- confirm.
        token_embedding = self.svg_transformer.transformer.transformer.wte
        return token_embedding(input_ids)

    def _get_svg_text(self, svg_list):
        """Append the tokenizer's EOS token to every SVG string (V1 style).

        Args:
            svg_list: Iterable of SVG strings.

        Returns:
            A new list with ``eos_token`` concatenated to each entry.
        """
        terminated = []
        for svg in svg_list:
            # Looked up per item so an empty input never touches the tokenizer.
            eos = self.svg_transformer.tokenizer.eos_token
            terminated.append(svg + eos)
        return terminated