c-gohlke commited on
Commit
d4cabc4
·
1 Parent(s): 82515b9

Update Space

Browse files
Files changed (1) hide show
  1. app.py +48 -22
app.py CHANGED
@@ -5,6 +5,7 @@ import os
5
  import yaml
6
  from pathlib import Path
7
  from loguru import logger
 
8
 
9
  SPACE_REPO = "c-gohlke/litrl"
10
  SPACE_REPO_TYPE = "space"
@@ -20,7 +21,22 @@ login( # type: ignore[no-untyped-call]
20
  new_session=False,
21
  )
22
 
23
- def get_best_gif(env: str) -> Path:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  hf_env_results_path = f"models/{env}/results.yaml"
25
  local_env_results_path = hf_hub_download(repo_id=MODEL_REPO, repo_type=MODEL_REPO_TYPE, filename=hf_env_results_path)
26
  with open(local_env_results_path, "r") as f:
@@ -32,13 +48,11 @@ def get_best_gif(env: str) -> Path:
32
  model_results = yaml.load(f, Loader=yaml.FullLoader)
33
 
34
  best_model = max(model_results, key=lambda model: model_results[model])
35
- hf_gif_path = f"models/{env}/{best_model_type}/{best_model}/demo.gif"
36
  return Path(hf_hub_download(repo_id=MODEL_REPO, repo_type=MODEL_REPO_TYPE, filename=hf_gif_path))
37
 
38
-
39
-
40
  def get_environments() -> list[str]:
41
- environments = []
42
  files = hf_api.list_repo_files(MODEL_REPO, repo_type=MODEL_REPO_TYPE)
43
  for file in files:
44
  vals = file.split("/")
@@ -47,45 +61,57 @@ def get_environments() -> list[str]:
47
  environments.append(vals[1])
48
  return environments
49
 
50
- def get_gif_paths(environments: list[str]) -> dict[str, str]:
51
- gif_paths = {}
 
 
 
 
 
 
52
  for env in environments:
53
- gif_paths[env] = get_best_gif(env)
54
- return gif_paths
55
 
56
  def run_demo() -> None:
57
  environments = get_environments()
58
- gif_paths = get_gif_paths(environments)
 
59
 
60
- def api_get_text(env_id: str)-> str:
61
  logger.info(f"Getting text for {env_id}")
62
- return gr.Markdown("# Greetings from LitRL!")
63
 
64
- def api_predict(env_id: str)-> gr.Image|None:
65
- if env_id not in gif_paths:
66
- logger.error(f"Environment {env_id} not found in {gif_paths}")
 
 
 
 
67
  return None
68
- return gr.Image(gif_paths[env_id], type="filepath")
69
 
70
  with gr.Blocks() as demo:
71
  md = gr.Markdown("# Greetings from LitRL!")
72
- env = gr.Dropdown(choices=gif_paths.keys(), value="CartPole-v1", label="Environment")
73
  button = gr.Button(value="Submit")
74
- cartpole_out = gr.Image(gif_paths[env.value], type="filepath")
75
 
76
- button.click(
77
  fn=api_get_text,
78
  inputs=env,
79
  outputs=md,
80
  api_name="get_text",
81
  )
82
- button.click(
83
  fn=api_predict,
84
  inputs=env,
85
  outputs=cartpole_out,
86
  api_name="predict",
87
  )
88
- demo.launch()
 
89
 
90
  if __name__ == "__main__":
91
- run_demo()
 
5
  import yaml
6
  from pathlib import Path
7
  from loguru import logger
8
+ from PIL import Image
9
 
10
  SPACE_REPO = "c-gohlke/litrl"
11
  SPACE_REPO_TYPE = "space"
 
21
  new_session=False,
22
  )
23
 
