File size: 4,857 Bytes
e50338f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
# Forge Canvas
# AGPL V3
# by lllyasviel
# Commercial Use is not allowed. (Contact us for commercial use.)

import gradio.component_meta


# Keep a reference to gradio's original `create_or_modify_pyi` so the patched
# wrapper defined below can still delegate to it.
create_or_modify_pyi_org = gradio.component_meta.create_or_modify_pyi


def create_or_modify_pyi_org_patched(component_class, class_name, events):
    """Best-effort wrapper around gradio's original ``create_or_modify_pyi``.

    The call is skipped entirely for a class named ``LogicalImage``; for every
    other class it is forwarded unchanged to the saved original function.

    Args:
        component_class: The component class gradio is processing; only its
            ``__name__`` is inspected here.
        class_name: Forwarded unchanged to the original function.
        events: Forwarded unchanged to the original function.

    Returns:
        Whatever the original function returns, or ``None`` when the call is
        skipped or an ordinary exception occurs.
    """
    try:
        if component_class.__name__ == 'LogicalImage':
            return None
        return create_or_modify_pyi_org(component_class, class_name, events)
    except Exception:
        # Deliberately best-effort: a failure here must not break importing
        # this module. Catching Exception (rather than a bare ``except``)
        # still lets KeyboardInterrupt and SystemExit propagate.
        return None


# Install the wrapper in place of gradio's function. This runs before the
# `import gradio as gr` further down and before LogicalImage is defined.
gradio.component_meta.create_or_modify_pyi = create_or_modify_pyi_org_patched


import os
import uuid
import base64
import gradio as gr
import numpy as np

from PIL import Image
from io import BytesIO
from gradio.context import Context
from functools import wraps


# Directory containing this module; canvas.html, canvas.css and canvas.min.js
# are resolved relative to it.
canvas_js_root_path = os.path.dirname(__file__)


def web_js(file_name):
    """Return a ``<script>`` tag (newline-terminated) for a file next to this module.

    The tag's ``src`` uses the ``file=<absolute path>`` form and carries the
    file's modification time as a query string (presumably so browsers refetch
    the script after it changes).

    Args:
        file_name: File name relative to ``canvas_js_root_path``.

    Raises:
        OSError: If the file does not exist or its mtime cannot be read.
    """
    script_path = os.path.join(canvas_js_root_path, file_name)
    modified_at = os.path.getmtime(script_path)
    return '<script src="file={}?{}"></script>\n'.format(script_path, modified_at)


def web_css(file_name):
    """Return a stylesheet ``<link>`` tag (newline-terminated) for a file next to this module.

    The tag's ``href`` uses the ``file=<absolute path>`` form and carries the
    file's modification time as a query string (presumably so browsers refetch
    the stylesheet after it changes).

    Args:
        file_name: File name relative to ``canvas_js_root_path``.

    Raises:
        OSError: If the file does not exist or its mtime cannot be read.
    """
    stylesheet_path = os.path.join(canvas_js_root_path, file_name)
    modified_at = os.path.getmtime(stylesheet_path)
    return '<link rel="stylesheet" href="file={}?{}">\n'.format(stylesheet_path, modified_at)


# When True, the LogicalImage textboxes created by ForgeCanvas are rendered
# visibly (they are passed `visible=DEBUG_MODE`).
DEBUG_MODE = False

# Read the HTML template once at import time. The context manager closes the
# file handle immediately instead of leaving it to the garbage collector.
with open(os.path.join(canvas_js_root_path, 'canvas.html'), encoding='utf-8') as _canvas_html_file:
    canvas_html = _canvas_html_file.read()

# Head markup referencing the canvas stylesheet and script, in that order.
canvas_head = web_css('canvas.css') + web_js('canvas.min.js')


def image_to_base64(image_array, numpy=True):
    """Encode an image as a ``data:image/png;base64,...`` URL string.

    Args:
        image_array: The image to encode. When ``numpy`` is true it is passed
            to ``Image.fromarray``; otherwise it is used directly and must
            provide ``convert`` (e.g. a PIL image).
        numpy: Whether ``image_array`` must first be converted with
            ``Image.fromarray``.

    Returns:
        The image converted to RGBA, saved as PNG, and base64-encoded behind
        the ``data:image/png;base64,`` prefix.
    """
    if numpy:
        pil_image = Image.fromarray(image_array)
    else:
        pil_image = image_array

    png_buffer = BytesIO()
    pil_image.convert("RGBA").save(png_buffer, format="PNG")

    encoded = base64.b64encode(png_buffer.getvalue()).decode('utf-8')
    return "data:image/png;base64," + encoded


def base64_to_image(base64_str, numpy=True):
    """Decode a base64 string (optionally a PNG data URL) into an RGBA image.

    Args:
        base64_str: Base64 text, with or without the leading
            ``data:image/png;base64,`` prefix.
        numpy: When true, return ``np.array`` of the decoded image; otherwise
            return the PIL image itself.

    Returns:
        The decoded image converted to RGBA, as a numpy array or a PIL image
        depending on ``numpy``.
    """
    data_url_prefix = "data:image/png;base64,"
    payload = base64_str
    if payload.startswith(data_url_prefix):
        payload = payload.replace(data_url_prefix, "")

    rgba_image = Image.open(BytesIO(base64.b64decode(payload))).convert("RGBA")
    if numpy:
        return np.array(rgba_image)
    return rgba_image


