Spaces:
Build error
Build error
Upload folder using huggingface_hub
Browse files- src/app.py +8 -3
- src/app_state.py +17 -8
- src/huggingface/get_environments.py +6 -1
- src/huggingface/get_files.py +2 -0
- src/huggingface/huggingface_client.py +2 -2
- src/typing.py +2 -0
src/app.py
CHANGED
@@ -1,6 +1,11 @@
|
|
1 |
-
|
2 |
from pathlib import Path
|
3 |
-
from typing import
|
|
|
|
|
|
|
|
|
|
|
4 |
|
5 |
from fastapi import Depends, FastAPI, Request, status
|
6 |
from fastapi.exceptions import RequestValidationError
|
@@ -28,7 +33,7 @@ def stream_mp4(mp4_path: Path) -> StreamingResponse:
|
|
28 |
return StreamingResponse(content=iter_file(), media_type="video/mp4")
|
29 |
|
30 |
|
31 |
-
ObservationType =
|
32 |
|
33 |
|
34 |
class GridResponseType(BaseModel):
|
|
|
1 |
+
import sys
|
2 |
from pathlib import Path
|
3 |
+
from typing import Any, Generator, List
|
4 |
+
|
5 |
+
if sys.version_info[:2] >= (3, 11):
|
6 |
+
from typing import Annotated
|
7 |
+
else:
|
8 |
+
from typing_extensions import Annotated
|
9 |
|
10 |
from fastapi import Depends, FastAPI, Request, status
|
11 |
from fastapi.exceptions import RequestValidationError
|
|
|
33 |
return StreamingResponse(content=iter_file(), media_type="video/mp4")
|
34 |
|
35 |
|
36 |
+
ObservationType = List[Board]
|
37 |
|
38 |
|
39 |
class GridResponseType(BaseModel):
|
src/app_state.py
CHANGED
@@ -1,20 +1,25 @@
|
|
1 |
-
from
|
|
|
|
|
2 |
|
3 |
import numpy as np
|
|
|
4 |
|
5 |
-
|
6 |
-
|
7 |
-
except ImportError:
|
8 |
-
from typing_extensions import Self
|
9 |
|
10 |
-
from
|
|
|
|
|
|
|
|
|
|
|
11 |
|
12 |
from litrl import make_multiagent
|
13 |
from litrl.algo.mcts.agent import MCTSAgent
|
14 |
from litrl.algo.mcts.mcts_config import MCTSConfigBuilder
|
15 |
from litrl.algo.mcts.rollout import VanillaRollout
|
16 |
from litrl.common.agent import Agent, RandomMultiAgent
|
17 |
-
from litrl.env.connect_four import ConnectFour
|
18 |
from litrl.model.sac.multi_agent import OnnxSacDeterministicMultiAgent
|
19 |
from src.typing import AgentType, CpuConfig, RolloutPolicy
|
20 |
|
@@ -34,7 +39,7 @@ class AppState:
|
|
34 |
self.set_agent() # TODO in properties setter.
|
35 |
self.agent: Agent[Any, Any]
|
36 |
|
37 |
-
def __new__(cls: type[
|
38 |
if cls._instance is None:
|
39 |
cls._instance = super().__new__(cls)
|
40 |
cls._instance.setup()
|
@@ -76,6 +81,10 @@ class AppState:
|
|
76 |
def get_action(self) -> int:
|
77 |
return self.agent.get_action(env=self.env)
|
78 |
|
|
|
|
|
|
|
|
|
79 |
def inform_action(self, action: int) -> None:
|
80 |
"""Update the agent's state as a result of external changes to the environment."""
|
81 |
if isinstance(self.agent, MCTSAgent):
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
from typing import TYPE_CHECKING, Any
|
4 |
|
5 |
import numpy as np
|
6 |
+
from loguru import logger
|
7 |
|
8 |
+
if TYPE_CHECKING:
|
9 |
+
import sys
|
|
|
|
|
10 |
|
11 |
+
from litrl.env.connect_four import ConnectFour
|
12 |
+
|
13 |
+
if sys.version_info[:2] >= (3, 11):
|
14 |
+
from typing import Self
|
15 |
+
else:
|
16 |
+
from typing_extensions import Self
|
17 |
|
18 |
from litrl import make_multiagent
|
19 |
from litrl.algo.mcts.agent import MCTSAgent
|
20 |
from litrl.algo.mcts.mcts_config import MCTSConfigBuilder
|
21 |
from litrl.algo.mcts.rollout import VanillaRollout
|
22 |
from litrl.common.agent import Agent, RandomMultiAgent
|
|
|
23 |
from litrl.model.sac.multi_agent import OnnxSacDeterministicMultiAgent
|
24 |
from src.typing import AgentType, CpuConfig, RolloutPolicy
|
25 |
|
|
|
39 |
self.set_agent() # TODO in properties setter.
|
40 |
self.agent: Agent[Any, Any]
|
41 |
|
42 |
+
def __new__(cls: type[AppState]) -> AppState: # noqa: PYI034
|
43 |
if cls._instance is None:
|
44 |
cls._instance = super().__new__(cls)
|
45 |
cls._instance.setup()
|
|
|
81 |
def get_action(self) -> int:
|
82 |
return self.agent.get_action(env=self.env)
|
83 |
|
84 |
+
def inform_reset(self) -> None:
|
85 |
+
if isinstance(self.agent, MCTSAgent):
|
86 |
+
self.agent.inform_reset()
|
87 |
+
|
88 |
def inform_action(self, action: int) -> None:
|
89 |
"""Update the agent's state as a result of external changes to the environment."""
|
90 |
if isinstance(self.agent, MCTSAgent):
|
src/huggingface/get_environments.py
CHANGED
@@ -1,4 +1,9 @@
|
|
1 |
-
from
|
|
|
|
|
|
|
|
|
|
|
2 |
|
3 |
from src.constants import ENV_RESULTS_FILE_DEPTH, MODEL_REPO, MODEL_REPO_TYPE
|
4 |
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
from typing import TYPE_CHECKING
|
4 |
+
|
5 |
+
if TYPE_CHECKING:
|
6 |
+
from huggingface_hub import HfApi
|
7 |
|
8 |
from src.constants import ENV_RESULTS_FILE_DEPTH, MODEL_REPO, MODEL_REPO_TYPE
|
9 |
|
src/huggingface/get_files.py
CHANGED
@@ -1,3 +1,5 @@
|
|
|
|
|
|
1 |
from pathlib import Path
|
2 |
|
3 |
import yaml
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
from pathlib import Path
|
4 |
|
5 |
import yaml
|
src/huggingface/huggingface_client.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
from huggingface_hub import HfApi
|
2 |
|
3 |
-
from .get_environments import get_environments
|
4 |
-
from .get_files import get_mp4_paths
|
5 |
|
6 |
|
7 |
class HuggingFaceClient:
|
|
|
1 |
from huggingface_hub import HfApi
|
2 |
|
3 |
+
from src.huggingface.get_environments import get_environments
|
4 |
+
from src.huggingface.get_files import get_mp4_paths
|
5 |
|
6 |
|
7 |
class HuggingFaceClient:
|
src/typing.py
CHANGED
@@ -1,3 +1,5 @@
|
|
|
|
|
|
1 |
import enum
|
2 |
|
3 |
from pydantic import BaseModel
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
import enum
|
4 |
|
5 |
from pydantic import BaseModel
|