c-gohlke commited on
Commit
5ae8333
·
1 Parent(s): fdfa48b

Upload folder using huggingface_hub

Browse files
src/constants.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ SPACE_REPO = "c-gohlke/litrl"
2
+ SPACE_REPO_TYPE = "space"
3
+ MODEL_REPO = "c-gohlke/litrl"
4
+ MODEL_REPO_TYPE = "model"
5
+
6
+ ENV_RESULTS_FILE_DEPTH = 3
src/huggingface/get_environments.py CHANGED
@@ -1,6 +1,7 @@
 
 
1
 
2
-
3
- def get_environments() -> list[str]:
4
  environments: list[str] = []
5
  files = hf_api.list_repo_files(MODEL_REPO, repo_type=MODEL_REPO_TYPE)
6
  for file in files:
 
1
+ from huggingface_hub import HfApi # type: ignore[import]
2
+ from src.constants import MODEL_REPO, MODEL_REPO_TYPE, ENV_RESULTS_FILE_DEPTH
3
 
4
+ def get_environments(hf_api: HfApi) -> list[str]:
 
5
  environments: list[str] = []
6
  files = hf_api.list_repo_files(MODEL_REPO, repo_type=MODEL_REPO_TYPE)
7
  for file in files:
src/huggingface/get_files.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ from huggingface_hub import hf_hub_download # type: ignore[import]
3
+ import yaml # type: ignore[import]
4
+ from src.constants import MODEL_REPO, MODEL_REPO_TYPE
5
+
6
+ def get_mp4_paths(environments: list[str]) -> dict[str, Path]:
7
+ mp4_paths: dict[str, Path] = {}
8
+ for env in environments:
9
+ mp4_paths[env] = get_best(env, filename="demo.mp4")
10
+ return mp4_paths
11
+
12
+
13
+ def get_best(env: str, filename: str = "demo.mp4") -> Path:
14
+ hf_env_results_path = f"models/{env}/results.yaml"
15
+ local_env_results_path = hf_hub_download(repo_id=MODEL_REPO, repo_type=MODEL_REPO_TYPE, filename=hf_env_results_path)
16
+ with open(local_env_results_path, "r") as f:
17
+ env_results = yaml.load(f, Loader=yaml.FullLoader)
18
+ best_model_type = max(env_results, key=lambda model: env_results[model])
19
+ model_results_path = f"models/{env}/{best_model_type}/results.yaml"
20
+ local_model_results_path = hf_hub_download(repo_id=MODEL_REPO, repo_type=MODEL_REPO_TYPE, filename=model_results_path)
21
+ with open(local_model_results_path, "r") as f:
22
+ model_results = yaml.load(f, Loader=yaml.FullLoader)
23
+
24
+ best_model = max(model_results, key=lambda model: model_results[model])
25
+ hf_gif_path = f"models/{env}/{best_model_type}/{best_model}/{filename}"
26
+ return Path(hf_hub_download(repo_id=MODEL_REPO, repo_type=MODEL_REPO_TYPE, filename=hf_gif_path))
src/huggingface/huggingface_client.py CHANGED
@@ -1,18 +1,8 @@
1
- from huggingface_hub import hf_hub_download
2
- from huggingface_hub import HfApi, login
3
  import os
4
- import yaml
5
- from pathlib import Path
6
  from loguru import logger
7
- from PIL import Image
8
- from .get_best_mp4 import get_mp4_paths
9
-
10
- SPACE_REPO = "c-gohlke/litrl"
11
- SPACE_REPO_TYPE = "space"
12
- MODEL_REPO = "c-gohlke/litrl"
13
- MODEL_REPO_TYPE = "model"
14
-
15
- ENV_RESULTS_FILE_DEPTH = 3
16
 
17
 
18
  class HugingFaceClient:
@@ -23,7 +13,8 @@ class HugingFaceClient:
23
  new_session=False,
24
  )
25
  self.hf_api = HfApi()
26
- self.mp4_paths = get_mp4_paths()
 
27
 
28
 
29
  def api_predict(self, env_id: str)-> bytes|None:
 
1
+ from huggingface_hub import HfApi, login # type: ignore[import]
 
2
  import os
 
 
3
  from loguru import logger
4
+ from .get_files import get_mp4_paths
5
+ from .get_environments import get_environments
 
 
 
 
 
 
 
6
 
7
 
8
  class HugingFaceClient:
 
13
  new_session=False,
14
  )
15
  self.hf_api = HfApi()
16
+ self.environments = get_environments(self.hf_api)
17
+ self.mp4_paths = get_mp4_paths(environments=self.environments)
18
 
19
 
20
  def api_predict(self, env_id: str)-> bytes|None:
src/typing.py CHANGED
@@ -17,8 +17,4 @@ class RolloutPolicy(enum.Enum):
17
  class CpuConfig(BaseModel):
18
  agent_type: AgentType
19
  simulations: int | None = None
20
- rollout_policy: RolloutPolicy | None = None
21
-
22
- # def __format__(self, __format_spec: str) -> str:
23
- # raise ValueError(f"__format__ not implemented for {self.__class__.__name__}")
24
- # return super().__format__(__format_spec)
 
17
  class CpuConfig(BaseModel):
18
  agent_type: AgentType
19
  simulations: int | None = None
20
+ rollout_policy: RolloutPolicy | None = None