# Remi Serra 202407
import html
from random import randrange

import gradio as gr

from svg_utils import decode_b64_string_to_pretty_xml, encode_svg_xml_to_b64_string
from watsonx_utils import wxEngine
from prompts import (
    list_prompts,
    get_prompt_template,
    get_prompt_example,
    get_prompt_primer,
)
from data_images import svg_three_dots

# Constants

# Model pre-selected in the model dropdown when the connected project offers it.
RECOMMENDED_MODEL = "ibm/granite-20b-code-instruct"
# NOTE(review): the original stop-sequence literal reads as [""] in this copy,
# which looks like an HTML sanitizer stripped "</svg>" - confirm the intended value.
SVG_STOP_SEQUENCES = ["</svg>"]
# Status shown when generation is requested before connecting / choosing a model.
WX_NOT_READY_STATUS = "Please connect to watsonx.ai and select a model first."


# Functions


def _svg_outputs(svg_xml: str):
    """Build the 4 outputs shared by file upload and "Accept".

    Returns:
        (data URI, SVG XML, encoded-string preview HTML, XML preview HTML).
    """
    data_string = xml_string_to_data_string(svg_xml)
    return (
        data_string,
        svg_xml,
        html_img_preview(data_string),
        xml_string_to_html_img(svg_xml),
    )


def read_file(uploaded_file):
    """Load an uploaded SVG file into the editor boxes and previews.

    Args:
        uploaded_file: value of the ``gr.File`` component; passed straight to
            ``open()``, so presumably a file path (gradio's default) - confirm.

    Returns:
        (data URI, SVG XML, encoded preview HTML, decoded preview HTML), or
        four no-op updates when no file was provided.
    """
    if not uploaded_file:
        # Leave the four bound components unchanged rather than returning a
        # bare None for a four-output event.
        return (gr.update(),) * 4
    # Context manager closes the handle; SVG files are read as UTF-8 text.
    with open(uploaded_file, "r", encoding="utf-8") as svg_file:
        svg_xml = svg_file.read()
    return _svg_outputs(svg_xml)


def encoded_string_box_change(data_string: str):
    """Decode a pasted data URI into pretty SVG XML and refresh both previews."""
    # print(f"encoded_string_box_change:image_data:{data_string}")
    svg_xml = decode_b64_string_to_pretty_xml(data_string)
    return svg_xml, html_img_preview(data_string), xml_string_to_html_img(svg_xml)


def xml_string_box_change(svg_xml: str):
    """Re-encode edited SVG XML as a data URI and refresh both previews."""
    data_string = xml_string_to_data_string(svg_xml)
    return data_string, html_img_preview(data_string), xml_string_to_html_img(svg_xml)


def xml_string_to_data_string(svg_xml: str):
    """Encode SVG XML as a ``data:image/svg+xml;base64,...`` URI."""
    b64 = encode_svg_xml_to_b64_string(svg_xml)
    data_string = "data:image/svg+xml;base64," + b64
    return data_string


def html_img_preview(data_string):
    """Return an ``<img>`` tag whose ``src`` is *data_string*.

    NOTE(review): the original tag markup was lost in extraction; this is a
    reconstruction. The value is HTML-escaped because it can come verbatim
    from the user-editable "Image string" textbox. Escaping does not alter a
    valid base64 data URI once the browser decodes the attribute.
    """
    return '<img src="%s">' % html.escape(str(data_string), quote=True)


def xml_string_to_html_img(svg_xml: str):
    """Render SVG XML as an ``<img>`` preview via a data URI."""
    data_string = xml_string_to_data_string(svg_xml)
    return html_img_preview(data_string)


# def replace_color(svg_xml: str, color_from, color_to):
#     print(
#         f"replace_color:svg_xml:{svg_xml}, color_from:{color_from}, color_to:{color_to} "
#     )
#     if svg_xml and color_from and color_to:
#         new_svg_xml = svg_xml.replace(color_from, color_to)
#         data_string = xml_string_to_data_string(new_svg_xml)
#         return data_string, new_svg_xml, html_img_preview(data_string)


