c-gohlke commited on
Commit
7011484
·
verified ·
1 Parent(s): 50fac44

Upload folder using huggingface_hub

Browse files
src/__init__.py CHANGED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+
3
+
4
+ class EndpointFilter(logging.Filter):
5
+ def filter(self, record: logging.LogRecord) -> bool:
6
+ return record.getMessage().find("/connect_four/bot_progress") == -1
7
+
8
+
9
+ # Filter out /endpoint
10
+ logging.getLogger("uvicorn.access").addFilter(EndpointFilter())
src/app.py CHANGED
@@ -1,49 +1,59 @@
1
- from typing import Annotated, Any, Generator
2
  from pathlib import Path
3
- from gymnasium.wrappers.record_video import RecordVideo
4
- import numpy as np
5
- from litrl.env.make import make
6
- from litrl.common.agent import RandomAgent
7
- from litrl.env.typing import SingleAgentId
8
  from fastapi import Depends, FastAPI, Request, status
9
  from fastapi.exceptions import RequestValidationError
10
  from fastapi.middleware.cors import CORSMiddleware
11
- from fastapi.responses import JSONResponse
12
- from litrl.env.connect_four import Board
13
  from loguru import logger
14
- from fastapi.responses import StreamingResponse, RedirectResponse
15
- from src.app_state import AppState
16
- from src.typing import CpuConfig
17
  from pydantic import BaseModel
 
18
  from src.huggingface.huggingface_client import HuggingFaceClient
 
 
 
 
19
  from litrl.env.connect_four import ConnectFour
 
 
 
20
 
21
  def stream_mp4(mp4_path: Path) -> StreamingResponse:
22
- def iter_file()-> Generator[bytes, Any, None]:
23
  with mp4_path.open(mode="rb") as env_file:
24
  yield from env_file
25
 
26
  return StreamingResponse(content=iter_file(), media_type="video/mp4")
27
 
 
28
  ObservationType = list[list[list[int]]]
 
 
29
  class GridResponseType(BaseModel):
30
  grid: ObservationType
31
  done: bool
32
 
33
- def step(env: ConnectFour, action: int)->GridResponseType:
 
34
  env.step(action)
35
  return observe(env)
36
 
37
- def observe(env: ConnectFour)->GridResponseType:
 
38
  obs = env.observe("player_1")
39
- done = env.terminations[env.agent_selection] or env.truncations[env.agent_selection]
40
- return {"grid": obs['observation'].tolist(), "done": done}
 
 
 
41
 
42
- def create_app() -> FastAPI:
43
  app = FastAPI()
44
 
45
- @app.get('/')
46
- async def to_docs():
47
  return RedirectResponse("/docs")
48
 
49
  @app.post(path="/connect_four/play", response_model=GridResponseType)
@@ -52,13 +62,13 @@ def create_app() -> FastAPI:
52
  app_state: Annotated[AppState, Depends(dependency=AppState)],
53
  ) -> GridResponseType:
54
  return step(app_state.env, action)
55
-
56
  @app.get(path="/connect_four/observe", response_model=GridResponseType)
57
  def endpoint_observe(
58
  app_state: Annotated[AppState, Depends(dependency=AppState)],
59
  ) -> GridResponseType:
60
  return observe(app_state.env)
61
-
62
  @app.post(path="/connect_four/bot_play", response_model=GridResponseType)
