File size: 4,479 Bytes
0ad74ed |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 |
from __future__ import annotations
import datetime
import os
import threading
from collections import OrderedDict
from collections.abc import Iterator
from copy import copy, deepcopy
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from gradio.blocks import Blocks
from gradio.components import State
class StateHolder:
    """Per-session state store with least-recently-used eviction.

    Maps a session id to its ``SessionState``. Sessions are kept in access
    order (oldest first); once more than ``capacity`` sessions are held, the
    least recently used ones are dropped.
    """

    def __init__(self):
        # Default cap; replaced by ``blocks.state_session_capacity`` in set_blocks().
        self.capacity = 10000
        # Access-ordered: least recently used session first.
        self.session_data: OrderedDict[str, SessionState] = OrderedDict()
        self.time_last_used: dict[str, datetime.datetime] = {}
        # Guards the LRU bookkeeping of session_data / time_last_used.
        self.lock = threading.Lock()

    def set_blocks(self, blocks: Blocks):
        """Attach ``blocks`` to this holder (and vice versa) and adopt its session capacity."""
        self.blocks = blocks
        blocks.state_holder = self
        self.capacity = blocks.state_session_capacity

    def reset(self, blocks: Blocks):
        """Reset the state holder with new blocks. Used during reload mode.

        Drops every stored session together with its last-used timestamp.
        """
        with self.lock:
            self.session_data = OrderedDict()
            # Cleared too, otherwise timestamps of discarded sessions accumulate.
            self.time_last_used = {}
        # Call set blocks again to set new ids
        self.set_blocks(blocks)

    def __getitem__(self, session_id: str) -> SessionState:
        """Return the state for ``session_id``, creating it on first access.

        Also marks the session as most recently used and records the access
        time. Lookup-or-create runs under ``self.lock`` so concurrent first
        requests for the same session end up sharing one ``SessionState``.
        """
        with self.lock:
            if session_id not in self.session_data:
                self.session_data[session_id] = SessionState(self.blocks)
            session_state = self.session_data[session_id]
            # Recorded before _touch so that, if this very session is evicted
            # (e.g. capacity 0), its timestamp is removed along with it.
            self.time_last_used[session_id] = datetime.datetime.now()
            self._touch(session_id)
        # Return the held reference: the dict entry may already be evicted.
        return session_state

    def __contains__(self, session_id: str) -> bool:
        return session_id in self.session_data

    def update(self, session_id: str):
        """Mark ``session_id`` as most recently used and evict sessions beyond ``capacity``."""
        with self.lock:
            self._touch(session_id)

    def _touch(self, session_id: str):
        """LRU bookkeeping for ``session_id``; the caller must already hold ``self.lock``."""
        if session_id in self.session_data:
            self.session_data.move_to_end(session_id)
        # ``while`` (not ``if``) so a lowered capacity is fully enforced.
        while self.session_data and len(self.session_data) > self.capacity:
            evicted_id, _ = self.session_data.popitem(last=False)
            # Keep the timestamp map in step with session_data.
            self.time_last_used.pop(evicted_id, None)

    def delete_all_expired_state(
        self,
    ):
        """Call ``delete_state(..., expired_only=True)`` for every stored session."""
        # Iterate over a snapshot: other threads may add or evict sessions meanwhile.
        for session_id in list(self.session_data):
            self.delete_state(session_id, expired_only=True)

    def delete_state(self, session_id: str, expired_only: bool = False):
        """Delete a session's State values, calling each component's ``delete_callback``.

        Args:
            session_id: Session whose values to delete; unknown ids are ignored.
            expired_only: If True, only values that ``state_components`` flags
                as expired are removed; otherwise all of them are.
        """
        session_state = self.session_data.get(session_id)
        if session_state is None:
            return
        to_delete = []
        for component, value, expired in session_state.state_components:
            if not expired_only or expired:
                component.delete_callback(value)
                to_delete.append(component._id)
        # Removed after the loop because state_components walks state_data.
        for component_id in to_delete:
            session_state.state_data.pop(component_id, None)
