File size: 13,162 Bytes
37c29c8
35805d2
37c29c8
 
 
 
 
 
 
 
 
7be1a06
37c29c8
 
7be1a06
37c29c8
 
7be1a06
37c29c8
 
 
7be1a06
37c29c8
7be1a06
37c29c8
7be1a06
37c29c8
 
 
7be1a06
37c29c8
 
7be1a06
37c29c8
 
7be1a06
37c29c8
7be1a06
37c29c8
 
 
 
 
 
 
 
 
86da7c8
37c29c8
 
7be1a06
37c29c8
 
 
 
7be1a06
 
 
 
 
 
 
 
 
37c29c8
 
 
 
 
 
 
 
 
 
7be1a06
37c29c8
7be1a06
 
 
37c29c8
 
 
 
 
 
 
 
7be1a06
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37c29c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7be1a06
 
37c29c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7be1a06
 
37c29c8
7be1a06
 
37c29c8
 
7be1a06
37c29c8
 
7be1a06
 
37c29c8
 
7be1a06
37c29c8
 
 
 
7be1a06
c0d26d5
37c29c8
7be1a06
 
 
 
 
37c29c8
 
 
7be1a06
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a2a6c6d
 
 
7be1a06
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37c29c8
7be1a06
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37c29c8
 
 
 
 
 
7be1a06
37c29c8
7be1a06
 
37c29c8
7be1a06
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37c29c8
 
 
 
7be1a06
37c29c8
 
7be1a06
 
 
 
 
 
 
 
 
 
 
 
 
37c29c8
7be1a06
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37c29c8
7be1a06
37c29c8
7be1a06
 
 
37c29c8
 
7be1a06
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37c29c8
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
# Remi Serra 202407
from env_utils import load_credentials
import gradio as gr
from random import randrange
from svg_utils import decode_b64_string_to_pretty_xml, encode_svg_xml_to_b64_string
from watsonx_utils import wxEngine
from prompts import (
    list_prompts,
    get_prompt_template,
    get_prompt_example,
    get_prompt_primer,
    get_prompt_uploadmode,
)
from data_images import svg_three_dots
from ibm_watsonx_ai.wml_client_error import WMLClientError


# Functions - input
def read_file(uploaded_file):
    """Load an uploaded SVG file and derive the three input-pane values.

    Args:
        uploaded_file: the value delivered by the ``gr.File`` upload event,
            presumably a filesystem path (gr.File "filepath" mode) -- TODO
            confirm against the installed Gradio version. A falsy value means
            nothing was uploaded.

    Returns:
        A 3-tuple ``(encoded_data_string, svg_xml, preview_html)`` for the
        encoded-string box, the XML box and the preview pane. When no file is
        given, empty strings and the placeholder preview are returned so the
        handler still yields one value per output.
    """
    if not uploaded_file:
        return "", "", image_placeholder
    # The context manager guarantees the handle is closed. UTF-8 is the XML
    # default encoding and avoids depending on the platform locale.
    with open(uploaded_file, "r", encoding="utf-8") as svg_file:
        svg_xml = svg_file.read()
    encoded_data_string = xml_string_to_data_string(svg_xml)
    return (
        encoded_data_string,
        svg_xml,
        html_img_preview(encoded_data_string),
    )


def input_encoded_string_box_change(data_string: str):
    """Sync the input pane after the user edits the encoded image string.

    Returns ``(pretty_xml, preview_html)``. The first item is the SVG XML
    decoded from *data_string* and pretty-printed. The second is an HTML
    preview of *data_string*.
    """
    pretty_xml = decode_b64_string_to_pretty_xml(data_string)
    preview_html = html_img_preview(data_string)
    return pretty_xml, preview_html


