Spaces:
Runtime error
Runtime error
import importlib | |
import torch | |
import os | |
from collections import OrderedDict | |
def get_func(func_name):
    """Return a function object by name.

    ``func_name`` must either be the bare name of a function defined in this
    module, or a dotted path relative to the top-level ``lib`` package whose
    last component is the attribute to fetch (``'a.b.f'`` resolves to
    attribute ``f`` of module ``lib.a.b``).

    Args:
        func_name: Bare name or dotted path of the function. An empty string
            means "no function".

    Returns:
        The resolved object, or ``None`` when ``func_name`` is ``''``.

    Raises:
        Any exception raised by the lookup is re-raised after a diagnostic is
        printed: ``KeyError`` for an unknown bare name, ``ImportError`` /
        ``ModuleNotFoundError`` for a bad module path, ``AttributeError`` when
        the module lacks the requested attribute.
    """
    if func_name == '':
        return None
    try:
        parts = func_name.split('.')
        # A bare name refers to a function defined in this module.
        if len(parts) == 1:
            return globals()[parts[0]]
        # Otherwise everything before the last component is a module path
        # under the top-level 'lib' package.
        module_name = 'lib.' + '.'.join(parts[:-1])
        module = importlib.import_module(module_name)
        return getattr(module, parts[-1])
    except Exception:
        # print() does not %-interpolate extra positional arguments, so the
        # message must be formatted explicitly.
        print('Failed to find function: %s' % func_name)
        raise
def load_ckpt(args, depth_model, shift_model, focal_model):
    """Load model weights from the checkpoint file at ``args.load_ckpt``.

    The checkpoint is read as a mapping holding state dicts under the keys
    ``'depth_model'``, ``'shift_model'`` and ``'focal_model'``. A leading
    ``'module.'`` on every parameter name (presumably left by saving a model
    wrapped in DataParallel/DistributedDataParallel) is stripped before
    loading. All loads use ``strict=True``.

    Args:
        args: Object with a ``load_ckpt`` attribute holding the checkpoint path.
        depth_model: Module that always receives ``checkpoint['depth_model']``.
        shift_model: Module for ``checkpoint['shift_model']``, or ``None`` to skip.
        focal_model: Module for ``checkpoint['focal_model']``, or ``None`` to skip.

    If ``args.load_ckpt`` is not an existing file, a notice is printed and the
    models are left untouched.
    """
    ckpt_path = args.load_ckpt
    if not os.path.isfile(ckpt_path):
        # Non-fatal, as before, but make it visible instead of silently
        # continuing with whatever weights the models currently hold.
        print("checkpoint %s not found, skipping weight loading" % ckpt_path)
        return

    print("loading checkpoint %s" % ckpt_path)
    # Deserialize onto the CPU so a checkpoint saved from GPU tensors can be
    # loaded on a CPU-only host; load_state_dict copies the values into each
    # model's existing parameters, so the models keep their own device.
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    checkpoint = torch.load(ckpt_path, map_location="cpu")

    if shift_model is not None:
        shift_model.load_state_dict(
            strip_prefix_if_present(checkpoint['shift_model'], 'module.'),
            strict=True)
    if focal_model is not None:
        focal_model.load_state_dict(
            strip_prefix_if_present(checkpoint['focal_model'], 'module.'),
            strict=True)
    depth_model.load_state_dict(
        strip_prefix_if_present(checkpoint['depth_model'], "module."),
        strict=True)

    # Drop the loaded tensors and release any cached GPU allocator memory.
    del checkpoint
    torch.cuda.empty_cache()
def strip_prefix_if_present(state_dict, prefix):
    """Return ``state_dict`` with ``prefix`` removed from the start of every key.

    The prefix is stripped only when *every* key starts with it; otherwise the
    input mapping itself is returned unchanged. Key order is preserved.

    Args:
        state_dict: Mapping from string keys (e.g. parameter names) to values.
        prefix: Leading string to remove, such as ``'module.'``.

    Returns:
        A new ``OrderedDict`` with the prefix removed from each key, or the
        original ``state_dict`` object if some key lacks the prefix.
    """
    if not all(key.startswith(prefix) for key in state_dict):
        return state_dict
    # Slice off only the leading prefix: str.replace() would also delete the
    # prefix where it appears later in a key (e.g. 'module.sub.module.w').
    cut = len(prefix)
    return OrderedDict((key[cut:], value) for key, value in state_dict.items())