ljy266987 commited on
Commit
679081c
1 Parent(s): a826f18

add modified zero

Browse files
.idea/.gitignore ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # Default ignored files
2
+ /shelf/
3
+ /workspace.xml
4
+ # Editor-based HTTP Client requests
5
+ /httpRequests/
6
+ # Datasource local storage ignored files
7
+ /dataSources/
8
+ /dataSources.local.xml
spaces/__init__.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ """
3
+
4
+ import sys
5
+
6
+ if sys.version_info.minor < 8: # pragma: no cover
7
+ raise RuntimeError("Importing PySpaces requires Python 3.8+")
8
+
9
+
10
+ from .zero.decorator import GPU
11
+ from .zero.torch import disable_cuda_intercept
12
+ from .gradio import gradio_auto_wrap
13
+ from .gradio import disable_gradio_auto_wrap
14
+ from .gradio import enable_gradio_auto_wrap
15
+
16
+ import os
17
+
18
+ # 获取全部环境变量
19
+ env_vars = os.environ
20
+
21
+ # 遍历并打印环境变量
22
+ for key, value in env_vars.items():
23
+ print(f"{key}: {value}")
24
+
25
+
26
+ __all__ = [
27
+ 'GPU',
28
+ 'disable_cuda_intercept',
29
+ 'gradio_auto_wrap',
30
+ 'disable_gradio_auto_wrap',
31
+ 'enable_gradio_auto_wrap',
32
+ ]
spaces/config.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ """
3
+ from __future__ import annotations
4
+
5
+ import os
6
+
7
+ from .utils import boolean
8
+
9
+
10
+ class Settings:
11
+ def __init__(self):
12
+ self.zero_gpu = boolean(
13
+ os.getenv('SPACES_ZERO_GPU'))
14
+ self.zero_device_api_url = (
15
+ os.getenv('SPACES_ZERO_DEVICE_API_URL'))
16
+ self.gradio_auto_wrap = boolean(
17
+ os.getenv('SPACES_GRADIO_AUTO_WRAP'))
18
+ self.zero_patch_torch_device = boolean(
19
+ os.getenv('ZERO_GPU_PATCH_TORCH_DEVICE'))
20
+
21
+
22
+ Config = Settings()
23
+
24
+
25
+ if Config.zero_gpu:
26
+ assert Config.zero_device_api_url is not None, (
27
+ 'SPACES_ZERO_DEVICE_API_URL env must be set '
28
+ 'on ZeroGPU Spaces (identified by SPACES_ZERO_GPU=true)'
29
+ )
spaces/gradio.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ """
3
+ from __future__ import annotations
4
+
5
+ from typing import Callable
6
+ from typing import Generator
7
+ from typing import TypeVar
8
+ from typing import overload
9
+ from typing_extensions import ParamSpec
10
+
11
+ from .config import Config
12
+ from .zero.decorator import GPU
13
+
14
+
15
+ Param = ParamSpec('Param')
16
+ Res = TypeVar('Res')
17
+
18
+
19
+ gradio_auto_wrap_enabled = Config.gradio_auto_wrap
20
+
21
+
22
+ def disable_gradio_auto_wrap():
23
+ global gradio_auto_wrap_enabled
24
+ gradio_auto_wrap_enabled = False
25
+
26
+ def enable_gradio_auto_wrap():
27
+ global gradio_auto_wrap_enabled
28
+ gradio_auto_wrap_enabled = True
29
+
30
+
31
+ @overload
32
+ def gradio_auto_wrap(
33
+ task:
34
+ Callable[Param, Res],
35
+ ) -> Callable[Param, Res]:
36
+ ...
37
+ @overload
38
+ def gradio_auto_wrap(
39
+ task:
40
+ None,
41
+ ) -> None:
42
+ ...
43
+ def gradio_auto_wrap(
44
+ task:
45
+ Callable[Param, Res]
46
+ | None,
47
+ ) -> (Callable[Param, Res]
48
+ | None):
49
+ """
50
+ """
51
+ if not gradio_auto_wrap_enabled:
52
+ return task
53
+ if not callable(task):
54
+ return task
55
+ return GPU(task) # type: ignore
spaces/utils.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ """
3
+ from __future__ import annotations
4
+
5
+ import sys
6
+ from functools import lru_cache as cache
7
+ from functools import partial
8
+
9
+ import multiprocessing
10
+ from multiprocessing.queues import SimpleQueue as _SimpleQueue
11
+ from pathlib import Path
12
+ from pickle import PicklingError
13
+ from typing import Callable
14
+ from typing import TypeVar
15
+
16
+
17
+ GRADIO_VERSION_ERROR_MESSAGE = "Make sure Gradio version is at least 3.46"
18
+
19
+
20
+ T = TypeVar('T')
21
+
22
+
23
+ @cache
24
+ def self_cgroup_device_path() -> str:
25
+ cgroup_content = Path('/proc/self/cgroup').read_text()
26
+ for line in cgroup_content.strip().split('\n'):
27
+ contents = line.split(':devices:')
28
+ if len(contents) != 2:
29
+ continue # pragma: no cover
30
+ return contents[1]
31
+ raise Exception # pragma: no cover
32
+
33
+
34
+ if sys.version_info.minor < 9: # pragma: no cover
35
+ _SimpleQueue.__class_getitem__ = classmethod(lambda cls, _: cls) # type: ignore
36
+
37
+ class SimpleQueue(_SimpleQueue[T]):
38
+ def __init__(self, *args):
39
+ super().__init__(*args, ctx=multiprocessing.get_context('fork'))
40
+ def put(self, obj: T):
41
+ try:
42
+ super().put(obj)
43
+ except PicklingError:
44
+ raise # pragma: no cover
45
+ # https://bugs.python.org/issue29187
46
+ except Exception as e:
47
+ message = str(e)
48
+ if not "pickle" in message:
49
+ raise # pragma: no cover
50
+ raise PicklingError(message)
51
+ def close(self): # Python 3.8 static typing trick
52
+ super().close() # type: ignore
53
+
54
+
55
+ def drop_params(fn: Callable[[], T]) -> Callable[..., T]:
56
+ def drop(*args):
57
+ return fn()
58
+ return drop
59
+
60
+
61
+ def boolean(value: str | None) -> bool:
62
+ return value is not None and value.lower() in ("1", "t", "true")
63
+
64
+
65
+ def gradio_request_var():
66
+ try:
67
+ from gradio.context import LocalContext
68
+ except ImportError: # pragma: no cover
69
+ raise RuntimeError(GRADIO_VERSION_ERROR_MESSAGE)
70
+ return LocalContext.request
71
+
72
+
73
+ debug = partial(print, 'SPACES_ZERO_GPU_DEBUG')
spaces/zero/__init__.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ """
3
+
4
+ from ..config import Config
5
+ from . import torch
6
+
7
+ if Config.zero_gpu:
8
+ if torch.is_in_bad_fork():
9
+ raise RuntimeError(
10
+ "CUDA has been initialized before importing the `spaces` package"
11
+ )
12
+ torch.patch() # pragma: no cover
spaces/zero/api.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Synced with huggingface/pyspaces:spaces/zero/api.py
3
+ """
4
+ from __future__ import annotations
5
+
6
+ from datetime import timedelta
7
+ from typing import Any
8
+ from typing import Generator
9
+ from typing import Literal
10
+ from typing import NamedTuple
11
+ from typing import Optional
12
+ from typing import overload
13
+
14
+ import httpx
15
+ from pydantic import BaseModel
16
+ from typing_extensions import assert_never
17
+
18
+
19
+ AllowToken = str
20
+ NvidiaIndex = int # TODO: Migrate to GpuIndex (less confusing for MIG)
21
+ NvidiaUUID = str
22
+ CGroupPath = str
23
+ VisitorId = str
24
+ Score = float
25
+
26
+
27
+ class ScheduleResponse(BaseModel):
28
+ idle: bool
29
+ nvidiaIndex: int
30
+ nvidiaUUID: str
31
+ allowToken: str | None
32
+
33
+
34
+ class QuotaInfos(BaseModel):
35
+ left: int
36
+ wait: timedelta
37
+
38
+
39
+ class ReportUsageMonitoringParams(NamedTuple):
40
+ nvidia_index: int
41
+ visitor_id: str
42
+ duration: timedelta
43
+
44
+
45
+ class QueueEvent(BaseModel):
46
+ event: Literal['ping', 'failed', 'succeeded']
47
+ data: Optional[ScheduleResponse] = None
48
+
49
+
50
+ def sse_parse(text: str):
51
+ event, *data = text.strip().splitlines()
52
+ assert event.startswith('event:')
53
+ event = event[6:].strip()
54
+ if event in ('ping', 'failed'):
55
+ return QueueEvent(event=event)
56
+ assert event == 'succeeded'
57
+ (data,) = data
58
+ assert data.startswith('data:')
59
+ data = data[5:].strip()
60
+ return QueueEvent(event=event, data=ScheduleResponse.parse_raw(data))
61
+
62
+
63
+ def sse_stream(res: httpx.Response) -> Generator[QueueEvent, Any, None]:
64
+ for text in res.iter_text():
65
+ if len(text) == 0:
66
+ break # pragma: no cover
67
+ try:
68
+ print(f"sse_stream: {text}")
69
+ yield sse_parse(text)
70
+ except GeneratorExit:
71
+ res.close()
72
+ break
73
+
74
+
75
+ class APIClient:
76
+
77
+ def __init__(self, client: httpx.Client):
78
+ self.client = client
79
+
80
+ def startup_report(self) -> httpx.codes:
81
+ res = self.client.post('/startup-report')
82
+ print(f"/startup-report: {res}")
83
+ return httpx.codes(res.status_code)
84
+
85
+ def schedule(
86
+ self,
87
+ cgroup_path: str,
88
+ task_id: int = 0,
89
+ token: str | None = None,
90
+ duration_seconds: int | None = None,
91
+ enable_queue: bool = True,
92
+ ):
93
+ params: dict[str, str | int | bool] = {
94
+ 'cgroupPath': cgroup_path,
95
+ 'taskId': task_id,
96
+ 'enableQueue': enable_queue,
97
+ }
98
+ if duration_seconds is not None:
99
+ params['durationSeconds'] = duration_seconds
100
+ if token is not None:
101
+ params['token'] = token
102
+ print(f"POST /schedule: {params}")
103
+
104
+ res = self.client.send(
105
+ request=self.client.build_request(
106
+ method='POST',
107
+ url='/schedule',
108
+ params=params,
109
+ ),
110
+ stream=True,
111
+ )
112
+ status = httpx.codes(res.status_code)
113
+ if (status is not httpx.codes.OK and
114
+ status is not httpx.codes.TOO_MANY_REQUESTS
115
+ ):
116
+ res.close()
117
+ return status
118
+ if "text/event-stream" in res.headers['content-type']:
119
+ return sse_stream(res)
120
+ res.read()
121
+ print(f"POST /schedule res: {res.json()}")
122
+ if status is httpx.codes.TOO_MANY_REQUESTS:
123
+ return QuotaInfos(**res.json()) # pragma: no cover
124
+ if status is httpx.codes.OK:
125
+ return ScheduleResponse(**res.json())
126
+ assert_never(status)
127
+
128
+ def allow(
129
+ self,
130
+ allow_token: str,
131
+ pid: int,
132
+ ):
133
+ params = {
134
+ 'allowToken': allow_token,
135
+ 'pid': pid,
136
+ }
137
+ res = self.client.post('/allow', params=params)
138
+ print(f"POST /allow param: {params} res: {res}")
139
+ return httpx.codes(res.status_code)
140
+
141
+ def release(
142
+ self,
143
+ nvidia_index: int,
144
+ cgroup_path: str,
145
+ task_id: int = 0,
146
+ fail: bool = False,
147
+ ) -> httpx.codes:
148
+ params = {
149
+ 'nvidiaIndex': nvidia_index,
150
+ 'cgroupPath': cgroup_path,
151
+ 'taskId': task_id,
152
+ 'fail': fail,
153
+ }
154
+ res = self.client.post('/release', params=params)
155
+ print(f"POST /release param: {params} res: {res}")
156
+ return httpx.codes(res.status_code)
157
+
158
+ def get_queue_size(self) -> int:
159
+ res = self.client.get('/queue-size')
160
+ assert res.status_code == 200, res.status_code
161
+ size = res.json()
162
+ assert isinstance(size, int)
163
+ return size
spaces/zero/bitsandbytes.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ """
3
+ # pyright: reportPrivateImportUsage=false
4
+
5
+ from __future__ import annotations
6
+
7
+ import importlib
8
+ from typing import TYPE_CHECKING
9
+ from typing import Tuple
10
+
11
+ from .utils import cuda_unavailable
12
+ from .utils import maybe_import_torch
13
+ from .utils import maybe_import_bitsandbytes
14
+
15
+ if TYPE_CHECKING:
16
+ import torch as Torch
17
+
18
+
19
+ if (torch := maybe_import_torch()) and (bnb := maybe_import_bitsandbytes()):
20
+
21
+ from torch.utils.weak import WeakTensorKeyDictionary
22
+
23
+ with cuda_unavailable(torch):
24
+ from bitsandbytes import cextension
25
+ from bitsandbytes import functional
26
+ try: # bitsandbytes < 0.44
27
+ from bitsandbytes.cuda_setup.main import CUDASetup
28
+ except ModuleNotFoundError: # pragma: no cover
29
+ CUDASetup = None
30
+ from bitsandbytes.nn import Int8Params
31
+ from bitsandbytes.nn import Params4bit
32
+
33
+ _param_to_8bit = Int8Params.to # type: ignore
34
+ _param_cuda_8bit = Int8Params.cuda
35
+ _param_to_4bit = Params4bit.to # type: ignore
36
+ _param_cuda_4bit = Params4bit.cuda
37
+
38
+ TensorToArgs = Tuple[torch.device, torch.dtype, bool, torch.memory_format]
39
+
40
+ to_ops_8bit: dict[Int8Params, TensorToArgs | None] = WeakTensorKeyDictionary() # type: ignore
41
+ to_ops_4bit: dict[Params4bit, TensorToArgs | None] = WeakTensorKeyDictionary() # type: ignore
42
+
43
+ def _to_op_register_8bit(self: Int8Params, *args, **kwargs):
44
+ parsed = torch._C._nn._parse_to(*args, **kwargs)
45
+ device, *_ = parsed
46
+ if not isinstance(device, torch.device): # pragma: no cover
47
+ return _param_to_8bit(self, *args, **kwargs)
48
+ if device.type != 'cuda':
49
+ return _param_to_8bit(self, *args, **kwargs)
50
+ to_ops_8bit[self] = parsed
51
+ return self
52
+
53
+ def _to_op_register_4bit(self: Params4bit, *args, **kwargs):
54
+ parsed = torch._C._nn._parse_to(*args, **kwargs)
55
+ device, *_ = parsed
56
+ if not isinstance(device, torch.device): # pragma: no cover
57
+ return _param_to_4bit(self, *args, **kwargs)
58
+ if device.type != 'cuda':
59
+ return _param_to_4bit(self, *args, **kwargs)
60
+ to_ops_4bit[self] = parsed
61
+ return self
62
+
63
+ def _cuda_op_arg_check(device: Torch.device | int | str | None) -> bool:
64
+ if device is None: # pragma: no cover
65
+ return True
66
+ if isinstance(device, int):
67
+ return True
68
+ if isinstance(device, str): # pragma: no cover
69
+ device = torch.device(device)
70
+ return device.type == 'cuda' # pragma: no cover
71
+
72
+ def _cuda_op_register_8bit(self: Int8Params, device: Torch.device | int | str | None = None, **kwargs):
73
+ if not _cuda_op_arg_check(device): # pragma: no cover
74
+ # Let PyTorch handle the fail
75
+ return _param_cuda_8bit(self, device, **kwargs)
76
+ to_ops_8bit[self] = None
77
+ return self
78
+
79
+ def _cuda_op_register_4bit(self: Params4bit, device: Torch.device | int | str | None = None, **kwargs):
80
+ if not _cuda_op_arg_check(device): # pragma: no cover
81
+ # Let PyTorch handle the fail
82
+ return _param_cuda_4bit(self, device, **kwargs)
83
+ to_ops_4bit[self] = None
84
+ return self
85
+
86
+ def _patch():
87
+ Int8Params.to = _to_op_register_8bit # type: ignore
88
+ Int8Params.cuda = _cuda_op_register_8bit # type: ignore
89
+ Params4bit.to = _to_op_register_4bit # type: ignore
90
+ Params4bit.cuda = _cuda_op_register_4bit # type: ignore
91
+
92
+ def _unpatch():
93
+ Int8Params.to = _param_to_8bit # type: ignore
94
+ Int8Params.cuda = _param_cuda_8bit
95
+ Params4bit.to = _param_to_4bit # type: ignore
96
+ Params4bit.cuda = _param_cuda_4bit
97
+
98
+ def _move():
99
+ if CUDASetup is not None:
100
+ CUDASetup._instance = None
101
+ importlib.reload(cextension)
102
+ functional.lib = cextension.lib
103
+ for op in to_ops_8bit.items():
104
+ tensor, parsed_args = op
105
+ if parsed_args:
106
+ _, dtype, _, memory_format = parsed_args
107
+ else:
108
+ dtype, memory_format = None, None
109
+ tensor.data = _param_to_8bit(tensor,
110
+ device='cuda',
111
+ dtype=dtype,
112
+ memory_format=memory_format,
113
+ ) # type: ignore
114
+ for op in to_ops_4bit.items():
115
+ tensor, parsed_args = op
116
+ if parsed_args:
117
+ _, dtype, _, memory_format = parsed_args
118
+ else:
119
+ dtype, memory_format = None, None
120
+ tensor.data = _param_to_4bit(tensor,
121
+ device='cuda',
122
+ dtype=dtype,
123
+ memory_format=memory_format,
124
+ ) # type: ignore
125
+
126
+ else:
127
+
128
+ _patch = lambda: None
129
+ _unpatch = lambda: None
130
+ _move = lambda: None
131
+
132
+
133
+ patch = _patch
134
+ unpatch = _unpatch
135
+ move = _move
spaces/zero/client.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ """
3
+ from __future__ import annotations
4
+
5
+ import os
6
+ import time
7
+ import warnings
8
+ from datetime import timedelta
9
+
10
+ import gradio as gr
11
+ import httpx
12
+
13
+ from .. import utils
14
+ from ..config import Config
15
+ from .api import APIClient
16
+ from .api import QuotaInfos
17
+ from .api import ScheduleResponse
18
+ from .gradio import get_event
19
+
20
+
21
+ TOKEN_HEADER = 'X-IP-Token'
22
+ DEFAULT_SCHEDULE_DURATION = 60
23
+
24
+ QUOTA_MESSAGE = "You have exceeded your GPU quota"
25
+ UNUSED_MESSAGE = "GPU device not used"
26
+ NO_GPU_MESSAGE_REGULAR = "No GPU is currently available"
27
+ NO_GPU_MESSAGE_INQUEUE = "No GPU is currently available for you after 60s"
28
+
29
+
30
+ def api_client():
31
+ assert Config.zero_device_api_url is not None
32
+ httpx_client = httpx.Client(base_url=Config.zero_device_api_url, timeout=60, verify=False)
33
+ print(f"api_client: {Config.zero_device_api_url}")
34
+ return APIClient(httpx_client)
35
+
36
+
37
+ def startup_report():
38
+ retries, max_retries = 0, 2
39
+ client = api_client()
40
+ while (status := client.startup_report()) is httpx.codes.NOT_FOUND: # pragma: no cover
41
+ time.sleep(1)
42
+ if (retries := retries + 1) > max_retries:
43
+ raise RuntimeError("Error while initializing ZeroGPU: NotFound")
44
+ if status is not httpx.codes.OK: # pragma: no cover
45
+ raise RuntimeError("Error while initializing ZeroGPU: Unknown")
46
+
47
+
48
+ def schedule(
49
+ task_id: int,
50
+ request: gr.Request | None = None,
51
+ duration: timedelta | None = None,
52
+ _first_attempt: bool = True,
53
+ ) -> ScheduleResponse:
54
+
55
+ if not gr.__version__.startswith('4.'): # pragma: no cover
56
+ raise RuntimeError("ZeroGPU is only compatible with Gradio 4+")
57
+
58
+ res = api_client().schedule(
59
+ cgroup_path=utils.self_cgroup_device_path(),
60
+ task_id=task_id,
61
+ token=_get_token(request),
62
+ duration_seconds=duration.seconds if duration is not None else None,
63
+ )
64
+
65
+ if isinstance(res, ScheduleResponse):
66
+ return res
67
+
68
+ if isinstance(res, QuotaInfos): # pragma: no cover
69
+ requested = duration.seconds if duration is not None else DEFAULT_SCHEDULE_DURATION
70
+ if res.wait < timedelta(0):
71
+ message = (
72
+ f"The requested GPU duration ({requested}s) "
73
+ f"is larger than the maximum allowed"
74
+ )
75
+ else:
76
+ message = (
77
+ f"You have exceeded your GPU quota "
78
+ f"({res.left}s left vs. {requested}s requested). "
79
+ f"Please retry in {res.wait}"
80
+ )
81
+ raise gr.Error(message)
82
+
83
+ if not isinstance(res, httpx.codes): # pragma: no cover
84
+ gr.Info("Waiting for a GPU to become available")
85
+ connection_event = get_event()
86
+ if connection_event is None and request is not None:
87
+ warnings.warn("ZeroGPU: Cannot get Gradio app Queue instance")
88
+ while True:
89
+ try:
90
+ event = next(res)
91
+ except StopIteration:
92
+ raise RuntimeError("Unexpected end of stream")
93
+ except httpx.RemoteProtocolError:
94
+ if not _first_attempt:
95
+ raise RuntimeError("Error while re-trying after queue disconnect")
96
+ return schedule(task_id, request, duration, _first_attempt=False)
97
+ if event.event == 'ping':
98
+ if connection_event is not None and not connection_event.alive:
99
+ res.close()
100
+ raise RuntimeError("Connection closed by visitor while queueing")
101
+ continue
102
+ if event.event == 'failed':
103
+ raise gr.Error(NO_GPU_MESSAGE_INQUEUE)
104
+ if event.event == 'succeeded':
105
+ assert event.data is not None
106
+ if connection_event is not None and not connection_event.alive:
107
+ release(task_id, event.data.nvidiaIndex)
108
+ raise RuntimeError("Connection closed by visitor on queue success")
109
+ gr.Info("Successfully acquired a GPU")
110
+ return event.data
111
+
112
+ if res is httpx.codes.SERVICE_UNAVAILABLE:
113
+ raise gr.Error(NO_GPU_MESSAGE_REGULAR)
114
+
115
+ # TODO: Find a way to log 'detail' response field
116
+ raise RuntimeError(f"ZeroGPU API /schedule error: {res} ({httpx.codes.get_reason_phrase(res)})") # pragma: no cover
117
+
118
+
119
+ def allow(allow_token: str) -> None:
120
+ pid = os.getpid()
121
+ assert pid != 1, "Allowing PID 1 on ZeroGPU will end up killing your Space"
122
+ assert api_client().allow(allow_token=allow_token, pid=pid) is httpx.codes.OK
123
+
124
+
125
+ def release(
126
+ task_id: int,
127
+ nvidia_index: int,
128
+ fail: bool = False,
129
+ allow_404: bool = False,
130
+ ) -> None:
131
+
132
+ res = api_client().release(
133
+ cgroup_path=utils.self_cgroup_device_path(),
134
+ task_id=task_id,
135
+ nvidia_index=nvidia_index,
136
+ fail=fail,
137
+ )
138
+
139
+ if res is httpx.codes.NO_CONTENT: # pragma: no cover
140
+ try:
141
+ gr.Warning(UNUSED_MESSAGE)
142
+ except AttributeError:
143
+ pass
144
+ warnings.warn(UNUSED_MESSAGE, RuntimeWarning)
145
+ return None
146
+
147
+ if res is httpx.codes.NOT_FOUND:
148
+ if not allow_404:
149
+ warnings.warn("ZeroGPU API /release warning: 404 Not Found")
150
+ return None
151
+
152
+ if httpx.codes.is_success(res):
153
+ return None
154
+
155
+ # TODO: Find a way to log 'detail' response field
156
+ # TODO: Only raise in dev environment. Simply warn in production ?
157
+ raise RuntimeError(f"ZeroGPU API /release error: {res} ({httpx.codes.get_reason_phrase(res)})") # pragma: no cover
158
+
159
+
160
+ def _get_token(request: gr.Request | None) -> str | None:
161
+
162
+ if request is None:
163
+ return None
164
+
165
+ headers = getattr(request, 'headers', None)
166
+ if headers is None or not hasattr(headers, '__dict__'):
167
+ raise gr.Error("Internal Gradio error")
168
+
169
+ # Compatibility trick
170
+ if not hasattr(headers, 'get'):
171
+ headers = headers.__dict__ # pragma: no cover
172
+
173
+ if not (token := headers.get(TOKEN_HEADER.lower())):
174
+ raise gr.Error("Internal infra error")
175
+
176
+ return token
spaces/zero/decorator.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ """
3
+ from __future__ import annotations
4
+
5
+ import inspect
6
+ import sys
7
+ import warnings
8
+ from datetime import timedelta
9
+ from functools import partial
10
+ from typing import Callable
11
+ from typing import TypeVar
12
+ from typing import overload
13
+ from typing_extensions import ParamSpec
14
+ from typing_extensions import Unpack
15
+
16
+ import gradio as gr
17
+
18
+ from ..config import Config
19
+ from . import client
20
+ from .types import EmptyKwargs
21
+ from .wrappers import regular_function_wrapper
22
+ from .wrappers import generator_function_wrapper
23
+
24
+
25
+ P = ParamSpec('P')
26
+ R = TypeVar('R')
27
+
28
+
29
+ decorated_cache: dict[Callable, Callable] = {}
30
+
31
+
32
+ @overload
33
+ def GPU(
34
+ task: None = None, *,
35
+ duration: int | timedelta | None = None,
36
+ ) -> Callable[[Callable[P, R]], Callable[P, R]]:
37
+ ...
38
+ @overload
39
+ def GPU(
40
+ task: Callable[P, R], *,
41
+ duration: int | timedelta | None = None,
42
+ ) -> Callable[P, R]:
43
+ ...
44
+ def GPU(
45
+ task: Callable[P, R] | None = None, *,
46
+ duration: int | timedelta | None = None,
47
+ **kwargs: Unpack[EmptyKwargs],
48
+ ) -> Callable[[Callable[P, R]], Callable[P, R]] | Callable[P, R]:
49
+ """
50
+ ZeroGPU decorator
51
+
52
+ Basic usage:
53
+ ```
54
+ @spaces.GPU
55
+ def fn(...):
56
+ # CUDA is available here
57
+ pass
58
+ ```
59
+
60
+ With custom duration:
61
+ ```
62
+ @spaces.GPU(duration=45) # Expressed in seconds
63
+ def fn(...):
64
+ # CUDA is available here
65
+ pass
66
+ ```
67
+
68
+ Args:
69
+ task (`Callable | None`): Python function that requires CUDA
70
+ duration (`int | datetime.timedelta`): Estimated duration in seconds or `datetime.timedelta`
71
+
72
+ Returns:
73
+ `Callable`: GPU-ready function
74
+ """
75
+ if "enable_queue" in kwargs:
76
+ warnings.warn("`enable_queue` parameter is now ignored and always set to `True`")
77
+ if task is None:
78
+ return partial(_GPU, duration=duration)
79
+ return _GPU(task, duration)
80
+
81
+
82
+ def _GPU(
83
+ task: Callable[P, R],
84
+ duration: int | timedelta | None,
85
+ ) -> Callable[P, R]:
86
+
87
+ if not Config.zero_gpu:
88
+ # TODO: still prepend gr.Request for type consistency ?
89
+ return task # type: ignore
90
+
91
+ if sys.version_info.minor < 9: # pragma: no cover
92
+ raise RuntimeError("Actually using @spaces.GPU on a ZeroGPU Space requires Python 3.9+")
93
+
94
+ if task in decorated_cache:
95
+ # TODO: Assert same duration ?
96
+ return decorated_cache[task] # type: ignore
97
+
98
+ if inspect.iscoroutinefunction(task):
99
+ raise NotImplementedError
100
+
101
+ if duration is None or isinstance(duration, timedelta):
102
+ timedelta_duration = duration
103
+ else:
104
+ timedelta_duration = timedelta(seconds=duration)
105
+
106
+ if inspect.isgeneratorfunction(task):
107
+ decorated = generator_function_wrapper(task, timedelta_duration)
108
+ else:
109
+ decorated = regular_function_wrapper(task, timedelta_duration)
110
+
111
+ client.startup_report()
112
+ decorated_cache.update({
113
+ task: decorated,
114
+ decorated: decorated,
115
+ })
116
+
117
+ return decorated # type: ignore
spaces/zero/gradio.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ """
3
+ from __future__ import annotations
4
+
5
+ from typing import NamedTuple
6
+ import warnings
7
+
8
+ from gradio.context import LocalContext
9
+ from gradio.helpers import Progress
10
+ from gradio.helpers import TrackedIterable
11
+ from gradio.queueing import Queue
12
+ from typing_extensions import assert_type
13
+
14
+ from ..utils import SimpleQueue
15
+ from .types import GeneratorResQueueResult
16
+ from .types import GradioQueueEvent
17
+ from .types import RegularResQueueResult
18
+
19
+
20
+ QUEUE_RPC_METHODS = [
21
+ "set_progress",
22
+ "log_message",
23
+ ]
24
+
25
+
26
+ class GradioPartialContext(NamedTuple):
27
+ event_id: str | None
28
+ in_event_listener: bool
29
+ progress: Progress | None
30
+
31
+ @staticmethod
32
+ def get():
33
+ TrackedIterable.__reduce__ = tracked_iterable__reduce__
34
+ return GradioPartialContext(
35
+ event_id=LocalContext.event_id.get(),
36
+ in_event_listener=LocalContext.in_event_listener.get(),
37
+ progress=LocalContext.progress.get(),
38
+ )
39
+
40
+ @staticmethod
41
+ def apply(context: 'GradioPartialContext'):
42
+ LocalContext.event_id.set(context.event_id)
43
+ LocalContext.in_event_listener.set(context.in_event_listener)
44
+ LocalContext.progress.set(context.progress)
45
+
46
+
47
+ def get_queue_instance():
48
+ blocks = LocalContext.blocks.get()
49
+ if blocks is None: # pragma: no cover
50
+ return None
51
+ return blocks._queue
52
+
53
+
54
+ def get_event():
55
+ queue = get_queue_instance()
56
+ event_id = LocalContext.event_id.get()
57
+ if queue is None:
58
+ return None
59
+ if event_id is None: # pragma: no cover
60
+ return None
61
+ for job in queue.active_jobs:
62
+ if job is None: # pragma: no cover
63
+ continue
64
+ for event in job:
65
+ if event._id == event_id:
66
+ return event
67
+
68
+
69
+ def try_process_queue_event(method_name: str, *args, **kwargs):
70
+ queue = get_queue_instance()
71
+ if queue is None: # pragma: no cover
72
+ warnings.warn("ZeroGPU: Cannot get Gradio app Queue instance")
73
+ return
74
+ method = getattr(queue, method_name, None)
75
+ assert callable(method)
76
+ method(*args, **kwargs)
77
+
78
+
79
+ def patch_gradio_queue(
80
+ res_queue: SimpleQueue[RegularResQueueResult | None] | SimpleQueue[GeneratorResQueueResult | None],
81
+ ):
82
+
83
+ def rpc_method(method_name: str):
84
+ def method(*args, **kwargs):
85
+ if args and isinstance(args[0], Queue):
86
+ args = args[1:] # drop `self`
87
+ res_queue.put(GradioQueueEvent(method_name, args, kwargs))
88
+ return method
89
+
90
+ for method_name in QUEUE_RPC_METHODS:
91
+ if (method := getattr(Queue, method_name, None)) is None: # pragma: no cover
92
+ warnings.warn(f"ZeroGPU: Gradio Queue has no {method_name} attribute")
93
+ continue
94
+ if not callable(method): # pragma: no cover
95
+ warnings.warn(f"ZeroGPU: Gradio Queue {method_name} is not callable")
96
+ continue
97
+ setattr(Queue, method_name, rpc_method(method_name))
98
+
99
+ TrackedIterable.__reduce__ = tracked_iterable__reduce__
100
+
101
+
102
+ def tracked_iterable__reduce__(self):
103
+ res: tuple = super(TrackedIterable, self).__reduce__() # type: ignore
104
+ cls, base, state, *_ = res
105
+ return cls, base,{**state, **{
106
+ 'iterable': None,
107
+ '_tqdm': None,
108
+ }}
spaces/zero/torch.py ADDED
@@ -0,0 +1,279 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ """
3
+ # pyright: reportPrivateImportUsage=false
4
+
5
+ from __future__ import annotations
6
+
7
+ import multiprocessing
8
+ import os
9
+ from concurrent.futures import ProcessPoolExecutor
10
+ from contextlib import suppress
11
+ from functools import partial
12
+ from types import SimpleNamespace
13
+ from typing import TYPE_CHECKING
14
+ from typing import Any
15
+ from typing import Optional
16
+ from typing import Tuple
17
+
18
+ from ..config import Config
19
+ from . import bitsandbytes
20
+ from .utils import maybe_import_torch
21
+
22
+ if TYPE_CHECKING:
23
+ import torch as Torch
24
+
25
+
26
+ # Nvidia A100.80G MIG (drivers 535) / Torch 2.2.0
27
+ CUDA_DEVICE_NAME = 'NVIDIA A100-SXM4-80GB MIG 3g.40gb'
28
+ CUDA_TOTAL_MEMORY = 42144366592
29
+ CUDA_MEM_GET_INFO = (41911451648, CUDA_TOTAL_MEMORY)
30
+ CUDA_DEVICE_CAPABILITY = (8, 0)
31
+ CUDA_DEVICE_PROPERTIES = SimpleNamespace(name=CUDA_DEVICE_NAME, major=8, minor=0, total_memory=CUDA_TOTAL_MEMORY, multi_processor_count=42)
32
+
33
+ GENERIC_METHOD_NAMES = [
34
+ 'arange',
35
+ 'as_tensor',
36
+ 'asarray',
37
+ 'bartlett_window',
38
+ 'blackman_window',
39
+ 'empty',
40
+ 'empty_like',
41
+ 'empty_strided',
42
+ 'eye',
43
+ 'full',
44
+ 'full_like',
45
+ 'hamming_window',
46
+ 'hann_window',
47
+ 'kaiser_window',
48
+ 'linspace',
49
+ 'logspace',
50
+ 'obj',
51
+ 'ones',
52
+ 'ones_like',
53
+ 'rand',
54
+ 'rand_like',
55
+ 'randint',
56
+ 'randint_like',
57
+ 'randn',
58
+ 'randn_like',
59
+ 'randperm',
60
+ 'range',
61
+ 'sparse_bsc_tensor',
62
+ 'sparse_bsr_tensor',
63
+ 'sparse_compressed_tensor',
64
+ 'sparse_coo_tensor',
65
+ 'sparse_csc_tensor',
66
+ 'sparse_csr_tensor',
67
+ 'tensor',
68
+ 'tril_indices',
69
+ 'triu_indices',
70
+ 'zeros',
71
+ 'zeros_like',
72
+ ]
73
+
74
+
75
+ if (torch := maybe_import_torch()):
76
+
77
+ from torch.utils.weak import WeakTensorKeyDictionary
78
+
79
+ TO_CUDA = (torch.device('cuda'), None, False, None)
80
+
81
+ _tensor__deepcopy__ = torch.Tensor.__deepcopy__
82
+ _tensor_to = torch.Tensor.to
83
+ _tensor_cuda = torch.Tensor.cuda
84
+ _tensor_cpu = torch.Tensor.cpu
85
+ _torch_generics = {name: getattr(torch, name) for name in GENERIC_METHOD_NAMES}
86
+ _cuda_init = torch._C._cuda_init
87
+ _cuda_available = torch.cuda.is_available
88
+ _cuda_device_count = torch.cuda.device_count
89
+ _cuda_current_device = torch.cuda.current_device
90
+ _cuda_mem_get_info = torch.cuda.mem_get_info
91
+ _cuda_get_device_capability = torch.cuda.get_device_capability
92
+ _cuda_get_device_properties = torch.cuda.get_device_properties
93
+ _cuda_get_device_name = torch.cuda.get_device_name
94
+
95
+ TensorToArgs = Tuple[Optional[torch.device], Optional[torch.dtype], bool, Optional[torch.memory_format]]
96
+
97
+ to_ops: dict[Torch.Tensor, TensorToArgs] = WeakTensorKeyDictionary() # type: ignore
98
+
99
+ def _tensor_new_register(*args, **kwargs):
100
+ new_tensor: Torch.Tensor = torch._C._TensorBase.__new__(*args, **kwargs)
101
+ if (base_tensor := new_tensor._base) is not None:
102
+ if base_tensor in to_ops:
103
+ to_ops[new_tensor] = to_ops[base_tensor]
104
+ return new_tensor
105
+
106
+ def _tensor_deepcopy_register(self: Torch.Tensor, memo):
107
+ new_tensor = _tensor__deepcopy__(self, memo)
108
+ if isinstance(new_tensor, torch.Tensor):
109
+ if self in to_ops:
110
+ to_ops[new_tensor] = to_ops[self]
111
+ return new_tensor
112
+
113
+ @property
114
+ def _tensor_device_property(self: Torch.Tensor):
115
+ if self in to_ops:
116
+ return torch.device(type='cuda', index=0)
117
+ del torch.Tensor.device
118
+ try:
119
+ return self.device
120
+ finally:
121
+ torch.Tensor.device = _tensor_device_property # type: ignore
122
+
123
+ @property
124
+ def _tensor_dtype_property(self: Torch.Tensor):
125
+ if self in to_ops:
126
+ if (to_dtype := to_ops[self][1]) is not None:
127
+ return to_dtype
128
+ del torch.Tensor.dtype
129
+ try:
130
+ return self.dtype
131
+ finally:
132
+ torch.Tensor.dtype = _tensor_dtype_property # type: ignore
133
+
134
+ def _to_op_register(self: Torch.Tensor, *args, **kwargs):
135
+ parsed = torch._C._nn._parse_to(*args, **kwargs)
136
+ device, dtype, *_ = parsed
137
+ try:
138
+ to_args = to_ops.pop(self)
139
+ except KeyError:
140
+ to_args = None
141
+ if device is None:
142
+ if to_args is not None:
143
+ to_ops[self] = (to_args[0], dtype, *to_args[2:])
144
+ return self
145
+ return _tensor_to(self, *args, **kwargs)
146
+ if device.type != 'cuda':
147
+ if to_args is not None:
148
+ if (to_dtype := to_args[1]) is not None:
149
+ kwargs = {'dtype': to_dtype, **kwargs}
150
+ return _tensor_to(self, *args, **kwargs)
151
+ to_ops[self] = parsed
152
+ return self
153
+
154
+ def _cuda_op_arg_check(device: Torch.device | int | str | None) -> bool:
155
+ if device is None:
156
+ return True
157
+ if isinstance(device, int):
158
+ return True
159
+ if isinstance(device, str):
160
+ device = torch.device(device)
161
+ return device.type == 'cuda'
162
+
163
+ def _cuda_op_register(self: Torch.Tensor, device: Torch.device | int | str | None = None, **kwargs):
164
+ if not _cuda_op_arg_check(device):
165
+ # Let PyTorch handle the fail
166
+ return _tensor_cuda(self, device, **kwargs)
167
+ to_ops[self] = TO_CUDA
168
+ return self
169
+
170
+ def _cpu_op_remove(self: Torch.Tensor, **kwargs):
171
+ try:
172
+ to_args = to_ops.pop(self)
173
+ except KeyError:
174
+ to_args = None
175
+ if to_args is not None:
176
+ if (to_dtype := to_args[1]) is not None:
177
+ return _tensor_to(self, 'cpu', **{'dtype': to_dtype, **kwargs})
178
+ return _tensor_cpu(self, **kwargs)
179
+
180
+ def _cuda_init_raise():
181
+ raise RuntimeError(
182
+ "CUDA must not be initialized in the main process "
183
+ "on Spaces with Stateless GPU environment.\n"
184
+ "You can look at this Stacktrace to find out "
185
+ "which part of your code triggered a CUDA init"
186
+ )
187
+
188
+ def _generic_method_register(name: str, *args: Any, **kwargs: Any):
189
+ try:
190
+ device = torch.device(kwargs.get('device', "cpu"))
191
+ except Exception:
192
+ return _torch_generics[name](*args, **kwargs)
193
+ if device.type != 'cuda':
194
+ return _torch_generics[name](*args, **kwargs)
195
+ tensor = _torch_generics[name](*args, **{**kwargs, 'device': "cpu"})
196
+ to_ops[tensor] = TO_CUDA
197
+ return tensor
198
+
199
+ def _patch():
200
+ torch.Tensor.__deepcopy__ = _tensor_deepcopy_register
201
+ torch.Tensor.__new__ = _tensor_new_register # pyright: ignore [reportAttributeAccessIssue]
202
+ torch.Tensor.to = _to_op_register # type: ignore
203
+ torch.Tensor.cuda = _cuda_op_register # type: ignore
204
+ torch.Tensor.cpu = _cpu_op_remove # type: ignore
205
+ if Config.zero_patch_torch_device:
206
+ torch.Tensor.device = _tensor_device_property # type: ignore
207
+ torch.Tensor.dtype = _tensor_dtype_property # pyright: ignore [reportAttributeAccessIssue]
208
+ for name in GENERIC_METHOD_NAMES:
209
+ setattr(torch, name, partial(_generic_method_register, name))
210
+ torch._C._cuda_init = _cuda_init_raise
211
+ torch.cuda.is_available = lambda: True
212
+ torch.cuda.device_count = lambda: 1
213
+ torch.cuda.current_device = lambda: 0
214
+ torch.cuda.mem_get_info = lambda *args, **kwargs: CUDA_MEM_GET_INFO
215
+ torch.cuda.get_device_capability = lambda *args, **kwargs: CUDA_DEVICE_CAPABILITY
216
+ torch.cuda.get_device_properties = lambda *args, **kwargs: CUDA_DEVICE_PROPERTIES
217
+ torch.cuda.get_device_name = lambda *args, **kwargs: CUDA_DEVICE_NAME
218
+ bitsandbytes.patch()
219
+
220
+ def _unpatch():
221
+ torch.Tensor.__deepcopy__ = _tensor__deepcopy__
222
+ with suppress(AttributeError):
223
+ del torch.Tensor.__new__
224
+ torch.Tensor.to = _tensor_to
225
+ torch.Tensor.cuda = _tensor_cuda
226
+ torch.Tensor.cpu = _tensor_cpu
227
+ with suppress(AttributeError):
228
+ del torch.Tensor.device
229
+ with suppress(AttributeError):
230
+ del torch.Tensor.dtype
231
+ for name in GENERIC_METHOD_NAMES:
232
+ setattr(torch, name, _torch_generics[name])
233
+ torch._C._cuda_init = _cuda_init
234
+ torch.cuda.is_available = _cuda_available
235
+ torch.cuda.device_count = _cuda_device_count
236
+ torch.cuda.current_device = _cuda_current_device
237
+ torch.cuda.mem_get_info = _cuda_mem_get_info
238
+ torch.cuda.get_device_capability = _cuda_get_device_capability
239
+ torch.cuda.get_device_properties = _cuda_get_device_properties
240
+ torch.cuda.get_device_name = _cuda_get_device_name
241
+ bitsandbytes.unpatch()
242
+
243
+ def _move(nvidia_uuid: str):
244
+ os.environ['CUDA_VISIBLE_DEVICES'] = nvidia_uuid
245
+ torch.Tensor([0]).cuda() # CUDA init
246
+ for op in to_ops.items():
247
+ tensor, parsed_args = op
248
+ _, dtype, _, memory_format = parsed_args
249
+ tensor.data = _tensor_to(tensor,
250
+ device='cuda',
251
+ dtype=dtype,
252
+ memory_format=memory_format,
253
+ ) # type: ignore
254
+ bitsandbytes.move()
255
+ torch.cuda.synchronize()
256
+
257
+ def _is_in_bad_fork():
258
+ with ProcessPoolExecutor(mp_context=multiprocessing.get_context('fork')) as e:
259
+ f = e.submit(torch.cuda._is_in_bad_fork)
260
+ return f.result()
261
+
262
+ def _disable_cuda_intercept():
263
+ torch.Tensor.to = _tensor_to
264
+ torch.Tensor.cuda = _tensor_cuda
265
+
266
+ else:
267
+
268
+ _patch = lambda: None
269
+ _unpatch = lambda: None
270
+ _move = lambda nvidia_uuid: None
271
+ _is_in_bad_fork = lambda: False
272
+ _disable_cuda_intercept = lambda: None
273
+
274
+
275
+ patch = _patch
276
+ unpatch = _unpatch
277
+ move = _move
278
+ is_in_bad_fork = _is_in_bad_fork
279
+ disable_cuda_intercept = _disable_cuda_intercept
spaces/zero/tqdm.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ """
3
+
4
+ from multiprocessing.synchronize import RLock as MultiprocessingRLock
5
+
6
+
7
+ def remove_tqdm_multiprocessing_lock():
8
+ from tqdm import tqdm
9
+ tqdm_lock = tqdm.get_lock()
10
+ assert tqdm_lock.__class__.__name__ == 'TqdmDefaultWriteLock'
11
+ tqdm_lock.locks = [
12
+ lock for lock in tqdm_lock.locks
13
+ if not isinstance(lock, MultiprocessingRLock)
14
+ ]
spaces/zero/types.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ """
3
+ from __future__ import annotations
4
+
5
+
6
+ from dataclasses import dataclass
7
+ from typing import Any
8
+ from typing import Dict
9
+ from typing import Tuple
10
+ from typing import TypedDict
11
+ from typing_extensions import Generic
12
+ from typing_extensions import ParamSpec
13
+ from typing_extensions import TypeAlias
14
+ from typing_extensions import TypeVar
15
+
16
+
17
+ Params = Tuple[Tuple[object, ...], Dict[str, Any]]
18
+ Res = TypeVar('Res')
19
+ Param = ParamSpec('Param')
20
+
21
+ class EmptyKwargs(TypedDict):
22
+ pass
23
+
24
+ @dataclass
25
+ class OkResult(Generic[Res]):
26
+ value: Res
27
+ @dataclass
28
+ class ExceptionResult:
29
+ value: Exception
30
+ @dataclass
31
+ class AbortedResult:
32
+ pass
33
+ @dataclass
34
+ class EndResult:
35
+ pass
36
+ @dataclass
37
+ class GradioQueueEvent:
38
+ method_name: str
39
+ args: tuple[Any, ...]
40
+ kwargs: dict[str, Any]
41
+
42
+ RegularResQueueResult: TypeAlias = "OkResult[Res] | ExceptionResult | GradioQueueEvent"
43
+ GeneratorResQueueResult: TypeAlias = "OkResult[Res] | ExceptionResult | EndResult | GradioQueueEvent"
44
+ YieldQueueResult: TypeAlias = "OkResult[Res] | ExceptionResult | EndResult | AbortedResult"
spaces/zero/utils.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ """
3
+ from __future__ import annotations
4
+
5
+ from contextlib import contextmanager
6
+ from importlib import metadata
7
+ from types import ModuleType
8
+
9
+ from packaging import version
10
+
11
+ from ..config import Config
12
+
13
+
14
+ def maybe_import_torch():
15
+ if not Config.zero_gpu:
16
+ return None
17
+ try:
18
+ import torch
19
+ except ImportError:
20
+ return None
21
+ return torch
22
+
23
+
24
+ @contextmanager
25
+ def cuda_unavailable(torch: ModuleType):
26
+ _is_available = torch.cuda.is_available
27
+ torch.cuda.is_available = lambda: False
28
+ yield
29
+ torch.cuda.is_available = _is_available
30
+
31
+
32
+ def maybe_import_bitsandbytes():
33
+ if (torch := maybe_import_torch()) is None:
34
+ return None # pragma: no cover
35
+ with cuda_unavailable(torch):
36
+ try:
37
+ import bitsandbytes
38
+ except ImportError:
39
+ bitsandbytes = None
40
+ else:
41
+ if (bnb_version := version.parse(metadata.version('bitsandbytes'))) < version.parse('0.40.0'):
42
+ raise RuntimeError(f"ZeroGPU requires bitsandbytes >= 0.40.0 (installed: {bnb_version})") # pragma: no cover
43
+ print("↑ Those bitsandbytes warnings are expected on ZeroGPU ↑")
44
+ return bitsandbytes
spaces/zero/wrappers.py ADDED
@@ -0,0 +1,347 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ """
3
+ from __future__ import annotations
4
+
5
+ import multiprocessing
6
+ import os
7
+ import signal
8
+ import traceback
9
+ from concurrent.futures import ThreadPoolExecutor
10
+ from contextvars import copy_context
11
+ from datetime import timedelta
12
+ from functools import partial
13
+ from functools import wraps
14
+ from multiprocessing.context import ForkProcess
15
+ from pickle import PicklingError
16
+ from queue import Empty
17
+ from queue import Queue as ThreadQueue
18
+ from threading import Thread
19
+ from typing import TYPE_CHECKING
20
+ from typing import Callable
21
+ from typing import Generator
22
+ from typing import Generic
23
+ from typing_extensions import assert_never
24
+
25
+ import gradio as gr
26
+ import psutil
27
+
28
+ from ..utils import debug
29
+ from ..utils import drop_params
30
+ from ..utils import gradio_request_var
31
+ from ..utils import SimpleQueue as Queue
32
+ from . import client
33
+ from . import torch
34
+ from .api import AllowToken
35
+ from .api import NvidiaIndex
36
+ from .api import NvidiaUUID
37
+ from .gradio import GradioPartialContext
38
+ from .gradio import patch_gradio_queue
39
+ from .gradio import try_process_queue_event
40
+ from .tqdm import remove_tqdm_multiprocessing_lock
41
+ from .types import * # TODO: Please don't do that
42
+
43
+
44
+ GENERATOR_GLOBAL_TIMEOUT = 20 * 60
45
+
46
+
47
+ Process = multiprocessing.get_context('fork').Process
48
+ forked = False
49
+
50
+
51
+ class Worker(Generic[Res]):
52
+ process: ForkProcess
53
+ arg_queue: Queue[tuple[Params, GradioPartialContext]]
54
+ res_queue: Queue[Res | None]
55
+ _sentinel: Thread
56
+
57
+ def __init__(
58
+ self,
59
+ target: Callable[[
60
+ Queue[tuple[Params, GradioPartialContext]],
61
+ Queue[Res | None],
62
+ AllowToken | None,
63
+ NvidiaUUID,
64
+ list[int],
65
+ ], None],
66
+ allow_token: str | None,
67
+ nvidia_uuid: str,
68
+ ):
69
+ self._sentinel = Thread(target=self._close_on_exit)
70
+ self.arg_queue = Queue()
71
+ self.res_queue = Queue()
72
+ fds = [c.fd for c in psutil.Process().connections()]
73
+ args = self.arg_queue, self.res_queue, allow_token, nvidia_uuid, fds
74
+ if TYPE_CHECKING:
75
+ target(*args)
76
+ self.process = Process(
77
+ target=target,
78
+ args=args,
79
+ daemon=True,
80
+ )
81
+ self.process.start()
82
+ self._sentinel.start()
83
+
84
+ def _close_on_exit(self):
85
+ self.process.join()
86
+ self.res_queue.put(None)
87
+
88
+
89
+ def worker_init(
90
+ res_queue: Queue[RegularResQueueResult | None] | Queue[GeneratorResQueueResult | None],
91
+ allow_token: str | None,
92
+ nvidia_uuid: str,
93
+ fds: list[int],
94
+ ) -> None | ExceptionResult:
95
+ try: # Unrecoverable init part
96
+ if allow_token is not None:
97
+ client.allow(allow_token)
98
+ torch.unpatch()
99
+ torch.move(nvidia_uuid)
100
+ patch_gradio_queue(res_queue)
101
+ except Exception as e: # pragma: no cover
102
+ traceback.print_exc()
103
+ return ExceptionResult(e)
104
+ try:
105
+ remove_tqdm_multiprocessing_lock()
106
+ except Exception: # pragma: no cover
107
+ print("Error while trying to remove tqdm mp_lock:")
108
+ traceback.print_exc()
109
+ for fd in fds:
110
+ try:
111
+ os.close(fd)
112
+ except Exception as e: # pragma: no cover
113
+ if isinstance(e, OSError) and e.errno == 9:
114
+ continue
115
+ traceback.print_exc()
116
+ return ExceptionResult(e)
117
+
118
+
119
+ def regular_function_wrapper(
120
+ task: Callable[Param, Res],
121
+ duration: timedelta | None,
122
+ ) -> Callable[Param, Res]:
123
+
124
+ request_var = gradio_request_var()
125
+ workers: dict[NvidiaIndex, Worker[RegularResQueueResult[Res]]] = {}
126
+ task_id = id(task)
127
+
128
+ @wraps(task)
129
+ def gradio_handler(*args: Param.args, **kwargs: Param.kwargs) -> Res:
130
+
131
+ if forked:
132
+ return task(*args, **kwargs)
133
+
134
+ request = request_var.get()
135
+ schedule_response = client.schedule(task_id=task_id, request=request, duration=duration)
136
+ allow_token = schedule_response.allowToken
137
+ nvidia_index = schedule_response.nvidiaIndex
138
+ nvidia_uuid = schedule_response.nvidiaUUID
139
+ release = partial(client.release, task_id=task_id, nvidia_index=nvidia_index)
140
+
141
+ worker = workers.get(nvidia_index)
142
+ if worker is None or not worker.process.is_alive():
143
+ worker = Worker(thread_wrapper, allow_token, nvidia_uuid)
144
+ workers[nvidia_index] = worker
145
+
146
+ try:
147
+ worker.arg_queue.put(((args, kwargs), GradioPartialContext.get()))
148
+ except PicklingError:
149
+ release(fail=True)
150
+ # TODO: Better error message (check what arg / kwarg is problematic ?)
151
+ raise
152
+
153
+ while True:
154
+ res = worker.res_queue.get()
155
+ if res is None:
156
+ release(fail=True, allow_404=True)
157
+ raise gr.Error("GPU task aborted")
158
+ if isinstance(res, ExceptionResult):
159
+ release(fail=True)
160
+ raise res.value
161
+ if isinstance(res, OkResult):
162
+ release()
163
+ return res.value
164
+ if isinstance(res, GradioQueueEvent):
165
+ try_process_queue_event(res.method_name, *res.args, **res.kwargs)
166
+ continue
167
+ assert_never(res)
168
+
169
+
170
+ def thread_wrapper(
171
+ arg_queue: Queue[tuple[Params, GradioPartialContext]],
172
+ res_queue: Queue[RegularResQueueResult[Res] | None],
173
+ allow_token: str | None,
174
+ nvidia_uuid: str,
175
+ fds: list[int],
176
+ ):
177
+ global forked
178
+ forked = True
179
+ if (res := worker_init(
180
+ res_queue=res_queue,
181
+ allow_token=allow_token,
182
+ nvidia_uuid=nvidia_uuid,
183
+ fds=fds,
184
+ )) is not None: # pragma: no cover
185
+ res_queue.put(res)
186
+ return
187
+ signal.signal(signal.SIGTERM, drop_params(arg_queue.close))
188
+ while True:
189
+ try:
190
+ (args, kwargs), gradio_context = arg_queue.get()
191
+ except OSError:
192
+ break
193
+ GradioPartialContext.apply(gradio_context)
194
+ context = copy_context()
195
+ with ThreadPoolExecutor() as executor:
196
+ future = executor.submit(context.run, task, *args, **kwargs) # type: ignore
197
+ try:
198
+ res = future.result()
199
+ except Exception as e:
200
+ traceback.print_exc()
201
+ res = ExceptionResult(e)
202
+ else:
203
+ res = OkResult(res)
204
+ try:
205
+ res_queue.put(res)
206
+ except PicklingError as e:
207
+ res_queue.put(ExceptionResult(e))
208
+
209
+ # https://github.com/python/cpython/issues/91002
210
+ if not hasattr(task, '__annotations__'):
211
+ gradio_handler.__annotations__ = {}
212
+
213
+ return gradio_handler
214
+
215
+
216
+ def generator_function_wrapper(
217
+ task: Callable[Param, Generator[Res, None, None]],
218
+ duration: timedelta | None,
219
+ ) -> Callable[Param, Generator[Res, None, None]]:
220
+
221
+ request_var = gradio_request_var()
222
+ workers: dict[NvidiaIndex, Worker[GeneratorResQueueResult[Res]]] = {}
223
+ task_id = id(task)
224
+
225
+ @wraps(task)
226
+ def gradio_handler(*args: Param.args, **kwargs: Param.kwargs) -> Generator[Res, None, None]:
227
+
228
+ if forked:
229
+ yield from task(*args, **kwargs)
230
+ return
231
+
232
+ request = request_var.get()
233
+ schedule_response = client.schedule(task_id=task_id, request=request, duration=duration)
234
+ allow_token = schedule_response.allowToken
235
+ nvidia_index = schedule_response.nvidiaIndex
236
+ nvidia_uuid = schedule_response.nvidiaUUID
237
+ release = partial(client.release, task_id=task_id, nvidia_index=nvidia_index)
238
+
239
+ worker = workers.get(nvidia_index)
240
+ if worker is None or not worker.process.is_alive():
241
+ worker = Worker(thread_wrapper, allow_token, nvidia_uuid)
242
+ workers[nvidia_index] = worker
243
+
244
+ try:
245
+ worker.arg_queue.put(((args, kwargs), GradioPartialContext.get()))
246
+ except PicklingError:
247
+ release(fail=True)
248
+ raise
249
+
250
+ yield_queue: ThreadQueue[YieldQueueResult[Res]] = ThreadQueue()
251
+ def fill_yield_queue(worker: Worker[GeneratorResQueueResult[Res]]):
252
+ while True:
253
+ res = worker.res_queue.get()
254
+ if res is None:
255
+ release(fail=True, allow_404=True)
256
+ yield_queue.put(AbortedResult())
257
+ return
258
+ if isinstance(res, ExceptionResult):
259
+ release(fail=True)
260
+ yield_queue.put(ExceptionResult(res.value))
261
+ return
262
+ if isinstance(res, EndResult):
263
+ release()
264
+ yield_queue.put(EndResult())
265
+ return
266
+ if isinstance(res, OkResult):
267
+ yield_queue.put(OkResult(res.value))
268
+ continue
269
+ if isinstance(res, GradioQueueEvent): # pragma: no cover (not working properly on Gradio side)
270
+ try_process_queue_event(res.method_name, *res.args, **res.kwargs)
271
+ continue
272
+ debug(f"fill_yield_queue: assert_never({res=})")
273
+ assert_never(res)
274
+ from typing_extensions import assert_never
275
+ with ThreadPoolExecutor() as e:
276
+ f = e.submit(fill_yield_queue, worker)
277
+ f.add_done_callback(lambda _: debug("fill_yield_queue DONE"))
278
+ while True:
279
+ try:
280
+ res = yield_queue.get(timeout=GENERATOR_GLOBAL_TIMEOUT)
281
+ except Empty: # pragma: no cover
282
+ debug(f"yield_queue TIMEOUT ({GENERATOR_GLOBAL_TIMEOUT=})")
283
+ raise
284
+ if isinstance(res, AbortedResult):
285
+ raise gr.Error("GPU task aborted")
286
+ if isinstance(res, ExceptionResult):
287
+ raise res.value
288
+ if isinstance(res, EndResult):
289
+ break
290
+ if isinstance(res, OkResult):
291
+ yield res.value
292
+ continue
293
+ debug(f"gradio_handler: assert_never({res=})")
294
+ assert_never(res)
295
+
296
+
297
+ def thread_wrapper(
298
+ arg_queue: Queue[tuple[Params, GradioPartialContext]],
299
+ res_queue: Queue[GeneratorResQueueResult[Res] | None],
300
+ allow_token: str | None,
301
+ nvidia_uuid: str,
302
+ fds: list[int],
303
+ ):
304
+ global forked
305
+ forked = True
306
+ if (res := worker_init(
307
+ res_queue=res_queue,
308
+ allow_token=allow_token,
309
+ nvidia_uuid=nvidia_uuid,
310
+ fds=fds,
311
+ )) is not None: # pragma: no cover
312
+ res_queue.put(res)
313
+ return
314
+ signal.signal(signal.SIGTERM, drop_params(arg_queue.close))
315
+ while True:
316
+ try:
317
+ (args, kwargs), gradio_context = arg_queue.get()
318
+ except OSError:
319
+ break
320
+ def iterate():
321
+ gen = task(*args, **kwargs) # type: ignore
322
+ while True:
323
+ try:
324
+ res = next(gen)
325
+ except StopIteration:
326
+ break
327
+ except Exception as e:
328
+ res_queue.put(ExceptionResult(e))
329
+ break
330
+ try:
331
+ res_queue.put(OkResult(res))
332
+ except PicklingError as e:
333
+ res_queue.put(ExceptionResult(e))
334
+ break
335
+ else:
336
+ continue
337
+ GradioPartialContext.apply(gradio_context)
338
+ context = copy_context()
339
+ with ThreadPoolExecutor() as executor:
340
+ executor.submit(context.run, iterate)
341
+ res_queue.put(EndResult())
342
+
343
+ # https://github.com/python/cpython/issues/91002
344
+ if not hasattr(task, '__annotations__'):
345
+ gradio_handler.__annotations__ = {}
346
+
347
+ return gradio_handler