File size: 374 Bytes
9943c93 b9f1ea6 9943c93 b9f1ea6 9943c93 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 |
import torch
import torch.nn.functional as F
import numpy as np
def load_tensor(path='tensor.pt', map_location=None):
    """Load the saved coefficients object from a ``torch.save`` file.

    Args:
        path: File to load. Defaults to ``'tensor.pt'``, the path the
            original hard-coded, so existing zero-argument calls are
            unchanged.
        map_location: Forwarded to ``torch.load``. ``None`` is
            ``torch.load``'s own default, so tensors are restored as
            ``torch.load`` would restore them by default.

    Returns:
        Whatever object was saved. ``calc_preds`` in this file unpacks it as
        ``layers, consts``. Presumably it is a pair of tensor lists; confirm
        against whatever wrote the file.
    """
    # NOTE(review): torch.load is pickle-based. Whether it restricts
    # unpickling (weights_only) depends on the installed torch version, so
    # only point this at trusted files.
    coeffs = torch.load(path, map_location=map_location)
    return coeffs
def calc_preds(coeffs, indeps):
    """Run ``indeps`` through a stack of affine layers and return sigmoid outputs.

    Args:
        coeffs: A ``(layers, consts)`` pair. ``layers[k]`` is right-multiplied
            onto the running value and ``consts[k]`` is added after it.
            Presumably these are weight matrices and bias vectors with
            compatible shapes; this is not checked here.
        indeps: Input tensor fed to the first layer.

    Returns:
        ``torch.sigmoid`` of the final layer's output. ReLU is applied after
        every layer except the last. With no layers this is simply
        ``torch.sigmoid(indeps)``.
    """
    layers, consts = coeffs
    final_idx = len(layers) - 1
    hidden = indeps
    for depth, weight in enumerate(layers):
        # consts is indexed, not zipped, so a too-short consts raises IndexError.
        hidden = hidden @ weight + consts[depth]
        # Hidden layers get a ReLU; the output layer stays linear before sigmoid.
        if depth == final_idx:
            continue
        hidden = F.relu(hidden)
    return torch.sigmoid(hidden)
|