c-gohlke commited on
Commit
26b48fb
·
1 Parent(s): bafb458

Delete app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -117
app.py DELETED
@@ -1,117 +0,0 @@
1
- import gradio as gr
2
- from huggingface_hub import hf_hub_download
3
- from huggingface_hub import HfApi, login
4
- import os
5
- import yaml
6
- from pathlib import Path
7
- from loguru import logger
8
- from PIL import Image
9
-
10
- SPACE_REPO = "c-gohlke/litrl"
11
- SPACE_REPO_TYPE = "space"
12
- MODEL_REPO = "c-gohlke/litrl"
13
- MODEL_REPO_TYPE = "model"
14
-
15
- ENV_RESULTS_FILE_DEPTH = 3
16
-
17
- hf_api = HfApi()
18
- login( # type: ignore[no-untyped-call]
19
- token=os.environ.get("HUGGINGFACE_TOKEN"),
20
- add_to_git_credential=True,
21
- new_session=False,
22
- )
23
-
24
- # def get_best_gif(env: str) -> Path:
25
- # hf_env_results_path = f"models/{env}/results.yaml"
26
- # local_env_results_path = hf_hub_download(repo_id=MODEL_REPO, repo_type=MODEL_REPO_TYPE, filename=hf_env_results_path)
27
- # with open(local_env_results_path, "r") as f:
28
- # env_results = yaml.load(f, Loader=yaml.FullLoader)
29
- # best_model_type = max(env_results, key=lambda model: env_results[model])
30
- # model_results_path = f"models/{env}/{best_model_type}/results.yaml"
31
- # local_model_results_path = hf_hub_download(repo_id=MODEL_REPO, repo_type=MODEL_REPO_TYPE, filename=model_results_path)
32
- # with open(local_model_results_path, "r") as f:
33
- # model_results = yaml.load(f, Loader=yaml.FullLoader)
34
-
35
- # best_model = max(model_results, key=lambda model: model_results[model])
36
- # hf_gif_path = f"models/{env}/{best_model_type}/{best_model}/demo.gif"
37
- # return Path(hf_hub_download(repo_id=MODEL_REPO, repo_type=MODEL_REPO_TYPE, filename=hf_gif_path))
38
-
39
- def get_best_mp4(env: str) -> Path:
40
- hf_env_results_path = f"models/{env}/results.yaml"
41
- local_env_results_path = hf_hub_download(repo_id=MODEL_REPO, repo_type=MODEL_REPO_TYPE, filename=hf_env_results_path)
42
- with open(local_env_results_path, "r") as f:
43
- env_results = yaml.load(f, Loader=yaml.FullLoader)
44
- best_model_type = max(env_results, key=lambda model: env_results[model])
45
- model_results_path = f"models/{env}/{best_model_type}/results.yaml"
46
- local_model_results_path = hf_hub_download(repo_id=MODEL_REPO, repo_type=MODEL_REPO_TYPE, filename=model_results_path)
47
- with open(local_model_results_path, "r") as f:
48
- model_results = yaml.load(f, Loader=yaml.FullLoader)
49
-
50
- best_model = max(model_results, key=lambda model: model_results[model])
51
- hf_gif_path = f"models/{env}/{best_model_type}/{best_model}/demo.mp4"
52
- return Path(hf_hub_download(repo_id=MODEL_REPO, repo_type=MODEL_REPO_TYPE, filename=hf_gif_path))
53
-
54
- def get_environments() -> list[str]:
55
- environments: list[str] = []
56
- files = hf_api.list_repo_files(MODEL_REPO, repo_type=MODEL_REPO_TYPE)
57
- for file in files:
58
- vals = file.split("/")
59
- # e.g. ['models', 'CartPole-v1', 'results.yaml']
60
- if len(vals) == ENV_RESULTS_FILE_DEPTH and vals[2] == "results.yaml" and vals[0] == "models":
61
- environments.append(vals[1])
62
- return environments
63
-
64
- # def get_gif_paths(environments: list[str]) -> dict[str, Path]:
65
- # gif_paths: dict[str, Path] = {}
66
- # for env in environments:
67
- # gif_paths[env] = get_best_gif(env)
68
- # return gif_paths
69
-
70
- def get_mp4_paths(environments: list[str]) -> dict[str, Path]:
71
- mp4_paths: dict[str, Path] = {}
72
- for env in environments:
73
- mp4_paths[env] = get_best_mp4(env)
74
- return mp4_paths
75
-
76
- def run_demo() -> None:
77
- environments = get_environments()
78
- # gif_paths = get_gif_paths(environments)
79
- mp4_paths = get_mp4_paths(environments)
80
-
81
- def api_get_text(env_id: str)-> gr.Markdown:
82
- logger.info(f"Getting text for {env_id}")
83
- return gr.Markdown("# Greetings from LitRL!")
84
-
85
- def api_predict(env_id: str)-> bytes:
86
- # if env_id not in gif_paths:
87
- # logger.error(f"Environment {env_id} not found in {gif_paths}")
88
- # return None
89
- # return Image.open(gif_paths[env_id], formats=["gif"])
90
- if env_id not in mp4_paths:
91
- logger.error(f"Environment {env_id} not found in {mp4_paths}")
92
- return None
93
- return gr.Video(mp4_paths[env_id])
94
-
95
- with gr.Blocks() as demo:
96
- md = gr.Markdown("# Greetings from LitRL!")
97
- env = gr.Dropdown(choices=list(mp4_paths.keys()), value="CartPole-v1", label="Environment")
98
- button = gr.Button(value="Submit")
99
- cartpole_out = gr.Video(mp4_paths[env.value], autoplay=True)
100
-
101
- button.click( # type: ignore[no-untyped-call]
102
- fn=api_get_text,
103
- inputs=env,
104
- outputs=md,
105
- api_name="get_text",
106
- )
107
- button.click( # type: ignore[no-untyped-call]
108
- fn=api_predict,
109
- inputs=env,
110
- outputs=cartpole_out,
111
- api_name="predict",
112
- )
113
-
114
- demo.launch() # type: ignore[no-untyped-call]
115
-
116
- if __name__ == "__main__":
117
- run_demo()