|
import urllib.request |
|
import tarfile |
|
from tqdm import tqdm |
|
import os |
|
import yaml |
|
from ruamel.yaml import YAML |
|
|
|
def read_plainconfig(configname):
    """Parse a YAML configuration file with a default ``ruamel.yaml.YAML()`` loader.

    Parameters
    ----------
    configname : str
        Path to the YAML file to read.

    Returns
    -------
    object
        Whatever ``YAML().load`` produces for the file's contents.

    Raises
    ------
    FileNotFoundError
        If ``configname`` does not exist on disk.
    """
    missing = not os.path.exists(configname)
    if missing:
        message = f"Config {configname} is not found. Please make sure that the file exists."
        raise FileNotFoundError(message)

    loader = YAML()
    with open(configname) as stream:
        parsed = loader.load(stream)
    return parsed
|
|
|
def _strip_top_level_dir(tarf):
    """Yield the members of ``tarf`` re-rooted below the archive's wrapper folder.

    The leading path component of the first member (e.g. ``xyz-trainsetxyshufflez``)
    is taken as the archive's wrapper folder. Members inside it are renamed
    relative to it, so their contents land directly in the extraction directory.
    The wrapper folder's own entry is skipped: extracting it would only re-apply
    the archive's mode/mtime to the target directory, and it would otherwise
    become an empty-named member. Members outside the wrapper folder are yielded
    unchanged.
    """
    members = tarf.getmembers()
    if not members:
        return
    top = members[0].path.split("/", 1)[0]
    prefix = top + "/"
    for member in members:
        if member.path.rstrip("/") == top:
            continue
        # Require the separator so that e.g. "top-other/x" is not treated as a
        # child of "top" (the old startswith(parent) check mangled such names).
        if member.path.startswith(prefix):
            member.path = member.path[len(prefix):]
        yield member


def DownloadModel(modelname, target_dir, urls_config="./model/pretrained_model_urls.yaml"):
    """Download a DeepLabCut Model Zoo project and unpack it into ``target_dir``.

    The model name is looked up in the URL table read from ``urls_config``. The
    matching gzip-compressed tar archive is fetched with a single request into
    a temporary file, while a tqdm progress bar shows the bytes received. The
    archive is then extracted with its wrapper folder stripped, so its contents
    land directly in ``target_dir``.

    Parameters
    ----------
    modelname : str
        Key of the model in the URL table.
    target_dir : str
        Directory the archive contents are extracted into.
    urls_config : str, optional
        YAML file mapping model names to download URLs. Relative paths resolve
        against the current working directory, matching the previous
        hard-coded location.

    Returns
    -------
    str
        ``target_dir``, unchanged. If ``modelname`` is not in the table, nothing
        is downloaded and the selectable model names are printed instead.
    """
    import tempfile  # local import: only needed for the download buffer

    neturls = read_plainconfig(urls_config)

    if modelname not in neturls:
        # Hide the bare backbone entries; only full model-zoo projects are listed.
        models = [
            fn for fn in neturls if "resnet_" not in fn and "mobilenet_" not in fn
        ]
        print("Model does not exist: ", modelname)
        print("Pick one of the following: ", models)
        return target_dir

    url = neturls[modelname]
    print(url)

    # The archive comes from the network. Where supported (Python 3.12+ and
    # security backports), the "data" extraction filter rejects members that
    # would escape target_dir via "..", absolute paths, or links. On older
    # interpreters without it, extraction is unprotected.
    extract_kwargs = {"filter": "data"} if hasattr(tarfile, "data_filter") else {}

    with tempfile.TemporaryFile() as buffer:
        # One request provides both the size (for the progress bar) and the
        # payload. The response is closed as soon as the body has been copied.
        with urllib.request.urlopen(url) as response:
            print(
                "Downloading the model from the DeepLabCut server @Harvard -> Go Crimson!!! {}....".format(
                    url
                )
            )
            # Content-Length may be missing (e.g. chunked transfer). In that
            # case tqdm shows an open-ended counter instead of failing on int(None).
            length = response.headers.get("Content-Length")
            total_size = int(length) if length else None
            with tqdm(unit="B", total=total_size, position=0) as pbar:
                for chunk in iter(lambda: response.read(1 << 16), b""):
                    buffer.write(chunk)
                    pbar.update(len(chunk))  # count the bytes actually received

        buffer.seek(0)
        with tarfile.open(fileobj=buffer, mode="r:gz") as tar:
            tar.extractall(
                target_dir, members=_strip_top_level_dir(tar), **extract_kwargs
            )

    return target_dir
|
|