File size: 14,007 Bytes
37c29c8
35805d2
37c29c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73dc8c0
37c29c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35805d2
 
37c29c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35805d2
 
 
 
37c29c8
 
35805d2
37c29c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
# Remi Serra 202407
from env_utils import load_credentials
import gradio as gr
from random import randrange
from svg_utils import decode_b64_string_to_pretty_xml, encode_svg_xml_to_b64_string
from watsonx_utils import wxEngine
from prompts import (
    list_prompts,
    get_prompt_template,
    get_prompt_example,
    get_prompt_primer,
)
from data_images import svg_three_dots


# Functions
def read_file(uploaded_file):
    """Load an uploaded SVG file and derive every value the editor displays.

    Args:
        uploaded_file: Path of the uploaded file, as accepted by ``open``.
            A falsy value (nothing uploaded) is ignored.

    Returns:
        A 4-tuple ``(data_string, svg_xml, encoded_preview, decoded_preview)``
        where both previews are HTML snippets, or ``None`` when
        ``uploaded_file`` is falsy.
    """
    if not uploaded_file:
        return None
    # The context manager closes the handle deterministically (the previous
    # bare ``open(...).read()`` left it to the garbage collector). The explicit
    # encoding keeps the result independent of the platform default.
    with open(uploaded_file, "r", encoding="utf-8") as svg_file:
        svg_xml = svg_file.read()
    data_string = xml_string_to_data_string(svg_xml)
    return (
        data_string,
        svg_xml,
        html_img_preview(data_string),
        xml_string_to_html_img(svg_xml),
    )


def encoded_string_box_change(data_string: str):
    """Recompute the XML text and both previews from an encoded image string.

    Returns:
        ``(svg_xml, encoded_preview, decoded_preview)`` where ``svg_xml`` is
        the decoded form of ``data_string`` and the previews are HTML snippets.
    """
    decoded_xml = decode_b64_string_to_pretty_xml(data_string)
    # Preview built straight from the string the user provided.
    encoded_preview = html_img_preview(data_string)
    # Preview built from the decoded XML (re-encoded by the helper).
    decoded_preview = xml_string_to_html_img(decoded_xml)
    return decoded_xml, encoded_preview, decoded_preview


def xml_string_box_change(svg_xml: str):
    """Recompute the encoded image string and both previews from SVG XML.

    Returns:
        ``(data_string, encoded_preview, decoded_preview)`` where the previews
        are HTML snippets.
    """
    encoded = xml_string_to_data_string(svg_xml)
    encoded_preview = html_img_preview(encoded)
    decoded_preview = xml_string_to_html_img(svg_xml)
    return encoded, encoded_preview, decoded_preview


def xml_string_to_data_string(svg_xml: str):
    """Return ``svg_xml`` as a ``data:image/svg+xml;base64,`` URL string."""
    return "data:image/svg+xml;base64," + encode_svg_xml_to_b64_string(svg_xml)


def html_img_preview(data_string):
    """Return an HTML snippet showing ``data_string`` as a centered image.

    Args:
        data_string: Value placed in the ``src`` attribute of the ``<img>``
            tag (for instance a ``data:image/svg+xml;base64,...`` URL).

    Returns:
        ``<center><img width=100 src="..."/></center>`` as a string.
    """
    import html

    # Escape the value because it can come straight from a user-editable
    # textbox: an embedded quote would otherwise end the attribute early.
    # The characters used by base64 data URLs are not affected by escaping.
    safe_src = html.escape(str(data_string), quote=True)
    # The closing tag used to be a second "<center>", leaving the element open.
    return '<center><img width=100 src="%s"/></center>' % safe_src


def xml_string_to_html_img(svg_xml: str):
    """Encode ``svg_xml`` as a data URL and return its HTML image preview."""
    return html_img_preview(xml_string_to_data_string(svg_xml))


# def replace_color(svg_xml: str, color_from, color_to):
#     print(
#         f"replace_color:svg_xml:{svg_xml}, color_from:{color_from}, color_to:{color_to} "
#     )
#     if svg_xml and color_from and color_to:
#         new_svg_xml = svg_xml.replace(color_from, color_to)
#         data_string = xml_string_to_data_string(new_svg_xml)
#         return data_string, new_svg_xml, html_img_preview(data_string)


# def switch_colors(color_from, color_to):
#     print(f"switch_colors:color_from:{color_from}, color_to:{color_to} ")
#     return color_to, color_from


# Functions - watsonx


def wx_prompt_change(prompt_template_name):
    """Look up the template, example and primer for the named prompt.

    Returns:
        ``(template, example, primer)`` in that order.
    """
    template = get_prompt_template(prompt_template_name)
    example = get_prompt_example(prompt_template_name)
    primer = get_prompt_primer(prompt_template_name)
    return template, example, primer