# def switch_colors(color_from, color_to):
#     print(f"switch_colors:color_from:{color_from}, color_to:{color_to} ")
#     return color_to, color_from


# Functions - watsonx


def wx_prompt_change(prompt_template_name):
    """Return (template, example instructions, primer) for the chosen template."""
    return (
        get_prompt_template(prompt_template_name),
        get_prompt_example(prompt_template_name),
        get_prompt_primer(prompt_template_name),
    )


def wx_models_dropdown(wx_engine_state):
    """Build the "Model" dropdown from the engine's model list.

    Args:
        wx_engine_state: a connected ``wxEngine`` or None (not yet connected).

    Returns:
        A ``gr.Dropdown``; empty with no selection when not connected.
        Pre-selects RECOMMENDED_MODEL when available, else the first model.
    """
    model_list = []
    default_value = None
    if wx_engine_state is not None:
        model_list = wx_engine_state.list_models()
        if RECOMMENDED_MODEL in model_list:
            default_value = RECOMMENDED_MODEL
        elif model_list:
            # Guard: an empty model list must not raise IndexError.
            default_value = model_list[0]
    return gr.Dropdown(
        label="Model",
        info=RECOMMENDED_MODEL + " recommended",
        choices=model_list,
        value=default_value,
    )


def wx_connect(wx_engine_state, apiendpoint, apikey, projectid):
    """Create a ``wxEngine`` from the credentials and refresh the model dropdown.

    The incoming state value is ignored and replaced by the new engine.
    """
    wx_engine_state = wxEngine(apiendpoint, apikey, projectid)
    print("watsonx.ai activated")
    return wx_engine_state, wx_models_dropdown(wx_engine_state)


def prepare_prompt(
    wx_engine_state, wx_model, wx_prompt, wx_instructions, wx_primer="", xml_string=""
):
    """Assemble the final prompt and derive generation settings.

    The template is filled via ``str.format`` with ``svg`` and ``instructions``
    placeholders, then the primer is appended so the model continues from it.

    Returns:
        (status message, max_new_tokens, prompt). The status is "Done." unless
        the prompt has more tokens than the model accepts, in which case it
        is a warning (the prompt will be truncated at generation time).
    """
    wx_status = "Done."
    wx_engine = wx_engine_state
    # get model specs
    model_max_tokens = wx_engine.get_model_max_tokens(wx_model)
    # Add "primer" at the end of the prompt
    prompt = wx_prompt.format(svg=xml_string, instructions=wx_instructions) + wx_primer
    # Test and alert if prompt is too long
    prompt_nb_tokens = wx_engine.get_prompt_nb_tokens(prompt, wx_model)
    if prompt_nb_tokens > model_max_tokens:
        wx_status = f"Warning: prompt length ({prompt_nb_tokens}) is more than the model max tokens ({model_max_tokens}), and will be truncated. Please review your instructions."
        print(wx_status)
    # calculate max new token based on xml_string - or 500 when original string is too small
    # note: prompt will be truncated if too long with GenTextParamsMetaNames.TRUNCATE_INPUT_TOKENS in generate()
    max_new_tokens = max(500, len(xml_string))
    return wx_status, max_new_tokens, prompt


def wx_generate(
    wx_engine_state, wx_model, wx_prompt, wx_instructions, wx_primer="", xml_string=""
):
    """Generate text in one call and return (status, result, result preview).

    The primer is prepended to the model output, since the model continues
    from it. Returns a status-only response if not connected or no model is selected.
    """
    if wx_engine_state is None or not wx_model:
        return WX_NOT_READY_STATUS, gr.update(), gr.update()
    wx_engine = wx_engine_state
    wx_status, max_new_tokens, prompt = prepare_prompt(
        wx_engine, wx_model, wx_prompt, wx_instructions, wx_primer, xml_string
    )
    wx_result = wx_primer + wx_engine.generate_text(
        modelid=wx_model,
        prompt=prompt,
        max_new_tokens=max_new_tokens,
        stop_sequences=SVG_STOP_SEQUENCES,
    )
    print(f"wx_generate:wx_result:{wx_result}")
    return wx_status, wx_result, xml_string_to_html_img(wx_result)


