svg-editor / app.py
remiserra's picture
typo
73dc8c0
raw
history blame
13.9 kB
# Remi Serra 202407
import gradio as gr
from random import randrange
from svg_utils import decode_b64_string_to_pretty_xml, encode_svg_xml_to_b64_string
from watsonx_utils import wxEngine
from prompts import (
list_prompts,
get_prompt_template,
get_prompt_example,
get_prompt_primer,
)
from data_images import svg_three_dots
# Functions
def read_file(uploaded_file):
    """Load an uploaded SVG file and derive every representation shown in the UI.

    Args:
        uploaded_file: Value handed over by the upload event; it is passed
            straight to ``open()``. A falsy value means nothing was uploaded.

    Returns:
        A 4-tuple ``(data_string, svg_xml, encoded_preview_html,
        decoded_preview_html)``, or ``None`` when ``uploaded_file`` is falsy.
    """
    if not uploaded_file:
        return None
    # The context manager closes the handle deterministically; the previous
    # bare open().read() left that to the garbage collector. UTF-8 is the XML
    # default encoding, so do not depend on the platform locale.
    with open(uploaded_file, "r", encoding="utf-8") as svg_file:
        svg_xml = svg_file.read()
    data_string = xml_string_to_data_string(svg_xml)
    return (
        data_string,
        svg_xml,
        html_img_preview(data_string),
        xml_string_to_html_img(svg_xml),
    )
def encoded_string_box_change(data_string: str):
    """React to an edit of the encoded image string.

    Returns the decoded XML, the preview built from the encoded string as
    given, and the preview rebuilt from the decoded XML.
    """
    decoded_xml = decode_b64_string_to_pretty_xml(data_string)
    encoded_preview = html_img_preview(data_string)
    decoded_preview = xml_string_to_html_img(decoded_xml)
    return decoded_xml, encoded_preview, decoded_preview
def xml_string_box_change(svg_xml: str):
    """React to an edit of the SVG XML.

    Returns the re-encoded data string, the preview built from that data
    string, and the preview built directly from the XML.
    """
    encoded = xml_string_to_data_string(svg_xml)
    encoded_preview = html_img_preview(encoded)
    decoded_preview = xml_string_to_html_img(svg_xml)
    return encoded, encoded_preview, decoded_preview
def xml_string_to_data_string(svg_xml: str):
    """Turn SVG XML into a ``data:image/svg+xml;base64,...`` string."""
    # The base64 payload comes from the svg_utils helper; only the prefix is added here.
    return "data:image/svg+xml;base64," + encode_svg_xml_to_b64_string(svg_xml)
def html_img_preview(data_string):
    """Return an HTML snippet showing ``data_string`` as a centered 100px-wide image.

    Args:
        data_string: Image source placed in the ``src`` attribute (in this
            file, a ``data:image/svg+xml;base64,...`` string).

    Returns:
        The ``<center><img .../></center>`` markup as a string.
    """
    # Local import keeps this block self-contained.
    from html import escape

    # The value can come straight from a user-edited textbox, so escape it
    # before embedding it in a quoted attribute. Well-formed base64 data
    # strings contain no escapable characters and pass through unchanged.
    safe_src = escape(str(data_string), quote=True)
    # The closing tag was previously emitted as a second opening <center>.
    return f'<center><img width=100 src="{safe_src}"/></center>'
def xml_string_to_html_img(svg_xml: str):
    """Build the HTML image preview for raw SVG XML."""
    # Encode to a data string first, then wrap it in the preview markup.
    return html_img_preview(xml_string_to_data_string(svg_xml))
# def replace_color(svg_xml: str, color_from, color_to):
# print(
# f"replace_color:svg_xml:{svg_xml}, color_from:{color_from}, color_to:{color_to} "
# )
# if svg_xml and color_from and color_to:
# new_svg_xml = svg_xml.replace(color_from, color_to)
# data_string = xml_string_to_data_string(new_svg_xml)
# return data_string, new_svg_xml, html_img_preview(data_string)
# def switch_colors(color_from, color_to):
# print(f"switch_colors:color_from:{color_from}, color_to:{color_to} ")
# return color_to, color_from
# Functions - watsonx
def wx_prompt_change(prompt_template_name):
    """Look up the pieces of the selected prompt template.

    Returns ``(template, example, primer)`` for ``prompt_template_name``, in
    that order.
    """
    template = get_prompt_template(prompt_template_name)
    example = get_prompt_example(prompt_template_name)
    primer = get_prompt_primer(prompt_template_name)
    return template, example, primer
