File size: 620 Bytes
8cbed98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20e8f6c
14333a0
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26

import math

import torch
import torch.utils.checkpoint
import torch.nn.functional as F
from transformers import PreTrainedModel

from transformers.modeling_outputs import (ModelOutput,)


class CounterModel(PreTrainedModel):
    """Toy model: a config-driven affine step followed by a 1 -> 1 linear layer.

    The scalar-like ``weight`` and ``bias`` are read from ``config``. They are
    stored as plain attributes, not registered ``nn.Parameter`` objects (unless
    the config already holds parameters). The only trainable submodule created
    here is ``self.linear``.
    """

    def __init__(self, config):
        """Build the model from ``config``.

        Args:
            config: A ``transformers`` config object. It must expose ``weight``
                and ``bias`` attributes. NOTE(review): these are presumably
                numbers or broadcastable tensors; confirm against the config
                class.
        """
        super().__init__(config)
        # PreTrainedModel.__init__ already stores the config; kept explicit.
        self.config = config
        # Order is preserved so attribute/parameter registration order is unchanged.
        self.weight = config.weight
        self.bias = config.bias
        self.linear = torch.nn.Linear(in_features=1, out_features=1)

    def forward(self, x, **kwargs):
        """Apply ``weight * x + bias``, then the 1 -> 1 linear layer.

        Args:
            x: Input tensor. Its last dimension must be 1 (after the affine
                step) to match ``nn.Linear(1, 1)``.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            The raw output of ``self.linear``. It is not wrapped in a
            ``ModelOutput``.
        """
        return self.linear(self.weight * x + self.bias)

    def add(self):
        """Return ``self.weight + self.bias`` (the config-derived values)."""
        total = self.weight + self.bias
        return total