from __future__ import annotations

import dataclasses
import inspect
import json
import re
import shutil
import textwrap
from pathlib import Path
from typing import Literal

import gradio


def _in_test_dir():
    """Check if the current working directory ends with gradio/js/preview/test."""
    return Path.cwd().parts[-4:] == ("gradio", "js", "preview", "test")


# Demo templates written into the generated ``demo/app.py``. ``{name}`` is
# filled in with the new component's class name via ``str.format``.
default_demo_code = """
example = {name}().example_value()

demo = gr.Interface(
    lambda x:x,
    {name}(),  # interactive version of your component
    {name}(),  # static version of your component
    # examples=[[example]],  # uncomment this line to view the "example version" of your component
)
"""

static_only_demo_code = """
example = {name}().example_value()

with gr.Blocks() as demo:
    with gr.Row():
        {name}(label="Blank"),  # blank component
        {name}(value=example, label="Populated"),  # populated component
"""

layout_demo_code = """
with gr.Blocks() as demo:
    with {name}():
        gr.Textbox(value="foo", interactive=True)
        gr.Number(value=10, interactive=True)
"""

# Doubled braces survive ``str.format`` as literal braces in the output.
fallback_code = """
with gr.Blocks() as demo:
    gr.Markdown("# Change the value (keep it JSON) and the front-end will update automatically.")
    {name}(value={{"message": "Hello from Gradio!"}}, label="Static")
"""


PATTERN_RE = r"gradio-template-\w+"
PATTERN = "gradio-template-{template}"


@dataclasses.dataclass
class ComponentFiles:
    """Locations and demo code used when scaffolding from a template.

    Empty ``python_file_name`` / ``js_dir`` values are replaced in
    ``__post_init__`` by defaults derived from the lower-cased template name.
    """

    # Name of the template class the new component is based on.
    template: str
    # Demo source formatted with ``name=...`` and written to ``demo/app.py``.
    demo_code: str = default_demo_code
    # File name of the template's Python source; defaults to "<template>.py".
    python_file_name: str = ""
    # Directory under ``_frontend_code`` holding the template's frontend.
    js_dir: str = ""

    def __post_init__(self):
        self.js_dir = self.js_dir or self.template.lower()
        self.python_file_name = self.python_file_name or f"{self.template.lower()}.py"


