Spaces:
Running
Running
# -*- coding: utf-8 -*- | |
import os | |
import click | |
import pandas as pd | |
import yaml | |
from torchnlp.download import download_file_maybe_extract | |
from .estimators import PolosEstimator, QualityEstimator | |
from .model_base import ModelBase | |
from .ranking import PolosRanker | |
# Registry used by `load_checkpoint`: the "model" entry stored with a
# checkpoint's hyperparameters is looked up here to pick the class whose
# `load_from_checkpoint` rebuilds the model.
str2model = dict(
    PolosEstimator=PolosEstimator,
    PolosRanker=PolosRanker,
    # Models that use the source only:
    QualityEstimator=QualityEstimator,
)
def get_cache_folder():
    """Return the default cache folder, creating it when it does not exist.

    The folder is relative to the current working directory.

    Return:
        - Relative path of the cache folder (ends with "/").
    """
    cache_directory = "./.cache/"
    # `exist_ok=True` replaces the former exists()-then-makedirs() pair, which
    # could raise FileExistsError when two processes started at the same time
    # and both saw the folder as missing.
    os.makedirs(cache_directory, exist_ok=True)
    return cache_directory
def download_model(model: str, saving_directory: str = None) -> str:
    """Function that downloads a pretrained model from AWS (unless cached).

    :param model: Name of the model to be loaded.
    :param saving_directory: RELATIVE path to the saving folder. A trailing "/"
        is appended when missing. Defaults to the folder returned by
        `get_cache_folder()`.

    Return:
        - Path to the checkpoint file (NOT a loaded model); pass it to
          `load_checkpoint` to obtain the model.

    Raises:
        Exception: if `model` is not a known model name, if its source is not
            an https URL, or if no `.ckpt` file is found in the model folder.
    """
    if saving_directory is None:
        saving_directory = get_cache_folder()

    # Every path below is built by plain "/" concatenation, so make sure the
    # folder ends with "/" instead of silently producing "<folder>reprod".
    if saving_directory and not saving_directory.endswith("/"):
        saving_directory += "/"

    os.makedirs(saving_directory, exist_ok=True)

    experiment_folder = saving_directory + "reprod"

    # Fast path: the expected checkpoint is already on disk.
    cached_checkpoint = experiment_folder + "/reprod.ckpt"
    if os.path.exists(cached_checkpoint):
        return cached_checkpoint

    models = {
        "polos": "https://polos-polaris.s3.ap-northeast-1.amazonaws.com/reprod.zip"
    }

    if os.path.isdir(saving_directory + model):
        click.secho(f"{model} is already in cache.", fg="yellow")

    elif model not in models:
        raise Exception(f"{model} is not a valid Polos model!")

    elif models[model].startswith("https://"):
        url = models[model]
        download_file_maybe_extract(url, directory=saving_directory)

        # Remove the downloaded archive. The previous code only looked for
        # "<model>.zip", but the URL's file name ("reprod.zip") differs from
        # the model name, so the archive was never cleaned up.
        # NOTE(review): assumes torchnlp stores the archive under the URL's
        # basename inside `saving_directory` - confirm against torchnlp.
        for archive_name in (model + ".zip", url.rsplit("/", 1)[-1]):
            archive_path = saving_directory + archive_name
            if os.path.exists(archive_path):
                os.remove(archive_path)

        # Only announce a successful download when one actually happened.
        click.secho("Download succeeded. Loading model...", fg="yellow")

    else:
        raise Exception("Something went wrong while dowloading the model!")

    # os.listdir() returns entries in arbitrary order; sort so that the chosen
    # checkpoint is deterministic when several are present.
    checkpoints = sorted(
        file for file in os.listdir(experiment_folder) if file.endswith(".ckpt")
    )
    if not checkpoints:
        raise Exception(f"No checkpoint (.ckpt) file found in {experiment_folder}!")

    return experiment_folder + "/" + checkpoints[-1]
def load_checkpoint(checkpoint: str) -> "ModelBase":
    """Function that loads a model from a checkpoint file.

    The hyperparameters are read from a file sitting next to the checkpoint:
    `meta_tags.csv` when present (older Lightning checkpoints), otherwise
    `hparams.yaml`. Their "model" entry selects the class in `str2model`.

    :param checkpoint: Path to the checkpoint file.

    Returns:
        - Polos Model, after `eval()` and `freeze()` have been called on it.

    Raises:
        Exception: if the checkpoint file does not exist, or neither
            `meta_tags.csv` nor `hparams.yaml` is found next to it.
        KeyError: if the "model" entry is missing or is not a key of
            `str2model`.
    """
    if not os.path.exists(checkpoint):
        raise Exception(f"{checkpoint} file not found!")

    # os.path handles every separator the platform accepts, unlike splitting
    # the path on "/" by hand.
    checkpoint_folder = os.path.dirname(checkpoint)
    tags_csv_file = os.path.join(checkpoint_folder, "meta_tags.csv")
    hparam_yaml_file = os.path.join(checkpoint_folder, "hparams.yaml")

    if os.path.exists(tags_csv_file):
        # Ugly conversion from older Lightning checkpoints.
        # `read_csv(squeeze=True)` was removed in pandas 2.0; squeezing the
        # resulting frame along the columns is the documented replacement.
        tags = (
            pd.read_csv(tags_csv_file, header=None, index_col=0)
            .squeeze("columns")
            .to_dict()
        )
        hparams = {}
        for k, v in tags.items():
            # Only plain non-negative decimal strings are converted to numbers;
            # every other value is kept exactly as read.
            if isinstance(v, str) and v.replace(".", "", 1).isdigit():
                hparams[k] = float(v) if "." in v else int(v)
            else:
                hparams[k] = v
        model = str2model[tags["model"]].load_from_checkpoint(
            checkpoint, hparams=hparams
        )

    elif os.path.exists(hparam_yaml_file):
        with open(hparam_yaml_file) as yaml_file:
            # NOTE(review): FullLoader is kept so existing hparams files keep
            # loading, but this file may come from a downloaded archive -
            # consider yaml.safe_load if the hparams only hold plain values.
            hparams = yaml.load(yaml_file.read(), Loader=yaml.FullLoader)
        model = str2model[hparams["model"]].load_from_checkpoint(
            checkpoint, hparams=hparams
        )

    else:
        # Report the folder that was actually inspected; the former message
        # pointed at a cache location this module never uses.
        raise Exception(
            "[meta_tags.csv|hparams.yaml] is missing from the checkpoint folder"
            f" ({checkpoint_folder or '.'})."
            " Please clean your cache folder and try to download the model again."
        )

    model.eval()
    model.freeze()
    return model