File size: 4,751 Bytes
ec0f03f
 
 
 
 
 
 
d4cabc4
ec0f03f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d4cabc4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ec0f03f
 
 
 
 
 
 
 
 
 
 
d4cabc4
ec0f03f
 
 
d4cabc4
ec0f03f
 
 
 
 
 
 
 
d4cabc4
 
 
 
 
 
 
 
ec0f03f
d4cabc4
 
ec0f03f
 
 
d4cabc4
 
ec0f03f
d4cabc4
ec0f03f
d4cabc4
ec0f03f
d4cabc4
 
 
 
 
 
 
ec0f03f
d4cabc4
ec0f03f
 
 
d4cabc4
ec0f03f
d4cabc4
ec0f03f
d4cabc4
ec0f03f
 
 
 
 
d4cabc4
ec0f03f
 
 
 
 
d4cabc4
 
ec0f03f
 
d4cabc4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import gradio as gr
from huggingface_hub import hf_hub_download
from huggingface_hub import HfApi, login
import os
import yaml
from pathlib import Path
from loguru import logger
from PIL import Image

# Hugging Face Hub repositories. SPACE_REPO/SPACE_REPO_TYPE identify a Hub
# Space repo; MODEL_REPO/MODEL_REPO_TYPE identify the model repo from which the
# results.yaml files and demo videos are downloaded.
SPACE_REPO = "c-gohlke/litrl"
SPACE_REPO_TYPE = "space"
MODEL_REPO = "c-gohlke/litrl"
MODEL_REPO_TYPE = "model"

# Number of path components in an environment-level results file,
# e.g. "models/CartPole-v1/results.yaml" -> ['models', 'CartPole-v1', 'results.yaml'].
ENV_RESULTS_FILE_DEPTH = 3

hf_api = HfApi()

# Only log in when a token is actually provided: `login(token=None)` makes
# huggingface_hub prompt for a token interactively, which cannot be answered in
# a headless deployment. Without a token, Hub calls use huggingface_hub's
# default credential lookup (or anonymous access).
_hf_token = os.environ.get("HUGGINGFACE_TOKEN")
if _hf_token:
    login(  # type: ignore[no-untyped-call]
        token=_hf_token,
        add_to_git_credential=True,
        new_session=False,
    )
else:
    logger.warning("HUGGINGFACE_TOKEN is not set; skipping Hugging Face login.")

# def get_best_gif(env: str) -> Path:
#     hf_env_results_path = f"models/{env}/results.yaml"
#     local_env_results_path = hf_hub_download(repo_id=MODEL_REPO, repo_type=MODEL_REPO_TYPE, filename=hf_env_results_path)
#     with open(local_env_results_path, "r") as f:
#         env_results = yaml.load(f, Loader=yaml.FullLoader)
#     best_model_type = max(env_results, key=lambda model: env_results[model])
#     model_results_path = f"models/{env}/{best_model_type}/results.yaml"
#     local_model_results_path = hf_hub_download(repo_id=MODEL_REPO, repo_type=MODEL_REPO_TYPE, filename=model_results_path)
#     with open(local_model_results_path, "r") as f:
#         model_results = yaml.load(f, Loader=yaml.FullLoader)

#     best_model = max(model_results, key=lambda model: model_results[model])
#     hf_gif_path = f"models/{env}/{best_model_type}/{best_model}/demo.gif"
#     return Path(hf_hub_download(repo_id=MODEL_REPO, repo_type=MODEL_REPO_TYPE, filename=hf_gif_path))

def _load_results(hf_path: str) -> dict:
    """Download a results YAML file from the model repo and parse it.

    Args:
        hf_path: Path of the file inside ``MODEL_REPO``.

    Returns:
        The parsed mapping. Its values are used to rank its keys (the key with
        the highest value is selected by the caller).

    Raises:
        ValueError: If the file is empty or does not contain a mapping.
    """
    local_path = hf_hub_download(repo_id=MODEL_REPO, repo_type=MODEL_REPO_TYPE, filename=hf_path)
    with open(local_path, "r", encoding="utf-8") as f:
        # safe_load: results files only need plain YAML types, so there is no
        # reason to accept the extra Python-specific tags FullLoader allows.
        results = yaml.safe_load(f)
    if not isinstance(results, dict) or not results:
        raise ValueError(f"Expected a non-empty mapping in {hf_path}, got {results!r}")
    return results