63
  def endpoint_bot_play(
64
  cpu_config: CpuConfig,
@@ -68,6 +78,18 @@ def create_app() -> FastAPI:
68
  action = app_state.get_action()
69
  return step(app_state.env, action)
70
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  @app.get(path="/connect_four/reset", response_model=GridResponseType)
72
  def endpoint_reset(
73
  app_state: Annotated[AppState, Depends(dependency=AppState)],
@@ -77,7 +99,7 @@ def create_app() -> FastAPI:
77
 
78
  @app.get(path="/get_huggingface_video")
79
  def endpoint_get_huggingface_video(
80
- env_id: SingleAgentId,
81
  hf_client: Annotated[HuggingFaceClient, Depends(dependency=HuggingFaceClient)],
82
  ) -> StreamingResponse:
83
  hf_client.mp4_paths[env_id]
@@ -85,7 +107,7 @@ def create_app() -> FastAPI:
85
 
86
  @app.get(path="/get_env_video")
87
  def endpoint_get_env_video(
88
- env_id: SingleAgentId,
89
  ) -> StreamingResponse:
90
  env = make(id=env_id, render_mode="rgb_array")
91
  env = RecordVideo(
@@ -93,29 +115,30 @@ def create_app() -> FastAPI:
93
  video_folder="tmp",
94
  )
95
  env.reset(seed=123)
 
 
 
 
 
96
  agent = RandomAgent[Any, Any]()
97
  terminated, truncated = False, False
98
  while not (terminated or truncated):
99
- action = agent.get_action(env=env)
100
  _, _, terminated, truncated, _ = env.step(action=action)
101
  env.render()
102
  env.video_recorder.close()
103
  return stream_mp4(mp4_path=Path(env.video_recorder.path))
104
 
105
  @app.exception_handler(exc_class_or_status_code=RequestValidationError)
106
- async def validation_exception_handler(
107
- request: Request, exc: RequestValidationError
108
- ) -> JSONResponse:
109
  logger.debug(f"url: {request.url}")
110
  if hasattr(request, "_body"):
111
- logger.debug(f"body: {request._body}")
112
  logger.debug(f"header: {request.headers}")
113
  logger.error(f"{request}: {exc}")
114
- exc_str = f"{exc}".replace("\n", " ").replace(" ", " ")
115
  content = {"status_code": 10422, "message": exc_str, "data": None}
116
- return JSONResponse(
117
- content=content, status_code=status.HTTP_422_UNPROCESSABLE_ENTITY
118
- )
119
 
120
  app.add_middleware(
121
  middleware_class=CORSMiddleware,
@@ -125,17 +148,3 @@ def create_app() -> FastAPI:
125
  allow_headers=["*"],
126
  )
127
  return app
128
-
129
- if __name__ == "__main__":
130
- import uvicorn
131
- import argparse
132
- parser = argparse.ArgumentParser()
133
- parser.add_argument("--host", type=str, default="0.0.0.0")
134
- parser.add_argument("--port", type=int, default=7860)
135
- parser.add_argument("--log", type=str, default="info")
136
- parser.add_argument('--reload', action='store_true', help='Reload flag')
137
- args = parser.parse_args()
138
-
139
- config = uvicorn.Config(app=create_app(), host=args.host, port=args.port, log_level=args.log, reload=args.reload)
140
- server = uvicorn.Server(config=config)
141
- server.run()
 
1
+ from collections.abc import Generator
2
  from pathlib import Path
3
+ from typing import Annotated, Any
4
+
 
 
 
5
  from fastapi import Depends, FastAPI, Request, status
6
  from fastapi.exceptions import RequestValidationError
7
  from fastapi.middleware.cors import CORSMiddleware
8
+ from fastapi.responses import JSONResponse, RedirectResponse, StreamingResponse
9
+ from gymnasium.wrappers.record_video import RecordVideo
10
  from loguru import logger
 
 
 
11
  from pydantic import BaseModel
12
+ from src.app_state import AppState
13
  from src.huggingface.huggingface_client import HuggingFaceClient
14
+ from src.typing import CpuConfig
15
+
16
+ from litrl.algo.mcts.agent import MCTSAgent
17
+ from litrl.common.agent import RandomAgent
18
  from litrl.env.connect_four import ConnectFour
19
+ from litrl.env.make import make
20
+ from litrl.env.typing import GymId
21
+
22
 
23
  def stream_mp4(mp4_path: Path) -> StreamingResponse:
24
+ def iter_file() -> Generator[bytes, Any, None]:
25
  with mp4_path.open(mode="rb") as env_file:
26
  yield from env_file
27
 
28
  return StreamingResponse(content=iter_file(), media_type="video/mp4")
29
 
30
+
31
  ObservationType = list[list[list[int]]]
32
+
33
+
34
  class GridResponseType(BaseModel):
35
  grid: ObservationType
36
  done: bool
37
 
38
+
39
+ def step(env: ConnectFour, action: int) -> GridResponseType:
40
  env.step(action)
41
  return observe(env)
42
 
43
+
44
+ def observe(env: ConnectFour) -> GridResponseType:
45
  obs = env.observe("player_1")
46
+ return GridResponseType(
47
+ grid=obs["observation"].tolist(),
48
+ done=bool(env.terminations[env.agent_selection] or env.truncations[env.agent_selection]), # TODO why needed?
49
+ )
50
+
51
 
52
+ def create_app() -> FastAPI: # noqa: C901 # TODO move to routes
53
  app = FastAPI()
54
 
55
+ @app.get("/")
56
+ async def redirect_to_docs() -> RedirectResponse:
57
  return RedirectResponse("/docs")
58
 
59
  @app.post(path="/connect_four/play", response_model=GridResponseType)
 
62
  app_state: Annotated[AppState, Depends(dependency=AppState)],
63
  ) -> GridResponseType:
64
  return step(app_state.env, action)
65
+
66
  @app.get(path="/connect_four/observe", response_model=GridResponseType)
67
  def endpoint_observe(
68
  app_state: Annotated[AppState, Depends(dependency=AppState)],
69
  ) -> GridResponseType:
70
  return observe(app_state.env)
71
+
72
  @app.post(path="/connect_four/bot_play", response_model=GridResponseType)
73
  def endpoint_bot_play(
74
  cpu_config: CpuConfig,
 
78
  action = app_state.get_action()
79
  return step(app_state.env, action)
80
 
81
+ @app.get(path="/connect_four/bot_progress", response_model=float)
82
+ def endpoint_bot_progress(
83
+ app_state: Annotated[AppState, Depends(dependency=AppState)],
84
+ ) -> float:
85
+ if isinstance(app_state.agent, MCTSAgent):
86
+ if app_state.cpu_config.simulations is None:
87
+ raise ValueError
88
+ return float(
89
+ app_state.agent.mcts._root.visits / app_state.cpu_config.simulations, # noqa: SLF001
90
+ ) # TODO why needed?
91
+ return 1.0
92
+
93
  @app.get(path="/connect_four/reset", response_model=GridResponseType)
94
  def endpoint_reset(
95
  app_state: Annotated[AppState, Depends(dependency=AppState)],
 
99
 
100
  @app.get(path="/get_huggingface_video")
101
  def endpoint_get_huggingface_video(
102
+ env_id: GymId,
103
  hf_client: Annotated[HuggingFaceClient, Depends(dependency=HuggingFaceClient)],
104
  ) -> StreamingResponse:
105
  hf_client.mp4_paths[env_id]
 
107
 
108
  @app.get(path="/get_env_video")
109
  def endpoint_get_env_video(
110
+ env_id: GymId,
111
  ) -> StreamingResponse:
112
  env = make(id=env_id, render_mode="rgb_array")
113
  env = RecordVideo(
 
115
  video_folder="tmp",
116
  )
117
  env.reset(seed=123)
118
+
119
+ if env.video_recorder is None:
120
+ msg = "env.video_recorder is None"
121
+ raise ValueError(msg)
122
+
123
  agent = RandomAgent[Any, Any]()
124
  terminated, truncated = False, False
125
  while not (terminated or truncated):
126
+ action = agent.get_action(env=env) # type: ignore[arg-type]
127
  _, _, terminated, truncated, _ = env.step(action=action)
128
  env.render()
129
  env.video_recorder.close()
130
  return stream_mp4(mp4_path=Path(env.video_recorder.path))
131
 
132
  @app.exception_handler(exc_class_or_status_code=RequestValidationError)
133
+ async def validation_exception_handler(request: Request, exc: RequestValidationError) -> JSONResponse:
 
 
134
  logger.debug(f"url: {request.url}")
135
  if hasattr(request, "_body"):
136
+ logger.debug(f"body: {request._body.decode()}") # noqa: SLF001
137
  logger.debug(f"header: {request.headers}")
138
  logger.error(f"{request}: {exc}")
139
+ exc_str = str(exc).replace("\n", " ").replace(" ", " ")
140
  content = {"status_code": 10422, "message": exc_str, "data": None}
141
+ return JSONResponse(content=content, status_code=status.HTTP_422_UNPROCESSABLE_ENTITY)
 
 
142
 
143
  app.add_middleware(
144
  middleware_class=CORSMiddleware,
 
148
  allow_headers=["*"],
149
  )
150
  return app
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/app_state.py CHANGED
@@ -1,45 +1,40 @@
1
  from typing import Any, Self
2
 
 
 
 
3
  from litrl import make_multiagent
4
  from litrl.algo.mcts.agent import MCTSAgent
5
- from litrl.algo.mcts.mcts_config import MCTSConfig
6
  from litrl.algo.mcts.rollout import VanillaRollout
7
- from litrl.algo.sac.agent import OnnxSacDeterministicAgent
8
  from litrl.common.agent import Agent, RandomMultiAgent
9
- from litrl.env.connect_four import Board
10
- from litrl.env.set_state import set_state
11
- from loguru import logger
12
- from src.typing import AgentType, CpuConfig, RolloutPolicy
13
  from litrl.env.connect_four import ConnectFour
 
 
14
 
15
  class AppState:
16
  _instance: Self | None = None
17
  env: ConnectFour
18
  cpu_config: CpuConfig
19
- agent: Agent[Any, Any]
20
 
21
  def setup(self) -> None:
22
  logger.debug("AppState setup called")
23
- self.env = make_multiagent(id="ConnectFour-v3", render_mode="rgb_array")
24
  self.env.reset(seed=123)
25
 
26
  self.cpu_config: CpuConfig = CpuConfig(agent_type=AgentType.RANDOM)
27
  self.set_agent() # TODO in properties setter.
28
- self.agent: Agent[Any, Any]
29
 
30
- def __new__(cls):
31
  if cls._instance is None:
32
  cls._instance = super().__new__(cls)
33
  cls._instance.setup()
34
  return cls._instance
35
-
36
- def set_board(self, board: Board) -> None:
37
- set_state(env=self.env, board=board)
38
 
39
  def set_config(self, cpu_config: CpuConfig) -> None:
40
- if (
41
- cpu_config != self.cpu_config
42
- ):
43
  self.cpu_config = cpu_config
44
  self.set_agent()
45
 
@@ -48,31 +43,32 @@ class AppState:
48
  case None:
49
  return RandomMultiAgent()
50
  case RolloutPolicy.SAC:
51
- return OnnxSacDeterministicAgent()
52
  case RolloutPolicy.RANDOM:
53
  return RandomMultiAgent()
54
  case _:
55
- raise NotImplementedError(
56
- f"cpu_config.rollout_policy: {self.cpu_config.rollout_policy}"
57
- )
58
 
59
  def set_agent(self) -> None:
60
  match self.cpu_config.agent_type:
61
  case AgentType.MCTS:
62
  rollout_agent = self.create_rollout()
63
- mcts_config = MCTSConfig(
64
- simulations=self.cpu_config.simulations or 50,
65
- rollout_strategy=VanillaRollout(rollout_agent=rollout_agent),
66
- )
 
 
 
67
  self.agent = MCTSAgent(cfg=mcts_config)
68
  case AgentType.RANDOM:
69
  self.agent = RandomMultiAgent()
70
  case AgentType.SAC:
71
- self.agent = OnnxSacDeterministicAgent()
72
  case _:
73
- raise NotImplementedError(
74
- f"cpu_config.name: {self.cpu_config.agent_type}"
75
- )
76
 
77
  def get_action(self) -> int:
78
  return self.agent.get_action(env=self.env)
 
1
  from typing import Any, Self
2
 
3
+ from loguru import logger
4
+ from src.typing import AgentType, CpuConfig, RolloutPolicy
5
+
6
  from litrl import make_multiagent
7
  from litrl.algo.mcts.agent import MCTSAgent
8
+ from litrl.algo.mcts.mcts_config import MCTSConfigBuilder
9
  from litrl.algo.mcts.rollout import VanillaRollout
 
10
  from litrl.common.agent import Agent, RandomMultiAgent
 
 
 
 
11
  from litrl.env.connect_four import ConnectFour
12
+ from litrl.model.sac.multi_agent import OnnxSacDeterministicMultiAgent
13
+
14
 
15
  class AppState:
16
  _instance: Self | None = None
17
  env: ConnectFour
18
  cpu_config: CpuConfig
19
+ agent: Agent[Any, int]
20
 
21
  def setup(self) -> None:
22
  logger.debug("AppState setup called")
23
+ self.env = make_multiagent(id="connect_four", render_mode="rgb_array")
24
  self.env.reset(seed=123)
25
 
26
  self.cpu_config: CpuConfig = CpuConfig(agent_type=AgentType.RANDOM)
27
  self.set_agent() # TODO in properties setter.
28
+ self.agent: Agent[Any, int]
29
 
30
+ def __new__(cls) -> "AppState":
31
  if cls._instance is None:
32
  cls._instance = super().__new__(cls)
33
  cls._instance.setup()
34
  return cls._instance
 
 
 
35
 
36
  def set_config(self, cpu_config: CpuConfig) -> None:
37
+ if cpu_config != self.cpu_config:
 
 
38
  self.cpu_config = cpu_config
39
  self.set_agent()
40
 
 
43
  case None:
44
  return RandomMultiAgent()
45
  case RolloutPolicy.SAC:
46
+ return OnnxSacDeterministicMultiAgent()
47
  case RolloutPolicy.RANDOM:
48
  return RandomMultiAgent()
49
  case _:
50
+ msg = f"cpu_config.rollout_policy: {self.cpu_config.rollout_policy}"
51
+ raise NotImplementedError(msg)
 
52
 
53
  def set_agent(self) -> None:
54
  match self.cpu_config.agent_type:
55
  case AgentType.MCTS:
56
  rollout_agent = self.create_rollout()
57
+ # fmt: off
58
+ mcts_config = (
59
+ MCTSConfigBuilder()
60
+ .set_simulations(self.cpu_config.simulations or 50)
61
+ .set_rollout_strategy(VanillaRollout(rollout_agent=rollout_agent))
62
+ ).build()
63
+ # fmt: on
64
  self.agent = MCTSAgent(cfg=mcts_config)
65
  case AgentType.RANDOM:
66
  self.agent = RandomMultiAgent()
67
  case AgentType.SAC:
68
+ self.agent = OnnxSacDeterministicMultiAgent() # type: ignore[assignment] # TODO
69
  case _:
70
+ msg = f"cpu_config.name: {self.cpu_config.agent_type}"
71
+ raise NotImplementedError(msg)
 
72
 
73
  def get_action(self) -> int:
74
  return self.agent.get_action(env=self.env)
src/constants.py CHANGED
@@ -1,6 +1,4 @@
1
- SPACE_REPO = "c-gohlke/litrl"
2
- SPACE_REPO_TYPE = "space"
3
  MODEL_REPO = "c-gohlke/litrl"
4
  MODEL_REPO_TYPE = "model"
5
 
6
- ENV_RESULTS_FILE_DEPTH = 3
 
 
 
1
  MODEL_REPO = "c-gohlke/litrl"
2
  MODEL_REPO_TYPE = "model"
3
 
4
+ ENV_RESULTS_FILE_DEPTH = 3
src/huggingface/__init__.py ADDED
File without changes
src/huggingface/get_environments.py CHANGED
@@ -1,5 +1,6 @@
1
- from huggingface_hub import HfApi # type: ignore[import]
2
- from src.constants import MODEL_REPO, MODEL_REPO_TYPE, ENV_RESULTS_FILE_DEPTH
 
3
 
4
  def get_environments(hf_api: HfApi) -> list[str]:
5
  environments: list[str] = []
@@ -9,4 +10,4 @@ def get_environments(hf_api: HfApi) -> list[str]:
9
  # e.g. ['models', 'CartPole-v1', 'results.yaml']
10
  if len(vals) == ENV_RESULTS_FILE_DEPTH and vals[2] == "results.yaml" and vals[0] == "models":
11
  environments.append(vals[1])
12
- return environments
 
1
+ from huggingface_hub import HfApi
2
+ from src.constants import ENV_RESULTS_FILE_DEPTH, MODEL_REPO, MODEL_REPO_TYPE
3
+
4
 
5
  def get_environments(hf_api: HfApi) -> list[str]:
6
  environments: list[str] = []
 
10
  # e.g. ['models', 'CartPole-v1', 'results.yaml']
11
  if len(vals) == ENV_RESULTS_FILE_DEPTH and vals[2] == "results.yaml" and vals[0] == "models":
12
  environments.append(vals[1])
13
+ return environments
src/huggingface/get_files.py CHANGED
@@ -1,8 +1,10 @@
1
  from pathlib import Path
2
- from huggingface_hub import hf_hub_download # type: ignore[import]
3
- import yaml # type: ignore[import]
 
4
  from src.constants import MODEL_REPO, MODEL_REPO_TYPE
5
 
 
6
  def get_mp4_paths(environments: list[str]) -> dict[str, Path]:
7
  mp4_paths: dict[str, Path] = {}
8
  for env in environments:
@@ -12,15 +14,23 @@ def get_mp4_paths(environments: list[str]) -> dict[str, Path]:
12
 
13
  def get_best(env: str, filename: str = "demo.mp4") -> Path:
14
  hf_env_results_path = f"models/{env}/results.yaml"
15
- local_env_results_path = hf_hub_download(repo_id=MODEL_REPO, repo_type=MODEL_REPO_TYPE, filename=hf_env_results_path)
16
- with open(local_env_results_path, "r") as f:
 
 
 
 
17
  env_results = yaml.load(f, Loader=yaml.FullLoader)
18
  best_model_type = max(env_results, key=lambda model: env_results[model])
19
  model_results_path = f"models/{env}/{best_model_type}/results.yaml"
20
- local_model_results_path = hf_hub_download(repo_id=MODEL_REPO, repo_type=MODEL_REPO_TYPE, filename=model_results_path)
21
- with open(local_model_results_path, "r") as f:
 
 
 
 
22
  model_results = yaml.load(f, Loader=yaml.FullLoader)
23
 
24
  best_model = max(model_results, key=lambda model: model_results[model])
25
  hf_gif_path = f"models/{env}/{best_model_type}/{best_model}/{filename}"
26
- return Path(hf_hub_download(repo_id=MODEL_REPO, repo_type=MODEL_REPO_TYPE, filename=hf_gif_path))
 
1
  from pathlib import Path
2
+
3
+ import yaml
4
+ from huggingface_hub import hf_hub_download
5
  from src.constants import MODEL_REPO, MODEL_REPO_TYPE
6
 
7
+
8
  def get_mp4_paths(environments: list[str]) -> dict[str, Path]:
9
  mp4_paths: dict[str, Path] = {}
10
  for env in environments:
 
14
 
15
  def get_best(env: str, filename: str = "demo.mp4") -> Path:
16
  hf_env_results_path = f"models/{env}/results.yaml"
17
+ local_env_results_path = hf_hub_download(
18
+ repo_id=MODEL_REPO,
19
+ repo_type=MODEL_REPO_TYPE,
20
+ filename=hf_env_results_path,
21
+ )
22
+ with Path(local_env_results_path).open() as f:
23
  env_results = yaml.load(f, Loader=yaml.FullLoader)
24
  best_model_type = max(env_results, key=lambda model: env_results[model])
25
  model_results_path = f"models/{env}/{best_model_type}/results.yaml"
26
+ local_model_results_path = hf_hub_download(
27
+ repo_id=MODEL_REPO,
28
+ repo_type=MODEL_REPO_TYPE,
29
+ filename=model_results_path,
30
+ )
31
+ with Path(local_model_results_path).open() as f:
32
  model_results = yaml.load(f, Loader=yaml.FullLoader)
33
 
34
  best_model = max(model_results, key=lambda model: model_results[model])
35
  hf_gif_path = f"models/{env}/{best_model_type}/{best_model}/{filename}"
36
+ return Path(hf_hub_download(repo_id=MODEL_REPO, repo_type=MODEL_REPO_TYPE, filename=hf_gif_path))
src/huggingface/huggingface_client.py CHANGED
@@ -1,13 +1,14 @@
1
- from huggingface_hub import HfApi, login # type: ignore[import]
2
  import os
3
- from loguru import logger
4
- from .get_files import get_mp4_paths
 
5
  from .get_environments import get_environments
 
6
 
7
 
8
  class HuggingFaceClient:
9
  def __init__(self) -> None:
10
- login( # type: ignore[no-untyped-call]
11
  token=os.environ.get("HUGGINGFACE_TOKEN"),
12
  add_to_git_credential=True,
13
  new_session=False,
@@ -15,11 +16,3 @@ class HuggingFaceClient:
15
  self.hf_api = HfApi()
16
  self.environments = get_environments(self.hf_api)
17
  self.mp4_paths = get_mp4_paths(environments=self.environments)
18
-
19
-
20
- def api_predict(self, env_id: str)-> bytes|None:
21
- if env_id not in self.mp4_paths:
22
- logger.error(f"Environment {env_id} not found in {self.mp4_paths}")
23
- return None
24
- with open(self.mp4_paths[env_id], "rb") as f:
25
- return f.read()
 
 
1
  import os
2
+
3
+ from huggingface_hub import HfApi, login
4
+
5
  from .get_environments import get_environments
6
+ from .get_files import get_mp4_paths
7
 
8
 
9
  class HuggingFaceClient:
10
  def __init__(self) -> None:
11
+ login(
12
  token=os.environ.get("HUGGINGFACE_TOKEN"),
13
  add_to_git_credential=True,
14
  new_session=False,
 
16
  self.hf_api = HfApi()
17
  self.environments = get_environments(self.hf_api)
18
  self.mp4_paths = get_mp4_paths(environments=self.environments)
 
 
 
 
 
 
 
 
src/typing.py CHANGED
@@ -17,4 +17,4 @@ class RolloutPolicy(enum.Enum):
17
  class CpuConfig(BaseModel):
18
  agent_type: AgentType
19
  simulations: int | None = None
20
- rollout_policy: RolloutPolicy | None = None
 
17
  class CpuConfig(BaseModel):
18
  agent_type: AgentType
19
  simulations: int | None = None
20
+ rollout_policy: RolloutPolicy | None = None