glenn-jocher commited on
Commit
264d860
·
unverified ·
1 Parent(s): 0395e39

ACON activation function (#2893)

Browse files

* ACON Activation Function

## 🚀 Feature

There is a new activation function [ACON (CVPR 2021)](https://arxiv.org/pdf/2009.04759.pdf) that unifies ReLU and Swish.
ACON is simple but very effective, code is here: https://github.com/nmaac/acon/blob/main/acon.py#L19

![image](https://user-images.githubusercontent.com/5032208/115676962-a38dfe80-a382-11eb-9883-61fa3216e3e6.png)

The improvements are very significant:
![image](https://user-images.githubusercontent.com/5032208/115680180-eac9be80-a385-11eb-9c7a-8643db552c69.png)

## Alternatives

It also has an enhanced version meta-ACON that uses a small network to learn beta explicitly, which may influence the speed a bit.

## Additional context

[Code](https://github.com/nmaac/acon) and [paper](https://arxiv.org/pdf/2009.04759.pdf).

* Update activations.py

Files changed (1) hide show
  1. utils/activations.py +41 -17
utils/activations.py CHANGED
@@ -19,23 +19,6 @@ class Hardswish(nn.Module): # export-friendly version of nn.Hardswish()
19
  return x * F.hardtanh(x + 3, 0., 6.) / 6. # for torchscript, CoreML and ONNX
20
 
21
 
22
- class MemoryEfficientSwish(nn.Module):
23
- class F(torch.autograd.Function):
24
- @staticmethod
25
- def forward(ctx, x):
26
- ctx.save_for_backward(x)
27
- return x * torch.sigmoid(x)
28
-
29
- @staticmethod
30
- def backward(ctx, grad_output):
31
- x = ctx.saved_tensors[0]
32
- sx = torch.sigmoid(x)
33
- return grad_output * (sx * (1 + x * (1 - sx)))
34
-
35
- def forward(self, x):
36
- return self.F.apply(x)
37
-
38
-
39
  # Mish https://github.com/digantamisra98/Mish --------------------------------------------------------------------------
40
  class Mish(nn.Module):
41
  @staticmethod
@@ -70,3 +53,44 @@ class FReLU(nn.Module):
70
 
71
  def forward(self, x):
72
  return torch.max(x, self.bn(self.conv(x)))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  return x * F.hardtanh(x + 3, 0., 6.) / 6. # for torchscript, CoreML and ONNX
20
 
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  # Mish https://github.com/digantamisra98/Mish --------------------------------------------------------------------------
23
  class Mish(nn.Module):
24
  @staticmethod
 
53
 
54
  def forward(self, x):
55
  return torch.max(x, self.bn(self.conv(x)))
56
+
57
+
58
+ # ACON https://arxiv.org/pdf/2009.04759.pdf ----------------------------------------------------------------------------
59
+ class AconC(nn.Module):
60
+ r""" ACON activation (activate or not).
61
+ AconC: (p1*x-p2*x) * sigmoid(beta*(p1*x-p2*x)) + p2*x, beta is a learnable parameter
62
+ according to "Activate or Not: Learning Customized Activation" <https://arxiv.org/pdf/2009.04759.pdf>.
63
+ """
64
+
65
+ def __init__(self, c1):
66
+ super().__init__()
67
+ self.p1 = nn.Parameter(torch.randn(1, c1, 1, 1))
68
+ self.p2 = nn.Parameter(torch.randn(1, c1, 1, 1))
69
+ self.beta = nn.Parameter(torch.ones(1, c1, 1, 1))
70
+
71
+ def forward(self, x):
72
+ dpx = (self.p1 - self.p2) * x
73
+ return dpx * torch.sigmoid(self.beta * dpx) + self.p2 * x
74
+
75
+
76
+ class MetaAconC(nn.Module):
77
+ r""" ACON activation (activate or not).
78
+ MetaAconC: (p1*x-p2*x) * sigmoid(beta*(p1*x-p2*x)) + p2*x, beta is generated by a small network
79
+ according to "Activate or Not: Learning Customized Activation" <https://arxiv.org/pdf/2009.04759.pdf>.
80
+ """
81
+
82
+ def __init__(self, c1, k=1, s=1, r=16): # ch_in, kernel, stride, r
83
+ super().__init__()
84
+ c2 = max(r, c1 // r)
85
+ self.p1 = nn.Parameter(torch.randn(1, c1, 1, 1))
86
+ self.p2 = nn.Parameter(torch.randn(1, c1, 1, 1))
87
+ self.fc1 = nn.Conv2d(c1, c2, k, s, bias=False)
88
+ self.bn1 = nn.BatchNorm2d(c2)
89
+ self.fc2 = nn.Conv2d(c2, c1, k, s, bias=False)
90
+ self.bn2 = nn.BatchNorm2d(c1)
91
+
92
+ def forward(self, x):
93
+ y = x.mean(dim=2, keepdims=True).mean(dim=3, keepdims=True)
94
+ beta = torch.sigmoid(self.bn2(self.fc2(self.bn1(self.fc1(y)))))
95
+ dpx = (self.p1 - self.p2) * x
96
+ return dpx * torch.sigmoid(beta * dpx) + self.p2 * x