import json
import os
from typing import Any, Dict, Tuple, Union

from utils import CONFIG_NAME, cached_path, hf_bucket_url, is_remote_url


class PretrainedConfig(object):
    """Base class for model configurations.

    Holds the attributes common to every model (output flags, generation
    defaults, task labels, tokenizer ids) and implements loading from a
    local file, a directory, a URL, or a hub model identifier.
    """

    model_type: str = ""
    is_composition: bool = False

    def __init__(self, **kwargs):
        # Attributes with defaults
        self.return_dict = kwargs.pop("return_dict", True)
        self.output_hidden_states = kwargs.pop("output_hidden_states", False)
        self.output_attentions = kwargs.pop("output_attentions", False)
        self.torchscript = kwargs.pop("torchscript", False)
        self.use_bfloat16 = kwargs.pop("use_bfloat16", False)
        self.pruned_heads = kwargs.pop("pruned_heads", {})
        self.tie_word_embeddings = kwargs.pop("tie_word_embeddings", True)

        # Encoder/decoder structure
        self.is_encoder_decoder = kwargs.pop("is_encoder_decoder", False)
        self.is_decoder = kwargs.pop("is_decoder", False)
        self.add_cross_attention = kwargs.pop("add_cross_attention", False)
        self.tie_encoder_decoder = kwargs.pop("tie_encoder_decoder", False)

        # Parameters for sequence generation
        self.max_length = kwargs.pop("max_length", 20)
        self.min_length = kwargs.pop("min_length", 0)
        self.do_sample = kwargs.pop("do_sample", False)
        self.early_stopping = kwargs.pop("early_stopping", False)
        self.num_beams = kwargs.pop("num_beams", 1)
        self.num_beam_groups = kwargs.pop("num_beam_groups", 1)
        self.diversity_penalty = kwargs.pop("diversity_penalty", 0.0)
        self.temperature = kwargs.pop("temperature", 1.0)
        self.top_k = kwargs.pop("top_k", 50)
        self.top_p = kwargs.pop("top_p", 1.0)
        self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0)
        self.length_penalty = kwargs.pop("length_penalty", 1.0)
        self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0)
        self.encoder_no_repeat_ngram_size = kwargs.pop("encoder_no_repeat_ngram_size", 0)
        self.bad_words_ids = kwargs.pop("bad_words_ids", None)
        self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
        self.chunk_size_feed_forward = kwargs.pop("chunk_size_feed_forward", 0)
        self.output_scores = kwargs.pop("output_scores", False)
        self.return_dict_in_generate = kwargs.pop("return_dict_in_generate", False)
        self.forced_bos_token_id = kwargs.pop("forced_bos_token_id", None)
        self.forced_eos_token_id = kwargs.pop("forced_eos_token_id", None)

        # Fine-tuning task arguments
        self.architectures = kwargs.pop("architectures", None)
        self.finetuning_task = kwargs.pop("finetuning_task", None)
        self.id2label = kwargs.pop("id2label", None)
        self.label2id = kwargs.pop("label2id", None)
        if self.id2label is not None:
            # An explicit id2label mapping takes precedence over num_labels.
            kwargs.pop("num_labels", None)
            # JSON serializes integer keys as strings; convert them back to ints.
            self.id2label = {int(key): value for key, value in self.id2label.items()}
        else:
            self.num_labels = kwargs.pop("num_labels", 2)

        # Tokenizer arguments
        self.tokenizer_class = kwargs.pop("tokenizer_class", None)
        self.prefix = kwargs.pop("prefix", None)
        self.bos_token_id = kwargs.pop("bos_token_id", None)
        self.pad_token_id = kwargs.pop("pad_token_id", None)
        self.eos_token_id = kwargs.pop("eos_token_id", None)
        self.sep_token_id = kwargs.pop("sep_token_id", None)
        self.decoder_start_token_id = kwargs.pop("decoder_start_token_id", None)

        # Task-specific arguments
        self.task_specific_params = kwargs.pop("task_specific_params", None)

        # TPU arguments
        self.xla_device = kwargs.pop("xla_device", None)

        # Name or path of the checkpoint this config was loaded from
        self._name_or_path = str(kwargs.pop("name_or_path", ""))

        # Drop the library version stamp; it is metadata, not a model attribute
        kwargs.pop("transformers_version", None)

        # Any leftover kwargs become additional attributes on the config
        for key, value in kwargs.items():
            try:
                setattr(self, key, value)
            except AttributeError as err:
                raise AttributeError(f"Can't set {key} with value {value} for {self}") from err
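
    # Illustrative sketch of the constructor's kwargs handling (values are
    # made up for the example): string label ids are normalized to ints,
    # and unknown keys simply become attributes.
    #
    #     config = PretrainedConfig(id2label={"0": "NEG", "1": "POS"}, my_flag=7)
    #     assert config.id2label == {0: "NEG", 1: "POS"}
    #     assert config.my_flag == 7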

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
        """Instantiate a config from a hub identifier, a local directory or file, or a URL."""
        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
        return cls.from_dict(config_dict, **kwargs)
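
    # Illustrative usage (requires network access; assumes the public
    # "bert-base-uncased" identifier, though any hub model id would do):
    #
    #     config = BertConfig.from_pretrained("bert-base-uncased")
    #     config, unused = BertConfig.from_pretrained(
    #         "bert-base-uncased", foo=False, return_unused_kwargs=True
    #     )
    #     assert unused == {"foo": False}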

    @classmethod
    def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
        """Read a JSON config file into a plain dict."""
        with open(json_file, "r", encoding="utf-8") as reader:
            text = reader.read()
        return json.loads(text)

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "PretrainedConfig":
        """Instantiate a config from a Python dict of parameters."""
        return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)

        config = cls(**config_dict)

        # JSON serializes integer keys as strings; convert them back to ints.
        if hasattr(config, "pruned_heads"):
            config.pruned_heads = {int(key): value for key, value in config.pruned_heads.items()}

        # Apply kwargs that match existing attributes; anything left over is
        # reported back to the caller as unused.
        to_remove = []
        for key, value in kwargs.items():
            if hasattr(config, key):
                setattr(config, key, value)
                to_remove.append(key)
        for key in to_remove:
            kwargs.pop(key, None)

        if return_unused_kwargs:
            return config, kwargs
        else:
            return config

    @classmethod
    def get_config_dict(
        cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Resolve and load the config file, returning its dict together with
        the kwargs not consumed by the download/caching options."""
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        use_auth_token = kwargs.pop("use_auth_token", None)
        local_files_only = kwargs.pop("local_files_only", False)
        revision = kwargs.pop("revision", None)

        # Resolve the argument to a config file location: a directory is
        # expected to contain CONFIG_NAME, a file or URL is used as-is, and
        # anything else is treated as a hub model identifier.
        pretrained_model_name_or_path = str(pretrained_model_name_or_path)
        if os.path.isdir(pretrained_model_name_or_path):
            config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
        elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
            config_file = pretrained_model_name_or_path
        else:
            config_file = hf_bucket_url(
                pretrained_model_name_or_path, filename=CONFIG_NAME, revision=revision, mirror=None
            )

        try:
            # Download the file if needed and return the path of the local cached copy.
            resolved_config_file = cached_path(
                config_file,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                local_files_only=local_files_only,
                use_auth_token=use_auth_token,
            )
            config_dict = cls._dict_from_json_file(resolved_config_file)

        except EnvironmentError as err:
            msg = (
                f"Can't load config for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
                f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
                f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a {CONFIG_NAME} file\n\n"
            )
            raise EnvironmentError(msg) from err

        except json.JSONDecodeError as err:
            msg = (
                "Configuration file at '{}' (resolved from '{}') is not a valid "
                "JSON file. Please check the file content.".format(resolved_config_file, config_file)
            )
            raise EnvironmentError(msg) from err

        return config_dict, kwargs
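
# Hypothetical helper, not part of the original module: a minimal sketch of
# how `from_dict` round-trips a config through JSON, including the
# `return_unused_kwargs` convention. All values are made up for the example.
def _demo_from_dict_round_trip():
    raw = json.dumps({"num_labels": 4, "temperature": 0.7})
    config_dict = json.loads(raw)

    # Keys matching existing attributes are applied; unknown keys are
    # handed back to the caller.
    config, unused = PretrainedConfig.from_dict(
        config_dict, top_k=10, not_a_config_key=True, return_unused_kwargs=True
    )
    assert config.num_labels == 4
    assert config.top_k == 10
    assert unused == {"not_a_config_key": True}
    return config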


class BertConfig(PretrainedConfig):
    """Configuration for a BERT model. Defaults match the bert-base-uncased
    architecture (12 layers, 12 heads, hidden size 768)."""

    model_type = "bert"

    def __init__(
        self,
        vocab_size=30522,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=2,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        pad_token_id=0,
        gradient_checkpointing=False,
        position_embedding_type="absolute",
        use_cache=True,
        **kwargs
    ):
        # Remaining kwargs (labels, generation settings, ...) are handled
        # by the base class.
        super().__init__(pad_token_id=pad_token_id, **kwargs)

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.gradient_checkpointing = gradient_checkpointing
        self.position_embedding_type = position_embedding_type
        self.use_cache = use_cache
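

# Illustrative self-check (not part of the original module); run this file
# directly to exercise the config plumbing with made-up values.
if __name__ == "__main__":
    # Defaults follow the BertConfig signature above.
    config = BertConfig()
    assert config.hidden_size == 768 and config.num_labels == 2

    # Overrides flow through **kwargs into PretrainedConfig.__init__;
    # string label ids come back as ints.
    config = BertConfig(num_labels=3, id2label={"0": "A", "1": "B", "2": "C"})
    assert config.id2label == {0: "A", 1: "B", 2: "C"}

    # from_dict applies matching keys and reports the rest as unused.
    config, unused = BertConfig.from_dict(
        {"hidden_size": 256, "num_hidden_layers": 4}, unknown_key=1, return_unused_kwargs=True
    )
    assert config.hidden_size == 256 and unused == {"unknown_key": 1}
    print("All config checks passed.")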