def wx_models_dropdown(wx_engine_state):
    """Build the "Model" dropdown from the models offered by the engine.

    Args:
        wx_engine_state: Engine object exposing ``list_models()``, or ``None``
            when no connection has been made yet.

    Returns:
        A ``gr.Dropdown`` whose choices are the engine's models. The recommended
        model is preselected when available, otherwise the first listed model.
        With no engine (or an empty model list) the dropdown has no selection.
    """
    recommended_model = "ibm/granite-20b-code-instruct"
    model_list = []
    default_value = None
    if wx_engine_state is not None:
        model_list = wx_engine_state.list_models()
        # The value must be a single choice. A stray trailing comma used to
        # wrap it in a 1-tuple, which does not match any entry of ``choices``.
        if recommended_model in model_list:
            default_value = recommended_model
        elif model_list:
            # Guarded so an empty list no longer raises IndexError.
            default_value = model_list[0]
    return gr.Dropdown(
        label="Model",
        info=recommended_model + " recommended",
        choices=model_list,
        value=default_value,
    )


def wx_connect(wx_engine_state, apiendpoint, apikey, projectid):
    """Create a new engine from the credentials and refresh the model dropdown.

    The incoming ``wx_engine_state`` value is not read; it is replaced by the
    newly created engine.

    Returns:
        ``(engine, models_dropdown)``.
    """
    engine = wxEngine(apiendpoint, apikey, projectid)
    print("watsonx.ai activated")
    models_dropdown = wx_models_dropdown(engine)
    return engine, models_dropdown


def prepare_prompt(
    wx_engine_state, wx_model, wx_prompt, wx_instructions, wx_primer="", xml_string=""
):
    """Assemble the final prompt and the generation budget.

    Args:
        wx_engine_state: Engine exposing ``get_model_max_tokens`` and
            ``get_prompt_nb_tokens``.
        wx_model: Model identifier handed to both engine calls.
        wx_prompt: Template formatted with the ``svg`` and ``instructions``
            keyword fields.
        wx_instructions: Text substituted for ``instructions``.
        wx_primer: Text appended after the formatted template.
        xml_string: SVG text substituted for ``svg``.

    Returns:
        ``(status, max_new_tokens, prompt)``. ``status`` is ``"Done."`` unless
        the prompt token count exceeds the model limit, in which case it is a
        warning message (also printed).
    """
    engine = wx_engine_state
    # Model limit first, then the prompt, then its token count.
    token_limit = engine.get_model_max_tokens(wx_model)
    filled_template = wx_prompt.format(svg=xml_string, instructions=wx_instructions)
    # The primer goes at the very end of the prompt.
    prompt = filled_template + wx_primer
    token_count = engine.get_prompt_nb_tokens(prompt, wx_model)

    too_long = token_count > token_limit
    if too_long:
        status = f"Warning: prompt length ({token_count}) is more than the model max tokens ({token_limit}), and will be truncated. Please review your instructions."
        print(status)
    else:
        status = "Done."

    # Budget follows the length of the input SVG text, with a floor of 500.
    # note: prompt will be truncated if too long with GenTextParamsMetaNames.TRUNCATE_INPUT_TOKENS in generate()
    svg_length = len(xml_string)
    max_new_tokens = svg_length if svg_length > 500 else 500
    return status, max_new_tokens, prompt


def wx_generate(
    wx_engine_state, wx_model, wx_prompt, wx_instructions, wx_primer="", xml_string=""
):
    """Generate the whole result in one call (non-streaming variant).

    Returns:
        ``(status, result_text, preview_html)`` where ``result_text`` is the
        primer followed by the generated text.
    """
    status, token_budget, full_prompt = prepare_prompt(
        wx_engine_state, wx_model, wx_prompt, wx_instructions, wx_primer, xml_string
    )
    generated = wx_engine_state.generate_text(
        modelid=wx_model,
        prompt=full_prompt,
        max_new_tokens=token_budget,
        stop_sequences=["</svg>"],
    )
    # The primer is the start of the answer, so it is put back in front.
    wx_result = wx_primer + generated
    print(f"wx_generate:wx_result:{wx_result}")
    preview = xml_string_to_html_img(wx_result)
    return status, wx_result, preview


