pythia-1_4b-deduped-measurement_pred-generated_stories / modeling_gpt_neox_measurement_pred.py
oliverdk's picture
End of training
2ceedc4 verified
raw
history blame contribute delete
686 Bytes
from transformers.models.gpt_neox import GPTNeoXPreTrainedModel, GPTNeoXModel
from transformers import PreTrainedTokenizerBase
from .modeling_measurement_pred import MeasurementPredictorMixin
from .configuration_gpt_neox_measurement_pred import GPTNeoXMeasurementPredictorConfig
class GPTNeoXMeasurementPredictor(GPTNeoXPreTrainedModel, MeasurementPredictorMixin):
    """GPT-NeoX backbone combined with ``MeasurementPredictorMixin``.

    This class only wires a ``GPTNeoXModel`` trunk into the model under the
    attribute ``gpt_neox``. Any measurement-prediction specific behaviour
    presumably lives in ``MeasurementPredictorMixin`` (defined in
    ``modeling_measurement_pred``, not visible here) — confirm there.
    """

    # Configuration class associated with this architecture (the
    # ``config_class`` hook consulted by transformers' PreTrainedModel).
    config_class = GPTNeoXMeasurementPredictorConfig

    def __init__(self, config: GPTNeoXMeasurementPredictorConfig) -> None:
        """Construct the model from ``config``.

        Args:
            config: configuration forwarded to the parent initialisers and
                used to build the ``GPTNeoXModel`` backbone.
        """
        super().__init__(config)
        # NOTE(review): the attribute name ``gpt_neox`` is kept verbatim;
        # checkpoint parameter names are likely keyed on it — confirm before
        # renaming.
        self.gpt_neox = GPTNeoXModel(config)
        # Invoke transformers' post-construction hook only after every
        # submodule above has been created.
        self.post_init()

    def set_pad_token(self, tokenizer: PreTrainedTokenizerBase) -> None:
        """Register ``"[PAD]"`` as the padding token of ``tokenizer``.

        The tokenizer is updated in place through ``add_special_tokens``.
        Neither the model's embeddings nor its config are modified here, and
        the method returns nothing.

        Args:
            tokenizer: the tokenizer whose pad token is (re)assigned.
        """
        pad_spec = {"pad_token": "[PAD]"}
        # NOTE(review): if "[PAD]" is not already in the vocabulary it gets a
        # new id; no resize_token_embeddings call happens here, so confirm
        # that id fits the model's embedding matrix.
        tokenizer.add_special_tokens(pad_spec)