# Templates whose file names, frontend directory, or demo code differ from
# the defaults that ``ComponentFiles`` derives from the template name.
OVERRIDES = {
    "AnnotatedImage": ComponentFiles(
        template="AnnotatedImage",
        python_file_name="annotated_image.py",
        demo_code=static_only_demo_code,
    ),
    "HighlightedText": ComponentFiles(
        template="HighlightedText",
        python_file_name="highlighted_text.py",
        demo_code=static_only_demo_code,
    ),
    "Chatbot": ComponentFiles(template="Chatbot", demo_code=static_only_demo_code),
    "Gallery": ComponentFiles(template="Gallery", demo_code=static_only_demo_code),
    "HTML": ComponentFiles(template="HTML", demo_code=static_only_demo_code),
    "Label": ComponentFiles(template="Label", demo_code=static_only_demo_code),
    "Markdown": ComponentFiles(template="Markdown", demo_code=static_only_demo_code),
    "Fallback": ComponentFiles(template="Fallback", demo_code=fallback_code),
    "Plot": ComponentFiles(template="Plot", demo_code=static_only_demo_code),
    "BarPlot": ComponentFiles(
        template="BarPlot",
        python_file_name="native_plot.py",
        js_dir="plot",
        demo_code=static_only_demo_code,
    ),
    "ClearButton": ComponentFiles(
        template="ClearButton",
        python_file_name="clear_button.py",
        js_dir="button",
        demo_code=static_only_demo_code,
    ),
    "ColorPicker": ComponentFiles(
        template="ColorPicker", python_file_name="color_picker.py"
    ),
    "DuplicateButton": ComponentFiles(
        template="DuplicateButton",
        python_file_name="duplicate_button.py",
        js_dir="button",
        demo_code=static_only_demo_code,
    ),
    "FileExplorer": ComponentFiles(
        template="FileExplorer",
        python_file_name="file_explorer.py",
        js_dir="fileexplorer",
        demo_code=textwrap.dedent(
            """
        import os

        with gr.Blocks() as demo:
            {name}(value=os.path.dirname(__file__).split(os.sep))
        """
        ),
    ),
    "LinePlot": ComponentFiles(
        template="LinePlot",
        python_file_name="native_plot.py",
        js_dir="plot",
        demo_code=static_only_demo_code,
    ),
    "LogoutButton": ComponentFiles(
        template="LogoutButton",
        python_file_name="logout_button.py",
        js_dir="button",
        demo_code=static_only_demo_code,
    ),
    "LoginButton": ComponentFiles(
        template="LoginButton",
        python_file_name="login_button.py",
        js_dir="button",
        demo_code=static_only_demo_code,
    ),
    "ScatterPlot": ComponentFiles(
        template="ScatterPlot",
        python_file_name="native_plot.py",
        js_dir="plot",
        demo_code=static_only_demo_code,
    ),
    "UploadButton": ComponentFiles(
        template="UploadButton",
        python_file_name="upload_button.py",
        demo_code=static_only_demo_code,
    ),
    "JSON": ComponentFiles(
        template="JSON",
        python_file_name="json_component.py",
        demo_code=static_only_demo_code,
    ),
    "Row": ComponentFiles(
        template="Row",
        demo_code=layout_demo_code,
    ),
    "Column": ComponentFiles(
        template="Column",
        demo_code=layout_demo_code,
    ),
    "Tabs": ComponentFiles(
        template="Tabs",
        demo_code=textwrap.dedent(
            """
        with gr.Blocks() as demo:
            with {name}():
                with gr.Tab("Tab 1"):
                    gr.Textbox(value="foo", interactive=True)
                with gr.Tab("Tab 2"):
                    gr.Number(value=10, interactive=True)
        """
        ),
    ),
    "Group": ComponentFiles(
        template="Group",
        demo_code=layout_demo_code,
    ),
    "Accordion": ComponentFiles(
        template="Accordion",
        demo_code=textwrap.dedent(
            """
        with gr.Blocks() as demo:
            with {name}(label="Accordion"):
                gr.Textbox(value="foo", interactive=True)
                gr.Number(value=10, interactive=True)
        """
        ),
    ),
    "Model3D": ComponentFiles(
        template="Model3D",
        js_dir="model3D",
        demo_code=textwrap.dedent(
            """
        with gr.Blocks() as demo:
            {name}()
        """
        ),
    ),
    "ImageEditor": ComponentFiles(
        template="ImageEditor",
        python_file_name="image_editor.py",
        js_dir="imageeditor",
    ),
    "MultimodalTextbox": ComponentFiles(
        template="MultimodalTextbox",
        python_file_name="multimodal_textbox.py",
        js_dir="multimodaltextbox",
    ),
    "DownloadButton": ComponentFiles(
        template="DownloadButton",
        python_file_name="download_button.py",
        js_dir="downloadbutton",
    ),
}


def _get_component_code(template: str | None) -> ComponentFiles:
    """Return the ``ComponentFiles`` describing ``template``.

    A falsy ``template`` is treated as ``"Fallback"``. Templates listed in
    ``OVERRIDES`` return that shared entry; any other name gets a new
    ``ComponentFiles`` built from the lower-cased template name.
    """
    template = template or "Fallback"
    if template in OVERRIDES:
        return OVERRIDES[template]
    return ComponentFiles(
        python_file_name=f"{template.lower()}.py",
        js_dir=template.lower(),
        template=template,
    )


def _get_js_dependency_version(name: str, local_js_dir: Path) -> str:
    """Read the ``version`` field from a local package's ``package.json``.

    Parameters:
        name: Scoped package name; the segment after the first ``/`` is used
            as the directory name under ``local_js_dir``.
        local_js_dir: Directory containing one sub-directory per package.
    Returns:
        The value stored under ``"version"`` in that ``package.json``.
    """
    package_json = json.loads(
        Path(local_js_dir / name.split("/")[1] / "package.json").read_text()
    )
    return package_json["version"]


