c-gohlke commited on
Commit
bafb458
·
1 Parent(s): 0d17912

Upload folder using huggingface_hub

Browse files
src/__init__.py ADDED
File without changes
src/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (171 Bytes). View file
 
src/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (153 Bytes). View file
 
src/__pycache__/app.cpython-311.pyc ADDED
Binary file (263 Bytes). View file
 
src/__pycache__/app.cpython-39.pyc ADDED
Binary file (210 Bytes). View file
 
src/__pycache__/app_state.cpython-311.pyc ADDED
Binary file (4.36 kB). View file
 
src/__pycache__/cpu_config.cpython-311.pyc ADDED
Binary file (1.83 kB). View file
 
src/__pycache__/create_app.cpython-311.pyc ADDED
Binary file (7.25 kB). View file
 
src/__pycache__/typing.cpython-311.pyc ADDED
Binary file (1.21 kB). View file
 
src/app_state.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any
2
+
3
+ from litrl import make_multiagent
4
+ from litrl.algo.mcts.agent import MCTSAgent
5
+ from litrl.algo.mcts.mcts_config import MCTSConfig
6
+ from litrl.algo.mcts.rollout import VanillaRollout
7
+ from litrl.algo.sac.agent import OnnxSacDeterministicAgent
8
+ from litrl.common.agent import Agent, RandomMultiAgent
9
+ from litrl.env.connect_four import Board
10
+ from litrl.env.set_state import set_state
11
+
12
+ from src.typing import AgentType, CpuConfig, RolloutPolicy
13
+
14
+
15
+ class AppState:
16
+ def __init__(self) -> None:
17
+ self.env = make_multiagent(id="ConnectFour-v3", render_mode="human")
18
+ self.env.reset(seed=123)
19
+ self.agent: Agent[Any, Any] | None = None
20
+ self.cpu_config: CpuConfig | None = None
21
+
22
+ def set_board(self, board: Board) -> None:
23
+ set_state(env=self.env, board=board)
24
+
25
+ def set_config(self, cpu_config: CpuConfig) -> None:
26
+ if (
27
+ self.agent is None
28
+ or self.cpu_config is None
29
+ or cpu_config != self.cpu_config
30
+ ):
31
+ self.cpu_config = cpu_config
32
+ self.set_agent()
33
+
34
+ def create_rollout(self) -> Agent[Any, Any]:
35
+ if self.cpu_config is None:
36
+ raise ValueError("self.cpu_config is None")
37
+ match self.cpu_config.rollout_policy:
38
+ case None:
39
+ return RandomMultiAgent()
40
+ case RolloutPolicy.SAC:
41
+ return OnnxSacDeterministicAgent()
42
+ case RolloutPolicy.RANDOM:
43
+ return RandomMultiAgent()
44
+ case _:
45
+ raise NotImplementedError(
46
+ f"cpu_config.rollout_policy: {self.cpu_config.rollout_policy}"
47
+ )
48
+
49
+ def set_agent(self) -> None:
50
+ if self.cpu_config is None:
51
+ raise ValueError("self.cpu_config is None")
52
+
53
+ match self.cpu_config.agent_type:
54
+ case AgentType.MCTS:
55
+ rollout_agent = self.create_rollout()
56
+ mcts_config = MCTSConfig(
57
+ simulations=self.cpu_config.simulations or 50,
58
+ rollout_strategy=VanillaRollout(rollout_agent=rollout_agent),
59
+ )
60
+ self.agent = MCTSAgent(cfg=mcts_config)
61
+ case AgentType.RANDOM:
62
+ self.agent = RandomMultiAgent()
63
+ case AgentType.SAC:
64
+ self.agent = OnnxSacDeterministicAgent()
65
+ case _:
66
+ raise NotImplementedError(
67
+ f"cpu_config.name: {self.cpu_config.agent_type}"
68
+ )
69
+
70
+ def get_action(self) -> int:
71
+ if self.agent is None:
72
+ raise ValueError("self.agent is None")
73
+ return self.agent.get_action(env=self.env)
src/connect4_backend.egg-info/PKG-INFO ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Metadata-Version: 2.1
2
+ Name: connect4-backend
3
+ Version: 0.0.1
4
+ Author: Clement Gohlke
5
+ Classifier: Programming Language :: Python :: 3
6
+ Classifier: License :: OSI Approved :: MIT License
7
+ Classifier: Operating System :: OS Independent
8
+ Requires-Python: >=3.11
9
+ Description-Content-Type: text/markdown
10
+ Requires-Dist: tensordict@ git+https://github.com/pytorch/tensordict.git@c3caa7612275306ce72697a82d5252681ddae0ab
11
+ Requires-Dist: torchrl@ git+https://github.com/pytorch/rl.git@1bb192e0f3ad9e7b8c6fa769bfa3bb9d82ca4f29
12
+ Requires-Dist: litrl~=0.0.9
13
+ Requires-Dist: fastapi==0.104.1
14
+ Requires-Dist: uvicorn==0.25.0
15
+ Requires-Dist: moviepy==1.0.3
16
+ Provides-Extra: test
17
+ Requires-Dist: pytest==7.4.4; extra == "test"
18
+ Requires-Dist: mypy==1.8.0; extra == "test"
19
+ Requires-Dist: httpx==0.26.0; extra == "test"
20
+
21
+ ---
22
+ title: Connect4
23
+ emoji: 🌐
24
+ colorFrom: blue
25
+ colorTo: yellow
26
+ sdk: docker
27
+ pinned: false
28
+ license: mit
29
+ ---
30
+
31
+ frontend adapted from [github](https://github.com/jprioses/connect-four-game)
32
+
33
+ Check out the configuration reference at [huggingface](https://huggingface.co/docs/hub/spaces-config-reference)
34
+
35
+ ## TODO
36
+
37
+ link to github repo
38
+
39
+ test block forced move
40
+
41
+ Connect to AWS database https://aws.amazon.com/rds/?p=ft&c=db&z=3
42
+ using alembic and sqlmodel
src/connect4_backend.egg-info/SOURCES.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ README.md
2
+ pyproject.toml
3
+ src/__init__.py
4
+ src/app_state.py
5
+ src/create_app.py
6
+ src/typing.py
7
+ src/connect4_backend.egg-info/PKG-INFO
8
+ src/connect4_backend.egg-info/SOURCES.txt
9
+ src/connect4_backend.egg-info/dependency_links.txt
10
+ src/connect4_backend.egg-info/requires.txt
11
+ src/connect4_backend.egg-info/top_level.txt
src/connect4_backend.egg-info/dependency_links.txt ADDED
@@ -0,0 +1 @@
 
 
1
+
src/connect4_backend.egg-info/requires.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ tensordict@ git+https://github.com/pytorch/tensordict.git@c3caa7612275306ce72697a82d5252681ddae0ab
2
+ torchrl@ git+https://github.com/pytorch/rl.git@1bb192e0f3ad9e7b8c6fa769bfa3bb9d82ca4f29
3
+ litrl~=0.0.9
4
+ fastapi==0.104.1
5
+ uvicorn==0.25.0
6
+ moviepy==1.0.3
7
+
8
+ [test]
9
+ pytest==7.4.4
10
+ mypy==1.8.0
11
+ httpx==0.26.0
src/connect4_backend.egg-info/top_level.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ __init__
2
+ app_state
3
+ create_app
4
+ typing
src/create_app.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Annotated, Any, Generator
2
+ from pathlib import Path
3
+ from gymnasium.wrappers.record_video import RecordVideo
4
+ from litrl.env.make import make
5
+ from litrl.common.agent import RandomAgent
6
+ from litrl.env.typing import SingleAgentId
7
+ from fastapi import Depends, FastAPI, Request, status
8
+ from fastapi.exceptions import RequestValidationError
9
+ from fastapi.middleware.cors import CORSMiddleware
10
+ from fastapi.responses import JSONResponse
11
+ from litrl.env.connect_four import Board
12
+ from loguru import logger
13
+ from fastapi.responses import StreamingResponse, FileResponse
14
+ from moviepy.editor import VideoFileClip
15
+ from src.app_state import AppState
16
+ from src.typing import CpuConfig
17
+
18
+
19
+ def create_app() -> FastAPI:
20
+ app = FastAPI()
21
+
22
+ @app.post("/", response_model=int)
23
+ def bot_action(
24
+ board: Board,
25
+ cpuConfig: CpuConfig,
26
+ app_state: Annotated[AppState, Depends(dependency=AppState)],
27
+ ) -> int:
28
+ app_state.set_config(cpu_config=cpuConfig)
29
+ app_state.set_board(board=board)
30
+ return app_state.get_action()
31
+
32
+ @app.post(path=f"/game", response_model=str)
33
+ def bot_action(
34
+ env_id: SingleAgentId,
35
+ ) -> str:
36
+ env = make(id=env_id, render_mode="rgb_array")
37
+ env = RecordVideo(
38
+ env=env,
39
+ video_folder="tmp",
40
+ )
41
+ env.reset(seed=123)
42
+ agent = RandomAgent[Any, Any]()
43
+ terminated, truncated = False, False
44
+ while not (terminated or truncated):
45
+ action = agent.get_action(env=env)
46
+ _, _, terminated, truncated, _ = env.step(action=action)
47
+ env.render()
48
+ env.video_recorder.close()
49
+
50
+ def iterfile()-> Generator[bytes, Any, None]:
51
+ with Path(env.video_recorder.path).open(mode="rb") as env_file:
52
+ yield from env_file
53
+
54
+ return StreamingResponse(content=iterfile(), media_type="video/mp4")
55
+
56
+ @app.get(path=f"/gif")
57
+ def bot_action(
58
+ env_id: SingleAgentId,
59
+ ) -> str:
60
+ env = make(id=env_id, render_mode="rgb_array")
61
+ env = RecordVideo(
62
+ env=env,
63
+ video_folder="tmp",
64
+ )
65
+ env.reset(seed=123)
66
+ agent = RandomAgent[Any, Any]()
67
+ terminated, truncated = False, False
68
+ while not (terminated or truncated):
69
+ action = agent.get_action(env=env)
70
+ _, _, terminated, truncated, _ = env.step(action=action)
71
+ env.render()
72
+ env.video_recorder.close()
73
+
74
+ mp4_path = Path(env.video_recorder.path)
75
+ video_clip = VideoFileClip(str(mp4_path))
76
+ gif_path = mp4_path.with_suffix(".gif")
77
+ video_clip.write_gif(str(gif_path))#, fps=30) # TODO check fps
78
+ return FileResponse(str(gif_path), media_type="image/gif")
79
+
80
+ @app.exception_handler(exc_class_or_status_code=RequestValidationError)
81
+ async def validation_exception_handler(
82
+ request: Request, exc: RequestValidationError
83
+ ) -> JSONResponse:
84
+ logger.debug(f"url: {request.url}")
85
+ if hasattr(request, "_body"):
86
+ logger.debug(f"body: {request._body}")
87
+ logger.debug(f"header: {request.headers}")
88
+ logger.error(f"{request}: {exc}")
89
+ exc_str = f"{exc}".replace("\n", " ").replace(" ", " ")
90
+ content = {"status_code": 10422, "message": exc_str, "data": None}
91
+ return JSONResponse(
92
+ content=content, status_code=status.HTTP_422_UNPROCESSABLE_ENTITY
93
+ )
94
+
95
+ app.add_middleware(
96
+ middleware_class=CORSMiddleware,
97
+ allow_origins="*",
98
+ allow_credentials=True,
99
+ allow_methods=["*"],
100
+ allow_headers=["*"],
101
+ )
102
+ return app
103
+
104
+ if __name__ == "__main__":
105
+ import uvicorn
106
+ import argparse
107
+ parser = argparse.ArgumentParser()
108
+ parser.add_argument("--host", type=str, default="0.0.0.0")
109
+ parser.add_argument("--port", type=int, default=8000)
110
+ args = parser.parse_args()
111
+ config = uvicorn.Config(app=create_app(), host=args.host, port=args.port, log_level="info")
112
+ server = uvicorn.Server(config=config)
113
+ server.run()
src/huggingface/get_best_mp4.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+
3
+ def get_mp4_paths(environments: list[str]) -> dict[str, Path]:
4
+ mp4_paths: dict[str, Path] = {}
5
+ for env in environments:
6
+ mp4_paths[env] = get_best_mp4(env)
7
+ return mp4_paths
8
+
9
+
10
+ def get_best_mp4(env: str) -> Path:
11
+ hf_env_results_path = f"models/{env}/results.yaml"
12
+ local_env_results_path = hf_hub_download(repo_id=MODEL_REPO, repo_type=MODEL_REPO_TYPE, filename=hf_env_results_path)
13
+ with open(local_env_results_path, "r") as f:
14
+ env_results = yaml.load(f, Loader=yaml.FullLoader)
15
+ best_model_type = max(env_results, key=lambda model: env_results[model])
16
+ model_results_path = f"models/{env}/{best_model_type}/results.yaml"
17
+ local_model_results_path = hf_hub_download(repo_id=MODEL_REPO, repo_type=MODEL_REPO_TYPE, filename=model_results_path)
18
+ with open(local_model_results_path, "r") as f:
19
+ model_results = yaml.load(f, Loader=yaml.FullLoader)
20
+
21
+ best_model = max(model_results, key=lambda model: model_results[model])
22
+ hf_gif_path = f"models/{env}/{best_model_type}/{best_model}/demo.mp4"
23
+ return Path(hf_hub_download(repo_id=MODEL_REPO, repo_type=MODEL_REPO_TYPE, filename=hf_gif_path))
src/huggingface/get_environments.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ def get_environments() -> list[str]:
4
+ environments: list[str] = []
5
+ files = hf_api.list_repo_files(MODEL_REPO, repo_type=MODEL_REPO_TYPE)
6
+ for file in files:
7
+ vals = file.split("/")
8
+ # e.g. ['models', 'CartPole-v1', 'results.yaml']
9
+ if len(vals) == ENV_RESULTS_FILE_DEPTH and vals[2] == "results.yaml" and vals[0] == "models":
10
+ environments.append(vals[1])
11
+ return environments
src/huggingface/get_gif.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ def get_best_gif(env: str) -> Path:
2
+ hf_env_results_path = f"models/{env}/results.yaml"
3
+ local_env_results_path = hf_hub_download(repo_id=MODEL_REPO, repo_type=MODEL_REPO_TYPE, filename=hf_env_results_path)
4
+ with open(local_env_results_path, "r") as f:
5
+ env_results = yaml.load(f, Loader=yaml.FullLoader)
6
+ best_model_type = max(env_results, key=lambda model: env_results[model])
7
+ model_results_path = f"models/{env}/{best_model_type}/results.yaml"
8
+ local_model_results_path = hf_hub_download(repo_id=MODEL_REPO, repo_type=MODEL_REPO_TYPE, filename=model_results_path)
9
+ with open(local_model_results_path, "r") as f:
10
+ model_results = yaml.load(f, Loader=yaml.FullLoader)
11
+
12
+ best_model = max(model_results, key=lambda model: model_results[model])
13
+ hf_gif_path = f"models/{env}/{best_model_type}/{best_model}/demo.gif"
14
+ return Path(hf_hub_download(repo_id=MODEL_REPO, repo_type=MODEL_REPO_TYPE, filename=hf_gif_path))
15
+
16
+
17
+
18
+ def get_gif_paths(environments: list[str]) -> dict[str, Path]:
19
+ gif_paths: dict[str, Path] = {}
20
+ for env in environments:
21
+ gif_paths[env] = get_best_gif(env)
22
+ return gif_paths
src/huggingface/huggingface_client.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from huggingface_hub import hf_hub_download
2
+ from huggingface_hub import HfApi, login
3
+ import os
4
+ import yaml
5
+ from pathlib import Path
6
+ from loguru import logger
7
+ from PIL import Image
8
+ from .get_best_mp4 import get_mp4_paths
9
+
10
+ SPACE_REPO = "c-gohlke/litrl"
11
+ SPACE_REPO_TYPE = "space"
12
+ MODEL_REPO = "c-gohlke/litrl"
13
+ MODEL_REPO_TYPE = "model"
14
+
15
+ ENV_RESULTS_FILE_DEPTH = 3
16
+
17
+
18
+ class HugingFaceClient:
19
+ def __init__(self) -> None:
20
+ login( # type: ignore[no-untyped-call]
21
+ token=os.environ.get("HUGGINGFACE_TOKEN"),
22
+ add_to_git_credential=True,
23
+ new_session=False,
24
+ )
25
+ self.hf_api = HfApi()
26
+ self.mp4_paths = get_mp4_paths()
27
+
28
+
29
+ def api_predict(self, env_id: str)-> bytes|None:
30
+ if env_id not in self.mp4_paths:
31
+ logger.error(f"Environment {env_id} not found in {self.mp4_paths}")
32
+ return None
33
+ with open(self.mp4_paths[env_id], "rb") as f:
34
+ return f.read()
src/typing.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import enum
2
+
3
+ from pydantic import BaseModel
4
+
5
+
6
+ class AgentType(enum.Enum):
7
+ RANDOM = "random"
8
+ MCTS = "mcts"
9
+ SAC = "sac"
10
+
11
+
12
+ class RolloutPolicy(enum.Enum):
13
+ RANDOM = "random"
14
+ SAC = "sac"
15
+
16
+
17
+ class CpuConfig(BaseModel):
18
+ agent_type: AgentType
19
+ simulations: int | None = None
20
+ rollout_policy: RolloutPolicy | None = None
21
+
22
+ # def __format__(self, __format_spec: str) -> str:
23
+ # raise ValueError(f"__format__ not implemented for {self.__class__.__name__}")
24
+ # return super().__format__(__format_spec)