def wx_models_dropdown(wx_engine_state):
    """Build the "Model" dropdown, filled from the engine when one is connected.

    Args:
        wx_engine_state: Connected engine exposing ``list_models()``, or
            ``None`` before any connection was made.

    Returns:
        A ``gr.Dropdown`` whose choices are the engine's models (empty without
        an engine). The recommended model is preselected when listed,
        otherwise the first listed model; nothing is preselected when no
        model is available.
    """
    recommended_model = "ibm/granite-20b-code-instruct"
    model_list = []
    default_value = None
    if wx_engine_state is not None:
        model_list = wx_engine_state.list_models()
        # The value must be a plain string: a stray trailing comma used to
        # wrap it in a 1-tuple. Also guard the fallback so an empty model
        # list no longer raises IndexError.
        if recommended_model in model_list:
            default_value = recommended_model
        elif model_list:
            default_value = model_list[0]
    return gr.Dropdown(
        label="Model",
        info=recommended_model + " recommended",
        choices=model_list,
        value=default_value,
    )
def wx_connect(wx_engine_state, apiendpoint, apikey, projectid):
    """Create a watsonx engine from the credentials and refresh the model list.

    The incoming ``wx_engine_state`` is ignored and replaced by a new
    ``wxEngine``. Returns the new engine together with a model dropdown built
    from it.
    """
    engine = wxEngine(apiendpoint, apikey, projectid)
    print("watsonx.ai activated")
    models_dropdown = wx_models_dropdown(engine)
    return engine, models_dropdown
def prepare_prompt(
    wx_engine_state, wx_model, wx_prompt, wx_instructions, wx_primer="", xml_string=""
):
    """Assemble the final prompt and the generation budget for a watsonx.ai call.

    Args:
        wx_engine_state: Connected engine exposing ``get_model_max_tokens`` and
            ``get_prompt_nb_tokens``; ``None`` when the user has not connected.
        wx_model: Model id forwarded to the engine's token helpers.
        wx_prompt: Template formatted with the ``svg`` and ``instructions``
            keyword fields.
        wx_instructions: Text substituted for ``instructions``.
        wx_primer: Text appended after the formatted template.
        xml_string: SVG XML substituted for ``svg``; its length also sizes the
            output budget.

    Returns:
        ``(wx_status, max_new_tokens, prompt)`` where ``wx_status`` is
        ``"Done."`` or a warning that the prompt exceeds the model limit.

    Raises:
        gr.Error: If no engine is connected yet.
    """
    if wx_engine_state is None:
        # Without this guard the first attribute access fails with an opaque
        # AttributeError; gr.Error carries a readable message to the UI.
        raise gr.Error(
            "watsonx.ai is not connected: enter your credentials and click 'Connect' first."
        )
    wx_status = "Done."
    wx_engine = wx_engine_state
    # get model specs
    model_max_tokens = wx_engine.get_model_max_tokens(wx_model)
    # Add "primer" at the end of the prompt
    prompt = wx_prompt.format(svg=xml_string, instructions=wx_instructions) + wx_primer
    # Test and alert if prompt is too long
    prompt_nb_tokens = wx_engine.get_prompt_nb_tokens(prompt, wx_model)
    if prompt_nb_tokens > model_max_tokens:
        wx_status = f"Warning: prompt length ({prompt_nb_tokens}) is more than the model max tokens ({model_max_tokens}), and will be truncated. Please review your instructions."
        print(wx_status)
    # calculate max new token based on xml_string - or 500 when original string is too small
    # note: prompt will be truncated if too long with GenTextParamsMetaNames.TRUNCATE_INPUT_TOKENS in generate()
    max_new_tokens = max(500, len(xml_string))
    return wx_status, max_new_tokens, prompt
def wx_generate(
    wx_engine_state, wx_model, wx_prompt, wx_instructions, wx_primer="", xml_string=""
):
    """Run one non-streaming generation and return status, text and preview.

    The prompt and token budget come from ``prepare_prompt``. The engine is
    called with ``</svg>`` as stop sequence, and its output is prefixed with
    ``wx_primer``. Returns ``(wx_status, wx_result, preview_html)``.
    """
    status, token_budget, full_prompt = prepare_prompt(
        wx_engine_state, wx_model, wx_prompt, wx_instructions, wx_primer, xml_string
    )
    completion = wx_engine_state.generate_text(
        modelid=wx_model,
        prompt=full_prompt,
        max_new_tokens=token_budget,
        stop_sequences=["</svg>"],
    )
    # The primer is part of the prompt, so it is put back in front of the output.
    wx_result = wx_primer + completion
    print(f"wx_generate:wx_result:{wx_result}")
    preview = xml_string_to_html_img(wx_result)
    return status, wx_result, preview
