c-gohlke commited on
Commit
ec0f03f
·
1 Parent(s): a7445ad

Update Space

Browse files
Files changed (2) hide show
  1. app.py +90 -0
  2. requirements.txt +1 -0
app.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from huggingface_hub import hf_hub_download
3
+ from huggingface_hub import HfApi, login
4
+ import os
5
+ import yaml
6
+ from pathlib import Path
7
+ from loguru import logger
8
+
9
+ SPACE_REPO = "c-gohlke/litrl"
10
+ SPACE_REPO_TYPE = "space"
11
+ MODEL_REPO = "c-gohlke/litrl"
12
+ MODEL_REPO_TYPE = "model"
13
+
14
+ ENV_RESULTS_FILE_DEPTH = 3
15
+
16
+ hf_api = HfApi()
17
+ login( # type: ignore[no-untyped-call]
18
+ token=os.environ.get("HUGGINGFACE_TOKEN"),
19
+ add_to_git_credential=True,
20
+ new_session=False,
21
+ )
22
+
23
+ def get_best_gif(env: str) -> Path:
24
+ hf_env_results_path = f"models/{env}/results.yaml"
25
+ local_env_results_path = hf_hub_download(repo_id=MODEL_REPO, repo_type=MODEL_REPO_TYPE, filename=hf_env_results_path)
26
+ with open(local_env_results_path, "r") as f:
27
+ env_results = yaml.load(f, Loader=yaml.FullLoader)
28
+ best_model_type = max(env_results, key=lambda model: env_results[model])
29
+ model_results_path = f"models/{env}/{best_model_type}/results.yaml"
30
+ local_model_results_path = hf_hub_download(repo_id=MODEL_REPO, repo_type=MODEL_REPO_TYPE, filename=model_results_path)
31
+ with open(local_model_results_path, "r") as f:
32
+ model_results = yaml.load(f, Loader=yaml.FullLoader)
33
+
34
+ best_model = max(model_results, key=lambda model: model_results[model])
35
+ hf_gif_path = f"models/{env}/{best_model_type}/{best_model}/demo.gif"
36
+ return Path(hf_hub_download(repo_id=MODEL_REPO, repo_type=MODEL_REPO_TYPE, filename=hf_gif_path))
37
+
38
+
39
+
40
+ def get_environments() -> list[str]:
41
+ environments = []
42
+ files = hf_api.list_repo_files(MODEL_REPO, repo_type=MODEL_REPO_TYPE)
43
+ for file in files:
44
+ vals = file.split("/")
45
+ # e.g. ['models', 'CartPole-v1', 'results.yaml']
46
+ if len(vals) == ENV_RESULTS_FILE_DEPTH and vals[2] == "results.yaml" and vals[0] == "models":
47
+ environments.append(vals[1])
48
+ return environments
49
+
50
+ def get_gif_paths(environments: list[str]) -> dict[str, str]:
51
+ gif_paths = {}
52
+ for env in environments:
53
+ gif_paths[env] = get_best_gif(env)
54
+
55
+ def run_demo() -> None:
56
+ environments = get_environments()
57
+ gif_paths = get_gif_paths(environments)
58
+
59
+ def api_get_text(env_id: str)-> str:
60
+ logger.info(f"Getting text for {env_id}")
61
+ return gr.Markdown("# Greetings from LitRL!")
62
+
63
+ def api_predict(env_id: str)-> gr.Image|None:
64
+ if env_id not in gif_paths:
65
+ logger.error(f"Environment {env_id} not found in {gif_paths}")
66
+ return None
67
+ return gr.Image(gif_paths[env_id], type="filepath")
68
+
69
+ with gr.Blocks() as demo:
70
+ md = gr.Markdown("# Greetings from LitRL!")
71
+ env = gr.Dropdown(choices=gif_paths.keys(), value="CartPole-v1", label="Environment")
72
+ button = gr.Button(value="Submit")
73
+ cartpole_out = gr.Image(gif_paths[env.value], type="filepath")
74
+
75
+ button.click(
76
+ fn=api_get_text,
77
+ inputs=env,
78
+ outputs=md,
79
+ api_name="get_text",
80
+ )
81
+ button.click(
82
+ fn=api_predict,
83
+ inputs=env,
84
+ outputs=cartpole_out,
85
+ api_name="predict",
86
+ )
87
+ demo.launch()
88
+
89
+ if __name__ == "__main__":
90
+ run_demo()
requirements.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ loguru==0.7.2