Spaces:
Build error
Build error
File size: 4,751 Bytes
ec0f03f d4cabc4 ec0f03f d4cabc4 ec0f03f d4cabc4 ec0f03f d4cabc4 ec0f03f d4cabc4 ec0f03f d4cabc4 ec0f03f d4cabc4 ec0f03f d4cabc4 ec0f03f d4cabc4 ec0f03f d4cabc4 ec0f03f d4cabc4 ec0f03f d4cabc4 ec0f03f d4cabc4 ec0f03f d4cabc4 ec0f03f d4cabc4 ec0f03f d4cabc4 ec0f03f d4cabc4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 |
import gradio as gr
from huggingface_hub import hf_hub_download
from huggingface_hub import HfApi, login
import os
import yaml
from pathlib import Path
from loguru import logger
from PIL import Image
SPACE_REPO = "c-gohlke/litrl"
SPACE_REPO_TYPE = "space"
MODEL_REPO = "c-gohlke/litrl"
MODEL_REPO_TYPE = "model"
# Per-environment result files live at "models/<env>/results.yaml",
# i.e. exactly 3 "/"-separated path components.
ENV_RESULTS_FILE_DEPTH = 3

hf_api = HfApi()

# Authenticate only when a token is actually provided. login(token=None) falls
# back to an interactive prompt (unless a token is already cached, since
# new_session=False), which cannot be answered in a headless Space.
_hf_token = os.environ.get("HUGGINGFACE_TOKEN")
if _hf_token:
    login(  # type: ignore[no-untyped-call]
        token=_hf_token,
        add_to_git_credential=True,
        new_session=False,
    )
else:
    logger.warning("HUGGINGFACE_TOKEN is not set; skipping Hugging Face login.")
# def get_best_gif(env: str) -> Path:
# hf_env_results_path = f"models/{env}/results.yaml"
# local_env_results_path = hf_hub_download(repo_id=MODEL_REPO, repo_type=MODEL_REPO_TYPE, filename=hf_env_results_path)
# with open(local_env_results_path, "r") as f:
# env_results = yaml.load(f, Loader=yaml.FullLoader)
# best_model_type = max(env_results, key=lambda model: env_results[model])
# model_results_path = f"models/{env}/{best_model_type}/results.yaml"
# local_model_results_path = hf_hub_download(repo_id=MODEL_REPO, repo_type=MODEL_REPO_TYPE, filename=model_results_path)
# with open(local_model_results_path, "r") as f:
# model_results = yaml.load(f, Loader=yaml.FullLoader)
# best_model = max(model_results, key=lambda model: model_results[model])
# hf_gif_path = f"models/{env}/{best_model_type}/{best_model}/demo.gif"
# return Path(hf_hub_download(repo_id=MODEL_REPO, repo_type=MODEL_REPO_TYPE, filename=hf_gif_path))
def _best_key_in_results(hf_results_path: str) -> str:
    """Download a results YAML from the model repo and return its best-scoring key.

    The file must be a non-empty mapping of name -> score; the key with the
    largest score wins (first one on ties, as with ``max``). Scores are assumed
    to be mutually comparable numbers — TODO confirm against the writer of
    these files.

    Raises:
        ValueError: if the file is empty or is not a mapping.
    """
    local_path = hf_hub_download(repo_id=MODEL_REPO, repo_type=MODEL_REPO_TYPE, filename=hf_results_path)
    with open(local_path, "r", encoding="utf-8") as f:
        # safe_load: results files only need plain YAML mappings/scalars, so
        # there is no reason to allow any Python-object tags.
        results = yaml.safe_load(f)
    if not isinstance(results, dict) or not results:
        raise ValueError(f"Expected a non-empty mapping in {hf_results_path}, got {results!r}")
    # str() keeps the value identical when interpolated into repo paths, even
    # if YAML parsed the key as a number.
    return str(max(results, key=results.__getitem__))


