Spaces:
Build error
Build error
import gradio as gr | |
from huggingface_hub import hf_hub_download | |
from huggingface_hub import HfApi, login | |
import os | |
import yaml | |
from pathlib import Path | |
from loguru import logger | |
from PIL import Image | |
# Hugging Face repositories used by this demo: the Space hosting the UI and the
# model repo holding per-environment results and demo videos.
SPACE_REPO = "c-gohlke/litrl"
SPACE_REPO_TYPE = "space"
MODEL_REPO = "c-gohlke/litrl"
MODEL_REPO_TYPE = "model"
# Paths of the form "models/<env>/results.yaml" split into exactly this many parts.
ENV_RESULTS_FILE_DEPTH = 3
hf_api = HfApi()

# Only log in when a token is actually configured: login(token=None) falls back to an
# interactive prompt, which cannot be answered in a headless Space.
_hf_token = os.environ.get("HUGGINGFACE_TOKEN")
if _hf_token:
    login(  # type: ignore[no-untyped-call]
        token=_hf_token,
        add_to_git_credential=True,
        new_session=False,
    )
else:
    logger.warning("HUGGINGFACE_TOKEN is not set; accessing the Hub without logging in.")
# def get_best_gif(env: str) -> Path: | |
# hf_env_results_path = f"models/{env}/results.yaml" | |
# local_env_results_path = hf_hub_download(repo_id=MODEL_REPO, repo_type=MODEL_REPO_TYPE, filename=hf_env_results_path) | |
# with open(local_env_results_path, "r") as f: | |
# env_results = yaml.load(f, Loader=yaml.FullLoader) | |
# best_model_type = max(env_results, key=lambda model: env_results[model]) | |
# model_results_path = f"models/{env}/{best_model_type}/results.yaml" | |
# local_model_results_path = hf_hub_download(repo_id=MODEL_REPO, repo_type=MODEL_REPO_TYPE, filename=model_results_path) | |
# with open(local_model_results_path, "r") as f: | |
# model_results = yaml.load(f, Loader=yaml.FullLoader) | |
# best_model = max(model_results, key=lambda model: model_results[model]) | |
# hf_gif_path = f"models/{env}/{best_model_type}/{best_model}/demo.gif" | |
# return Path(hf_hub_download(repo_id=MODEL_REPO, repo_type=MODEL_REPO_TYPE, filename=hf_gif_path)) | |
def _download_model_file(filename: str) -> str:
    """Download ``filename`` from the model repo and return the local file path."""
    return hf_hub_download(repo_id=MODEL_REPO, repo_type=MODEL_REPO_TYPE, filename=filename)


def _best_entry(results_filename: str) -> str:
    """Return the key with the highest value in a results YAML mapping from the model repo.

    Raises:
        ValueError: if the results file is empty.
    """
    local_path = _download_model_file(results_filename)
    with open(local_path, "r", encoding="utf-8") as f:
        # safe_load: the results files are plain mappings; avoid constructing arbitrary
        # Python objects from downloaded content.
        results = yaml.safe_load(f)
    if not results:
        raise ValueError(f"No results found in {results_filename}")
    return max(results, key=lambda name: results[name])


def get_best_mp4(env: str) -> Path:
    """Download and return the demo video of the best model for ``env``.

    First picks the best model type from ``models/<env>/results.yaml``. It then picks
    the best run of that type from ``models/<env>/<type>/results.yaml`` and downloads
    that run's ``demo.mp4``.

    Raises:
        ValueError: if either results file is empty.
    """
    best_model_type = _best_entry(f"models/{env}/results.yaml")
    best_model = _best_entry(f"models/{env}/{best_model_type}/results.yaml")
    hf_mp4_path = f"models/{env}/{best_model_type}/{best_model}/demo.mp4"
    return Path(_download_model_file(hf_mp4_path))
def get_environments() -> list[str]:
    """Return the names of all environments that have a ``models/<env>/results.yaml`` file.

    The names are returned in the order the model repo lists its files.
    """
    repo_files = hf_api.list_repo_files(MODEL_REPO, repo_type=MODEL_REPO_TYPE)
    # Each path is split on "/", e.g. "models/CartPole-v1/results.yaml"
    # -> ['models', 'CartPole-v1', 'results.yaml'].
    split_paths = (path.split("/") for path in repo_files)
    return [
        parts[1]
        for parts in split_paths
        if len(parts) == ENV_RESULTS_FILE_DEPTH
        and parts[2] == "results.yaml"
        and parts[0] == "models"
    ]
# def get_gif_paths(environments: list[str]) -> dict[str, Path]: | |
# gif_paths: dict[str, Path] = {} | |
# for env in environments: | |
# gif_paths[env] = get_best_gif(env) | |
# return gif_paths | |
def get_mp4_paths(environments: list[str]) -> dict[str, Path]:
    """Map each environment name to the local path of its best demo video."""
    return {env_name: get_best_mp4(env_name) for env_name in environments}
def run_demo() -> None:
    """Build and launch the Gradio demo that shows the best demo video per environment.

    All environments are discovered and their best videos downloaded before the UI is
    built. The UI offers an environment dropdown and a Submit button that updates the
    header text and the displayed video.

    Raises:
        RuntimeError: if the model repo contains no environments with results.
    """
    environments = get_environments()
    mp4_paths = get_mp4_paths(environments)
    if not mp4_paths:
        raise RuntimeError(f"No environments with results found in {MODEL_REPO}")
    # Prefer CartPole-v1 when available. Otherwise fall back to the first discovered
    # environment, so the initial video lookup below cannot raise KeyError.
    default_env = "CartPole-v1" if "CartPole-v1" in mp4_paths else next(iter(mp4_paths))

    def api_get_text(env_id: str) -> gr.Markdown:
        """Return the header markdown; the selected environment is only logged."""
        logger.info(f"Getting text for {env_id}")
        return gr.Markdown("# Greetings from LitRL!")

    def api_predict(env_id: str) -> "gr.Video | None":
        """Return a video component for ``env_id``, or None if it is unknown."""
        if env_id not in mp4_paths:
            logger.error(f"Environment {env_id} not found in {mp4_paths}")
            return None
        return gr.Video(mp4_paths[env_id])

    with gr.Blocks() as demo:
        md = gr.Markdown("# Greetings from LitRL!")
        env = gr.Dropdown(choices=list(mp4_paths.keys()), value=default_env, label="Environment")
        button = gr.Button(value="Submit")
        video_out = gr.Video(mp4_paths[default_env], autoplay=True)
        button.click(  # type: ignore[no-untyped-call]
            fn=api_get_text,
            inputs=env,
            outputs=md,
            api_name="get_text",
        )
        button.click(  # type: ignore[no-untyped-call]
            fn=api_predict,
            inputs=env,
            outputs=video_out,
            api_name="predict",
        )
    demo.launch()  # type: ignore[no-untyped-call]


if __name__ == "__main__":
    run_demo()