LitRL-Inference / app.py
c-gohlke's picture
Update Space
d4cabc4
raw
history blame
4.75 kB
import gradio as gr
from huggingface_hub import hf_hub_download
from huggingface_hub import HfApi, login
import os
import yaml
from pathlib import Path
from loguru import logger
from PIL import Image
# Hub repository identifiers. SPACE_REPO / SPACE_REPO_TYPE are not referenced
# by any code visible in this file; MODEL_REPO / MODEL_REPO_TYPE are passed to
# every hf_hub_download / list_repo_files call below.
SPACE_REPO = "c-gohlke/litrl"
SPACE_REPO_TYPE = "space"
MODEL_REPO = "c-gohlke/litrl"
MODEL_REPO_TYPE = "model"
# Number of "/"-separated components in an environment-level results file,
# e.g. "models/CartPole-v1/results.yaml" -> 3 (see get_environments).
ENV_RESULTS_FILE_DEPTH = 3
# Shared Hub API client, used to list the files of the model repo.
hf_api = HfApi()
# Authenticate against the Hub at import time using the HUGGINGFACE_TOKEN
# environment variable.
# NOTE(review): os.environ.get returns None when the variable is unset; what
# login() does with token=None (prompt / reuse a saved token) should be
# confirmed against the huggingface_hub documentation.
login( # type: ignore[no-untyped-call]
    token=os.environ.get("HUGGINGFACE_TOKEN"),
    add_to_git_credential=True,
    new_session=False,
)
# def get_best_gif(env: str) -> Path:
# hf_env_results_path = f"models/{env}/results.yaml"
# local_env_results_path = hf_hub_download(repo_id=MODEL_REPO, repo_type=MODEL_REPO_TYPE, filename=hf_env_results_path)
# with open(local_env_results_path, "r") as f:
# env_results = yaml.load(f, Loader=yaml.FullLoader)
# best_model_type = max(env_results, key=lambda model: env_results[model])
# model_results_path = f"models/{env}/{best_model_type}/results.yaml"
# local_model_results_path = hf_hub_download(repo_id=MODEL_REPO, repo_type=MODEL_REPO_TYPE, filename=model_results_path)
# with open(local_model_results_path, "r") as f:
# model_results = yaml.load(f, Loader=yaml.FullLoader)
# best_model = max(model_results, key=lambda model: model_results[model])
# hf_gif_path = f"models/{env}/{best_model_type}/{best_model}/demo.gif"
# return Path(hf_hub_download(repo_id=MODEL_REPO, repo_type=MODEL_REPO_TYPE, filename=hf_gif_path))
def _load_results_yaml(hf_path: str) -> dict:
    """Download a results YAML file from the model repo and parse it.

    Args:
        hf_path: Path of the YAML file inside ``MODEL_REPO``.

    Returns:
        The parsed top-level mapping of the file.

    Raises:
        ValueError: If the file does not contain a non-empty mapping, since
            no "best" entry can be selected from it.
    """
    local_path = hf_hub_download(repo_id=MODEL_REPO, repo_type=MODEL_REPO_TYPE, filename=hf_path)
    # Explicit encoding so parsing does not depend on the host locale.
    with open(local_path, "r", encoding="utf-8") as f:
        # NOTE(review): FullLoader is kept to preserve existing behavior, but
        # the content is downloaded from the Hub; if the results files only
        # hold plain scalars/mappings, yaml.safe_load would be the safer choice.
        results = yaml.load(f, Loader=yaml.FullLoader)
    if not isinstance(results, dict) or not results:
        raise ValueError(f"Expected a non-empty mapping in {hf_path}, got {results!r}")
    return results


def get_best_mp4(env: str) -> Path:
    """Return the local path of the demo video of the best model for *env*.

    The selection is done in two steps, each picking the key with the largest
    value from a ``results.yaml`` mapping:

    1. ``models/{env}/results.yaml`` selects the model type.
    2. ``models/{env}/{model_type}/results.yaml`` selects the model.

    The file ``models/{env}/{model_type}/{model}/demo.mp4`` is then downloaded
    from ``MODEL_REPO``.

    Args:
        env: Environment name as it appears in the repo path,
            e.g. ``"CartPole-v1"``.

    Returns:
        Path to the downloaded ``demo.mp4`` as returned by ``hf_hub_download``.

    Raises:
        ValueError: If one of the results files is empty or not a mapping.
    """
    env_results = _load_results_yaml(f"models/{env}/results.yaml")
    best_model_type = max(env_results, key=env_results.get)

    model_results = _load_results_yaml(f"models/{env}/{best_model_type}/results.yaml")
    best_model = max(model_results, key=model_results.get)

    hf_mp4_path = f"models/{env}/{best_model_type}/{best_model}/demo.mp4"
    return Path(hf_hub_download(repo_id=MODEL_REPO, repo_type=MODEL_REPO_TYPE, filename=hf_mp4_path))
