File size: 2,653 Bytes
b139995
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
from __future__ import annotations

from concurrent.futures import ThreadPoolExecutor
from gradio.components.gallery import (
    GalleryImageType,
    CaptionedGalleryImageType,
    GalleryImage,
    GalleryData,
)
from pathlib import Path
from urllib.parse import urlparse

import gradio
import numpy as np
import PIL.Image
from gradio_client.utils import is_http_url_like

from gradio import processing_utils, utils, wasm_utils
from gradio.data_classes import FileData


class Gallery(gradio.Gallery):
    def postprocess(
        self,
        value: list[GalleryImageType | CaptionedGalleryImageType] | None,
    ) -> GalleryData:
        """
        Convert a list of gallery items into a ``GalleryData`` payload.

        This is a patched version of the original function. PIL images are
        always written to the cache as PNG here (the mode-dependent choice
        between "png" and "webp" is disabled), whereas numpy arrays are
        written using ``self.format``.

        Parameters:
            value: Images, or ``(image, caption)`` pairs given as a tuple or
                list. An image may be a numpy array, a PIL image, a string
                (HTTP-like URL or file path) or a ``pathlib.Path``. ``None``
                produces an empty gallery.
        Returns:
            A ``GalleryData`` whose ``root`` holds one ``GalleryImage`` per
            input item, in input order.
        Raises:
            ValueError: if an item's image is none of the supported types.
        """
        if value is None:
            return GalleryData(root=[])

        def _to_gallery_image(item):
            # Split an optional (image, caption) pair into its two parts.
            caption = None
            if isinstance(item, (tuple, list)):
                item, caption = item

            # Only string inputs can set a URL; arrays and PIL images leave
            # both fields at None.
            url = None
            orig_name = None

            if isinstance(item, np.ndarray):
                cached = processing_utils.save_img_array_to_cache(
                    item, cache_dir=self.GRADIO_CACHE, format=self.format
                )
                file_path = str(utils.abspath(cached))
            elif isinstance(item, PIL.Image.Image):
                # The patch: PIL images are stored as PNG regardless of mode.
                pil_format = "png"
                cached = processing_utils.save_pil_to_cache(
                    item, cache_dir=self.GRADIO_CACHE, format=pil_format
                )
                file_path = str(utils.abspath(cached))
            elif isinstance(item, str):
                file_path = item
                if is_http_url_like(item):
                    # Remote image: keep the URL and name it after the last
                    # component of the URL path.
                    url = item
                    orig_name = Path(urlparse(item).path).name
                else:
                    orig_name = Path(item).name
            elif isinstance(item, Path):
                file_path = str(item)
                orig_name = item.name
            else:
                raise ValueError(f"Cannot process type as image: {type(item)}")

            file_data = FileData(path=file_path, url=url, orig_name=orig_name)
            return GalleryImage(image=file_data, caption=caption)

        if wasm_utils.IS_WASM:
            # Under WASM the items are converted one by one, without a pool.
            converted = [_to_gallery_image(item) for item in value]
        else:
            # executor.map yields results in the order of the inputs.
            with ThreadPoolExecutor() as pool:
                converted = list(pool.map(_to_gallery_image, value))
        return GalleryData(root=converted)