def _modify_js_deps(
    package_json: dict,
    key: Literal["dependencies", "devDependencies"],
    gradio_dir: Path,
):
    """Pin ``@gradio/*`` entries under ``key`` to the locally shipped versions.

    Parameters:
        package_json: Parsed ``package.json`` contents; modified in place.
        key: Which dependency table of ``package_json`` to rewrite.
        gradio_dir: Directory that contains the ``_frontend_code`` folder.
    Returns:
        The same ``package_json`` dict that was passed in.
    """
    # When running from gradio/js/preview/test the declared versions are left
    # untouched. The check does not depend on the dependency, so do it once.
    if _in_test_dir():
        return package_json

    frontend_code_dir = gradio_dir / "_frontend_code"
    # Only values are reassigned, so iterating the dict while updating is safe.
    for dep in package_json.get(key, []):
        if dep.startswith("@gradio/"):
            package_json[key][dep] = _get_js_dependency_version(
                dep, frontend_code_dir
            )
    return package_json


def delete_contents(directory: str | Path) -> None:
    """Delete all contents of a directory, but not the directory itself."""
    path = Path(directory)
    for child in path.glob("*"):
        if child.is_file():
            child.unlink()
        elif child.is_dir():
            shutil.rmtree(child)


def _create_frontend(
    name: str,  # noqa: ARG001
    component: ComponentFiles,
    directory: Path,
    package_name: str,
):
    """Copy the template's frontend into ``directory / "frontend"``.

    The copied ``package.json`` is renamed to ``package_name`` and has its
    dependency tables passed through ``_modify_js_deps``; a
    ``gradio.config.js`` is then copied in from this module's ``files``
    directory.

    Parameters:
        name: Unused.
        component: Supplies ``js_dir``, the frontend source directory name.
        directory: Root directory of the component being created.
        package_name: Value written to the ``name`` field of ``package.json``.
    """
    frontend = directory / "frontend"
    frontend.mkdir(exist_ok=True)

    p = Path(inspect.getfile(gradio)).parent

    def ignore(_src, names):
        # Skip changelogs, READMEs, and test/story/spec files when copying.
        return [
            n
            for n in names
            if n.startswith(("CHANGELOG", "README.md"))
            or ".test." in n
            or ".stories." in n
            or ".spec." in n
        ]

    shutil.copytree(
        str(p / "_frontend_code" / component.js_dir),
        frontend,
        dirs_exist_ok=True,
        ignore=ignore,
    )
    source_package_json = json.loads(Path(frontend / "package.json").read_text())
    source_package_json["name"] = package_name
    source_package_json = _modify_js_deps(source_package_json, "dependencies", p)
    source_package_json = _modify_js_deps(source_package_json, "devDependencies", p)
    (frontend / "package.json").write_text(json.dumps(source_package_json, indent=2))
    shutil.copy(
        str(Path(__file__).parent / "files" / "gradio.config.js"),
        frontend / "gradio.config.js",
    )


def _replace_old_class_name(old_class_name: str, new_class_name: str, content: str):
    """Replace occurrences of ``old_class_name`` in ``content``.

    Matches must sit on word boundaries and are skipped when directly
    preceded by ``import`` plus one whitespace character, or by a ``.``.
    """
    pattern = rf"(?<=\b)(?<!\bimport\s)(?<!\.){re.escape(old_class_name)}(?=\b)"
    return re.sub(pattern, new_class_name, content)


def _strip_document_lines(content: str):
    """Drop every line of ``content`` that starts with ``@document(``."""
    return "\n".join(
        [line for line in content.split("\n") if not line.startswith("@document(")]
    )