def wx_stream(
    wx_engine_state, wx_model, wx_prompt, wx_instructions, wx_primer="", xml_string=""
):
    """Generate the result incrementally, yielding after every chunk.

    Yields:
        ``(status, result_so_far, preview_html)``. While chunks arrive the
        status is a "Processing." message with zero to two extra dots and the
        preview is ``None``; the final yield carries the status from
        ``prepare_prompt`` and the rendered preview.
    """
    final_status, token_budget, full_prompt = prepare_prompt(
        wx_engine_state, wx_model, wx_prompt, wx_instructions, wx_primer, xml_string
    )
    # https://www.gradio.app/guides/streaming-outputs
    chunk_source = wx_engine_state.generate_text(
        modelid=wx_model,
        prompt=full_prompt,
        max_new_tokens=token_budget,
        stop_sequences=["</svg>"],
        stream=True,
    )
    # The accumulated text starts with the primer.
    wx_result = wx_primer
    for piece in chunk_source:
        wx_result = wx_result + piece
        progress = f"Processing.{'.'*int(randrange(3))}"
        yield progress, wx_result, None
    print(f"wx_stream:wx_result:{wx_result}")
    yield final_status, wx_result, xml_string_to_html_img(wx_result)


def wx_result_box_change(wx_result):
    """Return the HTML image preview for the current result text."""
    preview_html = xml_string_to_html_img(wx_result)
    return preview_html


def wx_accept(svg_xml):
    """Derive the editor values from an accepted SVG XML string.

    Returns:
        ``(data_string, svg_xml, encoded_preview, decoded_preview)`` where both
        previews are HTML snippets.
    """
    encoded = xml_string_to_data_string(svg_xml)
    encoded_preview = html_img_preview(encoded)
    decoded_preview = xml_string_to_html_img(svg_xml)
    return encoded, svg_xml, encoded_preview, decoded_preview


