File size: 3,177 Bytes
ec0f03f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82515b9
ec0f03f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import gradio as gr
from huggingface_hub import hf_hub_download
from huggingface_hub import HfApi, login
import os
import yaml
from pathlib import Path
from loguru import logger

SPACE_REPO = "c-gohlke/litrl"
SPACE_REPO_TYPE = "space"
MODEL_REPO = "c-gohlke/litrl"
MODEL_REPO_TYPE = "model"

# Number of path components in an environment-level results file,
# e.g. "models/CartPole-v1/results.yaml" -> ["models", "CartPole-v1", "results.yaml"].
ENV_RESULTS_FILE_DEPTH = 3

hf_api = HfApi()

# Only authenticate when a token is actually provided: with token=None,
# huggingface_hub's `login` falls back to an interactive prompt, which would
# block a headless deployment at import time.
_hf_token = os.environ.get("HUGGINGFACE_TOKEN")
if _hf_token:
    login(  # type: ignore[no-untyped-call]
        token=_hf_token,
        add_to_git_credential=True,
        new_session=False,
    )

def _load_results(hf_path: str) -> dict:
    """Download the YAML file at ``hf_path`` from MODEL_REPO and return its top-level mapping.

    Raises:
        ValueError: if the file is empty or its top level is not a non-empty mapping.
    """
    local_path = hf_hub_download(repo_id=MODEL_REPO, repo_type=MODEL_REPO_TYPE, filename=hf_path)
    with open(local_path, "r", encoding="utf-8") as f:
        # safe_load: the file is expected to hold plain YAML, so there is no need
        # for the extra tags FullLoader accepts.
        results = yaml.safe_load(f)
    if not isinstance(results, dict) or not results:
        raise ValueError(f"{hf_path} does not contain a non-empty mapping of results")
    return results


def get_best_gif(env: str) -> Path:
    """Return the local path of the demo GIF for the best model of ``env``.

    Selection is two-level: ``models/<env>/results.yaml`` maps model types to
    comparable values and the key with the largest value is chosen; then
    ``models/<env>/<model_type>/results.yaml`` is used the same way to pick the
    individual model. The GIF at ``models/<env>/<model_type>/<model>/demo.gif``
    is downloaded from MODEL_REPO.

    Raises:
        ValueError: if either results file is empty or not a mapping.
    """
    env_results = _load_results(f"models/{env}/results.yaml")
    best_model_type = max(env_results, key=env_results.__getitem__)

    model_results = _load_results(f"models/{env}/{best_model_type}/results.yaml")
    best_model = max(model_results, key=model_results.__getitem__)

    hf_gif_path = f"models/{env}/{best_model_type}/{best_model}/demo.gif"
    return Path(hf_hub_download(repo_id=MODEL_REPO, repo_type=MODEL_REPO_TYPE, filename=hf_gif_path))



def get_environments() -> list[str]:
    """Return the environment ids that have an environment-level results file in MODEL_REPO.

    An environment is listed when the repo contains a file at exactly
    ``models/<env>/results.yaml``; deeper results files (per model type or
    per model) are ignored. Order follows the repo file listing.
    """
    files = hf_api.list_repo_files(MODEL_REPO, repo_type=MODEL_REPO_TYPE)
    split_paths = (file.split("/") for file in files)
    return [
        parts[1]
        for parts in split_paths
        # e.g. ['models', 'CartPole-v1', 'results.yaml']; the length check
        # comes first so the indexing below is always in range.
        if len(parts) == ENV_RESULTS_FILE_DEPTH and parts[0] == "models" and parts[2] == "results.yaml"
    ]

def get_gif_paths(environments: list[str]) -> dict[str, Path]:
    """Map each environment id to the local path of its best model's demo GIF.

    Paths are obtained via ``get_best_gif``, one download per environment.
    An empty ``environments`` list yields an empty dict.
    """
    return {env: get_best_gif(env) for env in environments}

def run_demo(default_env: str = "CartPole-v1") -> None:
    """Build and launch the LitRL Gradio demo.

    Discovers the environments in MODEL_REPO, downloads the best demo GIF for
    each, and serves a page with an environment dropdown; clicking "Submit"
    refreshes the greeting and shows the selected environment's GIF.

    Args:
        default_env: environment pre-selected in the dropdown. If it has no
            results in the repo, the first discovered environment is used
            instead (or nothing, when no environments were found).
    """
    environments = get_environments()
    gif_paths = get_gif_paths(environments)

    # Fall back instead of indexing blindly, so a missing default environment
    # (or an empty repo) does not crash startup with a KeyError.
    initial_env = default_env if default_env in gif_paths else next(iter(gif_paths), None)
    initial_gif = str(gif_paths[initial_env]) if initial_env is not None else None

    def api_get_text(env_id: str) -> gr.Markdown:
        logger.info(f"Getting text for {env_id}")
        return gr.Markdown("# Greetings from LitRL!")

    def api_predict(env_id: str) -> gr.Image | None:
        if env_id not in gif_paths:
            logger.error(f"Environment {env_id} not found in {gif_paths}")
            return None
        return gr.Image(str(gif_paths[env_id]), type="filepath")

    with gr.Blocks() as demo:
        md = gr.Markdown("# Greetings from LitRL!")
        # Dropdown choices must be a list, not a dict_keys view.
        env = gr.Dropdown(choices=list(gif_paths), value=initial_env, label="Environment")
        button = gr.Button(value="Submit")
        cartpole_out = gr.Image(initial_gif, type="filepath")

        button.click(
            fn=api_get_text,
            inputs=env,
            outputs=md,
            api_name="get_text",
        )
        button.click(
            fn=api_predict,
            inputs=env,
            outputs=cartpole_out,
            api_name="predict",
        )

    # Launch once the Blocks context has exited, as in Gradio's documented usage.
    demo.launch()

if __name__ == "__main__":
    # Start the demo server only when this file is executed directly, not on import.
    run_demo()