Build error
Upload folder using huggingface_hub
- src/__init__.py +10 -0
- src/app.py +55 -46
- src/app_state.py +24 -28
- src/constants.py +1 -3
- src/huggingface/__init__.py +0 -0
- src/huggingface/get_environments.py +4 -3
- src/huggingface/get_files.py +17 -7
- src/huggingface/huggingface_client.py +5 -12
- src/typing.py +1 -1
src/__init__.py
CHANGED
@@ -0,0 +1,10 @@
+import logging
+
+
+class EndpointFilter(logging.Filter):
+    def filter(self, record: logging.LogRecord) -> bool:
+        return record.getMessage().find("/connect_four/bot_progress") == -1
+
+
+# Filter out /endpoint
+logging.getLogger("uvicorn.access").addFilter(EndpointFilter())
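
A quick way to sanity-check the new access-log filter is to feed it fake uvicorn access records (a minimal sketch; the request lines below are hypothetical, not taken from the repo):

import logging

# Same filter as src/__init__.py: drop access-log records that mention the
# bot_progress polling endpoint so they do not flood the logs.
class EndpointFilter(logging.Filter):
    def filter(self, record: logging.LogRecord) -> bool:
        return record.getMessage().find("/connect_four/bot_progress") == -1

def record(msg: str) -> logging.LogRecord:
    return logging.LogRecord(
        name="uvicorn.access", level=logging.INFO, pathname="", lineno=0,
        msg=msg, args=(), exc_info=None,
    )

assert EndpointFilter().filter(record("GET /connect_four/bot_progress HTTP/1.1 200")) is False  # suppressed
assert EndpointFilter().filter(record("GET /connect_four/observe HTTP/1.1 200")) is True  # kept
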
src/app.py
CHANGED
@@ -1,49 +1,59 @@
-from
+from collections.abc import Generator
 from pathlib import Path
-from
-
-from litrl.env.make import make
-from litrl.common.agent import RandomAgent
-from litrl.env.typing import SingleAgentId
+from typing import Annotated, Any
+
 from fastapi import Depends, FastAPI, Request, status
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
-from fastapi.responses import JSONResponse
-from
+from fastapi.responses import JSONResponse, RedirectResponse, StreamingResponse
+from gymnasium.wrappers.record_video import RecordVideo
 from loguru import logger
-from fastapi.responses import StreamingResponse, RedirectResponse
-from src.app_state import AppState
-from src.typing import CpuConfig
 from pydantic import BaseModel
+from src.app_state import AppState
 from src.huggingface.huggingface_client import HuggingFaceClient
+from src.typing import CpuConfig
+
+from litrl.algo.mcts.agent import MCTSAgent
+from litrl.common.agent import RandomAgent
 from litrl.env.connect_four import ConnectFour
+from litrl.env.make import make
+from litrl.env.typing import GymId
+

 def stream_mp4(mp4_path: Path) -> StreamingResponse:
-    def iter_file()-> Generator[bytes, Any, None]:
+    def iter_file() -> Generator[bytes, Any, None]:
         with mp4_path.open(mode="rb") as env_file:
             yield from env_file

     return StreamingResponse(content=iter_file(), media_type="video/mp4")

+
 ObservationType = list[list[list[int]]]
+
+
 class GridResponseType(BaseModel):
     grid: ObservationType
     done: bool

-
+
+def step(env: ConnectFour, action: int) -> GridResponseType:
     env.step(action)
     return observe(env)

-
+
+def observe(env: ConnectFour) -> GridResponseType:
     obs = env.observe("player_1")
-
-
+    return GridResponseType(
+        grid=obs["observation"].tolist(),
+        done=bool(env.terminations[env.agent_selection] or env.truncations[env.agent_selection]),  # TODO why needed?
+    )
+

-def create_app() -> FastAPI:
+def create_app() -> FastAPI:  # noqa: C901 # TODO move to routes
     app = FastAPI()

-    @app.get(
-    async def
+    @app.get("/")
+    async def redirect_to_docs() -> RedirectResponse:
         return RedirectResponse("/docs")

     @app.post(path="/connect_four/play", response_model=GridResponseType)
@@ -52,13 +62,13 @@ def create_app() -> FastAPI:
         app_state: Annotated[AppState, Depends(dependency=AppState)],
     ) -> GridResponseType:
         return step(app_state.env, action)
-
+
     @app.get(path="/connect_four/observe", response_model=GridResponseType)
     def endpoint_observe(
         app_state: Annotated[AppState, Depends(dependency=AppState)],
     ) -> GridResponseType:
         return observe(app_state.env)
-
+
     @app.post(path="/connect_four/bot_play", response_model=GridResponseType)
     def endpoint_bot_play(
         cpu_config: CpuConfig,
@@ -68,6 +78,18 @@ def create_app() -> FastAPI:
         action = app_state.get_action()
         return step(app_state.env, action)

+    @app.get(path="/connect_four/bot_progress", response_model=float)
+    def endpoint_bot_progress(
+        app_state: Annotated[AppState, Depends(dependency=AppState)],
+    ) -> float:
+        if isinstance(app_state.agent, MCTSAgent):
+            if app_state.cpu_config.simulations is None:
+                raise ValueError
+            return float(
+                app_state.agent.mcts._root.visits / app_state.cpu_config.simulations,  # noqa: SLF001
+            )  # TODO why needed?
+        return 1.0
+
     @app.get(path="/connect_four/reset", response_model=GridResponseType)
     def endpoint_reset(
         app_state: Annotated[AppState, Depends(dependency=AppState)],
@@ -77,7 +99,7 @@ def create_app() -> FastAPI:

     @app.get(path="/get_huggingface_video")
     def endpoint_get_huggingface_video(
-        env_id:
+        env_id: GymId,
         hf_client: Annotated[HuggingFaceClient, Depends(dependency=HuggingFaceClient)],
     ) -> StreamingResponse:
         hf_client.mp4_paths[env_id]
@@ -85,7 +107,7 @@ def create_app() -> FastAPI:

     @app.get(path="/get_env_video")
     def endpoint_get_env_video(
-        env_id:
+        env_id: GymId,
     ) -> StreamingResponse:
         env = make(id=env_id, render_mode="rgb_array")
         env = RecordVideo(
@@ -93,29 +115,30 @@ def create_app() -> FastAPI:
             video_folder="tmp",
         )
         env.reset(seed=123)
+
+        if env.video_recorder is None:
+            msg = "env.video_recorder is None"
+            raise ValueError(msg)
+
         agent = RandomAgent[Any, Any]()
         terminated, truncated = False, False
         while not (terminated or truncated):
-            action = agent.get_action(env=env)
+            action = agent.get_action(env=env)  # type: ignore[arg-type]
             _, _, terminated, truncated, _ = env.step(action=action)
             env.render()
         env.video_recorder.close()
         return stream_mp4(mp4_path=Path(env.video_recorder.path))

     @app.exception_handler(exc_class_or_status_code=RequestValidationError)
-    async def validation_exception_handler(
-        request: Request, exc: RequestValidationError
-    ) -> JSONResponse:
+    async def validation_exception_handler(request: Request, exc: RequestValidationError) -> JSONResponse:
         logger.debug(f"url: {request.url}")
         if hasattr(request, "_body"):
-            logger.debug(f"body: {request._body}")
+            logger.debug(f"body: {request._body.decode()}")  # noqa: SLF001
         logger.debug(f"header: {request.headers}")
         logger.error(f"{request}: {exc}")
-        exc_str =
+        exc_str = str(exc).replace("\n", " ").replace("   ", " ")
         content = {"status_code": 10422, "message": exc_str, "data": None}
-        return JSONResponse(
-            content=content, status_code=status.HTTP_422_UNPROCESSABLE_ENTITY
-        )
+        return JSONResponse(content=content, status_code=status.HTTP_422_UNPROCESSABLE_ENTITY)

     app.add_middleware(
         middleware_class=CORSMiddleware,
@@ -125,17 +148,3 @@ def create_app() -> FastAPI:
         allow_headers=["*"],
     )
     return app
-
-if __name__ == "__main__":
-    import uvicorn
-    import argparse
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--host", type=str, default="0.0.0.0")
-    parser.add_argument("--port", type=int, default=7860)
-    parser.add_argument("--log", type=str, default="info")
-    parser.add_argument('--reload', action='store_true', help='Reload flag')
-    args = parser.parse_args()
-
-    config = uvicorn.Config(app=create_app(), host=args.host, port=args.port, log_level=args.log, reload=args.reload)
-    server = uvicorn.Server(config=config)
-    server.run()
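
The commit also deletes the argparse/uvicorn launcher that used to sit under if __name__ == "__main__": in src/app.py, so the app factory is presumably started by an external command now. A minimal equivalent launch reusing the old defaults (the module path src.app is an assumption, not confirmed by this diff):

import uvicorn

from src.app import create_app

if __name__ == "__main__":
    # Same defaults as the removed launcher: 0.0.0.0:7860, info-level logs.
    uvicorn.run(create_app(), host="0.0.0.0", port=7860, log_level="info")
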
src/app_state.py
CHANGED
@@ -1,45 +1,40 @@
 from typing import Any, Self

+from loguru import logger
+from src.typing import AgentType, CpuConfig, RolloutPolicy
+
 from litrl import make_multiagent
 from litrl.algo.mcts.agent import MCTSAgent
-from litrl.algo.mcts.mcts_config import
+from litrl.algo.mcts.mcts_config import MCTSConfigBuilder
 from litrl.algo.mcts.rollout import VanillaRollout
-from litrl.algo.sac.agent import OnnxSacDeterministicAgent
 from litrl.common.agent import Agent, RandomMultiAgent
-from litrl.env.connect_four import Board
-from litrl.env.set_state import set_state
-from loguru import logger
-from src.typing import AgentType, CpuConfig, RolloutPolicy
 from litrl.env.connect_four import ConnectFour
+from litrl.model.sac.multi_agent import OnnxSacDeterministicMultiAgent
+

 class AppState:
     _instance: Self | None = None
     env: ConnectFour
     cpu_config: CpuConfig
-    agent: Agent[Any,
+    agent: Agent[Any, int]

     def setup(self) -> None:
         logger.debug("AppState setup called")
-        self.env = make_multiagent(id="
+        self.env = make_multiagent(id="connect_four", render_mode="rgb_array")
         self.env.reset(seed=123)

         self.cpu_config: CpuConfig = CpuConfig(agent_type=AgentType.RANDOM)
         self.set_agent()  # TODO in properties setter.
-        self.agent: Agent[Any,
+        self.agent: Agent[Any, int]

-    def __new__(cls):
+    def __new__(cls) -> "AppState":
         if cls._instance is None:
             cls._instance = super().__new__(cls)
             cls._instance.setup()
         return cls._instance
-
-    def set_board(self, board: Board) -> None:
-        set_state(env=self.env, board=board)

     def set_config(self, cpu_config: CpuConfig) -> None:
-        if (
-            cpu_config != self.cpu_config
-        ):
+        if cpu_config != self.cpu_config:
             self.cpu_config = cpu_config
             self.set_agent()

@@ -48,31 +43,32 @@ class AppState:
             case None:
                 return RandomMultiAgent()
             case RolloutPolicy.SAC:
-                return
+                return OnnxSacDeterministicMultiAgent()
             case RolloutPolicy.RANDOM:
                 return RandomMultiAgent()
             case _:
-
-
-                )
+                msg = f"cpu_config.rollout_policy: {self.cpu_config.rollout_policy}"
+                raise NotImplementedError(msg)

     def set_agent(self) -> None:
         match self.cpu_config.agent_type:
             case AgentType.MCTS:
                 rollout_agent = self.create_rollout()
-
-
-
-
+                # fmt: off
+                mcts_config = (
+                    MCTSConfigBuilder()
+                    .set_simulations(self.cpu_config.simulations or 50)
+                    .set_rollout_strategy(VanillaRollout(rollout_agent=rollout_agent))
+                ).build()
+                # fmt: on
                 self.agent = MCTSAgent(cfg=mcts_config)
             case AgentType.RANDOM:
                 self.agent = RandomMultiAgent()
             case AgentType.SAC:
-                self.agent =
+                self.agent = OnnxSacDeterministicMultiAgent()  # type: ignore[assignment] # TODO
             case _:
-
-
-                )
+                msg = f"cpu_config.name: {self.cpu_config.agent_type}"
+                raise NotImplementedError(msg)

     def get_action(self) -> int:
         return self.agent.get_action(env=self.env)
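
AppState overrides __new__ so every FastAPI Depends(AppState) resolves to the same object, which is how the board and the agent persist across requests (e.g. endpoint_bot_progress reads the MCTS agent that endpoint_bot_play is using). A stripped-down sketch of that singleton pattern (a generic placeholder class, not the project's code):

from typing import Self


class Singleton:
    _instance: Self | None = None

    def __new__(cls) -> "Singleton":
        # The first call builds and caches the instance; later calls reuse it.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance.setup()
        return cls._instance

    def setup(self) -> None:
        self.counter = 0


assert Singleton() is Singleton()  # both calls return the one cached instance
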
src/constants.py
CHANGED
@@ -1,6 +1,4 @@
-SPACE_REPO = "c-gohlke/litrl"
-SPACE_REPO_TYPE = "space"
 MODEL_REPO = "c-gohlke/litrl"
 MODEL_REPO_TYPE = "model"

-ENV_RESULTS_FILE_DEPTH = 3
+ENV_RESULTS_FILE_DEPTH = 3
src/huggingface/__init__.py
ADDED
File without changes
src/huggingface/get_environments.py
CHANGED
@@ -1,5 +1,6 @@
-from huggingface_hub import HfApi
-from src.constants import MODEL_REPO, MODEL_REPO_TYPE
+from huggingface_hub import HfApi
+from src.constants import ENV_RESULTS_FILE_DEPTH, MODEL_REPO, MODEL_REPO_TYPE
+

 def get_environments(hf_api: HfApi) -> list[str]:
     environments: list[str] = []
@@ -9,4 +10,4 @@ def get_environments(hf_api: HfApi) -> list[str]:
         # e.g. ['models', 'CartPole-v1', 'results.yaml']
         if len(vals) == ENV_RESULTS_FILE_DEPTH and vals[2] == "results.yaml" and vals[0] == "models":
             environments.append(vals[1])
-    return environments
+    return environments
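
get_environments keeps only repo paths shaped like models/<env>/results.yaml; ENV_RESULTS_FILE_DEPTH is the 3 path segments in that shape. The same filter applied to a hypothetical file listing (the paths below are illustrative, not real repo contents):

ENV_RESULTS_FILE_DEPTH = 3  # models/<env>/results.yaml splits into 3 parts

files = [
    "models/CartPole-v1/results.yaml",
    "models/connect_four/results.yaml",
    "models/CartPole-v1/sac/results.yaml",  # too deep: per-model results, not per-env
    "README.md",
]

environments: list[str] = []
for path in files:
    vals = path.split("/")
    if len(vals) == ENV_RESULTS_FILE_DEPTH and vals[2] == "results.yaml" and vals[0] == "models":
        environments.append(vals[1])

assert environments == ["CartPole-v1", "connect_four"]
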
src/huggingface/get_files.py
CHANGED
@@ -1,8 +1,10 @@
 from pathlib import Path
-
-import yaml
+
+import yaml
+from huggingface_hub import hf_hub_download
 from src.constants import MODEL_REPO, MODEL_REPO_TYPE

+
 def get_mp4_paths(environments: list[str]) -> dict[str, Path]:
     mp4_paths: dict[str, Path] = {}
     for env in environments:
@@ -12,15 +14,23 @@ def get_mp4_paths(environments: list[str]) -> dict[str, Path]:

 def get_best(env: str, filename: str = "demo.mp4") -> Path:
     hf_env_results_path = f"models/{env}/results.yaml"
-    local_env_results_path = hf_hub_download(
-
+    local_env_results_path = hf_hub_download(
+        repo_id=MODEL_REPO,
+        repo_type=MODEL_REPO_TYPE,
+        filename=hf_env_results_path,
+    )
+    with Path(local_env_results_path).open() as f:
         env_results = yaml.load(f, Loader=yaml.FullLoader)
     best_model_type = max(env_results, key=lambda model: env_results[model])
     model_results_path = f"models/{env}/{best_model_type}/results.yaml"
-    local_model_results_path = hf_hub_download(
-
+    local_model_results_path = hf_hub_download(
+        repo_id=MODEL_REPO,
+        repo_type=MODEL_REPO_TYPE,
+        filename=model_results_path,
+    )
+    with Path(local_model_results_path).open() as f:
         model_results = yaml.load(f, Loader=yaml.FullLoader)

     best_model = max(model_results, key=lambda model: model_results[model])
     hf_gif_path = f"models/{env}/{best_model_type}/{best_model}/{filename}"
-    return Path(hf_hub_download(repo_id=MODEL_REPO, repo_type=MODEL_REPO_TYPE, filename=hf_gif_path))
+    return Path(hf_hub_download(repo_id=MODEL_REPO, repo_type=MODEL_REPO_TYPE, filename=hf_gif_path))
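
get_best assumes each results.yaml maps candidate names to a comparable score and takes the argmax twice: first the best model type for the environment, then the best checkpoint of that type, which finally locates the demo.mp4 to download. A sketch of that selection with hypothetical scores (keys and numbers are made up):

# Hypothetical contents of models/<env>/results.yaml
env_results = {"sac": 0.91, "mcts": 0.84}
# Hypothetical contents of models/<env>/sac/results.yaml
model_results = {"checkpoint_10": 0.88, "checkpoint_25": 0.93}

best_model_type = max(env_results, key=lambda model: env_results[model])
best_model = max(model_results, key=lambda model: model_results[model])

assert (best_model_type, best_model) == ("sac", "checkpoint_25")
print(f"models/CartPole-v1/{best_model_type}/{best_model}/demo.mp4")
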
src/huggingface/huggingface_client.py
CHANGED
@@ -1,13 +1,14 @@
-from huggingface_hub import HfApi, login  # type: ignore[import]
 import os
-
-from
+
+from huggingface_hub import HfApi, login
+
 from .get_environments import get_environments
+from .get_files import get_mp4_paths


 class HuggingFaceClient:
     def __init__(self) -> None:
-        login(
+        login(
             token=os.environ.get("HUGGINGFACE_TOKEN"),
             add_to_git_credential=True,
             new_session=False,
@@ -15,11 +16,3 @@ class HuggingFaceClient:
         self.hf_api = HfApi()
         self.environments = get_environments(self.hf_api)
         self.mp4_paths = get_mp4_paths(environments=self.environments)
-
-
-    def api_predict(self, env_id: str)-> bytes|None:
-        if env_id not in self.mp4_paths:
-            logger.error(f"Environment {env_id} not found in {self.mp4_paths}")
-            return None
-        with open(self.mp4_paths[env_id], "rb") as f:
-            return f.read()
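
HuggingFaceClient logs in with the token from the HUGGINGFACE_TOKEN environment variable and eagerly builds its environments list and mp4_paths mapping, so on a Space that secret has to be configured before the first request constructs the client. A minimal usage sketch ("CartPole-v1" is a hypothetical key; real keys come from get_environments):

import os

from src.huggingface.huggingface_client import HuggingFaceClient

# The constructor calls huggingface_hub.login() with this token.
assert os.environ.get("HUGGINGFACE_TOKEN"), "set HUGGINGFACE_TOKEN first"

client = HuggingFaceClient()
print(client.environments)                   # envs that ship a models/<env>/results.yaml
print(client.mp4_paths.get("CartPole-v1"))   # local path of the best model's demo.mp4, if any
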
src/typing.py
CHANGED
@@ -17,4 +17,4 @@ class RolloutPolicy(enum.Enum):
 class CpuConfig(BaseModel):
     agent_type: AgentType
     simulations: int | None = None
-    rollout_policy: RolloutPolicy | None = None
+    rollout_policy: RolloutPolicy | None = None
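
CpuConfig is the request body that /connect_four/bot_play validates; rollout_policy is only consulted when agent_type is MCTS. A sketch of a payload it would accept (the specific values are chosen for illustration):

from src.typing import AgentType, CpuConfig, RolloutPolicy

# 50 simulations matches the fallback AppState.set_agent uses when none is given.
cfg = CpuConfig(agent_type=AgentType.MCTS, simulations=50, rollout_policy=RolloutPolicy.SAC)
print(cfg)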