HellSank commited on
Commit
b9f1ea6
·
1 Parent(s): 00d11d9

Update gears.py

Browse files
Files changed (1) hide show
  1. gears.py +2 -5
gears.py CHANGED
@@ -5,18 +5,15 @@ import numpy as np
5
 
6
 
7
  def load_tensor():
8
- coeffs = [None,None]
9
- coeffs[0] = torch.load('tensor1.pt')
10
- coeffs[1] = torch.load('tensor2.pt')
11
  return coeffs
12
 
13
 
14
-
15
  def calc_preds(coeffs, indeps):
16
  layers,consts = coeffs
17
  n = len(layers)
18
  res = indeps
19
  for i,l in enumerate(layers):
20
- res = res@l + consts[i]
21
  if i!=n-1: res = F.relu(res)
22
  return torch.sigmoid(res)
 
5
 
6
 
7
  def load_tensor():
8
+ coeffs = torch.load('tensor.pt')
 
 
9
  return coeffs
10
 
11
 
 
12
  def calc_preds(coeffs, indeps):
13
  layers,consts = coeffs
14
  n = len(layers)
15
  res = indeps
16
  for i,l in enumerate(layers):
17
+ res = res @ l + consts[i]
18
  if i!=n-1: res = F.relu(res)
19
  return torch.sigmoid(res)