class SessionState:
    """State of a single user session.

    Holds a per-session copy of the app's block config (``blocks_config``) and
    the current values of stateful blocks (``state_data``). Values of ``State``
    components written through ``__setitem__`` are also timestamped so that
    ``state_components`` can report whether they have expired.
    """

    def __init__(self, blocks: Blocks):
        # NOTE(review): shallow ``copy``; whether ``.blocks`` is isolated per
        # session depends on default_config's copy semantics — confirm.
        self.blocks_config = copy(blocks.default_config)
        # Current values of stateful blocks, keyed by block id.
        self.state_data: dict[int, Any] = {}
        # block id -> (time_to_live, time of last write); State components only.
        self._state_ttl: dict[int, tuple[Any, datetime.datetime]] = {}
        # When True, state_components applies STATE_TTL_WHEN_CLOSED instead.
        self.is_closed = False
        # When a session is closed, the state is stored for an hour to give the user time to reopen the session.
        # During testing we set to a lower value to be able to test
        # NOTE(review): the hour is measured from the value's last write, not
        # from the moment the session was closed.
        self.STATE_TTL_WHEN_CLOSED = (
            1 if os.getenv("GRADIO_IS_E2E_TEST", None) else 3600
        )

    def __getitem__(self, key: int) -> Any:
        """Return the session value of a stateful block, or the block itself otherwise.

        A stateful block's value is lazily initialised to a deep copy of the
        block's ``value`` attribute (``None`` if it has none), so sessions never
        share a mutable default.
        """
        block = self.blocks_config.blocks[key]
        if not block.stateful:
            return block
        if key not in self.state_data:
            self.state_data[key] = deepcopy(getattr(block, "value", None))
        return self.state_data[key]

    def __setitem__(self, key: int, value: Any):
        """Store ``value`` for block ``key``.

        For a ``State`` component the value goes into ``state_data`` and its
        time-to-live plus the current time are recorded; any other block is
        replaced in this session's ``blocks_config``.
        """
        from gradio.components import State

        block = self.blocks_config.blocks[key]
        if isinstance(block, State):
            self._state_ttl[key] = (block.time_to_live, datetime.datetime.now())
            self.state_data[key] = value
        else:
            self.blocks_config.blocks[key] = value

    def __contains__(self, key: int) -> bool:
        """Return whether ``key`` is present for this session.

        Unknown block ids yield ``False`` rather than raising ``KeyError``. A
        stateful block counts only once a value exists in ``state_data``.
        """
        if key not in self.blocks_config.blocks:
            return False
        if self.blocks_config.blocks[key].stateful:
            return key in self.state_data
        return True

    @staticmethod
    def _has_expired(
        created_at: datetime.datetime,
        time_to_live: Any,
        now: datetime.datetime | None = None,
    ) -> bool:
        """Return True if more than ``time_to_live`` seconds passed since ``created_at``.

        Uses ``total_seconds()``: ``timedelta.seconds`` drops whole days, which
        made values older than 24h look fresh. A ``None`` time-to-live is treated
        as "never expires" (NOTE(review): confirm State never stores None here).
        """
        if time_to_live is None:
            return False
        if now is None:
            now = datetime.datetime.now()
        return (now - created_at).total_seconds() > time_to_live

    @property
    def state_components(self) -> Iterator[tuple[State, Any, bool]]:
        """Yield ``(component, value, expired)`` for each timestamped State value.

        Only ``State`` blocks that were written via ``__setitem__`` are reported.
        While the session is closed, ``STATE_TTL_WHEN_CLOSED`` replaces the
        component's own time-to-live.
        """
        from gradio.components import State

        # Snapshot so a concurrent write to state_data cannot abort the sweep
        # with "dictionary changed size during iteration".
        for block_id, value in list(self.state_data.items()):
            block = self.blocks_config.blocks[block_id]
            if not isinstance(block, State) or block_id not in self._state_ttl:
                continue
            time_to_live, created_at = self._state_ttl[block_id]
            if self.is_closed:
                time_to_live = self.STATE_TTL_WHEN_CLOSED
            yield block, value, self._has_expired(created_at, time_to_live)
|