# coencoder_test2_phase1_2 / configuration_co_encoder.py
# sudy-super's picture
# Upload model and tokenizers
# aee9709 verified
# coding=utf-8
"""CoEncoder model configuration"""
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
from transformers import CONFIG_MAPPING
# Module-level logger, created through the `logging` helper imported from `transformers.utils`.
logger = logging.get_logger(__name__)
class CoEncoderConfig(PretrainedConfig):
    r"""
    Configuration class for a CoEncoder model.

    It bundles two sub-configurations (`context_config` and `text_config`) together with a
    handful of scalar settings, and forwards everything else to `PretrainedConfig`.

    Args:
        context_config (`dict` or config object, *optional*):
            Configuration of the context encoder. A `dict` is converted into a config object
            through `CONFIG_MAPPING`, using `"qwen2"` as the model type when the dict has no
            `"model_type"` key. Any other value (including `None`) is stored unchanged.
        text_config (`dict` or config object, *optional*):
            Configuration of the text model. A `dict` is converted through `CONFIG_MAPPING`,
            using `"llama"` as the model type when the dict has no `"model_type"` key.
            `None` yields a default `"llama"` config. Any other value is stored unchanged.
        ignore_index (`int`, *optional*, defaults to -100):
            Stored as `self.ignore_index` (presumably the label id ignored by the loss --
            confirm against the modeling code).
        connector_hidden_act (`str`, *optional*, defaults to `"gelu"`):
            Stored as `self.connector_hidden_act` (presumably the activation used by the
            connector module -- confirm against the modeling code).
        context_feature_layer (`int`, *optional*, defaults to -2):
            Stored as `self.context_feature_layer` (presumably the index of the context
            encoder layer whose features are used -- confirm against the modeling code).
        context_feature_select_strategy (`str`, *optional*, defaults to `"default"`):
            Only `"default"` is accepted.
        begin_of_context_token_id (`int`, *optional*):
            Stored as `self.begin_of_context_token_id`.
        end_of_context_token_id (`int`, *optional*):
            Stored as `self.end_of_context_token_id`.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Forwarded to `PretrainedConfig.__init__`.
        **kwargs:
            Forwarded to `PretrainedConfig.__init__`.

    Raises:
        ValueError: If `context_feature_select_strategy` is not `"default"`.
    """

    model_type = "co_encoder"

    def __init__(
        self,
        context_config=None,
        text_config=None,
        ignore_index=-100,
        connector_hidden_act="gelu",
        context_feature_layer=-2,
        context_feature_select_strategy="default",
        begin_of_context_token_id=None,
        end_of_context_token_id=None,
        tie_word_embeddings=False,
        **kwargs,
    ):
        self.ignore_index = ignore_index
        self.connector_hidden_act = connector_hidden_act
        self.context_feature_layer = context_feature_layer
        self.context_feature_select_strategy = context_feature_select_strategy
        self.begin_of_context_token_id = begin_of_context_token_id
        self.end_of_context_token_id = end_of_context_token_id

        if context_feature_select_strategy not in ["default"]:
            # The trailing space keeps the two adjacent literals from running together.
            raise ValueError(
                "context_feature_select_strategy should be one of 'default'. "
                f"Got: {context_feature_select_strategy}"
            )

        if isinstance(context_config, dict):
            # Build a new dict so the caller's argument is not mutated; an explicit
            # "model_type" in the input overrides the "qwen2" fallback.
            context_config = {"model_type": "qwen2", **context_config}
            context_config = CONFIG_MAPPING[context_config["model_type"]](**context_config)
        self.context_config = context_config

        if isinstance(text_config, dict):
            # Same copy-with-fallback approach, defaulting to "llama".
            text_config = {"model_type": "llama", **text_config}
            text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
        elif text_config is None:
            text_config = CONFIG_MAPPING["llama"]()
        self.text_config = text_config

        super().__init__(
            tie_word_embeddings=tie_word_embeddings,
            ignore_index=ignore_index,
            connector_hidden_act=connector_hidden_act,
            context_feature_layer=context_feature_layer,
            context_feature_select_strategy=context_feature_select_strategy,
            begin_of_context_token_id=begin_of_context_token_id,
            end_of_context_token_id=end_of_context_token_id,
            **kwargs
        )