def input_xml_string_box_change(svg_xml: str):
    """Sync the input pane after the user edits the SVG XML.

    Returns ``(data_uri, preview_html)``. The first item is the XML re-encoded
    as a base64 data URI. The second is an HTML preview built from that URI.
    """
    encoded = xml_string_to_data_string(svg_xml)
    preview_html = html_img_preview(encoded)
    return encoded, preview_html


def xml_string_to_data_string(svg_xml: str):
    """Wrap SVG XML into a ``data:image/svg+xml;base64,`` URI.

    The base64 payload comes from ``encode_svg_xml_to_b64_string``.
    """
    prefix = "data:image/svg+xml;base64,"
    return prefix + encode_svg_xml_to_b64_string(svg_xml)


def html_img_preview(data_string, width="100px"):
    """Return a horizontally centred ``<img>`` tag that displays *data_string*.

    Args:
        data_string: image source, normally a ``data:image/svg+xml;base64,...``
            URI. It can come straight from a user-editable textbox, so it is
            HTML-escaped before being placed in the ``src`` attribute. Escaping
            leaves valid base64 data URIs unchanged.
        width: value of the ``width`` attribute (default ``"100px"``).

    Returns:
        The HTML snippet for a ``gr.HTML`` component.
    """
    from html import escape  # local import keeps the module header untouched

    src = escape(str(data_string), quote=True)
    width_attr = escape(str(width), quote=True)
    return f'<img src="{src}" width="{width_attr}" style="display: block; margin-left: auto; margin-right: auto;"/>'


# Initial HTML for the preview panes, built from the svg_three_dots image.
image_placeholder = html_img_preview(data_string=svg_three_dots)

# Functions - watsonx


def wx_prompt_drop_change(prompt_template_name):
    """Refresh the prompt widgets when another action template is selected.

    Returns a dict mapping components to their new values or updates:
    - the template text, example instructions and primer for the template;
    - the visibility of the upload row and the encoded-string accordion,
      both set from ``get_prompt_uploadmode``.
    """
    upload_visible = get_prompt_uploadmode(prompt_template_name)
    updates = {}
    updates[wx_prompt_box] = get_prompt_template(prompt_template_name)
    updates[wx_instructions_box] = get_prompt_example(prompt_template_name)
    updates[wx_primer_box] = get_prompt_primer(prompt_template_name)
    updates[upload_row] = gr.Row(visible=upload_visible)
    updates[upload_accordeon] = gr.Accordion(visible=upload_visible)
    return updates


def wx_models_dropdown(wx_engine_state):
    """Build the model-selection dropdown from the connected watsonx engine.

    Args:
        wx_engine_state: the engine stored in ``gr.State``, or ``None`` when
            not connected yet.

    Returns:
        A ``gr.Dropdown`` listing ``wx_engine.list_models()``. The recommended
        model is preselected when it is listed, otherwise the first model is.
        With no engine, or an empty model list, the dropdown has no choices
        and no selection.
    """
    recommended_model = "ibm/granite-20b-code-instruct"
    wx_engine = wx_engine_state
    model_list = []
    default_value = None
    if wx_engine is not None:
        model_list = wx_engine.list_models() or []
        if recommended_model in model_list:
            default_value = recommended_model
        elif model_list:
            # Guarded: indexing an empty list used to raise IndexError.
            default_value = model_list[0]

    return gr.Dropdown(
        label="Model",
        info=recommended_model + " recommended",
        choices=model_list,
        value=default_value,
    )


