Feng Wang
commited on
Commit
·
713fbb7
1
Parent(s):
b77e7e1
fix(utils,exp): logger compat issue and exp check (#1618)
Browse files- tools/train.py +2 -1
- yolox/exp/__init__.py +1 -2
- yolox/exp/base_exp.py +2 -3
- yolox/exp/yolox_base.py +7 -1
- yolox/utils/logger.py +7 -3
tools/train.py
CHANGED
@@ -11,7 +11,7 @@ import torch
|
|
11 |
import torch.backends.cudnn as cudnn
|
12 |
|
13 |
from yolox.core import launch
|
14 |
-
from yolox.exp import Exp, get_exp
|
15 |
from yolox.utils import configure_module, configure_nccl, configure_omp, get_num_devices
|
16 |
|
17 |
|
@@ -123,6 +123,7 @@ if __name__ == "__main__":
|
|
123 |
args = make_parser().parse_args()
|
124 |
exp = get_exp(args.exp_file, args.name)
|
125 |
exp.merge(args.opts)
|
|
|
126 |
|
127 |
if not args.experiment_name:
|
128 |
args.experiment_name = exp.exp_name
|
|
|
11 |
import torch.backends.cudnn as cudnn
|
12 |
|
13 |
from yolox.core import launch
|
14 |
+
from yolox.exp import Exp, check_exp_value, get_exp
|
15 |
from yolox.utils import configure_module, configure_nccl, configure_omp, get_num_devices
|
16 |
|
17 |
|
|
|
123 |
args = make_parser().parse_args()
|
124 |
exp = get_exp(args.exp_file, args.name)
|
125 |
exp.merge(args.opts)
|
126 |
+
check_exp_value(exp)
|
127 |
|
128 |
if not args.experiment_name:
|
129 |
args.experiment_name = exp.exp_name
|
yolox/exp/__init__.py
CHANGED
@@ -1,7 +1,6 @@
|
|
1 |
#!/usr/bin/env python3
|
2 |
-
# -*- coding:utf-8 -*-
|
3 |
# Copyright (c) Megvii Inc. All rights reserved.
|
4 |
|
5 |
from .base_exp import BaseExp
|
6 |
from .build import get_exp
|
7 |
-
from .yolox_base import Exp
|
|
|
1 |
#!/usr/bin/env python3
|
|
|
2 |
# Copyright (c) Megvii Inc. All rights reserved.
|
3 |
|
4 |
from .base_exp import BaseExp
|
5 |
from .build import get_exp
|
6 |
+
from .yolox_base import Exp, check_exp_value
|
yolox/exp/base_exp.py
CHANGED
@@ -1,5 +1,4 @@
|
|
1 |
#!/usr/bin/env python3
|
2 |
-
# -*- coding:utf-8 -*-
|
3 |
# Copyright (c) Megvii Inc. All rights reserved.
|
4 |
|
5 |
import ast
|
@@ -66,7 +65,7 @@ class BaseExp(metaclass=ABCMeta):
|
|
66 |
return tabulate(exp_table, headers=table_header, tablefmt="fancy_grid")
|
67 |
|
68 |
def merge(self, cfg_list):
|
69 |
-
assert len(cfg_list) % 2 == 0
|
70 |
for k, v in zip(cfg_list[0::2], cfg_list[1::2]):
|
71 |
# only update value with same key
|
72 |
if hasattr(self, k):
|
@@ -74,7 +73,7 @@ class BaseExp(metaclass=ABCMeta):
|
|
74 |
src_type = type(src_value)
|
75 |
|
76 |
# pre-process input if source type is list or tuple
|
77 |
-
if isinstance(src_value, List
|
78 |
v = v.strip("[]()")
|
79 |
v = [t.strip() for t in v.split(",")]
|
80 |
|
|
|
1 |
#!/usr/bin/env python3
|
|
|
2 |
# Copyright (c) Megvii Inc. All rights reserved.
|
3 |
|
4 |
import ast
|
|
|
65 |
return tabulate(exp_table, headers=table_header, tablefmt="fancy_grid")
|
66 |
|
67 |
def merge(self, cfg_list):
|
68 |
+
assert len(cfg_list) % 2 == 0, f"length must be even, check value here: {cfg_list}"
|
69 |
for k, v in zip(cfg_list[0::2], cfg_list[1::2]):
|
70 |
# only update value with same key
|
71 |
if hasattr(self, k):
|
|
|
73 |
src_type = type(src_value)
|
74 |
|
75 |
# pre-process input if source type is list or tuple
|
76 |
+
if isinstance(src_value, (List, Tuple)):
|
77 |
v = v.strip("[]()")
|
78 |
v = [t.strip() for t in v.split(",")]
|
79 |
|
yolox/exp/yolox_base.py
CHANGED
@@ -1,5 +1,4 @@
|
|
1 |
#!/usr/bin/env python3
|
2 |
-
# -*- coding:utf-8 -*-
|
3 |
# Copyright (c) Megvii Inc. All rights reserved.
|
4 |
|
5 |
import os
|
@@ -11,6 +10,8 @@ import torch.nn as nn
|
|
11 |
|
12 |
from .base_exp import BaseExp
|
13 |
|
|
|
|
|
14 |
|
15 |
class Exp(BaseExp):
|
16 |
def __init__(self):
|
@@ -350,3 +351,8 @@ class Exp(BaseExp):
|
|
350 |
|
351 |
def eval(self, model, evaluator, is_distributed, half=False, return_outputs=False):
|
352 |
return evaluator.evaluate(model, is_distributed, half, return_outputs=return_outputs)
|
|
|
|
|
|
|
|
|
|
|
|
1 |
#!/usr/bin/env python3
|
|
|
2 |
# Copyright (c) Megvii Inc. All rights reserved.
|
3 |
|
4 |
import os
|
|
|
10 |
|
11 |
from .base_exp import BaseExp
|
12 |
|
13 |
+
__all__ = ["Exp", "check_exp_value"]
|
14 |
+
|
15 |
|
16 |
class Exp(BaseExp):
|
17 |
def __init__(self):
|
|
|
351 |
|
352 |
def eval(self, model, evaluator, is_distributed, half=False, return_outputs=False):
|
353 |
return evaluator.evaluate(model, is_distributed, half, return_outputs=return_outputs)
|
354 |
+
|
355 |
+
|
356 |
+
def check_exp_value(exp: Exp):
|
357 |
+
h, w = exp.input_size
|
358 |
+
assert h % 32 == 0 and w % 32 == 0, "input size must be multiples of 32"
|
yolox/utils/logger.py
CHANGED
@@ -1,5 +1,4 @@
|
|
1 |
#!/usr/bin/env python3
|
2 |
-
# -*- coding:utf-8 -*-
|
3 |
# Copyright (c) Megvii Inc. All rights reserved.
|
4 |
|
5 |
import inspect
|
@@ -58,7 +57,8 @@ class StreamToLoguru:
|
|
58 |
sys.__stdout__.write(buf)
|
59 |
|
60 |
def flush(self):
|
61 |
-
|
|
|
62 |
|
63 |
def isatty(self):
|
64 |
# when using colab, jax is installed by default and issue like
|
@@ -66,7 +66,11 @@ class StreamToLoguru:
|
|
66 |
# due to missing attribute like`isatty`.
|
67 |
# For more details, checked the following link:
|
68 |
# https://github.com/google/jax/blob/10720258ea7fb5bde997dfa2f3f71135ab7a6733/jax/_src/pretty_printer.py#L54 # noqa
|
69 |
-
return
|
|
|
|
|
|
|
|
|
70 |
|
71 |
|
72 |
def redirect_sys_output(log_level="INFO"):
|
|
|
1 |
#!/usr/bin/env python3
|
|
|
2 |
# Copyright (c) Megvii Inc. All rights reserved.
|
3 |
|
4 |
import inspect
|
|
|
57 |
sys.__stdout__.write(buf)
|
58 |
|
59 |
def flush(self):
|
60 |
+
# flush is related with CPR(cursor position report) in terminal
|
61 |
+
return sys.__stdout__.flush()
|
62 |
|
63 |
def isatty(self):
|
64 |
# when using colab, jax is installed by default and issue like
|
|
|
66 |
# due to missing attribute like`isatty`.
|
67 |
# For more details, checked the following link:
|
68 |
# https://github.com/google/jax/blob/10720258ea7fb5bde997dfa2f3f71135ab7a6733/jax/_src/pretty_printer.py#L54 # noqa
|
69 |
+
return sys.__stdout__.isatty()
|
70 |
+
|
71 |
+
def fileno(self):
|
72 |
+
# To solve the issue when using debug tools like pdb
|
73 |
+
return sys.__stdout__.fileno()
|
74 |
|
75 |
|
76 |
def redirect_sys_output(log_level="INFO"):
|