coin_net / gears.py
HellSank's picture
Update gears.py
b9f1ea6
raw
history blame contribute delete
374 Bytes
import torch
import torch.nn.functional as F
import numpy as np
def load_tensor(path='tensor.pt', map_location=None):
    """Load the saved model coefficients from disk.

    Args:
        path: File to read. Defaults to ``'tensor.pt'`` (the original
            hard-coded location), resolved relative to the current working
            directory.
        map_location: Forwarded unchanged to ``torch.load``; e.g. pass
            ``'cpu'`` to load tensors that were saved on a GPU onto a
            CPU-only machine. ``None`` keeps torch's default behaviour.

    Returns:
        The object stored in the file, exactly as ``torch.load`` returns it.
        ``calc_preds`` unpacks its ``coeffs`` argument as a
        ``(layers, consts)`` pair, so presumably that is what the checkpoint
        holds — confirm against whatever code wrote ``tensor.pt``.
    """
    # SECURITY NOTE: torch.load unpickles the file, which can execute
    # arbitrary code. Only load checkpoints from a trusted source; on
    # torch>=1.13 consider passing weights_only=True.
    coeffs = torch.load(path, map_location=map_location)
    return coeffs
def calc_preds(coeffs, indeps):
    """Run a forward pass of a small fully connected network.

    Args:
        coeffs: A two-element ``(layers, consts)`` pair. Layer ``k`` computes
            ``h @ layers[k] + consts[k]``. Presumably ``layers[k]`` is an
            ``(in, out)`` weight matrix and ``consts[k]`` a bias broadcastable
            to the output — TODO confirm against the saved checkpoint.
        indeps: Input features fed to the first layer (assumed to be
            ``(batch, in_features)``; not checked here).

    Returns:
        ``torch.sigmoid`` of the final layer's output, so every value lies
        in ``(0, 1)``. With no layers at all, this is simply
        ``sigmoid(indeps)``.
    """
    weights, biases = coeffs
    last = len(weights) - 1
    activations = indeps
    for idx, weight in enumerate(weights):
        activations = activations @ weight + biases[idx]
        # ReLU only between layers: the last layer's raw output is left
        # untouched so the sigmoid below sees unclipped logits.
        if idx < last:
            activations = F.relu(activations)
    return torch.sigmoid(activations)