rlawjdghek's picture
prep (#1)
61c2d32 verified
raw
history blame
3.47 kB
# Copyright (c) Facebook, Inc. and its affiliates.
import os
import torch
from detectron2.config import get_cfg
from detectron2.engine import default_setup
from detectron2.modeling import build_model
from densepose import add_densepose_config
_BASE_CONFIG_DIR = "configs"
_EVOLUTION_CONFIG_SUB_DIR = "evolution"
_HRNET_CONFIG_SUB_DIR = "HRNet"
_QUICK_SCHEDULES_CONFIG_SUB_DIR = "quick_schedules"
_BASE_CONFIG_FILE_PREFIX = "Base-"
_CONFIG_FILE_EXT = ".yaml"
def _get_base_config_dir():
"""
Return the base directory for configurations
"""
return os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", _BASE_CONFIG_DIR)
def _get_evolution_config_dir():
"""
Return the base directory for evolution configurations
"""
return os.path.join(_get_base_config_dir(), _EVOLUTION_CONFIG_SUB_DIR)
def _get_hrnet_config_dir():
"""
Return the base directory for HRNet configurations
"""
return os.path.join(_get_base_config_dir(), _HRNET_CONFIG_SUB_DIR)
def _get_quick_schedules_config_dir():
"""
Return the base directory for quick schedules configurations
"""
return os.path.join(_get_base_config_dir(), _QUICK_SCHEDULES_CONFIG_SUB_DIR)
def _collect_config_files(config_dir):
"""
Collect all configuration files (i.e. densepose_*.yaml) directly in the specified directory
"""
start = _get_base_config_dir()
results = []
for entry in os.listdir(config_dir):
path = os.path.join(config_dir, entry)
if not os.path.isfile(path):
continue
_, ext = os.path.splitext(entry)
if ext != _CONFIG_FILE_EXT:
continue
if entry.startswith(_BASE_CONFIG_FILE_PREFIX):
continue
config_file = os.path.relpath(path, start)
results.append(config_file)
return results
def get_config_files():
"""
Get all the configuration files (relative to the base configuration directory)
"""
return _collect_config_files(_get_base_config_dir())
def get_evolution_config_files():
"""
Get all the evolution configuration files (relative to the base configuration directory)
"""
return _collect_config_files(_get_evolution_config_dir())
def get_hrnet_config_files():
"""
Get all the HRNet configuration files (relative to the base configuration directory)
"""
return _collect_config_files(_get_hrnet_config_dir())
def get_quick_schedules_config_files():
"""
Get all the quick schedules configuration files (relative to the base configuration directory)
"""
return _collect_config_files(_get_quick_schedules_config_dir())
def get_model_config(config_file):
"""
Load and return the configuration from the specified file (relative to the base configuration
directory)
"""
cfg = get_cfg()
add_densepose_config(cfg)
path = os.path.join(_get_base_config_dir(), config_file)
cfg.merge_from_file(path)
if not torch.cuda.is_available():
cfg.MODEL.DEVICE = "cpu"
return cfg
def get_model(config_file):
"""
Get the model from the specified file (relative to the base configuration directory)
"""
cfg = get_model_config(config_file)
return build_model(cfg)
def setup(config_file):
"""
Setup the configuration from the specified file (relative to the base configuration directory)
"""
cfg = get_model_config(config_file)
cfg.freeze()
default_setup(cfg, {})