# APP
# Build the UI: widgets are declared inside nested layout containers, and each
# event listener is attached after the components it refers to exist.
with gr.Blocks() as demo:
    gr.Markdown("# SVG editor")
    with gr.Accordion("Get started:", open=True):
        gr.Markdown(
            """
            - Create a new SVG: Select the prompt template 'Create SVG', enter a description in the 'Instructions' box and click 'Submit'
            - Modify an existing SVG: Upload an SVG file, or paste an image string or SVG XML, then Select the prompt template 'Modify SVG', enter a change instruction in the 'Instructions' box and click 'Submit'
            - Describe an SVG: Upload, paste or generate an SVG file, Select the prompt template 'Describe SVG' and click 'Submit' """
        )

    # load env variables
    # (`status` is unpacked but not used anywhere in this block)
    status, env_apiendpoint, env_apikey, env_projectid = load_credentials()
    # init state - note gr.State() initial value must be deep-copyable - my wx_engine class is not
    # The engine therefore starts as None and is set by the Connect button.
    wx_engine_state = gr.State(None)
    with gr.Column():
        # Encoded string and preview
        with gr.Accordion("Load SVG:", open=True):
            with gr.Row():
                # Upload an .svg file
                uploaded_file = gr.File(scale=0, label="Upload an SVG file")
                # Paste an image string
                encoded_string_box = gr.Textbox(
                    label="Image string",
                    info="data:image/svg+xml;base64,...",
                    lines=7,
                    max_lines=7,
                    show_copy_button=True,
                    scale=3,
                )
                # original preview
                encoded_svg_preview = gr.HTML(f"<img src='{svg_three_dots}'/>")
            # Decoded string and preview
            with gr.Row():
                xml_string_box = gr.Textbox(
                    label="SVG XML",
                    lines=7,
                    max_lines=7,
                    show_copy_button=True,
                    scale=3,
                )
                # decoded preview
                decoded_svg_preview = gr.HTML(f"<img src='{svg_three_dots}'/>")
        # with gr.Row():  # Color changer
        #     color_from_area = gr.ColorPicker(label="Search color:", value="#000000", scale=0)
        #     color_switch_btn = gr.Button("<->", scale=0)
        #     color_to_area = gr.ColorPicker(label="Replace color:", value="#FFFFFF", scale=0)
        #     color_replace_btn = gr.Button("Replace", scale=0)

        with gr.Accordion("watsonx.ai:", open=True):
            with gr.Row():  # watsonx
                with gr.Column(scale=0):
                    with gr.Group():
                        # credentials
                        # Fields are pre-filled with the values returned by load_credentials().
                        with gr.Accordion("Credentials:", open=True):
                            wx_creds_endpoint = gr.Textbox(
                                label="Endpoint", value=env_apiendpoint, max_lines=1
                            )
                            wx_creds_apikey = gr.Textbox(
                                label="API key", value=env_apikey, max_lines=1
                            )
                            wx_creds_projectid = gr.Textbox(
                                label="Project id", value=env_projectid, max_lines=1
                            )
                            wx_connect_btn = gr.Button("Connect")
                        # model
                        # No engine exists yet, so the dropdown starts without choices;
                        # wx_connect returns a populated one as its second output.
                        wx_models_drop = wx_models_dropdown(None)
                        wx_connect_btn.click(
                            fn=wx_connect,
                            inputs=[
                                wx_engine_state,
                                wx_creds_endpoint,
                                wx_creds_apikey,
                                wx_creds_projectid,
                            ],
                            outputs=[wx_engine_state, wx_models_drop],
                        )
                        # prompt template
                        # The first listed template provides the initial values of the
                        # prompt, primer and instructions boxes below.
                        prompt_template_names = list_prompts()
                        wx_prompt_drop = gr.Dropdown(
                            label="Prompt template",
                            choices=prompt_template_names,
                            value=prompt_template_names[0],
                        )
                        wx_prompt_box = gr.Textbox(
                            info="Text",
                            show_label=False,
                            max_lines=5,
                            value=get_prompt_template(prompt_template_names[0]),
                        )
                        wx_primer_box = gr.Textbox(
                            info="Primer",
                            show_label=False,
                            max_lines=2,
                            value=get_prompt_primer(prompt_template_names[0]),
                        )
                with gr.Column():
                    with gr.Row():
                        wx_instructions_box = gr.Textbox(
                            label="Instructions",
                            scale=3,
                            value=get_prompt_example(prompt_template_names[0]),
                            show_copy_button=True,
                        )
                        wx_submit_btn = gr.Button("↓Submit↓", scale=0)
                        wx_accept_btn = gr.Button("↑Accept↑", scale=0)
                    with gr.Row():
                        wx_result_box = gr.Textbox(
                            label="Result",
                            lines=7,
                            max_lines=7,
                            scale=3,
                            show_copy_button=True,
                        )
                        wx_svg_preview = gr.HTML(f"<img src='{svg_three_dots}'/>")
                    wx_status_box = gr.Markdown("Status")

                # Selecting another template reloads prompt text, example instructions and primer.
                wx_prompt_drop.input(
                    fn=wx_prompt_change,
                    inputs=wx_prompt_drop,
                    outputs=[wx_prompt_box, wx_instructions_box, wx_primer_box],
                )
                # Submit runs the streaming generator (wx_stream); the non-streaming
                # wx_generate is kept as a commented-out alternative. The API endpoint
                # name is still "wx_generate".
                wx_submit_btn.click(
                    # fn=wx_generate,
                    fn=wx_stream,
                    inputs=[
                        wx_engine_state,
                        wx_models_drop,
                        wx_prompt_box,
                        wx_instructions_box,
                        wx_primer_box,
                        xml_string_box,
                    ],
                    outputs=[wx_status_box, wx_result_box, wx_svg_preview],
                    api_name="wx_generate",
                )
                # Manual edits of the result refresh only the result preview.
                wx_result_box.input(
                    fn=wx_result_box_change,
                    inputs=[wx_result_box],
                    outputs=[wx_svg_preview],
                )
                # Accept copies the result into the "Load SVG" section (both boxes and both previews).
                wx_accept_btn.click(
                    fn=wx_accept,
                    inputs=[wx_result_box],
                    outputs=[
                        encoded_string_box,
                        xml_string_box,
                        encoded_svg_preview,
                        decoded_svg_preview,
                    ],
                )

    # Actions
    # The two text boxes mirror each other: editing one rewrites the other and
    # both previews.
    # NOTE(review): `.input` (rather than `.change`) is presumably used so that
    # these programmatic updates do not re-trigger the handlers - confirm.
    encoded_string_box.input(
        fn=encoded_string_box_change,
        inputs=[encoded_string_box],
        outputs=[xml_string_box, encoded_svg_preview, decoded_svg_preview],
    )
    uploaded_file.upload(
        fn=read_file,
        inputs=uploaded_file,
        outputs=[
            encoded_string_box,
            xml_string_box,
            encoded_svg_preview,
            decoded_svg_preview,
        ],
    )
    xml_string_box.input(
        fn=xml_string_box_change,
        inputs=[xml_string_box],
        outputs=[encoded_string_box, encoded_svg_preview, decoded_svg_preview],
    )
    # color_switch_btn.click(
    #     fn=switch_colors,
    #     inputs=[color_from_area, color_to_area],
    #     outputs=[color_from_area, color_to_area],
    #     api_name="color_switch",
    # )
    # color_replace_btn.click(
    #     fn=replace_color,
    #     inputs=[xml_string_box, color_from_area, color_to_area],
    #     outputs=[encoded_string_box, xml_string_box, svg_preview],
    #     api_name="color_replace",
    # )


# Main
# Launch the Gradio app only when this file is executed as a script, not when
# it is imported.
if __name__ == "__main__":
    demo.launch()