File size: 4,479 Bytes
0ad74ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
from __future__ import annotations

import datetime
import os
import threading
from collections import OrderedDict
from collections.abc import Iterator
from copy import copy, deepcopy
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    from gradio.blocks import Blocks
    from gradio.components import State


class StateHolder:
    """Holds one ``SessionState`` per session id, bounded LRU-style by ``capacity``.

    ``session_data`` is ordered from least- to most-recently used; when it
    grows past ``capacity`` the least-recently-used sessions are dropped.
    ``time_last_used`` records when each session was last accessed via
    ``__getitem__``.
    """

    def __init__(self):
        # Placeholder limit; set_blocks() replaces it with the app's setting.
        self.capacity = 10000
        self.session_data: OrderedDict[str, SessionState] = OrderedDict()
        self.time_last_used: dict[str, datetime.datetime] = {}
        # Serialises structural changes to session_data / time_last_used.
        self.lock = threading.Lock()

    def set_blocks(self, blocks: Blocks):
        """Bind this holder to ``blocks`` (and vice versa) and adopt its session capacity."""
        self.blocks = blocks
        blocks.state_holder = self
        self.capacity = blocks.state_session_capacity

    def reset(self, blocks: Blocks):
        """Reset the state holder with new blocks. Used during reload mode."""
        with self.lock:
            self.session_data = OrderedDict()
            # Drop access times too; they would otherwise refer to sessions
            # that no longer exist.
            self.time_last_used.clear()
        # Call set blocks again to set new ids
        self.set_blocks(blocks)

    def __getitem__(self, session_id: str) -> SessionState:
        """Return the state for ``session_id``, creating it on first access.

        The session is marked most-recently-used and its access time is
        recorded. Lookup, creation and LRU bookkeeping happen under one lock
        so two concurrent first requests cannot each create (and one
        overwrite) a separate ``SessionState``.
        """
        with self.lock:
            session_state = self.session_data.get(session_id)
            if session_state is None:
                session_state = SessionState(self.blocks)
                self.session_data[session_id] = session_state
            self.time_last_used[session_id] = datetime.datetime.now()
            self._touch_and_evict(session_id)
        # Return the local reference: with a tiny capacity the session may
        # already have been evicted from session_data.
        return session_state

    def __contains__(self, session_id: str):
        """Return True if a state exists for ``session_id`` (does not create one)."""
        return session_id in self.session_data

    def update(self, session_id: str):
        """Mark ``session_id`` as most-recently-used and enforce ``capacity``."""
        with self.lock:
            self._touch_and_evict(session_id)

    def _touch_and_evict(self, session_id: str):
        """Move ``session_id`` to the MRU end and evict LRU sessions over capacity.

        The caller must already hold ``self.lock`` (a non-reentrant
        ``threading.Lock``).
        """
        if session_id in self.session_data:
            self.session_data.move_to_end(session_id)
        # ``while`` rather than ``if``: capacity may have been lowered by
        # set_blocks(), leaving more than one session over the limit.
        while self.session_data and len(self.session_data) > self.capacity:
            evicted_id, _ = self.session_data.popitem(last=False)
            # Forget the evicted session's access time as well, otherwise
            # time_last_used grows without bound.
            self.time_last_used.pop(evicted_id, None)

    def delete_all_expired_state(
        self,
    ):
        """Delete every expired State value across all known sessions."""
        # Iterate a snapshot of the ids: sessions can be added or reordered
        # (by __getitem__/update, or by a delete callback) while the sweep
        # runs, and iterating a mutating OrderedDict raises RuntimeError.
        with self.lock:
            session_ids = list(self.session_data)
        for session_id in session_ids:
            self.delete_state(session_id, expired_only=True)

    def delete_state(self, session_id: str, expired_only: bool = False):
        """Run delete callbacks for a session's State values and drop them.

        Args:
            session_id: Session to clean up; unknown ids are ignored.
            expired_only: If True, only values whose time-to-live has elapsed
                are deleted; otherwise every State value with a TTL record is.
        """
        # .get(): the session may be evicted concurrently between a
        # membership check and the lookup.
        session_state = self.session_data.get(session_id)
        if session_state is None:
            return
        # Collect ids first; state_data must not change while
        # state_components is being iterated.
        to_delete = []
        for component, value, expired in session_state.state_components:
            if expired or not expired_only:
                component.delete_callback(value)
                to_delete.append(component._id)
        for component_id in to_delete:
            del session_state.state_data[component_id]
            # Drop the TTL record too. If the value is later re-created from
            # the block default (SessionState.__getitem__ does not touch the
            # TTL), a stale record would make it look expired immediately.
            session_state._state_ttl.pop(component_id, None)


