Samuel Kristianto
commited on
Commit
·
f2e2be2
1
Parent(s):
ab81f36
updated blocks.py
Browse files- fm23da_files/blocks.py +151 -121
fm23da_files/blocks.py
CHANGED
@@ -12,7 +12,7 @@ import warnings
|
|
12 |
import webbrowser
|
13 |
from abc import abstractmethod
|
14 |
from types import ModuleType
|
15 |
-
from typing import TYPE_CHECKING, Any,
|
16 |
|
17 |
import anyio
|
18 |
import requests
|
@@ -20,6 +20,7 @@ from anyio import CapacityLimiter
|
|
20 |
from gradio_client import serializing
|
21 |
from gradio_client import utils as client_utils
|
22 |
from gradio_client.documentation import document, set_documentation_group
|
|
|
23 |
from typing_extensions import Literal
|
24 |
|
25 |
from gradio import (
|
@@ -34,7 +35,7 @@ from gradio import (
|
|
34 |
)
|
35 |
from gradio.context import Context
|
36 |
from gradio.deprecation import check_deprecated_parameters
|
37 |
-
from gradio.exceptions import DuplicateBlockError,
|
38 |
from gradio.helpers import EventData, create_tracker, skip, special_args
|
39 |
from gradio.themes import Default as DefaultTheme
|
40 |
from gradio.themes import ThemeClass as Theme
|
@@ -56,7 +57,7 @@ if TYPE_CHECKING: # Only import for type checking (is False at runtime).
|
|
56 |
|
57 |
from gradio.components import Component
|
58 |
|
59 |
-
BUILT_IN_THEMES:
|
60 |
t.name: t
|
61 |
for t in [
|
62 |
themes.Base(),
|
@@ -74,7 +75,7 @@ class Block:
|
|
74 |
*,
|
75 |
render: bool = True,
|
76 |
elem_id: str | None = None,
|
77 |
-
elem_classes:
|
78 |
visible: bool = True,
|
79 |
root_url: str | None = None, # URL that is prepended to all file paths
|
80 |
_skip_init_processing: bool = False, # Used for loading from Spaces
|
@@ -145,15 +146,15 @@ class Block:
|
|
145 |
else self.__class__.__name__.lower()
|
146 |
)
|
147 |
|
148 |
-
def get_expected_parent(self) ->
|
149 |
return None
|
150 |
|
151 |
def set_event_trigger(
|
152 |
self,
|
153 |
event_name: str,
|
154 |
fn: Callable | None,
|
155 |
-
inputs: Component |
|
156 |
-
outputs: Component |
|
157 |
preprocess: bool = True,
|
158 |
postprocess: bool = True,
|
159 |
scroll_to_output: bool = False,
|
@@ -164,12 +165,12 @@ class Block:
|
|
164 |
queue: bool | None = None,
|
165 |
batch: bool = False,
|
166 |
max_batch_size: int = 4,
|
167 |
-
cancels:
|
168 |
every: float | None = None,
|
169 |
collects_event_data: bool | None = None,
|
170 |
trigger_after: int | None = None,
|
171 |
trigger_only_on_success: bool = False,
|
172 |
-
) ->
|
173 |
"""
|
174 |
Adds an event to the component's dependencies.
|
175 |
Parameters:
|
@@ -251,7 +252,7 @@ class Block:
|
|
251 |
api_name_ = utils.append_unique_suffix(
|
252 |
api_name, [dep["api_name"] for dep in Context.root_block.dependencies]
|
253 |
)
|
254 |
-
if
|
255 |
warnings.warn(f"api_name {api_name} already exists, using {api_name_}")
|
256 |
api_name = api_name_
|
257 |
|
@@ -295,11 +296,11 @@ class Block:
|
|
295 |
|
296 |
@staticmethod
|
297 |
@abstractmethod
|
298 |
-
def update(**kwargs) ->
|
299 |
return {}
|
300 |
|
301 |
@classmethod
|
302 |
-
def get_specific_update(cls, generic_update:
|
303 |
generic_update = generic_update.copy()
|
304 |
del generic_update["__type__"]
|
305 |
specific_update = cls.update(**generic_update)
|
@@ -318,7 +319,7 @@ class BlockContext(Block):
|
|
318 |
visible: If False, this will be hidden but included in the Blocks config file (its visibility can later be updated).
|
319 |
render: If False, this will not be included in the Blocks config file at all.
|
320 |
"""
|
321 |
-
self.children:
|
322 |
Block.__init__(self, visible=visible, render=render, **kwargs)
|
323 |
|
324 |
def __enter__(self):
|
@@ -368,8 +369,8 @@ class BlockFunction:
|
|
368 |
def __init__(
|
369 |
self,
|
370 |
fn: Callable | None,
|
371 |
-
inputs:
|
372 |
-
outputs:
|
373 |
preprocess: bool,
|
374 |
postprocess: bool,
|
375 |
inputs_as_dict: bool,
|
@@ -399,13 +400,13 @@ class BlockFunction:
|
|
399 |
return str(self)
|
400 |
|
401 |
|
402 |
-
class class_or_instancemethod(classmethod):
|
403 |
def __get__(self, instance, type_):
|
404 |
descr_get = super().__get__ if instance is None else self.__func__.__get__
|
405 |
return descr_get(instance, type_)
|
406 |
|
407 |
|
408 |
-
def postprocess_update_dict(block: Block, update_dict:
|
409 |
"""
|
410 |
Converts a dictionary of updates into a format that can be sent to the frontend.
|
411 |
E.g. {"__type__": "generic_update", "value": "2", "interactive": False}
|
@@ -433,15 +434,15 @@ def postprocess_update_dict(block: Block, update_dict: Dict, postprocess: bool =
|
|
433 |
|
434 |
|
435 |
def convert_component_dict_to_list(
|
436 |
-
outputs_ids:
|
437 |
-
) ->
|
438 |
"""
|
439 |
Converts a dictionary of component updates into a list of updates in the order of
|
440 |
the outputs_ids and including every output component. Leaves other types of dictionaries unchanged.
|
441 |
E.g. {"textbox": "hello", "number": {"__type__": "generic_update", "value": "2"}}
|
442 |
Into -> ["hello", {"__type__": "generic_update"}, {"__type__": "generic_update", "value": "2"}]
|
443 |
"""
|
444 |
-
keys_are_blocks = [isinstance(key, Block) for key in predictions
|
445 |
if all(keys_are_blocks):
|
446 |
reordered_predictions = [skip() for _ in outputs_ids]
|
447 |
for component, value in predictions.items():
|
@@ -459,7 +460,7 @@ def convert_component_dict_to_list(
|
|
459 |
return predictions
|
460 |
|
461 |
|
462 |
-
def get_api_info(config:
|
463 |
"""
|
464 |
Gets the information needed to generate the API docs from a Blocks config.
|
465 |
Parameters:
|
@@ -468,10 +469,14 @@ def get_api_info(config: Dict, serialize: bool = True):
|
|
468 |
"""
|
469 |
api_info = {"named_endpoints": {}, "unnamed_endpoints": {}}
|
470 |
mode = config.get("mode", None)
|
|
|
|
|
|
|
471 |
|
472 |
for d, dependency in enumerate(config["dependencies"]):
|
473 |
dependency_info = {"parameters": [], "returns": []}
|
474 |
skip_endpoint = False
|
|
|
475 |
|
476 |
inputs = dependency["inputs"]
|
477 |
for i in inputs:
|
@@ -479,44 +484,51 @@ def get_api_info(config: Dict, serialize: bool = True):
|
|
479 |
if component["id"] == i:
|
480 |
break
|
481 |
else:
|
482 |
-
skip_endpoint = True # if component not found, skip
|
483 |
break
|
484 |
type = component["type"]
|
485 |
if (
|
486 |
not component.get("serializer")
|
487 |
and type not in serializing.COMPONENT_MAPPING
|
488 |
):
|
489 |
-
skip_endpoint =
|
490 |
-
True # if component is not serializable, skip this endpoint
|
491 |
-
)
|
492 |
break
|
|
|
|
|
493 |
label = component["props"].get("label", f"parameter_{i}")
|
494 |
# The config has the most specific API info (taking into account the parameters
|
495 |
# of the component), so we use that if it exists. Otherwise, we fallback to the
|
496 |
# Serializer's API info.
|
497 |
-
|
498 |
-
|
499 |
-
|
500 |
-
|
501 |
-
else:
|
502 |
-
info = component["api_info"]["raw_input"]
|
503 |
-
example = component["example_inputs"]["raw"]
|
504 |
else:
|
505 |
-
serializer = serializing.COMPONENT_MAPPING[type]()
|
506 |
assert isinstance(serializer, serializing.Serializable)
|
507 |
-
|
508 |
-
|
509 |
-
|
510 |
-
|
511 |
-
|
512 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
513 |
dependency_info["parameters"].append(
|
514 |
{
|
515 |
"label": label,
|
516 |
-
"
|
517 |
-
"
|
|
|
|
|
|
|
518 |
"component": type.capitalize(),
|
519 |
"example_input": example,
|
|
|
520 |
}
|
521 |
)
|
522 |
|
@@ -526,30 +538,41 @@ def get_api_info(config: Dict, serialize: bool = True):
|
|
526 |
if component["id"] == o:
|
527 |
break
|
528 |
else:
|
529 |
-
skip_endpoint = True # if component not found, skip
|
530 |
break
|
531 |
type = component["type"]
|
532 |
if (
|
533 |
not component.get("serializer")
|
534 |
and type not in serializing.COMPONENT_MAPPING
|
535 |
):
|
536 |
-
skip_endpoint =
|
537 |
-
True # if component is not serializable, skip this endpoint
|
538 |
-
)
|
539 |
break
|
|
|
|
|
540 |
label = component["props"].get("label", f"value_{o}")
|
541 |
serializer = serializing.COMPONENT_MAPPING[type]()
|
542 |
assert isinstance(serializer, serializing.Serializable)
|
543 |
-
|
544 |
-
|
545 |
-
|
546 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
547 |
dependency_info["returns"].append(
|
548 |
{
|
549 |
"label": label,
|
550 |
-
"
|
551 |
-
"
|
|
|
|
|
|
|
552 |
"component": type.capitalize(),
|
|
|
553 |
}
|
554 |
)
|
555 |
|
@@ -662,8 +685,8 @@ class Blocks(BlockContext):
|
|
662 |
if not self.analytics_enabled:
|
663 |
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "True"
|
664 |
super().__init__(render=False, **kwargs)
|
665 |
-
self.blocks:
|
666 |
-
self.fns:
|
667 |
self.dependencies = []
|
668 |
self.mode = mode
|
669 |
|
@@ -674,7 +697,7 @@ class Blocks(BlockContext):
|
|
674 |
self.height = None
|
675 |
self.api_open = True
|
676 |
|
677 |
-
self.is_space =
|
678 |
self.favicon_path = None
|
679 |
self.auth = None
|
680 |
self.dev_mode = True
|
@@ -692,7 +715,8 @@ class Blocks(BlockContext):
|
|
692 |
self.progress_tracking = None
|
693 |
self.ssl_verify = True
|
694 |
|
695 |
-
self.
|
|
|
696 |
|
697 |
if self.analytics_enabled:
|
698 |
is_custom_theme = not any(
|
@@ -712,7 +736,7 @@ class Blocks(BlockContext):
|
|
712 |
def from_config(
|
713 |
cls,
|
714 |
config: dict,
|
715 |
-
fns:
|
716 |
root_url: str | None = None,
|
717 |
) -> Blocks:
|
718 |
"""
|
@@ -726,7 +750,7 @@ class Blocks(BlockContext):
|
|
726 |
config = copy.deepcopy(config)
|
727 |
components_config = config["components"]
|
728 |
theme = config.get("theme", "default")
|
729 |
-
original_mapping:
|
730 |
|
731 |
def get_block_instance(id: int) -> Block:
|
732 |
for block_config in components_config:
|
@@ -843,10 +867,13 @@ class Blocks(BlockContext):
|
|
843 |
raise DuplicateBlockError(
|
844 |
f"A block with id: {self._id} has already been rendered in the current Blocks."
|
845 |
)
|
846 |
-
|
847 |
-
|
848 |
-
|
849 |
-
)
|
|
|
|
|
|
|
850 |
|
851 |
Context.root_block.blocks.update(self.blocks)
|
852 |
Context.root_block.fns.extend(self.fns)
|
@@ -858,7 +885,7 @@ class Blocks(BlockContext):
|
|
858 |
api_name,
|
859 |
[dep["api_name"] for dep in Context.root_block.dependencies],
|
860 |
)
|
861 |
-
if
|
862 |
warnings.warn(
|
863 |
f"api_name {api_name} already exists, using {api_name_}"
|
864 |
)
|
@@ -933,7 +960,9 @@ class Blocks(BlockContext):
|
|
933 |
None,
|
934 |
)
|
935 |
if inferred_fn_index is None:
|
936 |
-
raise
|
|
|
|
|
937 |
fn_index = inferred_fn_index
|
938 |
if not (self.is_callable(fn_index)):
|
939 |
raise ValueError(
|
@@ -966,9 +995,9 @@ class Blocks(BlockContext):
|
|
966 |
async def call_function(
|
967 |
self,
|
968 |
fn_index: int,
|
969 |
-
processed_input:
|
970 |
-
iterator:
|
971 |
-
requests: routes.Request |
|
972 |
event_id: str | None = None,
|
973 |
event_data: EventData | None = None,
|
974 |
):
|
@@ -987,17 +1016,9 @@ class Blocks(BlockContext):
|
|
987 |
is_generating = False
|
988 |
|
989 |
if block_fn.inputs_as_dict:
|
990 |
-
processed_input = [
|
991 |
-
{
|
992 |
-
input_component: data
|
993 |
-
for input_component, data in zip(block_fn.inputs, processed_input)
|
994 |
-
}
|
995 |
-
]
|
996 |
|
997 |
-
if isinstance(requests, list)
|
998 |
-
request = requests[0]
|
999 |
-
else:
|
1000 |
-
request = requests
|
1001 |
processed_input, progress_index, _ = special_args(
|
1002 |
block_fn.fn, processed_input, request, event_data
|
1003 |
)
|
@@ -1025,17 +1046,17 @@ class Blocks(BlockContext):
|
|
1025 |
else:
|
1026 |
prediction = None
|
1027 |
|
1028 |
-
if inspect.
|
1029 |
-
|
1030 |
-
|
1031 |
if not self.enable_queue:
|
1032 |
raise ValueError("Need to enable queue to use generators.")
|
1033 |
try:
|
1034 |
if iterator is None:
|
1035 |
iterator = prediction
|
1036 |
-
|
1037 |
-
utils.
|
1038 |
-
)
|
1039 |
is_generating = True
|
1040 |
except StopAsyncIteration:
|
1041 |
n_outputs = len(self.dependencies[fn_index].get("outputs"))
|
@@ -1055,7 +1076,7 @@ class Blocks(BlockContext):
|
|
1055 |
"iterator": iterator,
|
1056 |
}
|
1057 |
|
1058 |
-
def serialize_data(self, fn_index: int, inputs:
|
1059 |
dependency = self.dependencies[fn_index]
|
1060 |
processed_input = []
|
1061 |
|
@@ -1069,7 +1090,7 @@ class Blocks(BlockContext):
|
|
1069 |
|
1070 |
return processed_input
|
1071 |
|
1072 |
-
def deserialize_data(self, fn_index: int, outputs:
|
1073 |
dependency = self.dependencies[fn_index]
|
1074 |
predictions = []
|
1075 |
|
@@ -1085,7 +1106,7 @@ class Blocks(BlockContext):
|
|
1085 |
|
1086 |
return predictions
|
1087 |
|
1088 |
-
def validate_inputs(self, fn_index: int, inputs:
|
1089 |
block_fn = self.fns[fn_index]
|
1090 |
dependency = self.dependencies[fn_index]
|
1091 |
|
@@ -1107,10 +1128,7 @@ class Blocks(BlockContext):
|
|
1107 |
block = self.blocks[input_id]
|
1108 |
wanted_args.append(str(block))
|
1109 |
for inp in inputs:
|
1110 |
-
if isinstance(inp, str)
|
1111 |
-
v = f'"{inp}"'
|
1112 |
-
else:
|
1113 |
-
v = str(inp)
|
1114 |
received_args.append(v)
|
1115 |
|
1116 |
wanted = ", ".join(wanted_args)
|
@@ -1126,7 +1144,7 @@ Received inputs:
|
|
1126 |
[{received}]"""
|
1127 |
)
|
1128 |
|
1129 |
-
def preprocess_data(self, fn_index: int, inputs:
|
1130 |
block_fn = self.fns[fn_index]
|
1131 |
dependency = self.dependencies[fn_index]
|
1132 |
|
@@ -1147,7 +1165,7 @@ Received inputs:
|
|
1147 |
processed_input = inputs
|
1148 |
return processed_input
|
1149 |
|
1150 |
-
def validate_outputs(self, fn_index: int, predictions: Any |
|
1151 |
block_fn = self.fns[fn_index]
|
1152 |
dependency = self.dependencies[fn_index]
|
1153 |
|
@@ -1169,10 +1187,7 @@ Received inputs:
|
|
1169 |
block = self.blocks[output_id]
|
1170 |
wanted_args.append(str(block))
|
1171 |
for pred in predictions:
|
1172 |
-
if isinstance(pred, str)
|
1173 |
-
v = f'"{pred}"'
|
1174 |
-
else:
|
1175 |
-
v = str(pred)
|
1176 |
received_args.append(v)
|
1177 |
|
1178 |
wanted = ", ".join(wanted_args)
|
@@ -1187,7 +1202,7 @@ Received outputs:
|
|
1187 |
)
|
1188 |
|
1189 |
def postprocess_data(
|
1190 |
-
self, fn_index: int, predictions:
|
1191 |
):
|
1192 |
block_fn = self.fns[fn_index]
|
1193 |
dependency = self.dependencies[fn_index]
|
@@ -1211,10 +1226,11 @@ Received outputs:
|
|
1211 |
if predictions[i] is components._Keywords.FINISHED_ITERATING:
|
1212 |
output.append(None)
|
1213 |
continue
|
1214 |
-
except (IndexError, KeyError):
|
1215 |
raise ValueError(
|
1216 |
-
|
1217 |
-
|
|
|
1218 |
block = self.blocks[output_id]
|
1219 |
if getattr(block, "stateful", False):
|
1220 |
if not utils.is_update(predictions[i]):
|
@@ -1241,13 +1257,13 @@ Received outputs:
|
|
1241 |
async def process_api(
|
1242 |
self,
|
1243 |
fn_index: int,
|
1244 |
-
inputs:
|
1245 |
-
state:
|
1246 |
-
request: routes.Request |
|
1247 |
-
iterators:
|
1248 |
event_id: str | None = None,
|
1249 |
event_data: EventData | None = None,
|
1250 |
-
) ->
|
1251 |
"""
|
1252 |
Processes API calls from the frontend. First preprocesses the data,
|
1253 |
then runs the relevant function, then postprocesses the output.
|
@@ -1304,7 +1320,6 @@ Received outputs:
|
|
1304 |
|
1305 |
block_fn.total_runtime += result["duration"]
|
1306 |
block_fn.total_runs += 1
|
1307 |
-
|
1308 |
return {
|
1309 |
"data": data,
|
1310 |
"is_generating": is_generating,
|
@@ -1342,15 +1357,15 @@ Received outputs:
|
|
1342 |
"theme": self.theme.name,
|
1343 |
}
|
1344 |
|
1345 |
-
def
|
1346 |
if not isinstance(block, BlockContext):
|
1347 |
return {"id": block._id}
|
1348 |
children_layout = []
|
1349 |
for child in block.children:
|
1350 |
-
children_layout.append(
|
1351 |
return {"id": block._id, "children": children_layout}
|
1352 |
|
1353 |
-
config["layout"] =
|
1354 |
|
1355 |
for _id, block in self.blocks.items():
|
1356 |
props = block.get_config() if hasattr(block, "get_config") else {}
|
@@ -1393,10 +1408,10 @@ Received outputs:
|
|
1393 |
|
1394 |
@class_or_instancemethod
|
1395 |
def load(
|
1396 |
-
self_or_cls,
|
1397 |
fn: Callable | None = None,
|
1398 |
-
inputs:
|
1399 |
-
outputs:
|
1400 |
api_name: str | None = None,
|
1401 |
scroll_to_output: bool = False,
|
1402 |
show_progress: bool = True,
|
@@ -1413,7 +1428,7 @@ Received outputs:
|
|
1413 |
api_key: str | None = None,
|
1414 |
alias: str | None = None,
|
1415 |
**kwargs,
|
1416 |
-
) -> Blocks |
|
1417 |
"""
|
1418 |
For reverse compatibility reasons, this is both a class method and an instance
|
1419 |
method, the two of which, confusingly, do two completely different things.
|
@@ -1571,7 +1586,7 @@ Received outputs:
|
|
1571 |
debug: bool = False,
|
1572 |
enable_queue: bool | None = None,
|
1573 |
max_threads: int = 40,
|
1574 |
-
auth: Callable |
|
1575 |
auth_message: str | None = None,
|
1576 |
prevent_thread_lock: bool = False,
|
1577 |
show_error: bool = False,
|
@@ -1588,9 +1603,11 @@ Received outputs:
|
|
1588 |
ssl_verify: bool = True,
|
1589 |
quiet: bool = False,
|
1590 |
show_api: bool = True,
|
1591 |
-
file_directories:
|
|
|
|
|
1592 |
_frontend: bool = True,
|
1593 |
-
) ->
|
1594 |
"""
|
1595 |
Launches a simple web server that serves the demo. Can also be used to create a
|
1596 |
public link used by anyone to access the demo from their browser by setting share=True.
|
@@ -1619,7 +1636,9 @@ Received outputs:
|
|
1619 |
ssl_verify: If False, skips certificate validation which allows self-signed certificates to be used.
|
1620 |
quiet: If True, suppresses most print statements.
|
1621 |
show_api: If True, shows the api docs in the footer of the app. Default True. If the queue is enabled, then api_open parameter of .queue() will determine if the api docs are shown, independent of the value of show_api.
|
1622 |
-
file_directories:
|
|
|
|
|
1623 |
Returns:
|
1624 |
app: FastAPI app object that is running the demo
|
1625 |
local_url: Locally accessible link to the demo
|
@@ -1680,9 +1699,20 @@ Received outputs:
|
|
1680 |
self.queue()
|
1681 |
self.show_api = self.api_open if self.enable_queue else show_api
|
1682 |
|
1683 |
-
|
1684 |
-
|
1685 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1686 |
|
1687 |
self.validate_queue_settings()
|
1688 |
|
@@ -1759,7 +1789,7 @@ Received outputs:
|
|
1759 |
# a shareable link must be created.
|
1760 |
if _frontend and (not networking.url_ok(self.local_url)) and (not self.share):
|
1761 |
raise ValueError(
|
1762 |
-
"When localhost is not accessible, a shareable link must be created. Please set share=True."
|
1763 |
)
|
1764 |
|
1765 |
if self.is_colab:
|
@@ -1776,7 +1806,7 @@ Received outputs:
|
|
1776 |
)
|
1777 |
else:
|
1778 |
if not self.share:
|
1779 |
-
|
1780 |
|
1781 |
if self.share:
|
1782 |
if self.is_space:
|
|
|
12 |
import webbrowser
|
13 |
from abc import abstractmethod
|
14 |
from types import ModuleType
|
15 |
+
from typing import TYPE_CHECKING, Any, AsyncIterator, Callable
|
16 |
|
17 |
import anyio
|
18 |
import requests
|
|
|
20 |
from gradio_client import serializing
|
21 |
from gradio_client import utils as client_utils
|
22 |
from gradio_client.documentation import document, set_documentation_group
|
23 |
+
from packaging import version
|
24 |
from typing_extensions import Literal
|
25 |
|
26 |
from gradio import (
|
|
|
35 |
)
|
36 |
from gradio.context import Context
|
37 |
from gradio.deprecation import check_deprecated_parameters
|
38 |
+
from gradio.exceptions import DuplicateBlockError, InvalidApiNameError
|
39 |
from gradio.helpers import EventData, create_tracker, skip, special_args
|
40 |
from gradio.themes import Default as DefaultTheme
|
41 |
from gradio.themes import ThemeClass as Theme
|
|
|
57 |
|
58 |
from gradio.components import Component
|
59 |
|
60 |
+
BUILT_IN_THEMES: dict[str, Theme] = {
|
61 |
t.name: t
|
62 |
for t in [
|
63 |
themes.Base(),
|
|
|
75 |
*,
|
76 |
render: bool = True,
|
77 |
elem_id: str | None = None,
|
78 |
+
elem_classes: list[str] | str | None = None,
|
79 |
visible: bool = True,
|
80 |
root_url: str | None = None, # URL that is prepended to all file paths
|
81 |
_skip_init_processing: bool = False, # Used for loading from Spaces
|
|
|
146 |
else self.__class__.__name__.lower()
|
147 |
)
|
148 |
|
149 |
+
def get_expected_parent(self) -> type[BlockContext] | None:
|
150 |
return None
|
151 |
|
152 |
def set_event_trigger(
|
153 |
self,
|
154 |
event_name: str,
|
155 |
fn: Callable | None,
|
156 |
+
inputs: Component | list[Component] | set[Component] | None,
|
157 |
+
outputs: Component | list[Component] | None,
|
158 |
preprocess: bool = True,
|
159 |
postprocess: bool = True,
|
160 |
scroll_to_output: bool = False,
|
|
|
165 |
queue: bool | None = None,
|
166 |
batch: bool = False,
|
167 |
max_batch_size: int = 4,
|
168 |
+
cancels: list[int] | None = None,
|
169 |
every: float | None = None,
|
170 |
collects_event_data: bool | None = None,
|
171 |
trigger_after: int | None = None,
|
172 |
trigger_only_on_success: bool = False,
|
173 |
+
) -> tuple[dict[str, Any], int]:
|
174 |
"""
|
175 |
Adds an event to the component's dependencies.
|
176 |
Parameters:
|
|
|
252 |
api_name_ = utils.append_unique_suffix(
|
253 |
api_name, [dep["api_name"] for dep in Context.root_block.dependencies]
|
254 |
)
|
255 |
+
if api_name != api_name_:
|
256 |
warnings.warn(f"api_name {api_name} already exists, using {api_name_}")
|
257 |
api_name = api_name_
|
258 |
|
|
|
296 |
|
297 |
@staticmethod
|
298 |
@abstractmethod
|
299 |
+
def update(**kwargs) -> dict:
|
300 |
return {}
|
301 |
|
302 |
@classmethod
|
303 |
+
def get_specific_update(cls, generic_update: dict[str, Any]) -> dict:
|
304 |
generic_update = generic_update.copy()
|
305 |
del generic_update["__type__"]
|
306 |
specific_update = cls.update(**generic_update)
|
|
|
319 |
visible: If False, this will be hidden but included in the Blocks config file (its visibility can later be updated).
|
320 |
render: If False, this will not be included in the Blocks config file at all.
|
321 |
"""
|
322 |
+
self.children: list[Block] = []
|
323 |
Block.__init__(self, visible=visible, render=render, **kwargs)
|
324 |
|
325 |
def __enter__(self):
|
|
|
369 |
def __init__(
|
370 |
self,
|
371 |
fn: Callable | None,
|
372 |
+
inputs: list[Component],
|
373 |
+
outputs: list[Component],
|
374 |
preprocess: bool,
|
375 |
postprocess: bool,
|
376 |
inputs_as_dict: bool,
|
|
|
400 |
return str(self)
|
401 |
|
402 |
|
403 |
+
class class_or_instancemethod(classmethod): # noqa: N801
|
404 |
def __get__(self, instance, type_):
|
405 |
descr_get = super().__get__ if instance is None else self.__func__.__get__
|
406 |
return descr_get(instance, type_)
|
407 |
|
408 |
|
409 |
+
def postprocess_update_dict(block: Block, update_dict: dict, postprocess: bool = True):
|
410 |
"""
|
411 |
Converts a dictionary of updates into a format that can be sent to the frontend.
|
412 |
E.g. {"__type__": "generic_update", "value": "2", "interactive": False}
|
|
|
434 |
|
435 |
|
436 |
def convert_component_dict_to_list(
|
437 |
+
outputs_ids: list[int], predictions: dict
|
438 |
+
) -> list | dict:
|
439 |
"""
|
440 |
Converts a dictionary of component updates into a list of updates in the order of
|
441 |
the outputs_ids and including every output component. Leaves other types of dictionaries unchanged.
|
442 |
E.g. {"textbox": "hello", "number": {"__type__": "generic_update", "value": "2"}}
|
443 |
Into -> ["hello", {"__type__": "generic_update"}, {"__type__": "generic_update", "value": "2"}]
|
444 |
"""
|
445 |
+
keys_are_blocks = [isinstance(key, Block) for key in predictions]
|
446 |
if all(keys_are_blocks):
|
447 |
reordered_predictions = [skip() for _ in outputs_ids]
|
448 |
for component, value in predictions.items():
|
|
|
460 |
return predictions
|
461 |
|
462 |
|
463 |
+
def get_api_info(config: dict, serialize: bool = True):
|
464 |
"""
|
465 |
Gets the information needed to generate the API docs from a Blocks config.
|
466 |
Parameters:
|
|
|
469 |
"""
|
470 |
api_info = {"named_endpoints": {}, "unnamed_endpoints": {}}
|
471 |
mode = config.get("mode", None)
|
472 |
+
after_new_format = version.parse(config.get("version", "2.0")) > version.Version(
|
473 |
+
"3.28.3"
|
474 |
+
)
|
475 |
|
476 |
for d, dependency in enumerate(config["dependencies"]):
|
477 |
dependency_info = {"parameters": [], "returns": []}
|
478 |
skip_endpoint = False
|
479 |
+
skip_components = ["state"]
|
480 |
|
481 |
inputs = dependency["inputs"]
|
482 |
for i in inputs:
|
|
|
484 |
if component["id"] == i:
|
485 |
break
|
486 |
else:
|
487 |
+
skip_endpoint = True # if component not found, skip endpoint
|
488 |
break
|
489 |
type = component["type"]
|
490 |
if (
|
491 |
not component.get("serializer")
|
492 |
and type not in serializing.COMPONENT_MAPPING
|
493 |
):
|
494 |
+
skip_endpoint = True # if component not serializable, skip endpoint
|
|
|
|
|
495 |
break
|
496 |
+
if type in skip_components:
|
497 |
+
continue
|
498 |
label = component["props"].get("label", f"parameter_{i}")
|
499 |
# The config has the most specific API info (taking into account the parameters
|
500 |
# of the component), so we use that if it exists. Otherwise, we fallback to the
|
501 |
# Serializer's API info.
|
502 |
+
serializer = serializing.COMPONENT_MAPPING[type]()
|
503 |
+
if component.get("api_info") and after_new_format:
|
504 |
+
info = component["api_info"]
|
505 |
+
example = component["example_inputs"]["serialized"]
|
|
|
|
|
|
|
506 |
else:
|
|
|
507 |
assert isinstance(serializer, serializing.Serializable)
|
508 |
+
info = serializer.api_info()
|
509 |
+
example = serializer.example_inputs()["raw"]
|
510 |
+
python_info = info["info"]
|
511 |
+
if serialize and info["serialized_info"]:
|
512 |
+
python_info = serializer.serialized_info()
|
513 |
+
if (
|
514 |
+
isinstance(serializer, serializing.FileSerializable)
|
515 |
+
and component["props"].get("file_count", "single") != "single"
|
516 |
+
):
|
517 |
+
python_info = serializer._multiple_file_serialized_info()
|
518 |
+
|
519 |
+
python_type = client_utils.json_schema_to_python_type(python_info)
|
520 |
+
serializer_name = serializing.COMPONENT_MAPPING[type].__name__
|
521 |
dependency_info["parameters"].append(
|
522 |
{
|
523 |
"label": label,
|
524 |
+
"type": info["info"],
|
525 |
+
"python_type": {
|
526 |
+
"type": python_type,
|
527 |
+
"description": python_info.get("description", ""),
|
528 |
+
},
|
529 |
"component": type.capitalize(),
|
530 |
"example_input": example,
|
531 |
+
"serializer": serializer_name,
|
532 |
}
|
533 |
)
|
534 |
|
|
|
538 |
if component["id"] == o:
|
539 |
break
|
540 |
else:
|
541 |
+
skip_endpoint = True # if component not found, skip endpoint
|
542 |
break
|
543 |
type = component["type"]
|
544 |
if (
|
545 |
not component.get("serializer")
|
546 |
and type not in serializing.COMPONENT_MAPPING
|
547 |
):
|
548 |
+
skip_endpoint = True # if component not serializable, skip endpoint
|
|
|
|
|
549 |
break
|
550 |
+
if type in skip_components:
|
551 |
+
continue
|
552 |
label = component["props"].get("label", f"value_{o}")
|
553 |
serializer = serializing.COMPONENT_MAPPING[type]()
|
554 |
assert isinstance(serializer, serializing.Serializable)
|
555 |
+
info = serializer.api_info()
|
556 |
+
python_info = info["info"]
|
557 |
+
if serialize and info["serialized_info"]:
|
558 |
+
python_info = serializer.serialized_info()
|
559 |
+
if (
|
560 |
+
isinstance(serializer, serializing.FileSerializable)
|
561 |
+
and component["props"].get("file_count", "single") != "single"
|
562 |
+
):
|
563 |
+
python_info = serializer._multiple_file_serialized_info()
|
564 |
+
python_type = client_utils.json_schema_to_python_type(python_info)
|
565 |
+
serializer_name = serializing.COMPONENT_MAPPING[type].__name__
|
566 |
dependency_info["returns"].append(
|
567 |
{
|
568 |
"label": label,
|
569 |
+
"type": info["info"],
|
570 |
+
"python_type": {
|
571 |
+
"type": python_type,
|
572 |
+
"description": python_info.get("description", ""),
|
573 |
+
},
|
574 |
"component": type.capitalize(),
|
575 |
+
"serializer": serializer_name,
|
576 |
}
|
577 |
)
|
578 |
|
|
|
685 |
if not self.analytics_enabled:
|
686 |
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "True"
|
687 |
super().__init__(render=False, **kwargs)
|
688 |
+
self.blocks: dict[int, Block] = {}
|
689 |
+
self.fns: list[BlockFunction] = []
|
690 |
self.dependencies = []
|
691 |
self.mode = mode
|
692 |
|
|
|
697 |
self.height = None
|
698 |
self.api_open = True
|
699 |
|
700 |
+
self.is_space = os.getenv("SYSTEM") == "spaces"
|
701 |
self.favicon_path = None
|
702 |
self.auth = None
|
703 |
self.dev_mode = True
|
|
|
715 |
self.progress_tracking = None
|
716 |
self.ssl_verify = True
|
717 |
|
718 |
+
self.allowed_paths = []
|
719 |
+
self.blocked_paths = []
|
720 |
|
721 |
if self.analytics_enabled:
|
722 |
is_custom_theme = not any(
|
|
|
736 |
def from_config(
|
737 |
cls,
|
738 |
config: dict,
|
739 |
+
fns: list[Callable],
|
740 |
root_url: str | None = None,
|
741 |
) -> Blocks:
|
742 |
"""
|
|
|
750 |
config = copy.deepcopy(config)
|
751 |
components_config = config["components"]
|
752 |
theme = config.get("theme", "default")
|
753 |
+
original_mapping: dict[int, Block] = {}
|
754 |
|
755 |
def get_block_instance(id: int) -> Block:
|
756 |
for block_config in components_config:
|
|
|
867 |
raise DuplicateBlockError(
|
868 |
f"A block with id: {self._id} has already been rendered in the current Blocks."
|
869 |
)
|
870 |
+
overlapping_ids = set(Context.root_block.blocks).intersection(self.blocks)
|
871 |
+
for id in overlapping_ids:
|
872 |
+
# State componenents are allowed to be reused between Blocks
|
873 |
+
if not isinstance(self.blocks[id], components.State):
|
874 |
+
raise DuplicateBlockError(
|
875 |
+
"At least one block in this Blocks has already been rendered."
|
876 |
+
)
|
877 |
|
878 |
Context.root_block.blocks.update(self.blocks)
|
879 |
Context.root_block.fns.extend(self.fns)
|
|
|
885 |
api_name,
|
886 |
[dep["api_name"] for dep in Context.root_block.dependencies],
|
887 |
)
|
888 |
+
if api_name != api_name_:
|
889 |
warnings.warn(
|
890 |
f"api_name {api_name} already exists, using {api_name_}"
|
891 |
)
|
|
|
960 |
None,
|
961 |
)
|
962 |
if inferred_fn_index is None:
|
963 |
+
raise InvalidApiNameError(
|
964 |
+
f"Cannot find a function with api_name {api_name}"
|
965 |
+
)
|
966 |
fn_index = inferred_fn_index
|
967 |
if not (self.is_callable(fn_index)):
|
968 |
raise ValueError(
|
|
|
995 |
async def call_function(
|
996 |
self,
|
997 |
fn_index: int,
|
998 |
+
processed_input: list[Any],
|
999 |
+
iterator: AsyncIterator[Any] | None = None,
|
1000 |
+
requests: routes.Request | list[routes.Request] | None = None,
|
1001 |
event_id: str | None = None,
|
1002 |
event_data: EventData | None = None,
|
1003 |
):
|
|
|
1016 |
is_generating = False
|
1017 |
|
1018 |
if block_fn.inputs_as_dict:
|
1019 |
+
processed_input = [dict(zip(block_fn.inputs, processed_input))]
|
|
|
|
|
|
|
|
|
|
|
1020 |
|
1021 |
+
request = requests[0] if isinstance(requests, list) else requests
|
|
|
|
|
|
|
1022 |
processed_input, progress_index, _ = special_args(
|
1023 |
block_fn.fn, processed_input, request, event_data
|
1024 |
)
|
|
|
1046 |
else:
|
1047 |
prediction = None
|
1048 |
|
1049 |
+
if inspect.isgeneratorfunction(block_fn.fn) or inspect.isasyncgenfunction(
|
1050 |
+
block_fn.fn
|
1051 |
+
):
|
1052 |
if not self.enable_queue:
|
1053 |
raise ValueError("Need to enable queue to use generators.")
|
1054 |
try:
|
1055 |
if iterator is None:
|
1056 |
iterator = prediction
|
1057 |
+
if inspect.isgenerator(iterator):
|
1058 |
+
iterator = utils.SyncToAsyncIterator(iterator, self.limiter)
|
1059 |
+
prediction = await utils.async_iteration(iterator)
|
1060 |
is_generating = True
|
1061 |
except StopAsyncIteration:
|
1062 |
n_outputs = len(self.dependencies[fn_index].get("outputs"))
|
|
|
1076 |
"iterator": iterator,
|
1077 |
}
|
1078 |
|
1079 |
+
def serialize_data(self, fn_index: int, inputs: list[Any]) -> list[Any]:
|
1080 |
dependency = self.dependencies[fn_index]
|
1081 |
processed_input = []
|
1082 |
|
|
|
1090 |
|
1091 |
return processed_input
|
1092 |
|
1093 |
+
def deserialize_data(self, fn_index: int, outputs: list[Any]) -> list[Any]:
|
1094 |
dependency = self.dependencies[fn_index]
|
1095 |
predictions = []
|
1096 |
|
|
|
1106 |
|
1107 |
return predictions
|
1108 |
|
1109 |
+
def validate_inputs(self, fn_index: int, inputs: list[Any]):
|
1110 |
block_fn = self.fns[fn_index]
|
1111 |
dependency = self.dependencies[fn_index]
|
1112 |
|
|
|
1128 |
block = self.blocks[input_id]
|
1129 |
wanted_args.append(str(block))
|
1130 |
for inp in inputs:
|
1131 |
+
v = f'"{inp}"' if isinstance(inp, str) else str(inp)
|
|
|
|
|
|
|
1132 |
received_args.append(v)
|
1133 |
|
1134 |
wanted = ", ".join(wanted_args)
|
|
|
1144 |
[{received}]"""
|
1145 |
)
|
1146 |
|
1147 |
+
def preprocess_data(self, fn_index: int, inputs: list[Any], state: dict[int, Any]):
|
1148 |
block_fn = self.fns[fn_index]
|
1149 |
dependency = self.dependencies[fn_index]
|
1150 |
|
|
|
1165 |
processed_input = inputs
|
1166 |
return processed_input
|
1167 |
|
1168 |
+
def validate_outputs(self, fn_index: int, predictions: Any | list[Any]):
|
1169 |
block_fn = self.fns[fn_index]
|
1170 |
dependency = self.dependencies[fn_index]
|
1171 |
|
|
|
1187 |
block = self.blocks[output_id]
|
1188 |
wanted_args.append(str(block))
|
1189 |
for pred in predictions:
|
1190 |
+
v = f'"{pred}"' if isinstance(pred, str) else str(pred)
|
|
|
|
|
|
|
1191 |
received_args.append(v)
|
1192 |
|
1193 |
wanted = ", ".join(wanted_args)
|
|
|
1202 |
)
|
1203 |
|
1204 |
def postprocess_data(
|
1205 |
+
self, fn_index: int, predictions: list | dict, state: dict[int, Any]
|
1206 |
):
|
1207 |
block_fn = self.fns[fn_index]
|
1208 |
dependency = self.dependencies[fn_index]
|
|
|
1226 |
if predictions[i] is components._Keywords.FINISHED_ITERATING:
|
1227 |
output.append(None)
|
1228 |
continue
|
1229 |
+
except (IndexError, KeyError) as err:
|
1230 |
raise ValueError(
|
1231 |
+
"Number of output components does not match number "
|
1232 |
+
f"of values returned from from function {block_fn.name}"
|
1233 |
+
) from err
|
1234 |
block = self.blocks[output_id]
|
1235 |
if getattr(block, "stateful", False):
|
1236 |
if not utils.is_update(predictions[i]):
|
|
|
1257 |
async def process_api(
|
1258 |
self,
|
1259 |
fn_index: int,
|
1260 |
+
inputs: list[Any],
|
1261 |
+
state: dict[int, Any],
|
1262 |
+
request: routes.Request | list[routes.Request] | None = None,
|
1263 |
+
iterators: dict[int, Any] | None = None,
|
1264 |
event_id: str | None = None,
|
1265 |
event_data: EventData | None = None,
|
1266 |
+
) -> dict[str, Any]:
|
1267 |
"""
|
1268 |
Processes API calls from the frontend. First preprocesses the data,
|
1269 |
then runs the relevant function, then postprocesses the output.
|
|
|
1320 |
|
1321 |
block_fn.total_runtime += result["duration"]
|
1322 |
block_fn.total_runs += 1
|
|
|
1323 |
return {
|
1324 |
"data": data,
|
1325 |
"is_generating": is_generating,
|
|
|
1357 |
"theme": self.theme.name,
|
1358 |
}
|
1359 |
|
1360 |
+
def get_layout(block):
|
1361 |
if not isinstance(block, BlockContext):
|
1362 |
return {"id": block._id}
|
1363 |
children_layout = []
|
1364 |
for child in block.children:
|
1365 |
+
children_layout.append(get_layout(child))
|
1366 |
return {"id": block._id, "children": children_layout}
|
1367 |
|
1368 |
+
config["layout"] = get_layout(self)
|
1369 |
|
1370 |
for _id, block in self.blocks.items():
|
1371 |
props = block.get_config() if hasattr(block, "get_config") else {}
|
|
|
1408 |
|
1409 |
@class_or_instancemethod
|
1410 |
def load(
|
1411 |
+
self_or_cls, # noqa: N805
|
1412 |
fn: Callable | None = None,
|
1413 |
+
inputs: list[Component] | None = None,
|
1414 |
+
outputs: list[Component] | None = None,
|
1415 |
api_name: str | None = None,
|
1416 |
scroll_to_output: bool = False,
|
1417 |
show_progress: bool = True,
|
|
|
1428 |
api_key: str | None = None,
|
1429 |
alias: str | None = None,
|
1430 |
**kwargs,
|
1431 |
+
) -> Blocks | dict[str, Any] | None:
|
1432 |
"""
|
1433 |
For reverse compatibility reasons, this is both a class method and an instance
|
1434 |
method, the two of which, confusingly, do two completely different things.
|
|
|
1586 |
debug: bool = False,
|
1587 |
enable_queue: bool | None = None,
|
1588 |
max_threads: int = 40,
|
1589 |
+
auth: Callable | tuple[str, str] | list[tuple[str, str]] | None = None,
|
1590 |
auth_message: str | None = None,
|
1591 |
prevent_thread_lock: bool = False,
|
1592 |
show_error: bool = False,
|
|
|
1603 |
ssl_verify: bool = True,
|
1604 |
quiet: bool = False,
|
1605 |
show_api: bool = True,
|
1606 |
+
file_directories: list[str] | None = None,
|
1607 |
+
allowed_paths: list[str] | None = None,
|
1608 |
+
blocked_paths: list[str] | None = None,
|
1609 |
_frontend: bool = True,
|
1610 |
+
) -> tuple[FastAPI, str, str]:
|
1611 |
"""
|
1612 |
Launches a simple web server that serves the demo. Can also be used to create a
|
1613 |
public link used by anyone to access the demo from their browser by setting share=True.
|
|
|
1636 |
ssl_verify: If False, skips certificate validation which allows self-signed certificates to be used.
|
1637 |
quiet: If True, suppresses most print statements.
|
1638 |
show_api: If True, shows the api docs in the footer of the app. Default True. If the queue is enabled, then api_open parameter of .queue() will determine if the api docs are shown, independent of the value of show_api.
|
1639 |
+
file_directories: This parameter has been renamed to `allowed_paths`. It will be removed in a future version.
|
1640 |
+
allowed_paths: List of complete filepaths or parent directories that gradio is allowed to serve (in addition to the directory containing the gradio python file). Must be absolute paths. Warning: if you provide directories, any files in these directories or their subdirectories are accessible to all users of your app.
|
1641 |
+
blocked_paths: List of complete filepaths or parent directories that gradio is not allowed to serve (i.e. users of your app are not allowed to access). Must be absolute paths. Warning: takes precedence over `allowed_paths` and all other directories exposed by Gradio by default.
|
1642 |
Returns:
|
1643 |
app: FastAPI app object that is running the demo
|
1644 |
local_url: Locally accessible link to the demo
|
|
|
1699 |
self.queue()
|
1700 |
self.show_api = self.api_open if self.enable_queue else show_api
|
1701 |
|
1702 |
+
if file_directories is not None:
|
1703 |
+
warnings.warn(
|
1704 |
+
"The `file_directories` parameter has been renamed to `allowed_paths`. Please use that instead.",
|
1705 |
+
DeprecationWarning,
|
1706 |
+
)
|
1707 |
+
if allowed_paths is None:
|
1708 |
+
allowed_paths = file_directories
|
1709 |
+
self.allowed_paths = allowed_paths or []
|
1710 |
+
self.blocked_paths = blocked_paths or []
|
1711 |
+
|
1712 |
+
if not isinstance(self.allowed_paths, list):
|
1713 |
+
raise ValueError("`allowed_paths` must be a list of directories.")
|
1714 |
+
if not isinstance(self.blocked_paths, list):
|
1715 |
+
raise ValueError("`blocked_paths` must be a list of directories.")
|
1716 |
|
1717 |
self.validate_queue_settings()
|
1718 |
|
|
|
1789 |
# a shareable link must be created.
|
1790 |
if _frontend and (not networking.url_ok(self.local_url)) and (not self.share):
|
1791 |
raise ValueError(
|
1792 |
+
"When localhost is not accessible, a shareable link must be created. Please set share=True or check your proxy settings to allow access to localhost."
|
1793 |
)
|
1794 |
|
1795 |
if self.is_colab:
|
|
|
1806 |
)
|
1807 |
else:
|
1808 |
if not self.share:
|
1809 |
+
print(f'Running on local URL: https://{self.server_name}')
|
1810 |
|
1811 |
if self.share:
|
1812 |
if self.is_space:
|