class LogicalImage(gr.Textbox):
    """Textbox component whose string value is a base64 PNG data URL.

    On the Python side the value is exposed as an image: a numpy array when
    ``numpy`` is true, otherwise a PIL image. ``preprocess`` decodes the
    incoming string and ``postprocess`` encodes the outgoing image.
    """

    @wraps(gr.Textbox.__init__)
    def __init__(self, *args, numpy=True, **kwargs):
        """Create the component.

        Args:
            *args: Forwarded to ``gr.Textbox.__init__``.
            numpy: Whether images are exchanged as numpy arrays (true) or as
                PIL images (false).
            **kwargs: Forwarded to ``gr.Textbox.__init__``. A non-None
                ``value`` is treated as an image and encoded to a data URL;
                a ``value`` of ``None`` is dropped so the Textbox default
                applies.
        """
        self.numpy = numpy
        # Last `info` mapping seen in postprocess; re-attached in preprocess.
        self.infotext = dict()

        if 'value' in kwargs:
            initial_value = kwargs['value']
            if initial_value is not None:
                # Use the module-level encoder: this class defines no
                # `image_to_base64` method, so `self.image_to_base64(...)`
                # raised AttributeError for any non-None initial value.
                # Passing `numpy` keeps the encoding consistent with how
                # postprocess handles values.
                kwargs['value'] = image_to_base64(initial_value, numpy=numpy)
            else:
                del kwargs['value']

        super().__init__(*args, **kwargs)

    def preprocess(self, payload):
        """Decode the textbox payload into an image.

        Returns:
            ``None`` when ``payload`` is not a string or does not start with
            ``data:image/png;base64,``; otherwise the decoded RGBA image. If
            the decoded object has an ``info`` attribute, it is replaced with
            the mapping remembered from the last ``postprocess`` call.
        """
        if not isinstance(payload, str):
            return None

        if not payload.startswith("data:image/png;base64,"):
            return None

        image = base64_to_image(payload, numpy=self.numpy)
        if hasattr(image, 'info'):
            image.info = self.infotext

        return image

    def postprocess(self, value):
        """Encode an image as a PNG data URL for the textbox.

        Returns:
            ``None`` when ``value`` is ``None``; otherwise the data URL string.
            If ``value`` has an ``info`` attribute it is remembered on the
            component for later ``preprocess`` calls.
        """
        if value is None:
            return None

        if hasattr(value, 'info'):
            self.infotext = value.info

        return image_to_base64(value, numpy=self.numpy)

    def get_block_name(self):
        """Return ``"textbox"`` as this component's block name."""
        # Presumably makes the frontend render this as a standard textbox;
        # confirm against gradio's component resolution.
        return "textbox"


class ForgeCanvas:
    """Python-side handle for one Forge canvas widget.

    Construction creates three gradio components (the HTML canvas block and
    two ``LogicalImage`` textboxes for foreground and background) and registers
    a page-load callback that instantiates the JavaScript ``ForgeCanvas``
    class for this instance's generated id.
    """

    @staticmethod
    def _js_flag(value):
        """Return the JavaScript literal for a Python bool; pass other values through."""
        # str(True) / str(False) are 'True' / 'False', which JavaScript only
        # accepts if the page defines globals with those names. Emit the
        # native literals so the generated snippet is valid on its own.
        if isinstance(value, bool):
            return 'true' if value else 'false'
        return value

    def __init__(
            self,
            no_upload=False,
            no_scribbles=False,
            contrast_scribbles=False,
            height=512,
            scribble_color='#000000',
            scribble_color_fixed=False,
            scribble_width=4,
            scribble_width_fixed=False,
            scribble_alpha=100,
            scribble_alpha_fixed=False,
            scribble_softness=0,
            scribble_softness_fixed=False,
            visible=True,
            numpy=False,
            initial_image=None,
            elem_id=None,
            elem_classes=None
    ):
        """Build the canvas components inside the current gradio Blocks context.

        Args:
            no_upload, no_scribbles, contrast_scribbles, height,
            scribble_color, scribble_color_fixed, scribble_width,
            scribble_width_fixed, scribble_alpha, scribble_alpha_fixed,
            scribble_softness, scribble_softness_fixed: Forwarded, in this
                order, to the JavaScript ``ForgeCanvas`` constructor after the
                instance id. Their exact meaning is defined by the JavaScript
                side.
            visible: Passed to the ``gr.HTML`` block.
            numpy: Passed to both ``LogicalImage`` components.
            initial_image: Initial value of the background ``LogicalImage``.
            elem_id: Passed to the ``gr.HTML`` block.
            elem_classes: Passed to the ``gr.HTML`` block.
        """
        js_flag = self._js_flag

        # Unique id substituted for the 'forge_mixin' placeholder in the HTML
        # template and shared with the JavaScript constructor.
        self.uuid = 'uuid_' + uuid.uuid4().hex
        self.block = gr.HTML(canvas_html.replace('forge_mixin', self.uuid), visible=visible, elem_id=elem_id, elem_classes=elem_classes)
        # Both textboxes share the instance id as elem_id and differ only by class.
        self.foreground = LogicalImage(visible=DEBUG_MODE, label='foreground', numpy=numpy, elem_id=self.uuid, elem_classes=['logical_image_foreground'])
        self.background = LogicalImage(visible=DEBUG_MODE, label='background', numpy=numpy, value=initial_image, elem_id=self.uuid, elem_classes=['logical_image_background'])
        Context.root_block.load(None, js=f'async ()=>{{new ForgeCanvas("{self.uuid}", {js_flag(no_upload)}, {js_flag(no_scribbles)}, {js_flag(contrast_scribbles)}, {height}, '
                                         f"'{scribble_color}', {js_flag(scribble_color_fixed)}, {scribble_width}, {js_flag(scribble_width_fixed)}, "
                                         f'{scribble_alpha}, {js_flag(scribble_alpha_fixed)}, {scribble_softness}, {js_flag(scribble_softness_fixed)});}}')