c-gohlke committed on
Commit
acf3b96
·
verified ·
1 Parent(s): 0d04c8e

Upload folder using huggingface_hub

Browse files
src/app.py CHANGED
@@ -12,7 +12,7 @@ from pydantic import BaseModel
12
 
13
  from litrl.algo.mcts.agent import MCTSAgent
14
  from litrl.common.agent import RandomAgent
15
- from litrl.env.connect_four import ConnectFour
16
  from litrl.env.make import make
17
  from litrl.env.typing import GymId
18
  from src.app_state import AppState
@@ -28,7 +28,7 @@ def stream_mp4(mp4_path: Path) -> StreamingResponse:
28
  return StreamingResponse(content=iter_file(), media_type="video/mp4")
29
 
30
 
31
- ObservationType = list[list[list[int]]]
32
 
33
 
34
  class GridResponseType(BaseModel):
@@ -36,6 +36,14 @@ class GridResponseType(BaseModel):
36
  done: bool
37
 
38
 
 
 
 
 
 
 
 
 
39
  def step(env: ConnectFour, action: int) -> GridResponseType:
40
  env.step(action)
41
  return observe(env)
@@ -61,7 +69,9 @@ def create_app() -> FastAPI: # noqa: C901 # TODO move to routes
61
  action: int,
62
  app_state: Annotated[AppState, Depends(dependency=AppState)],
63
  ) -> GridResponseType:
64
- return step(app_state.env, action)
 
 
65
 
66
  @app.get(path="/connect_four/observe", response_model=GridResponseType)
67
  def endpoint_observe(
@@ -69,14 +79,20 @@ def create_app() -> FastAPI: # noqa: C901 # TODO move to routes
69
  ) -> GridResponseType:
70
  return observe(app_state.env)
71
 
72
- @app.post(path="/connect_four/bot_play", response_model=GridResponseType)
73
  def endpoint_bot_play(
74
  cpu_config: CpuConfig,
75
- app_state: Annotated[AppState, Depends(dependency=AppState)],
76
- ) -> GridResponseType:
77
  app_state.set_config(cpu_config)
78
  action = app_state.get_action()
79
- return step(app_state.env, action)
 
 
 
 
 
 
80
 
81
  @app.get(path="/connect_four/bot_progress", response_model=float)
82
  def endpoint_bot_progress(
@@ -85,9 +101,11 @@ def create_app() -> FastAPI: # noqa: C901 # TODO move to routes
85
  if isinstance(app_state.agent, MCTSAgent):
86
  if app_state.cpu_config.simulations is None:
87
  raise ValueError
 
 
88
  return float(
89
- app_state.agent.mcts._root.visits / app_state.cpu_config.simulations, # noqa: SLF001
90
- ) # TODO why needed?
91
  return 1.0
92
 
93
  @app.get(path="/connect_four/reset", response_model=GridResponseType)
 
12
 
13
  from litrl.algo.mcts.agent import MCTSAgent
14
  from litrl.common.agent import RandomAgent
15
+ from litrl.env.connect_four import Board, ConnectFour
16
  from litrl.env.make import make
17
  from litrl.env.typing import GymId
18
  from src.app_state import AppState
 
28
  return StreamingResponse(content=iter_file(), media_type="video/mp4")
29
 
30
 
31
+ ObservationType = list[Board]
32
 
33
 
34
  class GridResponseType(BaseModel):
 
36
  done: bool
37
 
38
 
39
+ class BotResponseType(GridResponseType):
40
+ action: int
41
+
42
+
43
+ def get_app_state() -> AppState:
44
+ return AppState()
45
+
46
+
47
  def step(env: ConnectFour, action: int) -> GridResponseType:
48
  env.step(action)
49
  return observe(env)
 
69
  action: int,
70
  app_state: Annotated[AppState, Depends(dependency=AppState)],
71
  ) -> GridResponseType:
72
+ response = step(app_state.env, action)
73
+ app_state.inform_action(action=action)
74
+ return response
75
 
76
  @app.get(path="/connect_four/observe", response_model=GridResponseType)
77
  def endpoint_observe(
 
79
  ) -> GridResponseType:
80
  return observe(app_state.env)
81
 
82
+ @app.post(path="/connect_four/bot_play", response_model=BotResponseType)
83
  def endpoint_bot_play(
84
  cpu_config: CpuConfig,
85
+ app_state: Annotated[AppState, Depends(dependency=get_app_state)],
86
+ ) -> BotResponseType:
87
  app_state.set_config(cpu_config)
88
  action = app_state.get_action()
89
+ response = step(app_state.env, action)
90
+ app_state.inform_action(action=action)
91
+ return BotResponseType(
92
+ grid=response.grid,
93
+ done=response.done,
94
+ action=action,
95
+ )
96
 
97
  @app.get(path="/connect_four/bot_progress", response_model=float)
