Spaces:
Runtime error
Runtime error
import torch | |
class Linear():
    """Fully connected layer computing ``x @ weight`` (plus ``bias`` if enabled).

    Parameters are plain tensors stored in ``self.params`` as ``[weight]`` or
    ``[weight, bias]``. The forward pass reads them from that list, and
    ``set_parameters`` keeps the ``weight``/``bias`` attributes in sync with it.
    """

    def __init__(self, in_n, out_n, bias=True) -> None:
        """Create the layer.

        Args:
            in_n: number of input features (rows of ``weight``).
            out_n: number of output features (columns of ``weight``).
            bias: whether to add a bias vector (initialised to zeros).
        """
        self.params = []
        self.have_bias = bias
        # Fan-in based initialisation: random normal scaled by 1/sqrt(in_n).
        self.weight = torch.randn((in_n, out_n)) / (in_n**0.5)
        self.params.append(self.weight)
        self.bias = None
        if self.have_bias:
            self.bias = torch.zeros(out_n)
            self.params.append(self.bias)

    def __call__(self, x, is_training=True):
        """Apply the layer to ``x`` and cache the result in ``self.out``.

        ``is_training`` is recorded on the instance but does not change the
        computation for this layer.
        """
        self.is_training = is_training
        out = x @ self.params[0]
        if self.have_bias:
            # Out-of-place add rather than mutating the matmul result.
            out = out + self.params[1]
        self.out = out
        return self.out

    def set_parameters(self, p):
        """Replace the parameter list and keep ``weight``/``bias`` consistent.

        Args:
            p: list laid out like ``parameters()`` returns it,
               i.e. ``[weight]`` or ``[weight, bias]``.
        """
        self.params = p
        self.weight = p[0] if len(p) > 0 else None
        self.bias = p[1] if self.have_bias and len(p) > 1 else None

    def parameters(self):
        """Return the list of parameter tensors used by the forward pass."""
        return self.params
class BatchNorm():
    """Batch normalisation over dimension 0 with a per-feature gain and bias.

    In training mode the batch mean/std normalise ``x`` and are blended into
    ``running_mean`` / ``running_std`` using ``momentum``. In inference mode the
    running statistics are used instead.

    NOTE(review): earlier versions normalised with ``x - mean / std`` (an
    operator-precedence bug). Parameters trained with that formula will not
    reproduce the same outputs under the corrected one.
    """

    def __init__(self, in_n, eps=1e-5, momentum=0.1) -> None:
        """Create the layer.

        Args:
            in_n: number of features (size of every per-feature tensor).
            eps: stabiliser; ``sqrt(eps)`` is added to the std before dividing.
            momentum: weight of the newest batch statistic in the running
                averages.
        """
        self.eps = eps
        self.is_training = True
        self.momentum = momentum
        self.running_mean = torch.zeros(in_n)
        self.running_std = torch.ones(in_n)
        self.gain = torch.ones(in_n)
        self.bias = torch.zeros(in_n)
        self.params = [self.gain, self.bias]

    def _normalize(self, x, mean, std):
        """Return ``gain * (x - mean) / (std + sqrt(eps)) + bias``."""
        # Subtract first, then divide: ``x - mean / std`` would divide only the
        # mean and leave ``x`` unnormalised.
        x_hat = (x - mean) / (std + self.eps**0.5)
        return self.params[0] * x_hat + self.params[1]

    def __call__(self, x, is_training=True):
        """Normalise ``x`` and cache the result in ``self.out``.

        Args:
            x: input whose dimension 0 is reduced over to form the statistics.
            is_training: use batch statistics (and update the running ones)
                when True; use the running statistics when False.
        """
        self.is_training = is_training
        if self.is_training:
            mean = x.mean(0, keepdim=True)
            # torch.std defaults to the unbiased (Bessel-corrected) estimator.
            std = x.std(0, keepdim=True)
            self.out = self._normalize(x, mean, std)
            with torch.no_grad():
                # Because mean/std keep dim 0, the running statistics take the
                # shape (1, in_n) after the first update; broadcasting is
                # unaffected.
                self.running_mean = (self.running_mean * (1 - self.momentum)
                                     + self.momentum * mean)
                self.running_std = (self.running_std * (1 - self.momentum)
                                    + self.momentum * std)
        else:
            self.out = self._normalize(x, self.running_mean, self.running_std)
        return self.out

    def set_parameters(self, p):
        """Replace ``[gain, bias]`` and keep the ``gain``/``bias`` attributes in sync."""
        self.params = p
        self.gain = p[0] if len(p) > 0 else None
        self.bias = p[1] if len(p) > 1 else None

    def set_mean_std(self, conf):
        """Load running statistics from ``[running_mean, running_std]``."""
        self.running_mean = conf[0]
        self.running_std = conf[1]

    def get_mean_std(self):
        """Return ``[running_mean, running_std]``."""
        return [self.running_mean, self.running_std]

    def parameters(self):
        """Return the list ``[gain, bias]`` used by the forward pass."""
        return self.params
class Activation():
    """Parameter-free element-wise nonlinearity, either ``'tanh'`` or ``'relu'``."""

    def __init__(self, activation='tanh'):
        """Select the nonlinearity.

        Args:
            activation: ``'tanh'`` (default) or ``'relu'``.

        Raises:
            ValueError: if ``activation`` is not one of the supported names.
        """
        self.params = []
        if activation == 'tanh':
            self.forward = self._forward_tanh
        elif activation == 'relu':
            self.forward = self._forward_relu
        else:
            # ValueError subclasses Exception, so existing handlers still work.
            raise ValueError('Only tanh, and relu activations are supported')

    def _forward_relu(self, x):
        """Return ``relu(x)``."""
        return torch.relu(x)

    def _forward_tanh(self, x):
        """Return ``tanh(x)``."""
        return torch.tanh(x)

    def __call__(self, x, is_training=True):
        """Apply the selected nonlinearity and cache the result in ``self.out``.

        ``is_training`` is recorded on the instance but does not change the
        computation.
        """
        self.is_training = is_training
        self.out = self.forward(x)
        return self.out

    def set_parameters(self, p):
        """Store ``p`` as the parameter list (the layer itself has none)."""
        self.params = p

    def parameters(self):
        """Return the parameter list (empty unless ``set_parameters`` set one)."""
        return self.params