def wx_connect_click(wx_engine_state, apiendpoint, apikey, projectid):
    """Connect to watsonx.ai with the given credentials and update the UI.

    Args:
        wx_engine_state: engine currently held in ``gr.State`` (``None`` until
            a connection succeeds). It is returned unchanged on failure.
        apiendpoint: watsonx.ai endpoint URL.
        apikey: API key.
        projectid: watsonx.ai project id.

    Returns:
        A 5-tuple for (state, model dropdown, credentials accordion, generate
        button, status box).

        On success: the new engine is stored, its models are listed, the
        accordion collapses and the generate button is enabled.

        On ``WMLClientError``: the previous state is kept, the model dropdown
        is reset to empty, the accordion stays open, the button is disabled
        and the error text becomes the status.
    """
    try:
        # Both calls stay inside the try so that a WMLClientError from either
        # one is reported in the UI instead of propagating.
        new_engine = wxEngine(apiendpoint, apikey, projectid)
        models_dropdown = wx_models_dropdown(new_engine)
    except WMLClientError as ex:
        template = "Exception {0} occurred: {1!r}"
        msg = template.format(type(ex).__name__, ex.args)
        print(msg)
        return (
            wx_engine_state,
            # An empty Dropdown update. A bare [] is not a valid value for a
            # single-select dropdown.
            wx_models_dropdown(None),
            gr.Accordion(open=True),
            gr.Button(interactive=False),
            # The status output is a gr.Markdown, so plain text is sent.
            msg,
        )
    msg = "watsonx.ai successfully activated"
    print(msg)
    return (
        new_engine,
        models_dropdown,
        gr.Accordion(open=False),
        gr.Button(interactive=True),
        msg,
    )


def prepare_prompt(
    wx_engine_state, wx_model, wx_prompt, wx_instructions, wx_primer="", xml_string=""
):
    """Assemble the final prompt and derive the generation budget.

    The template *wx_prompt* is filled through ``str.format`` with the
    ``svg`` and ``instructions`` fields, and *wx_primer* is appended at the
    end.

    Returns:
        ``(status, max_new_tokens, prompt)``. *status* is ``"Done."``, or a
        warning (also printed) when the prompt has more tokens than the
        model's maximum.
    """
    engine = wx_engine_state
    token_limit = engine.get_model_max_tokens(wx_model)
    filled_template = wx_prompt.format(svg=xml_string, instructions=wx_instructions)
    prompt = filled_template + wx_primer
    prompt_tokens = engine.get_prompt_nb_tokens(prompt, wx_model)
    status = "Done."
    if prompt_tokens > token_limit:
        status = f"Warning: prompt length ({prompt_tokens}) is more than the model max tokens ({token_limit}), and will be truncated. Please review your instructions."
        print(status)
    # Budget: at least 500 new tokens, or the input SVG's character count if
    # that is larger. Over-long prompts are truncated in generate() through
    # GenTextParamsMetaNames.TRUNCATE_INPUT_TOKENS.
    max_new_tokens = max(500, len(xml_string))
    return status, max_new_tokens, prompt


def wx_generate(
    wx_engine_state, wx_model, wx_prompt, wx_instructions, wx_primer="", xml_string=""
):
    """Generate an SVG with a single, non-streaming model call.

    Returns ``(status, data_uri, svg_xml, preview_html)``. *svg_xml* is the
    primer followed by the model output, and generation uses ``"</svg>"`` as
    a stop sequence.
    """
    status, token_budget, full_prompt = prepare_prompt(
        wx_engine_state, wx_model, wx_prompt, wx_instructions, wx_primer, xml_string
    )
    completion = wx_engine_state.generate_text(
        modelid=wx_model,
        prompt=full_prompt,
        max_new_tokens=token_budget,
        stop_sequences=["</svg>"],
    )
    svg_xml = wx_primer + completion
    print(f"wx_generate:wx_result:{svg_xml}")
    data_uri = xml_string_to_data_string(svg_xml)
    preview_html = html_img_preview(data_uri)
    return status, data_uri, svg_xml, preview_html