def wx_stream(
    wx_engine_state, wx_model, wx_prompt, wx_instructions, wx_primer="", xml_string=""
):
    """Stream a generation, yielding ``(status, text_so_far, preview)`` tuples.

    While chunks arrive, the status is a "Processing" marker with a random
    number of dots and the preview is ``None``. The last yield carries the
    status from ``prepare_prompt`` and the preview of the complete text.
    """
    final_status, token_budget, full_prompt = prepare_prompt(
        wx_engine_state, wx_model, wx_prompt, wx_instructions, wx_primer, xml_string
    )
    # https://www.gradio.app/guides/streaming-outputs
    chunk_stream = wx_engine_state.generate_text(
        modelid=wx_model,
        prompt=full_prompt,
        max_new_tokens=token_budget,
        stop_sequences=["</svg>"],
        stream=True,
    )
    # Start from the primer so the accumulated text includes it.
    wx_result = wx_primer
    for chunk in chunk_stream:
        wx_result += chunk
        yield f"Processing.{'.'*int(randrange(3))}", wx_result, None
    print(f"wx_stream:wx_result:{wx_result}")
    yield final_status, wx_result, xml_string_to_html_img(wx_result)
def wx_result_box_change(wx_result):
    """Rebuild the HTML preview for the text in the result box."""
    preview = xml_string_to_html_img(wx_result)
    return preview
def wx_accept(svg_xml):
    """Derive the encoded string and both previews from an accepted result.

    Returns ``(data_string, svg_xml, encoded_preview_html,
    decoded_preview_html)``.
    """
    encoded = xml_string_to_data_string(svg_xml)
    encoded_preview = html_img_preview(encoded)
    decoded_preview = xml_string_to_html_img(svg_xml)
    return encoded, svg_xml, encoded_preview, decoded_preview
