File size: 4,653 Bytes
0ad74ed |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 |
"""This module should not be used directly as its API is subject to change. Instead,
please use the `gr.Interface.from_pipeline()` function."""
from __future__ import annotations
from typing import TYPE_CHECKING
from gradio.pipelines_utils import (
handle_diffusers_pipeline,
handle_transformers_js_pipeline,
handle_transformers_pipeline,
)
if TYPE_CHECKING:
import diffusers
import transformers
def load_from_pipeline(
    pipeline: transformers.Pipeline | diffusers.DiffusionPipeline,  # type: ignore
) -> dict:
    """
    Gets the appropriate Interface kwargs for a given Hugging Face transformers.Pipeline or diffusers.DiffusionPipeline.

    Parameters:
        pipeline (transformers.Pipeline | diffusers.DiffusionPipeline): the pipeline from which to create an interface
    Returns:
        (dict): a dictionary of kwargs that can be used to construct an Interface object
    Raises:
        ValueError: if `pipeline` is neither a transformers nor a diffusers pipeline, or if the
            pipeline-specific handler returned no interface configuration for it.
    """
    # The pipeline family is identified by the module its class is defined in. Compute it
    # once here instead of re-deriving it inside `fn` on every call and again for the title.
    module_name = str(type(pipeline).__module__)
    is_transformers = module_name.startswith("transformers.pipelines.")
    is_diffusers = module_name.startswith("diffusers.pipelines.")

    if is_transformers:
        pipeline_info = handle_transformers_pipeline(pipeline)
    elif is_diffusers:
        pipeline_info = handle_diffusers_pipeline(pipeline)
    else:
        raise ValueError(
            "pipeline must be a transformers.pipeline or diffusers.pipeline"
        )

    # Fail fast with a descriptive error. Previously a falsy `pipeline_info` only surfaced
    # as a KeyError from the `del` statements below, so the check inside `fn` was never reached.
    if not pipeline_info:
        raise ValueError("pipeline_info can not be None.")

    def fn(*params):
        """Run preprocess -> pipeline -> postprocess for one set of Interface inputs."""
        data = pipeline_info["preprocess"](*params)
        if is_transformers:
            from transformers import pipelines

            # special cases whose preprocessed data is passed positionally, not as kwargs
            if isinstance(
                pipeline,
                (
                    pipelines.text_classification.TextClassificationPipeline,
                    pipelines.text2text_generation.Text2TextGenerationPipeline,
                    pipelines.text2text_generation.TranslationPipeline,
                    pipelines.token_classification.TokenClassificationPipeline,
                ),
            ):
                data = pipeline(*data)
            else:
                data = pipeline(**data)  # type: ignore
            # special case for object-detection and token-classification pipelines:
            # the original input image / text is also sent to the postprocess function
            if isinstance(
                pipeline,
                (
                    pipelines.object_detection.ObjectDetectionPipeline,
                    pipelines.token_classification.TokenClassificationPipeline,
                ),
            ):
                return pipeline_info["postprocess"](data, params[0])
            return pipeline_info["postprocess"](data)
        # Only a diffusers pipeline can reach this point (validated above).
        data = pipeline(**data)  # type: ignore
        return pipeline_info["postprocess"](data)

    interface_info = pipeline_info.copy()
    interface_info["fn"] = fn
    # preprocess/postprocess are consumed by `fn` and must not leak into the Interface kwargs.
    del interface_info["preprocess"]
    del interface_info["postprocess"]
    # define the title/description of the Interface
    interface_info["title"] = (
        pipeline.model.config.name_or_path
        if is_transformers
        else pipeline.__class__.__name__
    )
    return interface_info
def load_from_js_pipeline(pipeline) -> dict:
    """
    Gets the Interface kwargs for a transformers_js_py pipeline.

    Parameters:
        pipeline: a pipeline object whose class is defined in a `transformers_js_py` module
    Returns:
        (dict): kwargs (`fn`, `inputs`, `outputs`, `title`) that can be used to construct an Interface object
    Raises:
        ValueError: if `pipeline` does not come from transformers_js_py
    """
    if not str(type(pipeline).__module__).startswith("transformers_js_py."):
        raise ValueError("pipeline must be a transformers_js_py's pipeline")
    pipeline_info = handle_transformers_js_pipeline(pipeline)

    async def fn(*params):
        # Both hooks are optional: a falsy hook means values pass through unchanged.
        pre_hook = pipeline_info["preprocess"]
        post_hook = pipeline_info["postprocess"]
        forward_inputs = pipeline_info.get("postprocess_takes_inputs", False)

        if pre_hook:
            model_args = pre_hook(*params)
        else:
            model_args = params

        raw_result = await pipeline(*model_args)

        if not post_hook:
            return raw_result
        # Some postprocess hooks also need the original Interface inputs.
        extra_args = params if forward_inputs else ()
        return post_hook(raw_result, *extra_args)

    return {
        "fn": fn,
        "inputs": pipeline_info["inputs"],
        "outputs": pipeline_info["outputs"],
        "title": f"{pipeline.task} ({pipeline.model.config._name_or_path})",
    }
|