# commit d7e4f1f: add MLP models (author: boda)
import torch
class Linear:
    """Fully connected layer: out = x @ weight (+ bias)."""

    def __init__(self, in_n, out_n, bias=True) -> None:
        self.have_bias = bias
        # Scale the weights by 1/sqrt(in_n) so activations stay roughly unit-variance.
        self.weight = torch.randn((in_n, out_n)) / (in_n ** 0.5)
        self.params = [self.weight]
        self.bias = None
        if self.have_bias:
            self.bias = torch.zeros(out_n)
            self.params.append(self.bias)

    def __call__(self, x, is_training=True):
        # is_training is accepted only to keep the layer interface uniform;
        # a linear layer behaves the same in training and evaluation.
        self.is_training = is_training
        self.out = x @ self.params[0]
        if self.have_bias:
            self.out += self.params[1]
        return self.out

    def set_parameters(self, p):
        # Replace the parameter list wholesale (e.g. when loading saved weights).
        self.params = p

    def parameters(self):
        return self.params
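
# A minimal usage sketch (not part of the original commit); the input size,
# output size, and batch size below are illustrative assumptions.
def _demo_linear():
    layer = Linear(10, 5)
    x = torch.randn(32, 10)   # batch of 32 examples with 10 features each
    y = layer(x)              # -> shape (32, 5)
    assert y.shape == (32, 5)
    return y
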
class BatchNorm:
    """1D batch normalization with learnable gain/bias and running statistics."""

    def __init__(self, in_n, eps=1e-5, momentum=0.1) -> None:
        self.eps = eps
        self.is_training = True
        self.momentum = momentum
        # Running statistics used at inference time (not updated by backprop).
        self.running_mean = torch.zeros(in_n)
        self.running_std = torch.ones(in_n)
        # Learnable scale and shift.
        self.gain = torch.ones(in_n)
        self.bias = torch.zeros(in_n)
        self.params = [self.gain, self.bias]

    def __call__(self, x, is_training=True):
        self.is_training = is_training
        if self.is_training:
            mean = x.mean(0, keepdim=True)
            # Note: torch.std is unbiased (Bessel-corrected) by default, whereas
            # PyTorch's own BatchNorm normalizes with the biased batch variance.
            std = x.std(0, keepdim=True)
            # Normalize, then scale and shift; eps is defined on the variance
            # scale, so its square root is added to the std for stability.
            self.out = self.params[0] * (x - mean) / (std + self.eps ** 0.5) + self.params[1]
            with torch.no_grad():
                # Exponential moving average of the batch statistics.
                self.running_mean = self.running_mean * (1 - self.momentum) \
                    + self.momentum * mean
                self.running_std = self.running_std * (1 - self.momentum) \
                    + self.momentum * std
        else:
            self.out = self.params[0] * (x - self.running_mean) \
                / (self.running_std + self.eps ** 0.5) + self.params[1]
        return self.out

    def set_parameters(self, p):
        # Replace the learnable gain/bias wholesale (e.g. when loading saved weights).
        self.params = p

    def set_mean_std(self, conf):
        self.running_mean = conf[0]
        self.running_std = conf[1]

    def get_mean_std(self):
        return [self.running_mean, self.running_std]

    def parameters(self):
        return self.params
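
# A minimal usage sketch (not part of the original commit) of the intended
# train/eval split: training batches update the running statistics, and
# evaluation reuses them. Feature and batch sizes are illustrative assumptions.
def _demo_batchnorm():
    bn = BatchNorm(5)
    for _ in range(10):
        xb = torch.randn(64, 5) * 3 + 1   # batches with non-zero mean, non-unit std
        bn(xb, is_training=True)          # updates running_mean / running_std
    x_eval = torch.randn(8, 5) * 3 + 1
    out = bn(x_eval, is_training=False)   # normalizes with the running statistics
    return out
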
class Activation:
    """Element-wise nonlinearity; supports 'tanh' and 'relu'."""

    def __init__(self, activation='tanh'):
        self.params = []
        if activation == 'tanh':
            self.forward = self._forward_tanh
        elif activation == 'relu':
            self.forward = self._forward_relu
        else:
            raise ValueError("Only 'tanh' and 'relu' activations are supported")

    def _forward_relu(self, x):
        return torch.relu(x)

    def _forward_tanh(self, x):
        return torch.tanh(x)

    def __call__(self, x, is_training=True):
        self.is_training = is_training
        self.out = self.forward(x)
        return self.out

    def set_parameters(self, p):
        # An activation has no parameters; kept for interface consistency.
        self.params = p

    def parameters(self):
        return self.params
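
# A minimal end-to-end sketch (not part of the original commit): it wires the
# three layer types into a small MLP and runs one manual SGD step on an MSE
# loss. The layer sizes, batch size, and learning rate are illustrative
# assumptions, and requires_grad_() must be enabled explicitly because the
# layers above create plain tensors.
def _demo_mlp_step(lr=0.1):
    layers = [
        Linear(10, 32), BatchNorm(32), Activation('tanh'),
        Linear(32, 1, bias=False),
    ]
    params = [p for layer in layers for p in layer.parameters()]
    for p in params:
        p.requires_grad_()

    x = torch.randn(64, 10)
    target = torch.randn(64, 1)

    # Forward pass through the stack in training mode.
    out = x
    for layer in layers:
        out = layer(out, is_training=True)
    loss = ((out - target) ** 2).mean()

    # Backward pass and a single SGD update.
    for p in params:
        p.grad = None
    loss.backward()
    with torch.no_grad():
        for p in params:
            p -= lr * p.grad
    return loss.item()
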