Spaces:
Runtime error
Runtime error
# Copyright (c) Facebook, Inc. and its affiliates. | |
import os | |
import torch | |
from detectron2.config import get_cfg | |
from detectron2.engine import default_setup | |
from detectron2.modeling import build_model | |
from densepose import add_densepose_config | |
# Name of the directory (one level above this module) that holds all configuration files.
_BASE_CONFIG_DIR = "configs"

# Sub-directories of the base configuration directory, each scanned separately.
_EVOLUTION_CONFIG_SUB_DIR = "evolution"
_HRNET_CONFIG_SUB_DIR = "HRNet"
_QUICK_SCHEDULES_CONFIG_SUB_DIR = "quick_schedules"

# File-name rules used when collecting configuration files: files starting with this
# prefix are skipped, and only files with this exact extension are kept.
_BASE_CONFIG_FILE_PREFIX = "Base-"
_CONFIG_FILE_EXT = ".yaml"
def _get_base_config_dir():
    """
    Path of the directory that contains all configuration files.

    The path is built from the real (symlink-resolved) location of this module:
    its parent directory, then "..", then ``_BASE_CONFIG_DIR``. The ".." component
    is kept as-is (the path is not normalized).
    """
    module_dir = os.path.dirname(os.path.realpath(__file__))
    return os.path.join(module_dir, "..", _BASE_CONFIG_DIR)
def _get_evolution_config_dir():
    """
    Path of the ``_EVOLUTION_CONFIG_SUB_DIR`` sub-directory of the base configuration directory.
    """
    base_dir = _get_base_config_dir()
    return os.path.join(base_dir, _EVOLUTION_CONFIG_SUB_DIR)
def _get_hrnet_config_dir():
    """
    Path of the ``_HRNET_CONFIG_SUB_DIR`` sub-directory of the base configuration directory.
    """
    base_dir = _get_base_config_dir()
    return os.path.join(base_dir, _HRNET_CONFIG_SUB_DIR)
def _get_quick_schedules_config_dir():
    """
    Path of the ``_QUICK_SCHEDULES_CONFIG_SUB_DIR`` sub-directory of the base
    configuration directory.
    """
    base_dir = _get_base_config_dir()
    return os.path.join(base_dir, _QUICK_SCHEDULES_CONFIG_SUB_DIR)
def _collect_config_files(config_dir):
    """
    Collect the configuration files located directly in ``config_dir`` (no recursion).

    An entry is kept when all of the following hold:
      * it is a file (``os.path.isfile``: symlinks to files count, directories do not);
      * its extension is exactly ``_CONFIG_FILE_EXT``;
      * its name does not start with ``_BASE_CONFIG_FILE_PREFIX``
        (presumably shared base configs that other configs build on).

    Args:
        config_dir (str): directory to scan
    Returns:
        list[str]: paths of the kept files relative to the base configuration directory,
        in sorted file-name order
    """
    start = _get_base_config_dir()
    results = []
    # os.listdir returns entries in arbitrary order; sort so that the returned list
    # (e.g. used to parametrize tests) is reproducible across runs and filesystems.
    for entry in sorted(os.listdir(config_dir)):
        path = os.path.join(config_dir, entry)
        if not os.path.isfile(path):
            continue
        _, ext = os.path.splitext(entry)
        if ext != _CONFIG_FILE_EXT:
            continue
        if entry.startswith(_BASE_CONFIG_FILE_PREFIX):
            continue
        results.append(os.path.relpath(path, start))
    return results
def get_config_files():
    """
    List the configuration files stored directly in the base configuration directory.

    Returns:
        list[str]: file paths relative to the base configuration directory
    """
    base_dir = _get_base_config_dir()
    return _collect_config_files(base_dir)
def get_evolution_config_files():
    """
    List the configuration files stored directly in the evolution sub-directory.

    Returns:
        list[str]: file paths relative to the base configuration directory
    """
    evolution_dir = _get_evolution_config_dir()
    return _collect_config_files(evolution_dir)
def get_hrnet_config_files():
    """
    List the configuration files stored directly in the HRNet sub-directory.

    Returns:
        list[str]: file paths relative to the base configuration directory
    """
    hrnet_dir = _get_hrnet_config_dir()
    return _collect_config_files(hrnet_dir)
def get_quick_schedules_config_files():
    """
    List the configuration files stored directly in the quick schedules sub-directory.

    Returns:
        list[str]: file paths relative to the base configuration directory
    """
    quick_dir = _get_quick_schedules_config_dir()
    return _collect_config_files(quick_dir)
def get_model_config(config_file):
    """
    Build a detectron2 config extended with the DensePose options and merged with a file.

    Args:
        config_file (str): configuration file path, relative to the base configuration directory
    Returns:
        The merged config object. When ``torch.cuda.is_available()`` is False,
        ``MODEL.DEVICE`` is set to ``"cpu"``.
    """
    config = get_cfg()
    add_densepose_config(config)
    config.merge_from_file(os.path.join(_get_base_config_dir(), config_file))
    # Checked after merging, so the CPU fallback overrides any device set by the file.
    has_cuda = torch.cuda.is_available()
    if not has_cuda:
        config.MODEL.DEVICE = "cpu"
    return config
def get_model(config_file):
    """
    Instantiate, via detectron2's ``build_model``, the model described by ``config_file``
    (path relative to the base configuration directory).
    """
    return build_model(get_model_config(config_file))
def setup(config_file):
    """
    Load ``config_file`` (relative to the base configuration directory), freeze the
    resulting config and pass it to detectron2's ``default_setup`` together with an
    empty argument mapping.

    Returns None; the frozen config is not handed back to the caller.
    """
    frozen_cfg = get_model_config(config_file)
    frozen_cfg.freeze()
    no_cli_args = {}
    default_setup(frozen_cfg, no_cli_args)