glenn-jocher committed
Commit f346da9 · 1 Parent(s): eb99dff

update activations.py

Files changed (1)
  1. utils/activations.py +38 -42
utils/activations.py CHANGED
@@ -3,69 +3,65 @@ import torch.nn as nn
 import torch.nn.functional as F


-# Swish ------------------------------------------------------------------------
-class SwishImplementation(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, x):
-        ctx.save_for_backward(x)
-        return x * torch.sigmoid(x)
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        x = ctx.saved_tensors[0]
-        sx = torch.sigmoid(x)
-        return grad_output * (sx * (1 + x * (1 - sx)))
-
-
-class MemoryEfficientSwish(nn.Module):
+# Swish https://arxiv.org/pdf/1905.02244.pdf ---------------------------------------------------------------------------
+class Swish(nn.Module):  #
     @staticmethod
     def forward(x):
-        return SwishImplementation.apply(x)
+        return x * torch.sigmoid(x)


-class HardSwish(nn.Module):  # https://arxiv.org/pdf/1905.02244.pdf
+class HardSwish(nn.Module):
     @staticmethod
     def forward(x):
         return x * F.hardtanh(x + 3, 0., 6., True) / 6.


-class Swish(nn.Module):
-    @staticmethod
-    def forward(x):
-        return x * torch.sigmoid(x)
+class MemoryEfficientSwish(nn.Module):
+    class F(torch.autograd.Function):
+        @staticmethod
+        def forward(ctx, x):
+            ctx.save_for_backward(x)
+            return x * torch.sigmoid(x)
+
+        @staticmethod
+        def backward(ctx, grad_output):
+            x = ctx.saved_tensors[0]
+            sx = torch.sigmoid(x)
+            return grad_output * (sx * (1 + x * (1 - sx)))

+    def forward(self, x):
+        return self.F.apply(x)

-# Mish ------------------------------------------------------------------------
-class MishImplementation(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, x):
-        ctx.save_for_backward(x)
-        return x.mul(torch.tanh(F.softplus(x)))  # x * tanh(ln(1 + exp(x)))

+# Mish https://github.com/digantamisra98/Mish --------------------------------------------------------------------------
+class Mish(nn.Module):
     @staticmethod
-    def backward(ctx, grad_output):
-        x = ctx.saved_tensors[0]
-        sx = torch.sigmoid(x)
-        fx = F.softplus(x).tanh()
-        return grad_output * (fx + x * sx * (1 - fx * fx))
+    def forward(x):
+        return x * F.softplus(x).tanh()


 class MemoryEfficientMish(nn.Module):
-    @staticmethod
-    def forward(x):
-        return MishImplementation.apply(x)
+    class F(torch.autograd.Function):
+        @staticmethod
+        def forward(ctx, x):
+            ctx.save_for_backward(x)
+            return x.mul(torch.tanh(F.softplus(x)))  # x * tanh(ln(1 + exp(x)))
+
+        @staticmethod
+        def backward(ctx, grad_output):
+            x = ctx.saved_tensors[0]
+            sx = torch.sigmoid(x)
+            fx = F.softplus(x).tanh()
+            return grad_output * (fx + x * sx * (1 - fx * fx))

-
-class Mish(nn.Module):  # https://github.com/digantamisra98/Mish
-    @staticmethod
-    def forward(x):
-        return x * F.softplus(x).tanh()
+    def forward(self, x):
+        return self.F.apply(x)


-# FReLU https://arxiv.org/abs/2007.11824 --------------------------------------
+# FReLU https://arxiv.org/abs/2007.11824 -------------------------------------------------------------------------------
 class FReLU(nn.Module):
     def __init__(self, c1, k=3):  # ch_in, kernel
-        super(FReLU, self).__init__()
+        super().__init__()
         self.conv = nn.Conv2d(c1, c1, k, 1, 1, groups=c1)
         self.bn = nn.BatchNorm2d(c1)
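The refactor folds each standalone torch.autograd.Function into its wrapper module as a nested class (MemoryEfficientSwish.F, MemoryEfficientMish.F) while keeping the same hand-written gradients, so the memory-efficient modules should remain numerically interchangeable with the plain Swish and Mish. A minimal sketch of how that could be checked with torch.autograd.gradcheck; it assumes the updated file is importable as utils.activations, and the test tensor is illustrative:

import torch
from utils.activations import Swish, Mish, MemoryEfficientSwish, MemoryEfficientMish

x = torch.randn(16, dtype=torch.double, requires_grad=True)

# Forward passes agree: both pairs compute x * sigmoid(x) and x * tanh(softplus(x)).
assert torch.allclose(Swish()(x), MemoryEfficientSwish()(x))
assert torch.allclose(Mish()(x), MemoryEfficientMish()(x))

# gradcheck compares the hand-written backward() against numerical gradients,
# e.g. d/dx [x * sigmoid(x)] = sigmoid(x) * (1 + x * (1 - sigmoid(x))).
assert torch.autograd.gradcheck(MemoryEfficientSwish(), (x,))
assert torch.autograd.gradcheck(MemoryEfficientMish(), (x,))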
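HardSwish is the piecewise-linear Swish approximation from the MobileNetV3 paper referenced in the comments, i.e. x * relu6(x + 3) / 6 written via an in-place F.hardtanh on the temporary x + 3. On PyTorch builds that ship nn.Hardswish (1.6 and later), the two should agree; a small hedged comparison with arbitrary sample points:

import torch
import torch.nn as nn
from utils.activations import HardSwish

x = torch.linspace(-6, 6, steps=121)
# F.hardtanh(x + 3, 0., 6., True) clamps the temporary (x + 3) in place, i.e. relu6(x + 3);
# multiplying by x and dividing by 6 reproduces the built-in hard-swish curve.
assert torch.allclose(HardSwish()(x), nn.Hardswish()(x))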
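Swish, HardSwish and Mish are stateless, so they can be dropped into a model wherever an activation module is expected; FReLU is the exception, since it owns a per-channel depthwise conv and BatchNorm and therefore has to be constructed with the layer's channel count (its forward is outside this hunk). A hedged usage sketch; the ConvBlock below is purely illustrative and not part of the repository:

import torch
import torch.nn as nn
from utils.activations import Mish

class ConvBlock(nn.Module):
    # Conv -> BN -> activation; the activation module is chosen at construction time.
    def __init__(self, c1, c2, act=None):
        super().__init__()
        self.conv = nn.Conv2d(c1, c2, 3, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        # Stateless choices: Mish(), Swish(), HardSwish(); FReLU(c2) would also fit,
        # but needs the output channel count passed in.
        self.act = act if act is not None else Mish()

    def forward(self, x):
        return self.act(self.bn(self.conv(x)))

print(ConvBlock(3, 16)(torch.zeros(1, 3, 32, 32)).shape)  # torch.Size([1, 16, 32, 32])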