# coin_net / gears.py
# Author: Joao Henrique (commit 9943c93)
import torch
import torch.nn.functional as F
import numpy as np
def load_tensor(paths=('tensor1.pt', 'tensor2.pt'), map_location=None):
    """Load the saved model coefficients from disk.

    Each path is read with ``torch.load`` and the results are returned in
    order.  With the default paths the result is ``[layers, consts]``, the
    pair that ``calc_preds`` unpacks as ``layers, consts = coeffs``.
    (Presumably ``tensor1.pt`` holds per-layer weight matrices and
    ``tensor2.pt`` the matching biases — confirm against the saving code.)

    Args:
        paths: Iterable of file paths (str or os.PathLike) to load.  Defaults
            to the original hard-coded ``('tensor1.pt', 'tensor2.pt')``, which
            are resolved relative to the current working directory.
        map_location: Forwarded to ``torch.load``.  ``None`` (the default)
            restores tensors to the device they were saved from; pass e.g.
            ``'cpu'`` to load a CUDA-saved checkpoint on a machine without GPU.

    Returns:
        A list with one loaded object per path, in the same order as ``paths``.
    """
    # NOTE(review): torch.load unpickles its input; only load trusted files.
    return [torch.load(path, map_location=map_location) for path in paths]
def calc_preds(coeffs, indeps):
    """Run a forward pass of a fully connected network and return sigmoid outputs.

    Args:
        coeffs: A pair ``(layers, consts)``.  For each index ``k`` the
            activations are updated as ``acts @ layers[k] + consts[k]``.
        indeps: Input tensor fed to the first layer.

    Returns:
        ``torch.sigmoid`` of the final layer's output.  Every layer except the
        last is followed by a ReLU; the last layer is left linear before the
        sigmoid.  With no layers, this is simply ``torch.sigmoid(indeps)``.
    """
    weights, biases = coeffs
    final_idx = len(weights) - 1
    acts = indeps
    for idx, w in enumerate(weights):
        z = acts @ w + biases[idx]
        # Hidden layers get a ReLU; the output layer stays linear so the
        # sigmoid below sees the raw logits.
        acts = z if idx == final_idx else F.relu(z)
    return torch.sigmoid(acts)