98
  def endpoint_bot_progress(
 
101
  if isinstance(app_state.agent, MCTSAgent):
102
  if app_state.cpu_config.simulations is None:
103
  raise ValueError
104
+ if app_state.agent.mcts is None:
105
+ raise ValueError
106
  return float(
107
+ app_state.agent.mcts.root.visits / app_state.cpu_config.simulations,
108
+ ) # TODO why not recognized as float?
109
  return 1.0
110
 
111
  @app.get(path="/connect_four/reset", response_model=GridResponseType)
src/app_state.py CHANGED
@@ -25,18 +25,21 @@ class AppState:
25
 
26
  self.cpu_config: CpuConfig = CpuConfig(agent_type=AgentType.RANDOM)
27
  self.set_agent() # TODO in properties setter.
28
- self.agent: Agent[Any, int]
29
 
30
- def __new__(cls) -> "AppState":
31
  if cls._instance is None:
32
  cls._instance = super().__new__(cls)
33
  cls._instance.setup()
34
  return cls._instance
35
 
36
  def set_config(self, cpu_config: CpuConfig) -> None:
 
37
  if cpu_config != self.cpu_config:
38
  self.cpu_config = cpu_config
39
  self.set_agent()
 
 
40
 
41
  def create_rollout(self) -> Agent[Any, Any]:
42
  match self.cpu_config.rollout_policy:
@@ -51,8 +54,8 @@ class AppState:
51
  raise NotImplementedError(msg)
52
 
53
  def set_agent(self) -> None:
54
- match self.cpu_config.agent_type:
55
- case AgentType.MCTS:
56
  rollout_agent = self.create_rollout()
57
  # fmt: off
58
  mcts_config = (
@@ -62,9 +65,10 @@ class AppState:
62
  ).build()
63
  # fmt: on
64
  self.agent = MCTSAgent(cfg=mcts_config)
65
- case AgentType.RANDOM:
 
66
  self.agent = RandomMultiAgent()
67
- case AgentType.SAC:
68
  self.agent = OnnxSacDeterministicMultiAgent() # type: ignore[assignment] # TODO
69
  case _:
70
  msg = f"cpu_config.name: {self.cpu_config.agent_type}"
@@ -72,3 +76,8 @@ class AppState:
72
 
73
  def get_action(self) -> int:
74
  return self.agent.get_action(env=self.env)
 
 
 
 
 
 
25
 
26
  self.cpu_config: CpuConfig = CpuConfig(agent_type=AgentType.RANDOM)
27
  self.set_agent() # TODO in properties setter.
28
+ self.agent: Agent[Any, Any]
29
 
30
+ def __new__(cls: type["AppState"]) -> "AppState":
31
  if cls._instance is None:
32
  cls._instance = super().__new__(cls)
33
  cls._instance.setup()
34
  return cls._instance
35
 
36
  def set_config(self, cpu_config: CpuConfig) -> None:
37
+ logger.info(f"new cpu_config: {cpu_config}")
38
  if cpu_config != self.cpu_config:
39
  self.cpu_config = cpu_config
40
  self.set_agent()
41
+ else:
42
+ logger.info("cpu_config unchanged")
43
 
44
  def create_rollout(self) -> Agent[Any, Any]:
45
  match self.cpu_config.rollout_policy:
 
54
  raise NotImplementedError(msg)
55
 
56
  def set_agent(self) -> None:
57
+ match self.cpu_config.agent_type.value:
58
+ case AgentType.MCTS.value:
59
  rollout_agent = self.create_rollout()
60
  # fmt: off
61
  mcts_config = (
 
65
  ).build()
66
  # fmt: on
67
  self.agent = MCTSAgent(cfg=mcts_config)
68
+ logger.debug("set_agent: MCTSAgent")
69
+ case AgentType.RANDOM.value:
70
  self.agent = RandomMultiAgent()
71
+ case AgentType.SAC.value:
72
  self.agent = OnnxSacDeterministicMultiAgent() # type: ignore[assignment] # TODO
73
  case _:
74
  msg = f"cpu_config.name: {self.cpu_config.agent_type}"
 
76
 
77
  def get_action(self) -> int:
78
  return self.agent.get_action(env=self.env)
79
+
80
+ def inform_action(self, action: int) -> None:
81
+ """Update the agent's state as a result of external changes to the environment."""
82
+ if isinstance(self.agent, MCTSAgent) and self.agent.mcts is not None:
83
+ self.agent.mcts.update_root(action)
src/huggingface/huggingface_client.py CHANGED
@@ -1,6 +1,4 @@
1
- import os
2
-
3
- from huggingface_hub import HfApi, login
4
 
5
  from .get_environments import get_environments
6
  from .get_files import get_mp4_paths
@@ -8,11 +6,6 @@ from .get_files import get_mp4_paths
8
 
9
  class HuggingFaceClient:
10
  def __init__(self) -> None:
11
- login(
12
- token=os.environ.get("HUGGINGFACE_TOKEN"),
13
- add_to_git_credential=True,
14
- new_session=False,
15
- )
16
  self.hf_api = HfApi()
17
  self.environments = get_environments(self.hf_api)
18
  self.mp4_paths = get_mp4_paths(environments=self.environments)
 
1
+ from huggingface_hub import HfApi
 
 
2
 
3
  from .get_environments import get_environments
4
  from .get_files import get_mp4_paths
 
6
 
7
  class HuggingFaceClient:
8
  def __init__(self) -> None:
 
 
 
 
 
9
  self.hf_api = HfApi()
10
  self.environments = get_environments(self.hf_api)
11
  self.mp4_paths = get_mp4_paths(environments=self.environments)