def get_best_mp4(env: str) -> Path:
    """Download the demo video of the best-scoring model for ``env``.

    Selection is two-level:
      1. ``models/{env}/results.yaml`` picks the best model type;
      2. ``models/{env}/{model_type}/results.yaml`` picks the best model of that type.

    Returns:
        Local path (as returned by ``hf_hub_download``) of
        ``models/{env}/{model_type}/{model}/demo.mp4``.

    Raises:
        ValueError: if either results file is empty or not a mapping.
    """
    best_model_type = _best_key_in_results(f"models/{env}/results.yaml")
    best_model = _best_key_in_results(f"models/{env}/{best_model_type}/results.yaml")
    hf_mp4_path = f"models/{env}/{best_model_type}/{best_model}/demo.mp4"
    return Path(hf_hub_download(repo_id=MODEL_REPO, repo_type=MODEL_REPO_TYPE, filename=hf_mp4_path))
def get_environments() -> list[str]:
    """List the environments that have a top-level results file in the model repo.

    An environment counts when the repo contains exactly
    ``models/<env>/results.yaml`` (e.g. ``models/CartPole-v1/results.yaml``).
    Names are returned in the order the repo listing yields them.
    """
    repo_files = hf_api.list_repo_files(MODEL_REPO, repo_type=MODEL_REPO_TYPE)
    split_paths = (path.split("/") for path in repo_files)
    return [
        parts[1]
        for parts in split_paths
        if len(parts) == ENV_RESULTS_FILE_DEPTH and parts[0] == "models" and parts[2] == "results.yaml"
    ]
# def get_gif_paths(environments: list[str]) -> dict[str, Path]:
# gif_paths: dict[str, Path] = {}
# for env in environments:
# gif_paths[env] = get_best_gif(env)
# return gif_paths
def get_mp4_paths(environments: list[str]) -> dict[str, Path]:
    """Map each environment name to the local path of its best demo video.

    Calls ``get_best_mp4`` once per entry, in the given order; if a name is
    repeated, the last lookup wins.
    """
    return {env_name: get_best_mp4(env_name) for env_name in environments}
def run_demo() -> None:
    """Build and launch the LitRL Gradio demo.

    Finds every environment with results in the model repo, downloads the best
    model's ``demo.mp4`` for each, and serves a UI with an environment dropdown,
    a Submit button, a Markdown header and a video player. Clicking Submit
    re-renders the header and shows the selected environment's video.

    Raises:
        RuntimeError: if no environment with results is found, since there
            would be nothing to select or play.
    """
    environments = get_environments()
    mp4_paths = get_mp4_paths(environments)
    if not mp4_paths:
        raise RuntimeError(f"No environments with results found in {MODEL_REPO}")
    # Prefer CartPole-v1 as the initial selection, but fall back to the first
    # available environment so startup does not fail with a KeyError when
    # CartPole-v1 has no results.
    default_env = "CartPole-v1" if "CartPole-v1" in mp4_paths else next(iter(mp4_paths))

    def api_get_text(env_id: str) -> gr.Markdown:
        """Return the page header; ``env_id`` is currently only logged."""
        logger.info(f"Getting text for {env_id}")
        return gr.Markdown("# Greetings from LitRL!")

    def api_predict(env_id: str) -> "gr.Video | None":
        """Return a video component for ``env_id``'s best demo, or None if unknown."""
        if env_id not in mp4_paths:
            logger.error(f"Environment {env_id} not found in {mp4_paths}")
            return None
        return gr.Video(mp4_paths[env_id])

    with gr.Blocks() as demo:
        md = gr.Markdown("# Greetings from LitRL!")
        env = gr.Dropdown(choices=list(mp4_paths.keys()), value=default_env, label="Environment")
        button = gr.Button(value="Submit")
        video_out = gr.Video(mp4_paths[default_env], autoplay=True)
        button.click(  # type: ignore[no-untyped-call]
            fn=api_get_text,
            inputs=env,
            outputs=md,
            api_name="get_text",
        )
        button.click(  # type: ignore[no-untyped-call]
            fn=api_predict,
            inputs=env,
            outputs=video_out,
            api_name="predict",
        )
    demo.launch()  # type: ignore[no-untyped-call]
# Script entry point: build and launch the Gradio demo only when this file is
# executed directly, not when it is imported.
if __name__ == "__main__":
    run_demo()
|