def wx_stream(
    wx_engine_state, wx_model, wx_prompt, wx_instructions, wx_primer="", xml_string=""
):
    """Stream generation, yielding (status, partial result, preview) tuples.

    Intermediate yields carry a "Processing..." status and no preview. The
    final yield carries the real status and the rendered SVG preview.
    """
    if wx_engine_state is None or not wx_model:
        yield WX_NOT_READY_STATUS, gr.update(), gr.update()
        return
    wx_engine = wx_engine_state
    wx_status, max_new_tokens, prompt = prepare_prompt(
        wx_engine, wx_model, wx_prompt, wx_instructions, wx_primer, xml_string
    )
    wx_result = wx_primer
    # https://www.gradio.app/guides/streaming-outputs
    wx_result_generator = wx_engine.generate_text(
        modelid=wx_model,
        prompt=prompt,
        max_new_tokens=max_new_tokens,
        stop_sequences=SVG_STOP_SEQUENCES,
        stream=True,
    )
    for chunk in wx_result_generator:
        wx_result += chunk
        # Varying dot count animates the status; concatenation (rather than a
        # nested same-quote f-string) keeps this valid before Python 3.12.
        yield "Processing." + "." * randrange(3), wx_result, None
    print(f"wx_stream:wx_result:{wx_result}")
    yield wx_status, wx_result, xml_string_to_html_img(wx_result)


def wx_result_box_change(wx_result):
    """Re-render the result preview after the user edits the result box."""
    return xml_string_to_html_img(wx_result)


def wx_accept(svg_xml):
    """Copy the generated SVG back into the editor boxes and previews."""
    return _svg_outputs(svg_xml)