class SessionState:
    """Per-session view over a Blocks app's components and their state values.

    Stateful blocks (``block.stateful``) get a per-session value kept in
    ``state_data`` and initialised lazily from a deep copy of the block's
    ``value``. Non-stateful blocks are returned from this session's
    ``blocks_config``, which is a ``copy.copy`` of the app's default config.
    """

    def __init__(self, blocks: Blocks):
        self.blocks_config = copy(blocks.default_config)
        # Per-session values of stateful blocks, keyed by block id.
        self.state_data: dict[int, Any] = {}
        # block id -> (time_to_live in seconds, when the value was last set).
        # Only filled by __setitem__ for State blocks.
        self._state_ttl: dict[int, tuple[Any, datetime.datetime]] = {}
        self.is_closed = False
        # When a session is closed, the state is stored for an hour to give the user time to reopen the session.
        # During testing we set to a lower value to be able to test
        self.STATE_TTL_WHEN_CLOSED = (
            1 if os.getenv("GRADIO_IS_E2E_TEST", None) else 3600
        )

    def __getitem__(self, key: int) -> Any:
        """Return the session value for a stateful block, else the block itself.

        A stateful block's value is created on first access as a deep copy of
        the block's ``value`` attribute (``None`` if it has none), so sessions
        never share a mutable default.
        """
        block = self.blocks_config.blocks[key]
        if not block.stateful:
            return block
        if key not in self.state_data:
            self.state_data[key] = deepcopy(getattr(block, "value", None))
        return self.state_data[key]

    def __setitem__(self, key: int, value: Any):
        """Store ``value`` for block ``key``.

        ``State`` blocks keep the value in ``state_data`` and (re)start their
        time-to-live clock; any other block is replaced by ``value`` in
        ``self.blocks_config.blocks``.
        """
        from gradio.components import State

        block = self.blocks_config.blocks[key]
        if not isinstance(block, State):
            self.blocks_config.blocks[key] = value
            return
        self._state_ttl[key] = (block.time_to_live, datetime.datetime.now())
        self.state_data[key] = value

    def __contains__(self, key: int):
        """Return True if ``key`` has a value in this session.

        Unknown block ids return False instead of raising ``KeyError``.
        Stateful blocks count only once a value exists in ``state_data``.
        """
        blocks = self.blocks_config.blocks
        if key not in blocks:
            return False
        if blocks[key].stateful:
            return key in self.state_data
        return True

    @property
    def state_components(self) -> Iterator[tuple[State, Any, bool]]:
        """Yield ``(component, value, expired)`` for each State value with a TTL record.

        A value is expired once more than its time-to-live (seconds) has
        elapsed since it was last set. After the session is closed,
        ``STATE_TTL_WHEN_CLOSED`` replaces the component's own TTL. A ``None``
        time-to-live is treated as "never expires".
        """
        from gradio.components import State

        now = datetime.datetime.now()
        # Iterate a snapshot: __getitem__ may add entries to state_data while
        # this generator is suspended between yields, which would otherwise
        # raise RuntimeError on the next step.
        for block_id, value in list(self.state_data.items()):
            block = self.blocks_config.blocks[block_id]
            if not isinstance(block, State) or block_id not in self._state_ttl:
                continue
            time_to_live, last_set = self._state_ttl[block_id]
            if self.is_closed:
                time_to_live = self.STATE_TTL_WHEN_CLOSED
            # total_seconds(), not .seconds: .seconds drops whole days (so a
            # TTL over a day could never expire) and wraps negative deltas to
            # ~86399, which would wrongly mark values as expired.
            elapsed = (now - last_set).total_seconds()
            expired = time_to_live is not None and elapsed > time_to_live
            yield block, value, expired