# APP
# Declares the components and event handlers of the app; `demo` is the object
# launched by the main guard at the bottom of the file.
with gr.Blocks() as demo:
gr.Markdown("# SVG editor")
with gr.Accordion("Get started:", open=True):
gr.Markdown(
"""
- Create a new SVG: Select the prompt template 'Create SVG', enter a description in the 'Instructions' box and click 'Submit'
- Modify an existing SVG: Upload an SVG file, or paste an image string or SVG XML, then Select the prompt template 'Modify SVG', enter a change instruction in the 'Instructions' box and click 'Submit'
- Describe an SVG: Upload, paste or generate an SVG file, Select the prompt template 'Describe SVG' and click 'Submit' """
)
# init state - note gr.State() initial value must be deep-copyable - my wx_engine class is not
# Starts as None; the Connect handler (wx_connect) replaces it with the engine.
wx_engine_state = gr.State(None)
with gr.Column():
# Encoded string and preview
with gr.Accordion("Load SVG:", open=True):
with gr.Row():
# Upload an .svg file
uploaded_file = gr.File(scale=0, label="Upload an SVG file")
# Paste an image string
encoded_string_box = gr.Textbox(
label="Image string",
info="data:image/svg+xml;base64,...",
lines=7,
max_lines=7,
show_copy_button=True,
scale=3,
)
# original preview
# Shows the svg_three_dots image until a handler writes a real preview.
encoded_svg_preview = gr.HTML(f"<img src='{svg_three_dots}'/>")
# Decoded string and preview
with gr.Row():
xml_string_box = gr.Textbox(
label="SVG XML",
lines=7,
max_lines=7,
show_copy_button=True,
scale=3,
)
# decoded preview
decoded_svg_preview = gr.HTML(f"<img src='{svg_three_dots}'/>")
# with gr.Row(): # Color changer
# color_from_area = gr.ColorPicker(label="Search color:", value="#000000", scale=0)
# color_switch_btn = gr.Button("<->", scale=0)
# color_to_area = gr.ColorPicker(label="Replace color:", value="#FFFFFF", scale=0)
# color_replace_btn = gr.Button("Replace", scale=0)
with gr.Accordion("watsonx.ai:", open=True):
with gr.Row(): # watsonx
with gr.Column(scale=0):
with gr.Group():
# credentials
with gr.Accordion("Credentials:", open=True):
wx_creds_endpoint = gr.Textbox(
label="Endpoint",
value="https://us-south.ml.cloud.ibm.com",
max_lines=1,
)
wx_creds_apikey = gr.Textbox(label="API key", max_lines=1)
wx_creds_projectid = gr.Textbox(
label="Project id", max_lines=1
)
wx_connect_btn = gr.Button("Connect")
# model
# Built without an engine, so the choice list is empty until Connect runs.
wx_models_drop = wx_models_dropdown(None)
# Connect: build the engine from the three credential boxes, then store it
# in the state and replace the model dropdown.
wx_connect_btn.click(
fn=wx_connect,
inputs=[
wx_engine_state,
wx_creds_endpoint,
wx_creds_apikey,
wx_creds_projectid,
],
outputs=[wx_engine_state, wx_models_drop],
)
# prompt template
# The first name returned by list_prompts() seeds the dropdown, the template
# text, the primer and the example instructions below.
prompt_template_names = list_prompts()
wx_prompt_drop = gr.Dropdown(
label="Prompt template",
choices=prompt_template_names,
value=prompt_template_names[0],
)
wx_prompt_box = gr.Textbox(
info="Text",
show_label=False,
max_lines=5,
value=get_prompt_template(prompt_template_names[0]),
)
wx_primer_box = gr.Textbox(
info="Primer",
show_label=False,
max_lines=2,
value=get_prompt_primer(prompt_template_names[0]),
)
with gr.Column():
with gr.Row():
wx_instructions_box = gr.Textbox(
label="Instructions",
scale=3,
value=get_prompt_example(prompt_template_names[0]),
show_copy_button=True,
)
wx_submit_btn = gr.Button("↓Submit↓", scale=0)
wx_accept_btn = gr.Button("↑Accept↑", scale=0)
with gr.Row():
wx_result_box = gr.Textbox(
label="Result",
lines=7,
max_lines=7,
scale=3,
show_copy_button=True,
)
wx_svg_preview = gr.HTML(f"<img src='{svg_three_dots}'/>")
wx_status_box = gr.Markdown("Status")
# Picking a template reloads the template text, the example instructions and
# the primer.
wx_prompt_drop.input(
fn=wx_prompt_change,
inputs=wx_prompt_drop,
outputs=[wx_prompt_box, wx_instructions_box, wx_primer_box],
)
# Submit uses the streaming handler; the non-streaming wx_generate is kept as a
# commented-out alternative and still lends its name to the API endpoint.
wx_submit_btn.click(
# fn=wx_generate,
fn=wx_stream,
inputs=[
wx_engine_state,
wx_models_drop,
wx_prompt_box,
wx_instructions_box,
wx_primer_box,
xml_string_box,
],
outputs=[wx_status_box, wx_result_box, wx_svg_preview],
api_name="wx_generate",
)
# Editing the result text refreshes its preview.
wx_result_box.input(
fn=wx_result_box_change,
inputs=[wx_result_box],
outputs=[wx_svg_preview],
)
# Accept: copy the result into the image-string and SVG XML boxes and refresh
# both of their previews.
wx_accept_btn.click(
fn=wx_accept,
inputs=[wx_result_box],
outputs=[
encoded_string_box,
xml_string_box,
encoded_svg_preview,
decoded_svg_preview,
],
)
# Actions
# Each handler below rewrites the other representation(s) and both previews
# from whichever input was changed.
encoded_string_box.input(
fn=encoded_string_box_change,
inputs=[encoded_string_box],
outputs=[xml_string_box, encoded_svg_preview, decoded_svg_preview],
)
uploaded_file.upload(
fn=read_file,
inputs=uploaded_file,
outputs=[
encoded_string_box,
xml_string_box,
encoded_svg_preview,
decoded_svg_preview,
],
)
xml_string_box.input(
fn=xml_string_box_change,
inputs=[xml_string_box],
outputs=[encoded_string_box, encoded_svg_preview, decoded_svg_preview],
)
# color_switch_btn.click(
# fn=switch_colors,
# inputs=[color_from_area, color_to_area],
# outputs=[color_from_area, color_to_area],
# api_name="color_switch",
# )
# color_replace_btn.click(
# fn=replace_color,
# inputs=[xml_string_box, color_from_area, color_to_area],
# outputs=[encoded_string_box, xml_string_box, svg_preview],
# api_name="color_replace",
# )
# Main
if __name__ == "__main__":
    # Start the Gradio app only when this file is executed as a script.
    demo.launch()