Spaces:
Sleeping
Sleeping
"""Contains all of the components that can be used with Gradio Interface / Blocks.
Along with the docs for each component, you can find the names of example demos that use
each component. These demos are located in the `demo` directory."""
from __future__ import annotations | |
import hashlib | |
import os | |
import secrets | |
import shutil | |
import tempfile | |
import urllib.request | |
from enum import Enum | |
from pathlib import Path | |
from typing import TYPE_CHECKING, Any, Callable | |
import aiofiles | |
import numpy as np | |
import requests | |
from fastapi import UploadFile | |
from gradio_client import utils as client_utils | |
from gradio_client.documentation import set_documentation_group | |
from gradio_client.serializing import ( | |
Serializable, | |
) | |
from PIL import Image as _Image # using _ to minimize namespace pollution | |
from gradio import processing_utils, utils | |
from gradio.blocks import Block, BlockContext | |
from gradio.deprecation import warn_deprecation, warn_style_method_deprecation | |
from gradio.events import ( | |
EventListener, | |
) | |
from gradio.layouts import Column, Form, Row | |
if TYPE_CHECKING:
    from typing import TypedDict

    # Typing-only description of a dataframe payload: a "headers" list of
    # strings and a nested-list "data" of str/int/bool cells. Never created at
    # runtime because this branch only runs under a static type checker.
    DataframeData = TypedDict(
        "DataframeData",
        {
            "headers": list[str],
            "data": list[list[str | int | bool]],
        },
    )
# Presumably tags the documented objects defined below with the "component"
# documentation group -- confirm in gradio_client.documentation.
set_documentation_group("component")
# NOTE(review): per the Pillow docs, Image.init() eagerly loads all available
# file-format drivers; the original note says this works around the issue below.
_Image.init()  # fixes https://github.com/gradio-app/gradio/issues/2843
class _Keywords(Enum): | |
NO_VALUE = "NO_VALUE" # Used as a sentinel to determine if nothing is provided as a argument for `value` in `Component.update()` | |
FINISHED_ITERATING = "FINISHED_ITERATING" # Used to skip processing of a component's value (needed for generators + state) | |
class Component(Block, Serializable):
    """
    Base class declaring the methods every gradio component is expected to provide.
    """

    def __init__(self, *args, **kwargs):
        # All positional/keyword arguments go to Block; the EventListener
        # initializer is then run explicitly on this instance.
        Block.__init__(self, *args, **kwargs)
        EventListener.__init__(self)

    def __str__(self):
        return repr(self)

    def __repr__(self):
        return f"{self.get_block_name()}"

    def get_config(self):
        """
        :return: a dictionary with context variables for the javascript file associated with the context
        """
        # "name" comes first; any keys from the parent config are merged on top.
        config = {"name": self.get_block_name()}
        config.update(super().get_config())
        return config

    def preprocess(self, x: Any) -> Any:
        """
        Hook for preprocessing a function input; the base implementation returns it unchanged.
        """
        return x

    def postprocess(self, y):
        """
        Hook for postprocessing a function output; the base implementation returns it unchanged.
        """
        return y

    def style(self, *args, **kwargs):
        """
        Deprecated. Pass these arguments to the component's constructor instead.
        """
        warn_style_method_deprecation()

        # (keyword, deprecation message, whether a list/tuple value asks for
        # the parent layout to become compact)
        retired_params = (
            (
                "rounded",
                "'rounded' styling is no longer supported. To round adjacent components together, place them in a Column(variant='box').",
                True,
            ),
            (
                "margin",
                "'margin' styling is no longer supported. To place adjacent components together without margin, place them in a Column(variant='box').",
                True,
            ),
            (
                "border",
                "'border' styling is no longer supported. To place adjacent components in a shared border, place them in a Column(variant='box').",
                False,
            ),
        )

        compact_parent = False
        for param, message, sequence_means_compact in retired_params:
            if param not in kwargs:
                continue
            warn_deprecation(message)
            setting = kwargs.pop(param)
            if sequence_means_compact and isinstance(setting, (list, tuple)):
                compact_parent = True

        # Whatever remains after removing the retired keywords is unrecognised.
        for leftover in kwargs:
            warn_deprecation(f"Unknown style parameter: {leftover}")

        if compact_parent:
            parent = self.parent
            if isinstance(parent, (Row, Column)) and parent.variant == "default":
                parent.variant = "compact"
        return self
