Space build status: Build error

Commit: Upload folder using huggingface_hub

Files changed:
- src/__init__.py +0 -0
- src/__pycache__/__init__.cpython-311.pyc +0 -0
- src/__pycache__/__init__.cpython-39.pyc +0 -0
- src/__pycache__/app.cpython-311.pyc +0 -0
- src/__pycache__/app.cpython-39.pyc +0 -0
- src/__pycache__/app_state.cpython-311.pyc +0 -0
- src/__pycache__/cpu_config.cpython-311.pyc +0 -0
- src/__pycache__/create_app.cpython-311.pyc +0 -0
- src/__pycache__/typing.cpython-311.pyc +0 -0
- src/app_state.py +73 -0
- src/connect4_backend.egg-info/PKG-INFO +42 -0
- src/connect4_backend.egg-info/SOURCES.txt +11 -0
- src/connect4_backend.egg-info/dependency_links.txt +1 -0
- src/connect4_backend.egg-info/requires.txt +11 -0
- src/connect4_backend.egg-info/top_level.txt +4 -0
- src/create_app.py +113 -0
- src/huggingface/get_best_mp4.py +23 -0
- src/huggingface/get_environments.py +11 -0
- src/huggingface/get_gif.py +22 -0
- src/huggingface/huggingface_client.py +34 -0
- src/typing.py +24 -0
src/__init__.py
ADDED
File without changes

src/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (171 Bytes)

src/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (153 Bytes)

src/__pycache__/app.cpython-311.pyc
ADDED
Binary file (263 Bytes)

src/__pycache__/app.cpython-39.pyc
ADDED
Binary file (210 Bytes)

src/__pycache__/app_state.cpython-311.pyc
ADDED
Binary file (4.36 kB)

src/__pycache__/cpu_config.cpython-311.pyc
ADDED
Binary file (1.83 kB)

src/__pycache__/create_app.cpython-311.pyc
ADDED
Binary file (7.25 kB)

src/__pycache__/typing.cpython-311.pyc
ADDED
Binary file (1.21 kB)
src/app_state.py
ADDED
@@ -0,0 +1,73 @@

from typing import Any

from litrl import make_multiagent
from litrl.algo.mcts.agent import MCTSAgent
from litrl.algo.mcts.mcts_config import MCTSConfig
from litrl.algo.mcts.rollout import VanillaRollout
from litrl.algo.sac.agent import OnnxSacDeterministicAgent
from litrl.common.agent import Agent, RandomMultiAgent
from litrl.env.connect_four import Board
from litrl.env.set_state import set_state

from src.typing import AgentType, CpuConfig, RolloutPolicy


class AppState:
    def __init__(self) -> None:
        self.env = make_multiagent(id="ConnectFour-v3", render_mode="human")
        self.env.reset(seed=123)
        self.agent: Agent[Any, Any] | None = None
        self.cpu_config: CpuConfig | None = None

    def set_board(self, board: Board) -> None:
        set_state(env=self.env, board=board)

    def set_config(self, cpu_config: CpuConfig) -> None:
        if (
            self.agent is None
            or self.cpu_config is None
            or cpu_config != self.cpu_config
        ):
            self.cpu_config = cpu_config
            self.set_agent()

    def create_rollout(self) -> Agent[Any, Any]:
        if self.cpu_config is None:
            raise ValueError("self.cpu_config is None")
        match self.cpu_config.rollout_policy:
            case None:
                return RandomMultiAgent()
            case RolloutPolicy.SAC:
                return OnnxSacDeterministicAgent()
            case RolloutPolicy.RANDOM:
                return RandomMultiAgent()
            case _:
                raise NotImplementedError(
                    f"cpu_config.rollout_policy: {self.cpu_config.rollout_policy}"
                )

    def set_agent(self) -> None:
        if self.cpu_config is None:
            raise ValueError("self.cpu_config is None")

        match self.cpu_config.agent_type:
            case AgentType.MCTS:
                rollout_agent = self.create_rollout()
                mcts_config = MCTSConfig(
                    simulations=self.cpu_config.simulations or 50,
                    rollout_strategy=VanillaRollout(rollout_agent=rollout_agent),
                )
                self.agent = MCTSAgent(cfg=mcts_config)
            case AgentType.RANDOM:
                self.agent = RandomMultiAgent()
            case AgentType.SAC:
                self.agent = OnnxSacDeterministicAgent()
            case _:
                raise NotImplementedError(
                    f"cpu_config.name: {self.cpu_config.agent_type}"
                )

    def get_action(self) -> int:
        if self.agent is None:
            raise ValueError("self.agent is None")
        return self.agent.get_action(env=self.env)
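For orientation, a minimal sketch of how this state object is meant to be driven outside FastAPI. This is not part of the commit; the board value is omitted because its exact litrl format is not shown here.

    # Hypothetical usage sketch, not part of the commit.
    from src.app_state import AppState
    from src.typing import AgentType, CpuConfig, RolloutPolicy

    state = AppState()
    # Request an MCTS opponent with a SAC rollout; set_config only rebuilds the
    # agent when the config actually changes.
    state.set_config(
        cpu_config=CpuConfig(
            agent_type=AgentType.MCTS, simulations=100, rollout_policy=RolloutPolicy.SAC
        )
    )
    # state.set_board(board=board)  # board: a litrl Board for the current position
    column = state.get_action()  # column index chosen by the configured agent
    print(column)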
src/connect4_backend.egg-info/PKG-INFO
ADDED
@@ -0,0 +1,42 @@

Metadata-Version: 2.1
Name: connect4-backend
Version: 0.0.1
Author: Clement Gohlke
Classifier: Programming Language :: Python :: 3
Classifier: License :: OSI Approved :: MIT License
Classifier: Operating System :: OS Independent
Requires-Python: >=3.11
Description-Content-Type: text/markdown
Requires-Dist: tensordict@ git+https://github.com/pytorch/tensordict.git@c3caa7612275306ce72697a82d5252681ddae0ab
Requires-Dist: torchrl@ git+https://github.com/pytorch/rl.git@1bb192e0f3ad9e7b8c6fa769bfa3bb9d82ca4f29
Requires-Dist: litrl~=0.0.9
Requires-Dist: fastapi==0.104.1
Requires-Dist: uvicorn==0.25.0
Requires-Dist: moviepy==1.0.3
Provides-Extra: test
Requires-Dist: pytest==7.4.4; extra == "test"
Requires-Dist: mypy==1.8.0; extra == "test"
Requires-Dist: httpx==0.26.0; extra == "test"

---
title: Connect4
emoji: 🌐
colorFrom: blue
colorTo: yellow
sdk: docker
pinned: false
license: mit
---

frontend adapted from [github](https://github.com/jprioses/connect-four-game)

Check out the configuration reference at [huggingface](https://huggingface.co/docs/hub/spaces-config-reference)

## TODO

link to github repo

test block forced move

Connect to AWS database https://aws.amazon.com/rds/?p=ft&c=db&z=3
using alembic and sqlmodel
src/connect4_backend.egg-info/SOURCES.txt
ADDED
@@ -0,0 +1,11 @@

README.md
pyproject.toml
src/__init__.py
src/app_state.py
src/create_app.py
src/typing.py
src/connect4_backend.egg-info/PKG-INFO
src/connect4_backend.egg-info/SOURCES.txt
src/connect4_backend.egg-info/dependency_links.txt
src/connect4_backend.egg-info/requires.txt
src/connect4_backend.egg-info/top_level.txt
src/connect4_backend.egg-info/dependency_links.txt
ADDED
@@ -0,0 +1 @@

src/connect4_backend.egg-info/requires.txt
ADDED
@@ -0,0 +1,11 @@

tensordict@ git+https://github.com/pytorch/tensordict.git@c3caa7612275306ce72697a82d5252681ddae0ab
torchrl@ git+https://github.com/pytorch/rl.git@1bb192e0f3ad9e7b8c6fa769bfa3bb9d82ca4f29
litrl~=0.0.9
fastapi==0.104.1
uvicorn==0.25.0
moviepy==1.0.3

[test]
pytest==7.4.4
mypy==1.8.0
httpx==0.26.0
src/connect4_backend.egg-info/top_level.txt
ADDED
@@ -0,0 +1,4 @@

__init__
app_state
create_app
typing
src/create_app.py
ADDED
@@ -0,0 +1,113 @@

from typing import Annotated, Any, Generator
from pathlib import Path
from gymnasium.wrappers.record_video import RecordVideo
from litrl.env.make import make
from litrl.common.agent import RandomAgent
from litrl.env.typing import SingleAgentId
from fastapi import Depends, FastAPI, Request, status
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from litrl.env.connect_four import Board
from loguru import logger
from fastapi.responses import StreamingResponse, FileResponse
from moviepy.editor import VideoFileClip
from src.app_state import AppState
from src.typing import CpuConfig


def create_app() -> FastAPI:
    app = FastAPI()

    @app.post("/", response_model=int)
    def bot_action(
        board: Board,
        cpuConfig: CpuConfig,
        app_state: Annotated[AppState, Depends(dependency=AppState)],
    ) -> int:
        # Configure the requested CPU opponent, load the posted board, return its move.
        app_state.set_config(cpu_config=cpuConfig)
        app_state.set_board(board=board)
        return app_state.get_action()

    @app.post(path="/game", response_model=str)
    def game_video(
        env_id: SingleAgentId,
    ) -> StreamingResponse:
        # Play one episode with a random agent while RecordVideo writes an mp4 to tmp/.
        env = make(id=env_id, render_mode="rgb_array")
        env = RecordVideo(
            env=env,
            video_folder="tmp",
        )
        env.reset(seed=123)
        agent = RandomAgent[Any, Any]()
        terminated, truncated = False, False
        while not (terminated or truncated):
            action = agent.get_action(env=env)
            _, _, terminated, truncated, _ = env.step(action=action)
            env.render()
        env.video_recorder.close()

        def iterfile() -> Generator[bytes, Any, None]:
            with Path(env.video_recorder.path).open(mode="rb") as env_file:
                yield from env_file

        return StreamingResponse(content=iterfile(), media_type="video/mp4")

    @app.get(path="/gif")
    def game_gif(
        env_id: SingleAgentId,
    ) -> FileResponse:
        # Same random rollout as /game, but the recorded mp4 is converted to a gif.
        env = make(id=env_id, render_mode="rgb_array")
        env = RecordVideo(
            env=env,
            video_folder="tmp",
        )
        env.reset(seed=123)
        agent = RandomAgent[Any, Any]()
        terminated, truncated = False, False
        while not (terminated or truncated):
            action = agent.get_action(env=env)
            _, _, terminated, truncated, _ = env.step(action=action)
            env.render()
        env.video_recorder.close()

        mp4_path = Path(env.video_recorder.path)
        video_clip = VideoFileClip(str(mp4_path))
        gif_path = mp4_path.with_suffix(".gif")
        video_clip.write_gif(str(gif_path))  # , fps=30) # TODO check fps
        return FileResponse(str(gif_path), media_type="image/gif")

    @app.exception_handler(exc_class_or_status_code=RequestValidationError)
    async def validation_exception_handler(
        request: Request, exc: RequestValidationError
    ) -> JSONResponse:
        logger.debug(f"url: {request.url}")
        if hasattr(request, "_body"):
            logger.debug(f"body: {request._body}")
        logger.debug(f"header: {request.headers}")
        logger.error(f"{request}: {exc}")
        exc_str = f"{exc}".replace("\n", " ").replace("   ", " ")
        content = {"status_code": 10422, "message": exc_str, "data": None}
        return JSONResponse(
            content=content, status_code=status.HTTP_422_UNPROCESSABLE_ENTITY
        )

    app.add_middleware(
        middleware_class=CORSMiddleware,
        allow_origins="*",
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
    return app


if __name__ == "__main__":
    import uvicorn
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int, default=8000)
    args = parser.parse_args()
    config = uvicorn.Config(app=create_app(), host=args.host, port=args.port, log_level="info")
    server = uvicorn.Server(config=config)
    server.run()
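As a quick smoke test of the move endpoint, a request along the following lines would work once the app is running locally. This is a sketch, not part of the commit: the host/port, the 6x7 board of 0/1/2 integers, and the top-level body keys are assumptions that depend on how litrl declares Board; only the lowercase enum values come from src/typing.py.

    # Hypothetical client call (assumes localhost:8000 and a 6x7 grid of 0/1/2 ints).
    import httpx

    board = [[0] * 7 for _ in range(6)]  # empty Connect Four grid (assumed encoding)
    payload = {
        "board": board,
        "cpuConfig": {"agent_type": "mcts", "simulations": 50, "rollout_policy": "random"},
    }
    response = httpx.post("http://localhost:8000/", json=payload)
    print(response.json())  # column index chosen by the bot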
src/huggingface/get_best_mp4.py
ADDED
@@ -0,0 +1,23 @@

from pathlib import Path

import yaml
from huggingface_hub import hf_hub_download

# NOTE: MODEL_REPO and MODEL_REPO_TYPE are expected to be in scope here; this
# commit defines them in src/huggingface/huggingface_client.py.


def get_mp4_paths(environments: list[str]) -> dict[str, Path]:
    mp4_paths: dict[str, Path] = {}
    for env in environments:
        mp4_paths[env] = get_best_mp4(env)
    return mp4_paths


def get_best_mp4(env: str) -> Path:
    # Pick the best-scoring model type for this environment from its results.yaml.
    hf_env_results_path = f"models/{env}/results.yaml"
    local_env_results_path = hf_hub_download(repo_id=MODEL_REPO, repo_type=MODEL_REPO_TYPE, filename=hf_env_results_path)
    with open(local_env_results_path, "r") as f:
        env_results = yaml.load(f, Loader=yaml.FullLoader)
    best_model_type = max(env_results, key=lambda model: env_results[model])
    # Then pick the best-scoring run of that model type.
    model_results_path = f"models/{env}/{best_model_type}/results.yaml"
    local_model_results_path = hf_hub_download(repo_id=MODEL_REPO, repo_type=MODEL_REPO_TYPE, filename=model_results_path)
    with open(local_model_results_path, "r") as f:
        model_results = yaml.load(f, Loader=yaml.FullLoader)

    best_model = max(model_results, key=lambda model: model_results[model])
    hf_mp4_path = f"models/{env}/{best_model_type}/{best_model}/demo.mp4"
    return Path(hf_hub_download(repo_id=MODEL_REPO, repo_type=MODEL_REPO_TYPE, filename=hf_mp4_path))
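Both results.yaml lookups above reduce to taking the dictionary key with the highest score. A toy illustration of that selection step (the scores below are invented):

    # Toy example of the best-entry selection used in get_best_mp4 (scores are made up).
    env_results = {"mcts": 0.62, "sac": 0.71, "random": 0.18}
    best_model_type = max(env_results, key=lambda model: env_results[model])
    print(best_model_type)  # "sac", the key with the highest value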
src/huggingface/get_environments.py
ADDED
@@ -0,0 +1,11 @@

# NOTE: hf_api (an HfApi instance), MODEL_REPO, MODEL_REPO_TYPE and
# ENV_RESULTS_FILE_DEPTH are expected to be in scope here; this commit defines
# them in src/huggingface/huggingface_client.py.


def get_environments() -> list[str]:
    # An environment is any top-level folder under models/ that has a results.yaml.
    environments: list[str] = []
    files = hf_api.list_repo_files(MODEL_REPO, repo_type=MODEL_REPO_TYPE)
    for file in files:
        vals = file.split("/")
        # e.g. ['models', 'CartPole-v1', 'results.yaml']
        if len(vals) == ENV_RESULTS_FILE_DEPTH and vals[2] == "results.yaml" and vals[0] == "models":
            environments.append(vals[1])
    return environments
src/huggingface/get_gif.py
ADDED
@@ -0,0 +1,22 @@

from pathlib import Path

import yaml
from huggingface_hub import hf_hub_download

# NOTE: MODEL_REPO and MODEL_REPO_TYPE are expected to be in scope here; this
# commit defines them in src/huggingface/huggingface_client.py.


def get_best_gif(env: str) -> Path:
    # Same selection logic as get_best_mp4, but returning the demo gif.
    hf_env_results_path = f"models/{env}/results.yaml"
    local_env_results_path = hf_hub_download(repo_id=MODEL_REPO, repo_type=MODEL_REPO_TYPE, filename=hf_env_results_path)
    with open(local_env_results_path, "r") as f:
        env_results = yaml.load(f, Loader=yaml.FullLoader)
    best_model_type = max(env_results, key=lambda model: env_results[model])
    model_results_path = f"models/{env}/{best_model_type}/results.yaml"
    local_model_results_path = hf_hub_download(repo_id=MODEL_REPO, repo_type=MODEL_REPO_TYPE, filename=model_results_path)
    with open(local_model_results_path, "r") as f:
        model_results = yaml.load(f, Loader=yaml.FullLoader)

    best_model = max(model_results, key=lambda model: model_results[model])
    hf_gif_path = f"models/{env}/{best_model_type}/{best_model}/demo.gif"
    return Path(hf_hub_download(repo_id=MODEL_REPO, repo_type=MODEL_REPO_TYPE, filename=hf_gif_path))


def get_gif_paths(environments: list[str]) -> dict[str, Path]:
    gif_paths: dict[str, Path] = {}
    for env in environments:
        gif_paths[env] = get_best_gif(env)
    return gif_paths
src/huggingface/huggingface_client.py
ADDED
@@ -0,0 +1,34 @@

from huggingface_hub import hf_hub_download
from huggingface_hub import HfApi, login
import os
import yaml
from pathlib import Path
from loguru import logger
from PIL import Image
from .get_best_mp4 import get_mp4_paths

SPACE_REPO = "c-gohlke/litrl"
SPACE_REPO_TYPE = "space"
MODEL_REPO = "c-gohlke/litrl"
MODEL_REPO_TYPE = "model"

ENV_RESULTS_FILE_DEPTH = 3


class HugingFaceClient:
    def __init__(self) -> None:
        login(  # type: ignore[no-untyped-call]
            token=os.environ.get("HUGGINGFACE_TOKEN"),
            add_to_git_credential=True,
            new_session=False,
        )
        self.hf_api = HfApi()
        # NOTE: get_mp4_paths expects a list of environment ids (see get_best_mp4.py).
        self.mp4_paths = get_mp4_paths()

    def api_predict(self, env_id: str) -> bytes | None:
        # Return the raw mp4 bytes for the requested environment, if one was found.
        if env_id not in self.mp4_paths:
            logger.error(f"Environment {env_id} not found in {self.mp4_paths}")
            return None
        with open(self.mp4_paths[env_id], "rb") as f:
            return f.read()
src/typing.py
ADDED
@@ -0,0 +1,24 @@

import enum

from pydantic import BaseModel


class AgentType(enum.Enum):
    RANDOM = "random"
    MCTS = "mcts"
    SAC = "sac"


class RolloutPolicy(enum.Enum):
    RANDOM = "random"
    SAC = "sac"


class CpuConfig(BaseModel):
    agent_type: AgentType
    simulations: int | None = None
    rollout_policy: RolloutPolicy | None = None

    # def __format__(self, __format_spec: str) -> str:
    #     raise ValueError(f"__format__ not implemented for {self.__class__.__name__}")
    #     return super().__format__(__format_spec)
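Because CpuConfig is a Pydantic model, the lowercase values defined above are what a JSON payload can send; a minimal sketch of constructing one directly (not part of the commit):

    # Minimal sketch: Pydantic coerces the raw strings into the enums defined above.
    from src.typing import AgentType, CpuConfig

    cfg = CpuConfig(agent_type="mcts", simulations=200, rollout_policy="sac")
    assert cfg.agent_type is AgentType.MCTS
    assert cfg.simulations == 200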