def wx_stream(
    wx_engine_state, wx_model, wx_prompt, wx_instructions, wx_primer="", xml_string=""
):
    """Stream an SVG generation, yielding UI updates as chunks arrive.

    Yields:
        - For each streamed chunk: ``(progress, progress, partial_xml, None)``,
          where *progress* is "Processing." plus 0-2 random dots and
          *partial_xml* is the primer plus the chunks received so far.
        - At the end: ``(status, data_uri, svg_xml, preview_html)``.

    See https://www.gradio.app/guides/streaming-outputs
    """
    status, token_budget, full_prompt = prepare_prompt(
        wx_engine_state, wx_model, wx_prompt, wx_instructions, wx_primer, xml_string
    )
    svg_xml = wx_primer
    chunks = wx_engine_state.generate_text(
        modelid=wx_model,
        prompt=full_prompt,
        max_new_tokens=token_budget,
        stop_sequences=["</svg>"],
        stream=True,
    )
    for piece in chunks:
        svg_xml += piece
        progress = "Processing." + "." * randrange(3)
        yield progress, progress, svg_xml, None
    print(f"wx_stream:wx_result:{svg_xml}")
    data_uri = xml_string_to_data_string(svg_xml)
    yield status, data_uri, svg_xml, html_img_preview(data_uri)


# Functions - output


def output_xml_string_box_change(xml_string):
    """Re-derive the result pane after the user edits the result SVG XML.

    Returns ``(data_uri, xml_string, preview_html)``. The XML itself is
    passed back unchanged.
    """
    data_uri = xml_string_to_data_string(xml_string)
    preview_html = html_img_preview(data_uri)
    return data_uri, xml_string, preview_html