24
+ # def get_best_gif(env: str) -> Path:
25
+ # hf_env_results_path = f"models/{env}/results.yaml"
26
+ # local_env_results_path = hf_hub_download(repo_id=MODEL_REPO, repo_type=MODEL_REPO_TYPE, filename=hf_env_results_path)
27
+ # with open(local_env_results_path, "r") as f:
28
+ # env_results = yaml.load(f, Loader=yaml.FullLoader)
29
+ # best_model_type = max(env_results, key=lambda model: env_results[model])
30
+ # model_results_path = f"models/{env}/{best_model_type}/results.yaml"
31
+ # local_model_results_path = hf_hub_download(repo_id=MODEL_REPO, repo_type=MODEL_REPO_TYPE, filename=model_results_path)
32
+ # with open(local_model_results_path, "r") as f:
33
+ # model_results = yaml.load(f, Loader=yaml.FullLoader)
34
+
35
+ # best_model = max(model_results, key=lambda model: model_results[model])
36
+ # hf_gif_path = f"models/{env}/{best_model_type}/{best_model}/demo.gif"
37
+ # return Path(hf_hub_download(repo_id=MODEL_REPO, repo_type=MODEL_REPO_TYPE, filename=hf_gif_path))
38
+
39
+ def get_best_mp4(env: str) -> Path:
40
  hf_env_results_path = f"models/{env}/results.yaml"
41
  local_env_results_path = hf_hub_download(repo_id=MODEL_REPO, repo_type=MODEL_REPO_TYPE, filename=hf_env_results_path)
42
  with open(local_env_results_path, "r") as f:
 
48
  model_results = yaml.load(f, Loader=yaml.FullLoader)
49
 
50
  best_model = max(model_results, key=lambda model: model_results[model])
51
+ hf_gif_path = f"models/{env}/{best_model_type}/{best_model}/demo.mp4"
52
  return Path(hf_hub_download(repo_id=MODEL_REPO, repo_type=MODEL_REPO_TYPE, filename=hf_gif_path))
53
 
 
 
54
  def get_environments() -> list[str]:
55
+ environments: list[str] = []
56
  files = hf_api.list_repo_files(MODEL_REPO, repo_type=MODEL_REPO_TYPE)
57
  for file in files:
58
  vals = file.split("/")
 
61
  environments.append(vals[1])
62
  return environments
63
 
64
+ # def get_gif_paths(environments: list[str]) -> dict[str, Path]:
65
+ # gif_paths: dict[str, Path] = {}
66
+ # for env in environments:
67
+ # gif_paths[env] = get_best_gif(env)
68
+ # return gif_paths
69
+
70
+ def get_mp4_paths(environments: list[str]) -> dict[str, Path]:
71
+ mp4_paths: dict[str, Path] = {}
72
  for env in environments:
73
+ mp4_paths[env] = get_best_mp4(env)
74
+ return mp4_paths
75
 
76
  def run_demo() -> None:
77
  environments = get_environments()
78
+ # gif_paths = get_gif_paths(environments)
79
+ mp4_paths = get_mp4_paths(environments)
80
 
81
+ def api_get_text(env_id: str)-> gr.Markdown:
82
  logger.info(f"Getting text for {env_id}")
83
+ return gr.Markdown("# Greetings from LitRL!")
84
 
85
+ def api_predict(env_id: str)-> bytes:
86
+ # if env_id not in gif_paths:
87
+ # logger.error(f"Environment {env_id} not found in {gif_paths}")
88
+ # return None
89
+ # return Image.open(gif_paths[env_id], formats=["gif"])
90
+ if env_id not in mp4_paths:
91
+ logger.error(f"Environment {env_id} not found in {mp4_paths}")
92
  return None
93
+ return gr.Video(mp4_paths[env_id])
94
 
95
  with gr.Blocks() as demo:
96
  md = gr.Markdown("# Greetings from LitRL!")
97
+ env = gr.Dropdown(choices=list(mp4_paths.keys()), value="CartPole-v1", label="Environment")
98
  button = gr.Button(value="Submit")
99
+ cartpole_out = gr.Video(mp4_paths[env.value], autoplay=True)
100
 
101
+ button.click( # type: ignore[no-untyped-call]
102
  fn=api_get_text,
103
  inputs=env,
104
  outputs=md,
105
  api_name="get_text",
106
  )
107
+ button.click( # type: ignore[no-untyped-call]
108
  fn=api_predict,
109
  inputs=env,
110
  outputs=cartpole_out,
111
  api_name="predict",
112
  )
113
+
114
+ demo.launch() # type: ignore[no-untyped-call]
115
 
116
  if __name__ == "__main__":
117
+ run_demo()