from typing import Union

import einops
import torch
import torch.nn as nn
from transformers import BatchEncoding, PreTrainedModel

# VLMConfig, projection_layers, and LABEL_MASK are assumed to be defined
# elsewhere in this post (LABEL_MASK marks positions ignored by the loss).


class Model(PreTrainedModel):
    config_class = VLMConfig

    def __init__(
        self,
        config: VLMConfig,
        image_model,
        language_model,
        num_projections: int,
        tokenizer,
        prepend_text: str,
        image_tokens: int,
    ):
        super().__init__(config)
        self.image_model = image_model
        self.language_model = language_model
        # Maps image features into the language model's embedding space.
        self.projector = nn.Sequential(
            *projection_layers(
                image_model.num_features,
                language_model.config.hidden_size,
                num_projections,
            )
        )
        self.tokenizer = tokenizer
        self.eos_token = tokenizer.eos_token
        self.prepend_text = prepend_text
        self.image_tokens = image_tokens

        # Embed the prompt once and split it at the EOS token: the projected
        # image tokens will be spliced in at that position.
        input_ids = tokenizer(prepend_text, return_tensors="pt").input_ids
        eos_token_index = (
            (input_ids[0] == tokenizer.eos_token_id).nonzero(as_tuple=True)[0].item()
        )
        text_embeddings = self.language_model.get_input_embeddings()(input_ids).detach()
        self.prepend_embeddings = text_embeddings[:, :eos_token_index]
        self.postpend_embeddings = text_embeddings[:, eos_token_index:]
        # Prompt and image positions are always attended to and never
        # contribute to the loss.
        self.attention_mask = torch.ones(1, text_embeddings.shape[1] + image_tokens)
        self.labels = torch.full((1, self.attention_mask.shape[1]), LABEL_MASK)

    def project_image_features(self, images: torch.Tensor):
        # Flatten the spatial feature map into a token sequence, then project
        # each token to the language model's hidden size.
        image_features = self.image_model.forward_features(images)
        image_features = einops.rearrange(image_features, "bs dim w h -> bs (w h) dim")
        encoder_outputs = self.projector(image_features)
        return encoder_outputs

    def forward(self, images: torch.Tensor, tokenized_captions: BatchEncoding):
        image_outputs = self.project_image_features(images)
        caption_embeddings = self.language_model.get_input_embeddings()(
            tokenized_captions.input_ids
        ).detach()
        device = images.device
        # Sequence layout: prompt prefix | image tokens | prompt suffix | caption.
        embeddings = torch.cat(
            [
                self.prepend_embeddings.to(device).expand(len(images), -1, -1),
                image_outputs,
                self.postpend_embeddings.to(device).expand(len(images), -1, -1),
                caption_embeddings,
            ],
            dim=1,
        )
        attention_mask = torch.cat(
            [
                self.attention_mask.to(device).expand(len(images), -1),
                tokenized_captions.attention_mask,
            ],
            dim=1,
        )
        # Only caption tokens carry labels; padded caption positions are
        # masked out of the loss as well.
        labels = torch.cat(
            [
                self.labels.to(device).expand(len(images), -1),
                tokenized_captions.input_ids.clone(),
            ],
            dim=1,
        )
        labels[attention_mask == 0] = LABEL_MASK
        return self.language_model(
            inputs_embeds=embeddings,
            attention_mask=attention_mask,
            labels=labels,
        )

    def generate(self, images: torch.Tensor, generator_kwargs: dict[str, Union[int, float]]):
        # Same layout as forward(), minus the caption: the language model
        # continues the sequence from the prompt suffix.
        image_outputs = self.project_image_features(images)
        device = images.device
        embeddings = torch.cat(
            [
                self.prepend_embeddings.to(device).expand(len(images), -1, -1),
                image_outputs,
                self.postpend_embeddings.to(device).expand(len(images), -1, -1),
            ],
            dim=1,
        )
        attention_mask = self.attention_mask.to(device).expand(len(images), -1)
        return self.language_model.generate(
            inputs_embeds=embeddings,
            attention_mask=attention_mask,
            eos_token_id=self.tokenizer.eos_token_id,
            **generator_kwargs,
        )
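
# A minimal usage sketch, not from the original code: the checkpoint names,
# image_tokens=49 (the 7x7 feature map a resnet50 produces for 224x224
# input), and the default-constructed VLMConfig() are illustrative
# assumptions. The EOS token inside prepend_text marks where the image
# embeddings are spliced into the prompt.
import timm
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
language_model = AutoModelForCausalLM.from_pretrained("gpt2")
image_model = timm.create_model("resnet50", pretrained=True, num_classes=0)

model = Model(
    config=VLMConfig(),  # assumed to be default-constructible
    image_model=image_model,
    language_model=language_model,
    num_projections=2,
    tokenizer=tokenizer,
    prepend_text=f"Image:{tokenizer.eos_token}Caption:",
    image_tokens=49,  # 7x7 feature map -> 49 tokens for 224x224 inputs
)

images = torch.randn(2, 3, 224, 224)  # dummy batch of two RGB images
generated = model.generate(images, {"max_new_tokens": 20})
print(tokenizer.batch_decode(generated, skip_special_tokens=True))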