File size: 1,152 Bytes
bafb458
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
def _download_from_model_repo(filename: str) -> str:
    """Download *filename* from the model repository and return its local path."""
    return hf_hub_download(repo_id=MODEL_REPO, repo_type=MODEL_REPO_TYPE, filename=filename)


def _best_scoring_key(results_filename: str) -> object:
    """Download a ``results.yaml`` mapping and return its highest-scoring key.

    Raises:
        ValueError: If the file does not contain a non-empty mapping.
    """
    local_path = _download_from_model_repo(results_filename)
    with open(local_path, "r", encoding="utf-8") as f:
        # NOTE(review): FullLoader is kept for compatibility with the existing
        # results files; if they only contain plain score mappings,
        # yaml.safe_load is the safer choice for data fetched from a remote repo.
        results = yaml.load(f, Loader=yaml.FullLoader)
    # An empty YAML file parses to None; without this check the failure would
    # surface as an unhelpful TypeError/ValueError from max().
    if not isinstance(results, dict) or not results:
        raise ValueError(
            f"{results_filename}: expected a non-empty mapping of scores, got {results!r}"
        )
    # max() returns the first maximal key in iteration order on ties.
    return max(results, key=results.__getitem__)


def get_best_gif(env: str) -> Path:
    """Download the demo GIF of the best-scoring model for *env*.

    Selection is done in two stages from ``results.yaml`` files in the model
    repository:

    1. ``models/{env}/results.yaml`` maps model types to scores; the model
       type with the largest score is chosen.
    2. ``models/{env}/{model_type}/results.yaml`` maps models to scores; the
       model with the largest score is chosen.

    Its ``demo.gif`` is then downloaded.

    Args:
        env: Environment name, used as a path component in the repository.

    Returns:
        Local filesystem path of the downloaded ``demo.gif``.

    Raises:
        ValueError: If either ``results.yaml`` is empty or not a mapping.
    """
    best_model_type = _best_scoring_key(f"models/{env}/results.yaml")
    model_type_dir = f"models/{env}/{best_model_type}"
    best_model = _best_scoring_key(f"{model_type_dir}/results.yaml")
    return Path(_download_from_model_repo(f"{model_type_dir}/{best_model}/demo.gif"))



def get_gif_paths(environments: list[str]) -> dict[str, Path]:
    """Map each environment name to the local path of its best demo GIF.

    Environments are processed in the order given. A name that appears more
    than once is looked up again, and the later result is kept.
    """
    return {env_name: get_best_gif(env_name) for env_name in environments}