import math

import torch
import torch.utils.checkpoint
import torch.nn.functional as F
from transformers import PreTrainedModel
from transformers.modeling_outputs import ModelOutput


class CounterModel(PreTrainedModel):
    """Toy model: apply ``weight * x + bias`` from the config, then a 1->1 linear layer.

    ``weight`` and ``bias`` are copied from the config as ordinary attributes.
    Unless the config already supplies ``torch.nn.Parameter`` objects, they are
    NOT registered parameters: they are not yielded by ``parameters()``, not
    updated by an optimizer, and not included in ``state_dict()``. The only
    module-owned trainable weights are those of ``self.linear``.
    """

    def __init__(self, config):
        """Read the affine coefficients from ``config`` and build the linear head.

        Args:
            config: Model configuration passed to ``PreTrainedModel``. It must
                expose ``weight`` and ``bias`` attributes. NOTE(review):
                presumably scalars or tensors broadcastable with the forward
                input -- confirm against the config class.
        """
        super().__init__(config)
        self.config = config
        self.weight = config.weight
        self.bias = config.bias
        # in_features=1, out_features=1: after the affine step, the input to
        # this layer must have a trailing dimension of size 1.
        self.linear = torch.nn.Linear(1, 1)

    def forward(self, x, **kwargs):
        """Scale and shift ``x`` by the config coefficients, then apply ``self.linear``.

        Args:
            x: Input tensor. ``self.weight * x + self.bias`` must produce a
                tensor whose last dimension is 1 (required by ``Linear(1, 1)``).
            **kwargs: Accepted for call-site compatibility and ignored.

        Returns:
            The raw output tensor of ``self.linear`` (not wrapped in a
            ``ModelOutput``).
        """
        shifted = self.weight * x + self.bias
        logits = self.linear(shifted)
        return logits

    def add(self):
        """Return ``self.weight + self.bias`` (the config coefficients summed)."""
        return self.weight + self.bias