import torch
import torch.nn as nn
from functools import partial

from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
)

from starvector.train.util import get_module_class_from_name


class StarCoderModel(nn.Module):
    def __init__(self, config, **kwargs):
        super(StarCoderModel, self).__init__()

        self.init_tokenizer(config.starcoder_model_name)
        self.max_length = config.max_length

        # Load the pretrained StarCoder checkpoint in bfloat16 with FlashAttention-2
        model_config = AutoConfig.from_pretrained(config.starcoder_model_name, trust_remote_code=True)
        model_config.use_cache = config.use_cache
        model_config.use_bfloat16 = True
        model = AutoModelForCausalLM.from_pretrained(
            config.starcoder_model_name,
            config=model_config,
            attn_implementation="flash_attention_2",
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
        )
        # Resize embeddings to cover the special tokens added in init_tokenizer
        model.resize_token_embeddings(len(self.tokenizer))
        self.transformer = model

        # Prompt appended after the image tokens to start SVG generation
        self.prompt = '<svg'

        # Decoder layer class used for FSDP auto-wrapping
        transformer_layer_cls = kwargs.get('transformer_layer_cls', 'Starcoder2DecoderLayer')
        self.transformer_layer_cls = get_module_class_from_name(self, transformer_layer_cls)

    def init_tokenizer(self, model_name):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)

        # Include padding and eos tokens in the vocabulary if they are missing
        if self.tokenizer.eos_token_id is None:
            self.tokenizer.add_special_tokens({"eos_token": "[EOS]"})
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.add_special_tokens({"pad_token": "[PAD]"})

        # Special tokens marking the SVG, image, and caption segments
        self.svg_start_token = "<svg-start>"
        self.svg_end_token = "<svg-end>"
        self.image_start_token = "<image-start>"
        self.text_start_token = "<caption-start>"
        self.tokenizer.add_tokens([self.svg_start_token, self.image_start_token, self.text_start_token, self.svg_end_token])
        self.svg_start_token_id = self.tokenizer.encode(self.svg_start_token)[0]
        self.svg_end_token_id = self.tokenizer.encode(self.svg_end_token)[0]

        # Left-padding so generation continues right after the prompt
        self.tokenizer.padding_side = "left"

    def get_fsdp_wrapping_policy(self):
        """Return a `transformer_auto_wrap_policy` where we wrap each instance of `self.transformer_layer_cls`."""
        transformer_block_policy = partial(
            transformer_auto_wrap_policy, transformer_layer_cls={self.transformer_layer_cls}
        )

        return transformer_block_policy
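

# --- Usage sketch (added for illustration; not part of the original StarVector file) ---
# A minimal example of how this module might be constructed and prepared for FSDP,
# assuming a simple namespace config exposing the fields read above. The checkpoint
# name "bigcode/starcoder2-3b", max_length=8192, and use_cache=False are assumptions,
# not values taken from this file, and loading requires a GPU with FlashAttention-2.
if __name__ == "__main__":
    from types import SimpleNamespace

    config = SimpleNamespace(
        starcoder_model_name="bigcode/starcoder2-3b",  # hypothetical checkpoint
        max_length=8192,                               # hypothetical context length
        use_cache=False,                               # disable KV cache for training
    )

    model = StarCoderModel(config)

    # The returned policy wraps every Starcoder2DecoderLayer when passed to FSDP
    # as `auto_wrap_policy`.
    policy = model.get_fsdp_wrapping_policy()
    print(model.tokenizer.padding_side, model.svg_start_token_id, policy)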