def get_environments() -> list[str]:
    """List the environments that have a results file in the model repo.

    An environment is recognised by a repo file whose path has exactly
    ``ENV_RESULTS_FILE_DEPTH`` components of the form
    ``models/<environment>/results.yaml``.

    Returns:
        The ``<environment>`` component of every matching file, in the order
        the files are listed by the Hub API.
    """

    def _is_env_results(parts: list[str]) -> bool:
        # e.g. ['models', 'CartPole-v1', 'results.yaml']
        return (
            len(parts) == ENV_RESULTS_FILE_DEPTH
            and parts[2] == "results.yaml"
            and parts[0] == "models"
        )

    repo_files = hf_api.list_repo_files(MODEL_REPO, repo_type=MODEL_REPO_TYPE)
    split_paths = (repo_file.split("/") for repo_file in repo_files)
    return [parts[1] for parts in split_paths if _is_env_results(parts)]
# def get_gif_paths(environments: list[str]) -> dict[str, Path]:
# gif_paths: dict[str, Path] = {}
# for env in environments:
# gif_paths[env] = get_best_gif(env)
# return gif_paths
def get_mp4_paths(environments: list[str]) -> dict[str, Path]:
    """Resolve the best demo video for each environment.

    Args:
        environments: Environment names to look up.

    Returns:
        A dict mapping each environment name to the path returned by
        ``get_best_mp4`` for that environment.
    """
    return {env_name: get_best_mp4(env_name) for env_name in environments}
def run_demo() -> None:
    """Build and launch the Gradio demo.

    Looks up the environments available in the model repo, downloads the best
    demo video for each, and serves a UI with an environment dropdown, a
    submit button, a Markdown header and a video player. Two API endpoints are
    registered on the button: ``get_text`` and ``predict``.
    """
    # Local import: only needed for the annotations below, and keeps this
    # function independent of the file's top-level import block.
    from typing import Optional

    environments = get_environments()
    mp4_paths = get_mp4_paths(environments)

    # Prefer CartPole-v1 as the initial selection, but do not crash with a
    # KeyError if the repo does not contain it (or contains no environment).
    preferred_env = "CartPole-v1"
    default_env: Optional[str]
    if preferred_env in mp4_paths:
        default_env = preferred_env
    else:
        default_env = next(iter(mp4_paths), None)
        logger.warning(f"Default environment {preferred_env} not available, using {default_env}")

    def api_get_text(env_id: str) -> gr.Markdown:
        """Return the Markdown header; *env_id* is only logged."""
        logger.info(f"Getting text for {env_id}")
        return gr.Markdown("# Greetings from LitRL!")

    def api_predict(env_id: str) -> Optional[gr.Video]:
        """Return the video component for *env_id*, or None if it is unknown."""
        if env_id not in mp4_paths:
            logger.error(f"Environment {env_id} not found in {mp4_paths}")
            return None
        return gr.Video(mp4_paths[env_id])

    with gr.Blocks() as demo:
        md = gr.Markdown("# Greetings from LitRL!")
        env = gr.Dropdown(choices=list(mp4_paths), value=default_env, label="Environment")
        button = gr.Button(value="Submit")
        # .get() yields None (an empty player) when no environment is available.
        cartpole_out = gr.Video(mp4_paths.get(default_env), autoplay=True)
        button.click(  # type: ignore[no-untyped-call]
            fn=api_get_text,
            inputs=env,
            outputs=md,
            api_name="get_text",
        )
        button.click(  # type: ignore[no-untyped-call]
            fn=api_predict,
            inputs=env,
            outputs=cartpole_out,
            api_name="predict",
        )
    demo.launch()  # type: ignore[no-untyped-call]
# Launch the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    run_demo()