File size: 620 Bytes
8cbed98 20e8f6c 14333a0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 |
import math
import torch
import torch.utils.checkpoint
import torch.nn.functional as F
from transformers import PreTrainedModel
from transformers.modeling_outputs import (ModelOutput,)
class CounterModel(PreTrainedModel):
    """Toy model: scale and shift the input, then apply a 1 -> 1 linear layer.

    ``weight`` and ``bias`` are copied from the config onto the instance as
    ordinary attribute assignments. Whether they end up trainable depends on
    what type the config holds (presumably plain numbers — TODO confirm).
    """

    def __init__(self, config):
        super().__init__(config)
        # Stored explicitly on the instance (presumably redundant with the base
        # class constructor; kept unchanged).
        self.config = config
        scale, shift = config.weight, config.bias
        self.weight = scale
        self.bias = shift
        # Single-input, single-output affine layer applied after the scale/shift.
        self.linear = torch.nn.Linear(in_features=1, out_features=1)

    def forward(self, x, **kwargs):
        """Return ``self.linear(self.weight * x + self.bias)``.

        Args:
            x: Input tensor. ``self.linear`` is ``nn.Linear(1, 1)``, so the
                trailing dimension is presumably of size 1 (TODO confirm with
                callers).
            **kwargs: Accepted for call-site compatibility and ignored.

        Returns:
            The raw output of ``self.linear``. This is not wrapped in a
            ``ModelOutput``.
        """
        return self.linear(self.weight * x + self.bias)

    def add(self):
        """Return ``self.weight + self.bias``."""
        total = self.weight + self.bias
        return total