def _create_backend(
    name: str, component: ComponentFiles, directory: Path, package_name: str
):
    """Generate the Python side of a new custom component.

    Writes ``README.md``, ``.gitignore``, ``pyproject.toml``, ``demo/app.py``
    and ``demo/__init__.py`` under ``directory``, then copies the template's
    Python source (and its ``.pyi`` stub when one exists) into
    ``backend/<package_name>``, renaming the template class to ``name`` and
    removing ``@document(`` lines.

    Parameters:
        name: Class name of the new component.
        component: Template description (template name, source file, demo).
        directory: Root directory of the component being created.
        package_name: Name of the Python package to generate.
    Raises:
        ValueError: If ``component.template`` matches no name (ignoring case)
            in ``gradio.components``, ``gradio.layouts`` or
            ``gradio._simple_templates``.
    """

    def find_template_in_list(template, list_to_search):
        # Case-insensitive lookup that returns the correctly cased name.
        for item in list_to_search:
            if template.lower() == item.lower():
                return item
        return None

    lists_to_search = [
        (gradio.components.__all__, "components"),
        (gradio.layouts.__all__, "layouts"),
        (gradio._simple_templates.__all__, "_simple_templates"),  # type: ignore
    ]

    correct_cased_template = None
    module = None
    for list_, module_name in lists_to_search:
        correct_cased_template = find_template_in_list(component.template, list_)
        if correct_cased_template:
            module = module_name
            break

    if not correct_cased_template:
        raise ValueError(
            f"Cannot find {component.template} in gradio.components, gradio.layouts, or gradio._simple_templates. "
            "Please pass in a valid component name via the --template option. It must match the name of the python class."
        )

    if not module:
        raise ValueError("Module not found")

    # These README contents are used to install the component but they are overwritten later
    readme_contents = textwrap.dedent(
        """
# {package_name}
A Custom Gradio component.

## Example usage

```python
import gradio as gr
from {package_name} import {name}
```
"""
    ).format(package_name=package_name, name=name)
    (directory / "README.md").write_text(readme_contents)

    backend = directory / "backend" / package_name
    backend.mkdir(exist_ok=True, parents=True)

    gitignore = Path(__file__).parent / "files" / "gitignore"
    gitignore_contents = gitignore.read_text()
    gitignore_dest = directory / ".gitignore"
    gitignore_dest.write_text(gitignore_contents)

    pyproject = Path(__file__).parent / "files" / "pyproject_.toml"
    pyproject_contents = pyproject.read_text()
    pyproject_dest = directory / "pyproject.toml"
    pyproject_contents = pyproject_contents.replace("<<name>>", package_name).replace(
        "<<template>>", PATTERN.format(template=correct_cased_template)
    )
    pyproject_dest.write_text(pyproject_contents)

    demo_dir = directory / "demo"
    demo_dir.mkdir(exist_ok=True, parents=True)

    (demo_dir / "app.py").write_text(
        f"""
import gradio as gr
from {package_name} import {name}

{component.demo_code.format(name=name)}

if __name__ == "__main__":
    demo.launch()
"""
    )
    (demo_dir / "__init__.py").touch()

    init = backend / "__init__.py"
    init.write_text(
        f"""
from .{name.lower()} import {name}

__all__ = ['{name}']
"""
    )

    p = Path(inspect.getfile(gradio)).parent
    python_file = backend / f"{name.lower()}.py"

    shutil.copy(
        str(p / module / component.python_file_name),
        str(python_file),
    )

    source_pyi_file = p / module / component.python_file_name.replace(".py", ".pyi")
    pyi_file = backend / f"{name.lower()}.pyi"
    if source_pyi_file.exists():
        shutil.copy(str(source_pyi_file), str(pyi_file))

    content = python_file.read_text()
    content = _replace_old_class_name(correct_cased_template, name, content)
    content = _strip_document_lines(content)
    python_file.write_text(content)

    if pyi_file.exists():
        pyi_content = pyi_file.read_text()
        # Rename inside the stub's own text; passing the ``.py`` contents
        # here would overwrite the stub with the implementation source.
        pyi_content = _replace_old_class_name(
            correct_cased_template, name, pyi_content
        )
        pyi_content = _strip_document_lines(pyi_content)
        pyi_file.write_text(pyi_content)