class IOComponent(Component):
    """
    A base class for defining methods that all input/output components should have.
    """

    def __init__(
        self,
        *,
        value: Any = None,
        label: str | None = None,
        info: str | None = None,
        show_label: bool | None = None,
        container: bool = True,
        scale: int | None = None,
        min_width: int | None = None,
        interactive: bool | None = None,
        visible: bool = True,
        elem_id: str | None = None,
        elem_classes: list[str] | str | None = None,
        load_fn: Callable | None = None,
        every: float | None = None,
        **kwargs,
    ):
        # Paths of temporary files this component has written or registered.
        self.temp_files: set[str] = set()
        # GRADIO_TEMP_DIR, when set and non-empty, overrides <system temp>/gradio.
        self.DEFAULT_TEMP_DIR = os.environ.get("GRADIO_TEMP_DIR") or str(
            Path(tempfile.gettempdir()) / "gradio"
        )

        Component.__init__(
            self, elem_id=elem_id, elem_classes=elem_classes, visible=visible, **kwargs
        )

        self.label = label
        self.info = info
        # Without a container the label is forced off (warning if it was requested).
        if not container:
            if show_label:
                warn_deprecation("show_label has no effect when container is False.")
            show_label = False
        if show_label is None:
            show_label = True
        self.show_label = show_label
        self.container = container
        if scale is not None and scale != round(scale):
            warn_deprecation(
                f"'scale' value should be an integer. Using {scale} will cause issues."
            )
        self.scale = scale
        self.min_width = min_width
        self.interactive = interactive

        # load_event is set in the Blocks.attach_load_events method
        self.load_event: None | dict[str, Any] = None
        self.load_event_to_attach = None
        # NOTE: the `load_fn` argument is overwritten here. A callable `value` is
        # called once (with no arguments) for the initial value and is also
        # registered as a load event, optionally repeated every `every` seconds.
        load_fn, initial_value = self.get_load_fn_and_initial_value(value)
        self.value = (
            initial_value
            if self._skip_init_processing
            else self.postprocess(initial_value)
        )
        if callable(load_fn):
            self.attach_load_event(load_fn, every)

    @staticmethod
    def hash_file(file_path: str | Path, chunk_num_blocks: int = 128) -> str:
        """Return the SHA-1 hex digest of the file at `file_path`, read in chunks
        of `chunk_num_blocks` hash blocks."""
        sha1 = hashlib.sha1()
        with open(file_path, "rb") as f:
            for chunk in iter(lambda: f.read(chunk_num_blocks * sha1.block_size), b""):
                sha1.update(chunk)
        return sha1.hexdigest()

    @staticmethod
    def hash_url(url: str, chunk_num_blocks: int = 128) -> str:
        """Return a SHA-1 hex digest of the content fetched from `url`.

        Reading stops once more than 100MB has been read, so for very large
        resources only the leading portion contributes to the digest.
        """
        sha1 = hashlib.sha1()
        chunk_size = chunk_num_blocks * sha1.block_size
        max_file_size = 100 * 1024 * 1024  # 100MB
        total_read = 0
        # The context manager guarantees the response/connection is closed.
        with urllib.request.urlopen(url) as remote:
            while True:
                data = remote.read(chunk_size)
                # Count the bytes actually received, not the amount requested.
                total_read += len(data)
                if not data or total_read > max_file_size:
                    break
                sha1.update(data)
        return sha1.hexdigest()

    @staticmethod
    def hash_bytes(bytes: bytes):
        """Return the SHA-1 hex digest of an in-memory bytes object."""
        sha1 = hashlib.sha1()
        sha1.update(bytes)
        return sha1.hexdigest()

    @staticmethod
    def hash_base64(base64_encoding: str, chunk_num_blocks: int = 128) -> str:
        """Return the SHA-1 hex digest of the UTF-8 encoding of a base64 string,
        processed in slices of `chunk_num_blocks` hash blocks."""
        sha1 = hashlib.sha1()
        for i in range(0, len(base64_encoding), chunk_num_blocks * sha1.block_size):
            data = base64_encoding[i : i + chunk_num_blocks * sha1.block_size]
            sha1.update(data.encode("utf-8"))
        return sha1.hexdigest()

    def make_temp_copy_if_needed(self, file_path: str | Path) -> str:
        """Returns a temporary file path for a copy of the given file path if it does
        not already exist. Otherwise returns the path to the existing temp file."""
        # The copy lives in a directory named after the file's content hash.
        temp_dir = self.hash_file(file_path)
        temp_dir = Path(self.DEFAULT_TEMP_DIR) / temp_dir
        temp_dir.mkdir(exist_ok=True, parents=True)

        name = client_utils.strip_invalid_filename_characters(Path(file_path).name)
        full_temp_file_path = str(utils.abspath(temp_dir / name))

        if not Path(full_temp_file_path).exists():
            shutil.copy2(file_path, full_temp_file_path)

        self.temp_files.add(full_temp_file_path)
        return full_temp_file_path

    async def save_uploaded_file(self, file: UploadFile, upload_dir: str) -> str:
        """Stream an uploaded file into a new random directory under `upload_dir`
        (in 100MB reads) and return the written file's path."""
        temp_dir = secrets.token_hex(
            20
        )  # Since the full file is being uploaded anyways, there is no benefit to hashing the file.
        temp_dir = Path(upload_dir) / temp_dir
        temp_dir.mkdir(exist_ok=True, parents=True)

        if file.filename:
            file_name = Path(file.filename).name
            name = client_utils.strip_invalid_filename_characters(file_name)
        else:
            name = f"tmp{secrets.token_hex(5)}"

        full_temp_file_path = str(utils.abspath(temp_dir / name))

        async with aiofiles.open(full_temp_file_path, "wb") as output_file:
            while True:
                content = await file.read(100 * 1024 * 1024)
                if not content:
                    break
                await output_file.write(content)

        return full_temp_file_path

    def download_temp_copy_if_needed(self, url: str) -> str:
        """Downloads a file and makes a temporary file path for a copy if does not already
        exist. Otherwise returns the path to the existing temp file."""
        # The URL content is hashed first to decide the cache directory.
        temp_dir = self.hash_url(url)
        temp_dir = Path(self.DEFAULT_TEMP_DIR) / temp_dir
        temp_dir.mkdir(exist_ok=True, parents=True)

        name = client_utils.strip_invalid_filename_characters(Path(url).name)
        full_temp_file_path = str(utils.abspath(temp_dir / name))

        if not Path(full_temp_file_path).exists():
            with requests.get(url, stream=True) as r, open(
                full_temp_file_path, "wb"
            ) as f:
                shutil.copyfileobj(r.raw, f)

        self.temp_files.add(full_temp_file_path)
        return full_temp_file_path

    def base64_to_temp_file_if_needed(
        self, base64_encoding: str, file_name: str | None = None
    ) -> str:
        """Converts a base64 encoding to a file and returns the path to the file if
        the file doesn't already exist. Otherwise returns the path to the existing file.
        """
        temp_dir = self.hash_base64(base64_encoding)
        temp_dir = Path(self.DEFAULT_TEMP_DIR) / temp_dir
        temp_dir.mkdir(exist_ok=True, parents=True)

        # Prefer the caller's (sanitised) name, then file.<guessed ext>, then "file".
        guess_extension = client_utils.get_extension(base64_encoding)
        if file_name:
            file_name = client_utils.strip_invalid_filename_characters(file_name)
        elif guess_extension:
            file_name = f"file.{guess_extension}"
        else:
            file_name = "file"

        full_temp_file_path = str(utils.abspath(temp_dir / file_name))  # type: ignore

        if not Path(full_temp_file_path).exists():
            data, _ = client_utils.decode_base64_to_binary(base64_encoding)
            with open(full_temp_file_path, "wb") as fb:
                fb.write(data)

        self.temp_files.add(full_temp_file_path)
        return full_temp_file_path

    def pil_to_temp_file(self, img: _Image.Image, dir: str, format="png") -> str:
        """Save `img` as `image.<format>` inside `dir/<sha1 of its encoded bytes>/`
        and return that path."""
        bytes_data = processing_utils.encode_pil_to_bytes(img, format)
        temp_dir = Path(dir) / self.hash_bytes(bytes_data)
        temp_dir.mkdir(exist_ok=True, parents=True)
        filename = str(temp_dir / f"image.{format}")
        img.save(filename, pnginfo=processing_utils.get_pil_metadata(img))
        return filename

    def img_array_to_temp_file(self, arr: np.ndarray, dir: str) -> str:
        """Convert an array to uint8, wrap it as a PIL image and save it as PNG
        via `pil_to_temp_file`."""
        pil_image = _Image.fromarray(
            processing_utils._convert(arr, np.uint8, force_copy=False)
        )
        return self.pil_to_temp_file(pil_image, dir, format="png")

    def audio_to_temp_file(self, data: np.ndarray, sample_rate: int, format: str):
        """Write audio samples to `audio.<format>` in a directory named after the
        hash of `data`'s raw bytes and return the path."""
        temp_dir = Path(self.DEFAULT_TEMP_DIR) / self.hash_bytes(data.tobytes())
        temp_dir.mkdir(exist_ok=True, parents=True)
        filename = str(temp_dir / f"audio.{format}")
        processing_utils.audio_to_file(sample_rate, data, filename, format=format)
        return filename

    def file_bytes_to_file(self, data: bytes, file_name: str):
        """Write `data` into a hash-named directory under DEFAULT_TEMP_DIR, keeping
        only the final path component of `file_name`; returns a `Path`."""
        path = Path(self.DEFAULT_TEMP_DIR) / self.hash_bytes(data)
        path.mkdir(exist_ok=True, parents=True)
        path = path / Path(file_name).name
        path.write_bytes(data)
        return path

    def get_config(self):
        """Extend the parent config with the shared input/output display settings;
        "info" is only included when set."""
        config = {
            "label": self.label,
            "show_label": self.show_label,
            "container": self.container,
            "scale": self.scale,
            "min_width": self.min_width,
            "interactive": self.interactive,
            **super().get_config(),
        }
        if self.info:
            config["info"] = self.info
        return config

    @staticmethod
    def get_load_fn_and_initial_value(value):
        """Split `value` into `(load_fn, initial_value)`: a callable is invoked once
        for the initial value and returned as the load function; anything else is
        returned as-is with no load function."""
        if callable(value):
            initial_value = value()
            load_fn = value
        else:
            initial_value = value
            load_fn = None
        return load_fn, initial_value

    def attach_load_event(self, callable: Callable, every: float | None):
        """Add a load event that runs `callable`, optionally every `every` seconds."""
        self.load_event_to_attach = (callable, every)

    def as_example(self, input_data):
        """Return the input data in a way that can be displayed by the examples dataset component in the front-end."""
        return input_data
