c-gohlke commited on
Commit
5b344d4
·
verified ·
1 Parent(s): 0a38bdd

Upload folder using huggingface_hub

Browse files
src/app.py CHANGED
@@ -1,6 +1,11 @@
1
- from collections.abc import Generator
2
  from pathlib import Path
3
- from typing import Annotated, Any
 
 
 
 
 
4
 
5
  from fastapi import Depends, FastAPI, Request, status
6
  from fastapi.exceptions import RequestValidationError
@@ -28,7 +33,7 @@ def stream_mp4(mp4_path: Path) -> StreamingResponse:
28
  return StreamingResponse(content=iter_file(), media_type="video/mp4")
29
 
30
 
31
- ObservationType = list[Board]
32
 
33
 
34
  class GridResponseType(BaseModel):
 
1
+ import sys
2
  from pathlib import Path
3
+ from typing import Any, Generator, List
4
+
5
+ if sys.version_info[:2] >= (3, 11):
6
+ from typing import Annotated
7
+ else:
8
+ from typing_extensions import Annotated
9
 
10
  from fastapi import Depends, FastAPI, Request, status
11
  from fastapi.exceptions import RequestValidationError
 
33
  return StreamingResponse(content=iter_file(), media_type="video/mp4")
34
 
35
 
36
+ ObservationType = List[Board]
37
 
38
 
39
  class GridResponseType(BaseModel):
src/app_state.py CHANGED
@@ -1,20 +1,25 @@
1
- from typing import Any
 
 
2
 
3
  import numpy as np
 
4
 
5
- try:
6
- from typing import Self
7
- except ImportError:
8
- from typing_extensions import Self
9
 
10
- from loguru import logger
 
 
 
 
 
11
 
12
  from litrl import make_multiagent
13
  from litrl.algo.mcts.agent import MCTSAgent
14
  from litrl.algo.mcts.mcts_config import MCTSConfigBuilder
15
  from litrl.algo.mcts.rollout import VanillaRollout
16
  from litrl.common.agent import Agent, RandomMultiAgent
17
- from litrl.env.connect_four import ConnectFour
18
  from litrl.model.sac.multi_agent import OnnxSacDeterministicMultiAgent
19
  from src.typing import AgentType, CpuConfig, RolloutPolicy
20
 
@@ -34,7 +39,7 @@ class AppState:
34
  self.set_agent() # TODO in properties setter.
35
  self.agent: Agent[Any, Any]
36
 
37
- def __new__(cls: type["AppState"]) -> "AppState":
38
  if cls._instance is None:
39
  cls._instance = super().__new__(cls)
40
  cls._instance.setup()
@@ -76,6 +81,10 @@ class AppState:
76
  def get_action(self) -> int:
77
  return self.agent.get_action(env=self.env)
78
 
 
 
 
 
79
  def inform_action(self, action: int) -> None:
80
  """Update the agent's state as a result of external changes to the environment."""
81
  if isinstance(self.agent, MCTSAgent):
 
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING, Any
4
 
5
  import numpy as np
6
+ from loguru import logger
7
 
8
+ if TYPE_CHECKING:
9
+ import sys
 
 
10
 
11
+ from litrl.env.connect_four import ConnectFour
12
+
13
+ if sys.version_info[:2] >= (3, 11):
14
+ from typing import Self
15
+ else:
16
+ from typing_extensions import Self
17
 
18
  from litrl import make_multiagent
19
  from litrl.algo.mcts.agent import MCTSAgent
20
  from litrl.algo.mcts.mcts_config import MCTSConfigBuilder
21
  from litrl.algo.mcts.rollout import VanillaRollout
22
  from litrl.common.agent import Agent, RandomMultiAgent
 
23
  from litrl.model.sac.multi_agent import OnnxSacDeterministicMultiAgent
24
  from src.typing import AgentType, CpuConfig, RolloutPolicy
25
 
 
39
  self.set_agent() # TODO in properties setter.
40
  self.agent: Agent[Any, Any]
41
 
42
+ def __new__(cls: type[AppState]) -> AppState: # noqa: PYI034
43
  if cls._instance is None:
44
  cls._instance = super().__new__(cls)
45
  cls._instance.setup()
 
81
  def get_action(self) -> int:
82
  return self.agent.get_action(env=self.env)
83
 
84
+ def inform_reset(self) -> None:
85
+ if isinstance(self.agent, MCTSAgent):
86
+ self.agent.inform_reset()
87
+
88
  def inform_action(self, action: int) -> None:
89
  """Update the agent's state as a result of external changes to the environment."""
90
  if isinstance(self.agent, MCTSAgent):
src/huggingface/get_environments.py CHANGED
@@ -1,4 +1,9 @@
1
- from huggingface_hub import HfApi
 
 
 
 
 
2
 
3
  from src.constants import ENV_RESULTS_FILE_DEPTH, MODEL_REPO, MODEL_REPO_TYPE
4
 
 
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING
4
+
5
+ if TYPE_CHECKING:
6
+ from huggingface_hub import HfApi
7
 
8
  from src.constants import ENV_RESULTS_FILE_DEPTH, MODEL_REPO, MODEL_REPO_TYPE
9
 
src/huggingface/get_files.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  from pathlib import Path
2
 
3
  import yaml
 
1
+ from __future__ import annotations
2
+
3
  from pathlib import Path
4
 
5
  import yaml
src/huggingface/huggingface_client.py CHANGED
@@ -1,7 +1,7 @@
1
  from huggingface_hub import HfApi
2
 
3
- from .get_environments import get_environments
4
- from .get_files import get_mp4_paths
5
 
6
 
7
  class HuggingFaceClient:
 
1
  from huggingface_hub import HfApi
2
 
3
+ from src.huggingface.get_environments import get_environments
4
+ from src.huggingface.get_files import get_mp4_paths
5
 
6
 
7
  class HuggingFaceClient:
src/typing.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import enum
2
 
3
  from pydantic import BaseModel
 
1
+ from __future__ import annotations
2
+
3
  import enum
4
 
5
  from pydantic import BaseModel