# APP layout
with gr.Blocks(theme="Zarkel/IBM_Carbon_Theme") as demo:
    gr.Markdown("# SVG editor")
    # The steps below match the actual UI. The buttons are 'Connect' and
    # '↓Generate↓', and Generate stays disabled until a connection succeeds.
    # The upload controls' visibility follows the selected action, so the
    # action is chosen before uploading.
    gr.Markdown(
        """### Get started:
        - Connect: review the watsonx.ai credentials and click 'Connect'.
        - Create a new SVG: enter a description in the 'Instructions' box and click 'Generate'.
        - Modify an existing SVG: select the action 'Modify SVG', upload an SVG file or paste an image string or SVG XML, then enter a change instruction in the 'Instructions' box and click 'Generate'."""
    )

    # init state - note gr.State() initial value must be deep-copyable - my wx_engine class is not
    wx_engine_state = gr.State(None)
    with gr.Row():  # main UI
        with gr.Column(scale=0):  # watsonx setup
            # prompt template selection
            prompt_template_names = list_prompts()
            default_prompt_template_name = prompt_template_names[0]
            wx_prompt_drop = gr.Dropdown(
                label="Action",
                choices=prompt_template_names,
                value=default_prompt_template_name,
            )
            # credentials, pre-filled from load_credentials()
            status_unused, env_apiendpoint, env_apikey, env_projectid = (
                load_credentials()
            )
            with gr.Accordion("Credentials", open=True) as credentials_accordeon:
                wx_creds_endpoint = gr.Textbox(
                    label="Endpoint",
                    value=env_apiendpoint or "https://us-south.ml.cloud.ibm.com",
                    max_lines=1,
                )
                # type="password" masks the secret on screen. The value handed
                # to the Connect handler is unchanged.
                wx_creds_apikey = gr.Textbox(
                    label="API key", value=env_apikey, max_lines=1, type="password"
                )
                wx_creds_projectid = gr.Textbox(
                    label="Project id", value=env_projectid, max_lines=1
                )
                wx_connect_btn = gr.Button("Connect")
            # model: empty until a connection succeeds
            wx_models_drop = wx_models_dropdown(None)
            # prompt text and primer for the default action
            wx_prompt_box = gr.Textbox(
                info="Text",
                show_label=False,
                max_lines=5,
                value=get_prompt_template(default_prompt_template_name),
            )
            wx_primer_box = gr.Textbox(
                info="Primer",
                show_label=False,
                max_lines=2,
                value=get_prompt_primer(default_prompt_template_name),
            )
        with gr.Column():  # main pane
            # Upload: visibility follows the selected action's upload mode
            with gr.Row(
                visible=get_prompt_uploadmode(default_prompt_template_name)
            ) as upload_row:
                # Upload an .svg file
                input_file = gr.File(scale=0, label="Upload an SVG file")
                # original preview
                input_svg_preview = gr.HTML(image_placeholder)
                # decoded SVG XML
                input_xml_string_box = gr.Textbox(
                    label="Input SVG XML",
                    lines=7,
                    max_lines=7,
                    show_copy_button=True,
                    scale=3,
                )
            with gr.Accordion(
                label="Input encoded string",
                open=False,
                visible=get_prompt_uploadmode(default_prompt_template_name),
            ) as upload_accordeon:
                # Encoded image string
                input_encoded_string_box = gr.Textbox(
                    label="Image string",
                    info="data:image/svg+xml;base64,...",
                    lines=7,
                    max_lines=7,
                    show_copy_button=True,
                    scale=3,
                    container=False,
                )

            # modification: disabled until a connection succeeds
            with gr.Row():
                wx_instructions_box = gr.Textbox(
                    label="Instructions",
                    scale=3,
                    value=get_prompt_example(default_prompt_template_name),
                    show_copy_button=True,
                )
                wx_generate_btn = gr.Button("↓Generate↓", scale=0, interactive=False)
            output_svg_preview = gr.HTML(image_placeholder)
            output_xml_string_box = gr.Textbox(
                label="Result SVG XML",
                lines=7,
                max_lines=7,
                scale=3,
                show_copy_button=True,
            )
            with gr.Accordion(label="Result encoded string", open=False):
                output_encoded_string_box = gr.Textbox(
                    label="Image string",
                    info="data:image/svg+xml;base64,...",
                    lines=7,
                    max_lines=7,
                    show_copy_button=True,
                    scale=3,
                    container=False,
                )

            wx_status_box = gr.Markdown("Status")

    # Map controls to functions
    wx_prompt_drop.input(
        fn=wx_prompt_drop_change,
        inputs=wx_prompt_drop,
        outputs=[
            wx_prompt_box,
            wx_instructions_box,
            wx_primer_box,
            upload_row,
            upload_accordeon,
        ],
    )
    wx_connect_btn.click(
        fn=wx_connect_click,
        inputs=[
            wx_engine_state,
            wx_creds_endpoint,
            wx_creds_apikey,
            wx_creds_projectid,
        ],
        outputs=[
            wx_engine_state,
            wx_models_drop,
            credentials_accordeon,
            wx_generate_btn,
            wx_status_box,
        ],
    )
    input_file.upload(
        fn=read_file,
        inputs=input_file,
        outputs=[
            input_encoded_string_box,
            input_xml_string_box,
            input_svg_preview,
        ],
    )
    input_encoded_string_box.input(
        fn=input_encoded_string_box_change,
        inputs=[input_encoded_string_box],
        outputs=[input_xml_string_box, input_svg_preview],
    )
    input_xml_string_box.input(
        fn=input_xml_string_box_change,
        inputs=[input_xml_string_box],
        outputs=[input_encoded_string_box, input_svg_preview],
    )
    # Streaming generator: outputs update progressively while chunks arrive.
    wx_generate_btn.click(
        fn=wx_stream,
        inputs=[
            wx_engine_state,
            wx_models_drop,
            wx_prompt_box,
            wx_instructions_box,
            wx_primer_box,
            input_xml_string_box,
        ],
        outputs=[
            wx_status_box,
            output_encoded_string_box,
            output_xml_string_box,
            output_svg_preview,
        ],
        api_name="wx_generate",
    )
    output_xml_string_box.input(
        fn=output_xml_string_box_change,
        inputs=[output_xml_string_box],
        outputs=[
            output_encoded_string_box,
            output_xml_string_box,
            output_svg_preview,
        ],
    )

# Main
if __name__ == "__main__":
    demo.launch()