# APP
with gr.Blocks() as demo:
    gr.Markdown("# SVG editor")
    with gr.Accordion("Get started:", open=True):
        gr.Markdown(
            """
            - Create a new SVG: Select the prompt template 'Create SVG', enter a description in the 'Instructions' box and click 'Submit'
            - Modify an existing SVG: Upload an SVG file, or paste an image string or SVG XML, then Select the prompt template 'Modify SVG', enter a change instruction in the 'Instructions' box and click 'Submit'
            - Describe an SVG: Upload, paste or generate an SVG file, Select the prompt template 'Describe SVG' and click 'Submit'
            """
        )
    # init state - note gr.State() initial value must be deep-copyable - my wx_engine class is not
    wx_engine_state = gr.State(None)

    with gr.Column():
        # Encoded string and preview
        with gr.Accordion("Load SVG:", open=True):
            with gr.Row():
                # Upload an .svg file
                uploaded_file = gr.File(scale=0, label="Upload an SVG file")
                # Paste an image string
                encoded_string_box = gr.Textbox(
                    label="Image string",
                    info="data:image/svg+xml;base64,...",
                    lines=7,
                    max_lines=7,
                    show_copy_button=True,
                    scale=3,
                )
                # original preview
                # NOTE(review): the initial HTML here reads as f"" in this copy;
                # svg_three_dots is imported and was likely the placeholder - confirm.
                encoded_svg_preview = gr.HTML("")
            # Decoded string and preview
            with gr.Row():
                xml_string_box = gr.Textbox(
                    label="SVG XML",
                    lines=7,
                    max_lines=7,
                    show_copy_button=True,
                    scale=3,
                )
                # decoded preview
                decoded_svg_preview = gr.HTML("")
            # with gr.Row():
            #     Color changer
            #     color_from_area = gr.ColorPicker(label="Search color:", value="#000000", scale=0)
            #     color_switch_btn = gr.Button("<->", scale=0)
            #     color_to_area = gr.ColorPicker(label="Replace color:", value="#FFFFFF", scale=0)
            #     color_replace_btn = gr.Button("Replace", scale=0)

        with gr.Accordion("watsonx.ai:", open=True):
            with gr.Row():
                # watsonx
                with gr.Column(scale=0):
                    with gr.Group():
                        # credentials
                        with gr.Accordion("Credentials:", open=True):
                            wx_creds_endpoint = gr.Textbox(
                                label="Endpoint",
                                value="https://us-south.ml.cloud.ibm.com",
                                max_lines=1,
                            )
                            wx_creds_apikey = gr.Textbox(label="API key", max_lines=1)
                            wx_creds_projectid = gr.Textbox(
                                label="Project id", max_lines=1
                            )
                            wx_connect_btn = gr.Button("Connect")
                        # model
                        wx_models_drop = wx_models_dropdown(None)
                        wx_connect_btn.click(
                            fn=wx_connect,
                            inputs=[
                                wx_engine_state,
                                wx_creds_endpoint,
                                wx_creds_apikey,
                                wx_creds_projectid,
                            ],
                            outputs=[wx_engine_state, wx_models_drop],
                        )
                        # prompt template
                        prompt_template_names = list_prompts()
                        default_prompt_name = prompt_template_names[0]
                        wx_prompt_drop = gr.Dropdown(
                            label="Prompt template",
                            choices=prompt_template_names,
                            value=default_prompt_name,
                        )
                        wx_prompt_box = gr.Textbox(
                            info="Text",
                            show_label=False,
                            max_lines=5,
                            value=get_prompt_template(default_prompt_name),
                        )
                        wx_primer_box = gr.Textbox(
                            info="Primer",
                            show_label=False,
                            max_lines=2,
                            value=get_prompt_primer(default_prompt_name),
                        )
                with gr.Column():
                    with gr.Row():
                        wx_instructions_box = gr.Textbox(
                            label="Instructions",
                            scale=3,
                            value=get_prompt_example(default_prompt_name),
                            show_copy_button=True,
                        )
                        wx_submit_btn = gr.Button("↓Submit↓", scale=0)
                        wx_accept_btn = gr.Button("↑Accept↑", scale=0)
                    with gr.Row():
                        wx_result_box = gr.Textbox(
                            label="Result",
                            lines=7,
                            max_lines=7,
                            scale=3,
                            show_copy_button=True,
                        )
                        wx_svg_preview = gr.HTML("")
                    wx_status_box = gr.Markdown("Status")

    wx_prompt_drop.input(
        fn=wx_prompt_change,
        inputs=wx_prompt_drop,
        outputs=[wx_prompt_box, wx_instructions_box, wx_primer_box],
    )
    wx_submit_btn.click(
        # fn=wx_generate,
        fn=wx_stream,
        inputs=[
            wx_engine_state,
            wx_models_drop,
            wx_prompt_box,
            wx_instructions_box,
            wx_primer_box,
            xml_string_box,
        ],
        outputs=[wx_status_box, wx_result_box, wx_svg_preview],
        api_name="wx_generate",
    )
    wx_result_box.input(
        fn=wx_result_box_change,
        inputs=[wx_result_box],
        outputs=[wx_svg_preview],
    )
    wx_accept_btn.click(
        fn=wx_accept,
        inputs=[wx_result_box],
        outputs=[
            encoded_string_box,
            xml_string_box,
            encoded_svg_preview,
            decoded_svg_preview,
        ],
    )

    # Actions
    encoded_string_box.input(
        fn=encoded_string_box_change,
        inputs=[encoded_string_box],
        outputs=[xml_string_box, encoded_svg_preview, decoded_svg_preview],
    )
    uploaded_file.upload(
        fn=read_file,
        inputs=uploaded_file,
        outputs=[
            encoded_string_box,
            xml_string_box,
            encoded_svg_preview,
            decoded_svg_preview,
        ],
    )
    xml_string_box.input(
        fn=xml_string_box_change,
        inputs=[xml_string_box],
        outputs=[encoded_string_box, encoded_svg_preview, decoded_svg_preview],
    )
    # color_switch_btn.click(
    #     fn=switch_colors,
    #     inputs=[color_from_area, color_to_area],
    #     outputs=[color_from_area, color_to_area],
    #     api_name="color_switch",
    # )
    # color_replace_btn.click(
    #     fn=replace_color,
    #     inputs=[xml_string_box, color_from_area, color_to_area],
    #     outputs=[encoded_string_box, xml_string_box, svg_preview],
    #     api_name="color_replace",
    # )

# Main
if __name__ == "__main__":
    demo.launch()