Commit
·
4fb8cb3
1
Parent(s):
4d7f222
Update export.py with v3.0 Hardswish() support (#831)
Browse files- models/export.py +2 -1
models/export.py
CHANGED
@@ -7,6 +7,7 @@ Usage:
|
|
7 |
import argparse
|
8 |
|
9 |
import torch
|
|
|
10 |
|
11 |
from models.common import Conv
|
12 |
from models.experimental import attempt_load
|
@@ -32,7 +33,7 @@ if __name__ == '__main__':
|
|
32 |
# Update model
|
33 |
for k, m in model.named_modules():
|
34 |
m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatability
|
35 |
-
if isinstance(m, Conv):
|
36 |
m.act = Hardswish() # assign activation
|
37 |
# if isinstance(m, Detect):
|
38 |
# m.forward = m.forward_export # assign forward (optional)
|
|
|
7 |
import argparse
|
8 |
|
9 |
import torch
|
10 |
+
import torch.nn as nn
|
11 |
|
12 |
from models.common import Conv
|
13 |
from models.experimental import attempt_load
|
|
|
33 |
# Update model
|
34 |
for k, m in model.named_modules():
|
35 |
m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatability
|
36 |
+
if isinstance(m, Conv) and isinstance(m.act, nn.Hardswish):
|
37 |
m.act = Hardswish() # assign activation
|
38 |
# if isinstance(m, Detect):
|
39 |
# m.forward = m.forward_export # assign forward (optional)
|