import torch
import torch.nn as nn
from transformers import BertModel, PreTrainedModel, BertConfig, AutoModel
from typing import List, Optional

from .configuration_marqo_arctic_bge_chimera_m import ChimeraConfig


class Chimera(PreTrainedModel):
    """Ensemble encoder that concatenates first-token embeddings of sub-encoders.

    ``self.model`` is an ``nn.ModuleDict`` of encoders keyed ``"model_0"``,
    ``"model_1"``, ... . :meth:`forward` feeds the same inputs to every
    encoder and concatenates their position-0 hidden states along the last
    dimension, in the dict's insertion order.
    """

    config_class = ChimeraConfig

    def __init__(self, config: ChimeraConfig):
        super().__init__(config)
        # NOTE(review): the sub-encoder architecture is hard-coded (BERT-base
        # sizes) rather than read from ``config``; confirm it matches the
        # checkpoints that are loaded into this model.
        bert_config = BertConfig(
            vocab_size=30522,
            hidden_size=768,
            num_hidden_layers=12,
            num_attention_heads=12,
            intermediate_size=3072,
            hidden_act="gelu",
            hidden_dropout_prob=0.1,
            attention_probs_dropout_prob=0.1,
            max_position_embeddings=512,
            type_vocab_size=2,
            initializer_range=0.02,
            layer_norm_eps=1e-12,
        )
        # Two identically configured encoders. Their initial weights are
        # presumably overwritten by a checkpoint or by
        # load_weights_from_automodels -- confirm with callers.
        self.model = nn.ModuleDict(
            {
                "model_0": BertModel(bert_config),
                "model_1": BertModel(bert_config),
            }
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Embed a batch with every sub-encoder and concatenate the results.

        Args:
            input_ids: Token ids passed unchanged to each sub-encoder
                (assumed shape ``(batch, seq_len)`` -- TODO confirm).
            attention_mask: Attention mask passed unchanged to each sub-encoder.
            token_type_ids: Optional segment ids passed unchanged to each
                sub-encoder.

        Returns:
            For each sub-encoder, element 0 of its output (``last_hidden_state``
            for ``BertModel``) sliced at position 0 of dimension 1. These are
            concatenated along the last dimension in ``self.model`` order.
        """
        embeddings = []
        for encoder in self.model.values():
            encoder_output = encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
            )
            # First-token ([CLS]-position) hidden state of the first output.
            embeddings.append(encoder_output[0][:, 0])
        return torch.cat(embeddings, dim=-1)

    def load_weights_from_automodels(
        self, in_models: List[str], has_pooling_layer: List[bool]
    ) -> None:
        """Replace the sub-encoders with pretrained models loaded via AutoModel.

        ``self.model`` is only replaced after every model has loaded
        successfully, so a failure part-way leaves the existing encoders intact.

        Args:
            in_models: Model names or paths for ``AutoModel.from_pretrained``.
                The i-th entry becomes ``self.model["model_i"]``.
            has_pooling_layer: Per-model value passed as ``add_pooling_layer``.
                It must have the same length as ``in_models``.

        Raises:
            ValueError: If ``in_models`` is empty, or if the two lists differ
                in length.
        """
        if not in_models:
            # An empty ModuleDict would only fail later, inside torch.cat in forward.
            raise ValueError("in_models must contain at least one model name")
        if len(in_models) != len(has_pooling_layer):
            raise ValueError(
                f"has_pooling_layer has {len(has_pooling_layer)} entries but "
                f"in_models has {len(in_models)}; they must be the same length"
            )

        loaded = {}
        for i, (model_name, add_pooling) in enumerate(
            zip(in_models, has_pooling_layer)
        ):
            # SECURITY: trust_remote_code=True executes Python code shipped
            # with the checkpoint; only pass trusted model names here.
            model = AutoModel.from_pretrained(
                model_name,
                add_pooling_layer=add_pooling,
                trust_remote_code=True,
            )
            # Only each loaded sub-model is switched to eval mode; this
            # module's own training flag is left as-is.
            model.eval()
            loaded[f"model_{i}"] = model

        self.model = nn.ModuleDict(loaded)