import torch



class Linear:
    def __init__(self, in_n, out_n, bias=True) -> None:
        self.params = []
        self.have_bias = bias
        # Scale the Gaussian init by 1/sqrt(fan_in) so pre-activations keep
        # roughly unit variance regardless of layer width.
        self.weight = torch.randn((in_n, out_n)) / in_n**0.5
        self.params.append(self.weight)

        self.bias = None
        if self.have_bias:
            self.bias = torch.zeros(out_n)
            self.params.append(self.bias)
    
    def __call__(self, x, is_training=True):
        self.is_training = is_training

        self.out = x @ self.params[0]
        if self.have_bias:
            self.out += self.params[1]
        return self.out

    def set_parameters(self, p):
        # Replace the parameter list wholesale: [weight] or [weight, bias],
        # matching how __call__ indexes into self.params.
        self.params = p

    def parameters(self):
        return self.params
        

class BatchNorm:
    def __init__(self, in_n, eps=1e-5, momentum=0.1) -> None:
        self.eps = eps
        self.is_training = True
        self.momentum = momentum
        # Buffers: exponential moving averages of the batch statistics,
        # used instead of batch statistics at eval time.
        self.running_mean = torch.zeros(in_n)
        self.running_std = torch.ones(in_n)
        # Learnable per-feature scale and shift.
        self.gain = torch.ones(in_n)
        self.bias = torch.zeros(in_n)
        self.params = [self.gain, self.bias]

    def __call__(self, x, is_training=True):
        self.is_training = is_training
        if self.is_training:
            # Normalize with the biased batch statistics, as nn.BatchNorm1d
            # does in its forward pass (torch.std is unbiased by default).
            mean = x.mean(0, keepdim=True)
            std = x.std(0, unbiased=False, keepdim=True)

            # Adding sqrt(eps) to the std guards the division, comparable
            # to the usual sqrt(var + eps) denominator.
            self.out = self.params[0] * (x - mean) / (std + self.eps**0.5) + self.params[1]

            # Track running statistics for eval mode; no grad needed here.
            with torch.no_grad():
                self.running_mean = self.running_mean * (1 - self.momentum) \
                    + self.momentum * mean
                self.running_std = self.running_std * (1 - self.momentum) \
                    + self.momentum * std
        else:
            self.out = self.params[0] * (x - self.running_mean) / (self.running_std + self.eps**0.5) + self.params[1]

        return self.out

    def set_parameters(self, p):
        # Replace the [gain, bias] parameter list wholesale.
        self.params = p

    def set_mean_std(self, conf):
        # Restore saved running statistics, e.g. when loading a checkpoint.
        self.running_mean = conf[0]
        self.running_std = conf[1]

    def get_mean_std(self):
        return [self.running_mean, self.running_std]

    def parameters(self):
        return self.params

class Activation:
    def __init__(self, activation='tanh'):
        self.params = []  # activations have no learnable parameters
        if activation == 'tanh':
            self.forward = self._forward_tanh
        elif activation == 'relu':
            self.forward = self._forward_relu
        else:
            raise ValueError("Only 'tanh' and 'relu' activations are supported")

    def _forward_relu(self, x):
        return torch.relu(x)

    def _forward_tanh(self, x):
        return torch.tanh(x)

    def __call__(self, x, is_training=True):
        self.is_training = is_training
        self.out = self.forward(x)
        return self.out
    
    def set_parameters(self, p):
        self.params = p

    def parameters(self):
        return self.params
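

# --- Usage sketch (illustrative, not part of the module's API): a minimal
# --- smoke test, assuming the layers are meant to be chained the way
# --- nn.Sequential chains modules. Sizes, seed, and data are arbitrary.
if __name__ == "__main__":
    torch.manual_seed(42)

    # A small MLP: 10 input features -> 32 hidden units -> 1 output.
    layers = [
        Linear(10, 32),
        BatchNorm(32),
        Activation('tanh'),
        Linear(32, 1),
    ]

    # Training-mode pass: BatchNorm normalizes with batch statistics and
    # updates its running mean/std as a side effect.
    out = torch.randn(64, 10)
    for layer in layers:
        out = layer(out, is_training=True)
    print(out.shape)  # torch.Size([64, 1])

    # Eval-mode pass: BatchNorm falls back to the running statistics, so a
    # single example (batch of 1) is handled consistently.
    out = torch.randn(1, 10)
    for layer in layers:
        out = layer(out, is_training=False)
    print(out.shape)  # torch.Size([1, 1])

    # Gather every learnable tensor, e.g. to hand to an optimizer.
    params = [p for layer in layers for p in layer.parameters()]
    print(len(params))  # 6: weight+bias per Linear, gain+bias for BatchNorm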