def get_best_mp4(env: str) -> Path:
    """Return a local path to the demo video of the best model for ``env``.

    Selection happens at two levels:
    ``models/{env}/results.yaml`` ranks model types, and
    ``models/{env}/{model_type}/results.yaml`` ranks the individual models of
    the winning type. At each level, the entry with the highest value wins.
    The chosen model's ``demo.mp4`` is then fetched with ``hf_hub_download``.

    Raises:
        ValueError: If either results file is empty or not a mapping.
    """
    env_results = _load_results(f"models/{env}/results.yaml")
    best_model_type = max(env_results, key=env_results.get)

    model_results = _load_results(f"models/{env}/{best_model_type}/results.yaml")
    best_model = max(model_results, key=model_results.get)

    hf_mp4_path = f"models/{env}/{best_model_type}/{best_model}/demo.mp4"
    return Path(hf_hub_download(repo_id=MODEL_REPO, repo_type=MODEL_REPO_TYPE, filename=hf_mp4_path))

def get_environments() -> list[str]:
    """List environment ids that have a top-level results file in the model repo.

    A repo path qualifies when it is exactly ``models/<env>/results.yaml``.
    Ids are returned in the order the repo listing yields them.
    """
    listing = hf_api.list_repo_files(MODEL_REPO, repo_type=MODEL_REPO_TYPE)
    split_paths = (repo_path.split("/") for repo_path in listing)
    return [
        parts[1]
        for parts in split_paths
        # e.g. ['models', 'CartPole-v1', 'results.yaml']
        if len(parts) == ENV_RESULTS_FILE_DEPTH
        and parts[2] == "results.yaml"
        and parts[0] == "models"
    ]

# def get_gif_paths(environments: list[str]) -> dict[str, Path]:
#     gif_paths: dict[str, Path] = {}
#     for env in environments:
#         gif_paths[env] = get_best_gif(env)
#     return gif_paths

def get_mp4_paths(environments: list[str]) -> dict[str, Path]:
    """Map each environment id to the local path of its best model's demo video.

    Entries follow the order of ``environments``; ``get_best_mp4`` is called once
    per listed id.
    """
    return {env_id: get_best_mp4(env_id) for env_id in environments}

def run_demo(default_env: str = "CartPole-v1") -> None:
    """Build and launch the Gradio demo that shows each environment's best-model video.

    Args:
        default_env: Environment pre-selected in the dropdown. If it has no demo
            video, the first available environment is used instead. If the
            model repo lists no environments, nothing is pre-selected.
    """
    from typing import Optional

    environments = get_environments()
    mp4_paths = get_mp4_paths(environments)

    # Fall back instead of raising KeyError when the preferred default is absent.
    initial_env: Optional[str] = default_env if default_env in mp4_paths else next(iter(mp4_paths), None)
    if initial_env != default_env:
        logger.warning(f"Default environment {default_env} not available; using {initial_env}")

    def api_get_text(env_id: str) -> gr.Markdown:
        logger.info(f"Getting text for {env_id}")
        return gr.Markdown("# Greetings from LitRL!")

    def api_predict(env_id: str) -> Optional[gr.Video]:
        # Returning None when the id is unknown (e.g. a value sent through the API).
        if env_id not in mp4_paths:
            logger.error(f"Environment {env_id} not found in {mp4_paths}")
            return None
        return gr.Video(mp4_paths[env_id])

    with gr.Blocks() as demo:
        md = gr.Markdown("# Greetings from LitRL!")
        env = gr.Dropdown(choices=list(mp4_paths.keys()), value=initial_env, label="Environment")
        button = gr.Button(value="Submit")
        video_out = gr.Video(
            mp4_paths[initial_env] if initial_env is not None else None,
            autoplay=True,
        )

        # Two listeners on the same button keep the separate "get_text" and
        # "predict" API endpoints.
        button.click(  # type: ignore[no-untyped-call]
            fn=api_get_text,
            inputs=env,
            outputs=md,
            api_name="get_text",
        )
        button.click(  # type: ignore[no-untyped-call]
            fn=api_predict,
            inputs=env,
            outputs=video_out,
            api_name="predict",
        )

    # Launch after the Blocks context has closed, following Gradio's documented
    # Blocks usage.
    demo.launch()  # type: ignore[no-untyped-call]

# Launch the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    run_demo()