class FormComponent:
    """Mixin for components that expect to be placed inside a `Form` layout."""

    def get_expected_parent(self) -> type[Form] | None:
        # Only an explicit `container = False` opts out of the Form wrapper;
        # a missing attribute still yields Form.
        opted_out = getattr(self, "container", None) is False
        return None if opted_out else Form
def component(cls_name: str) -> Component:
    """Instantiate, with no arguments, the class that `cls_name` resolves to.

    Raises ValueError if the resolved class produces a layout (`BlockContext`)
    instead of a component.
    """
    instance = utils.component_or_layout_class(cls_name)()
    if not isinstance(instance, BlockContext):
        return instance
    raise ValueError(f"Invalid component: {instance.__class__}")
def get_component_instance(
    comp: str | dict | Component, render: bool | None = None
) -> Component:
    """
    Returns a component instance from a string, dict, or Component object.
    Parameters:
        comp: the component to instantiate. If a string, must be the name of a component, e.g. "dropdown". If a dict, must have a "name" key, e.g. {"name": "dropdown", "choices": ["a", "b"]}. If a Component object, will be returned as is.
        render: whether to render the component. If True, renders the component (if not already rendered). If False, *unrenders* the component (if already rendered) -- this is useful when constructing an Interface or ChatInterface inside of a Blocks. If None, does not render or unrender the component.
    """
    if isinstance(comp, str):
        component_obj = component(comp)
    elif isinstance(comp, dict):
        # Pop "name" from a shallow copy so the caller's dict is left intact
        # (popping in place made a second call with the same dict raise KeyError).
        init_kwargs = dict(comp)
        name = init_kwargs.pop("name")
        component_cls = utils.component_or_layout_class(name)
        component_obj = component_cls(**init_kwargs)
        if isinstance(component_obj, BlockContext):
            raise ValueError(f"Invalid component: {name}")
    elif isinstance(comp, Component):
        component_obj = comp
    else:
        raise ValueError(
            f"Component must provided as a `str` or `dict` or `Component` but is {comp}"
        )

    # render=None leaves the rendered state untouched.
    if render and not component_obj.is_rendered:
        component_obj.render()
    elif render is False and component_obj.is_rendered:
        component_obj.unrender()
    return component_obj