import torch.nn as nn from transformers import ( AutoConfig, AutoModelForCausalLM, AutoTokenizer, ) from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy from functools import partial from starvector.train.util import get_module_class_from_name import torch class StarCoderModel(nn.Module): def __init__(self, config, **kwargs): super(StarCoderModel, self).__init__() self.init_tokenizer(config.starcoder_model_name) self.max_length = config.max_length model_config = AutoConfig.from_pretrained(config.starcoder_model_name, trust_remote_code=True) model_config.use_cache = config.use_cache model_config.use_bfloat16 = True model = AutoModelForCausalLM.from_pretrained( config.starcoder_model_name, config=model_config, attn_implementation="flash_attention_2", torch_dtype=torch.bfloat16, trust_remote_code=True) model.resize_token_embeddings(len(self.tokenizer)) self.transformer = model # Prompt the model after image self.prompt = '