Spaces:
Running
Running
# Remi Serra 202407 | |
import html
from random import randrange

import gradio as gr

from svg_utils import decode_b64_string_to_pretty_xml, encode_svg_xml_to_b64_string
from watsonx_utils import wxEngine
from prompts import (
    list_prompts,
    get_prompt_template,
    get_prompt_example,
    get_prompt_primer,
)
from data_images import svg_three_dots
# Functions
def read_file(uploaded_file):
    """Load an uploaded SVG file and build every representation the UI displays.

    Args:
        uploaded_file: value passed by the ``gr.File`` upload event; it is handed
            straight to ``open()``, so presumably a filesystem path. Falsy when
            nothing was uploaded.

    Returns:
        ``(data_string, svg_xml, encoded_preview_html, decoded_preview_html)``,
        or ``None`` when ``uploaded_file`` is falsy.
    """
    if not uploaded_file:
        return None
    # Context manager so the handle is closed even if reading fails; explicit
    # encoding so the result does not depend on the host locale.
    with open(uploaded_file, "r", encoding="utf-8") as svg_file:
        svg_xml = svg_file.read()
    data_string = xml_string_to_data_string(svg_xml)
    return (
        data_string,
        svg_xml,
        html_img_preview(data_string),
        xml_string_to_html_img(svg_xml),
    )
def encoded_string_box_change(data_string: str):
    """Refresh the XML box and both previews after the image string is edited.

    Returns ``(svg_xml, encoded_preview_html, decoded_preview_html)``.
    """
    pretty_xml = decode_b64_string_to_pretty_xml(data_string)
    encoded_preview = html_img_preview(data_string)
    decoded_preview = xml_string_to_html_img(pretty_xml)
    return pretty_xml, encoded_preview, decoded_preview
def xml_string_box_change(svg_xml: str):
    """Refresh the image string and both previews after the SVG XML is edited.

    Returns ``(data_string, encoded_preview_html, decoded_preview_html)``.
    """
    encoded = xml_string_to_data_string(svg_xml)
    return (
        encoded,
        html_img_preview(encoded),
        xml_string_to_html_img(svg_xml),
    )
def xml_string_to_data_string(svg_xml: str):
    """Turn SVG XML into a ``data:image/svg+xml;base64,...`` URI for an ``<img src>``."""
    uri_prefix = "data:image/svg+xml;base64,"
    return uri_prefix + encode_svg_xml_to_b64_string(svg_xml)
def html_img_preview(data_string, width=100):
    """Return an HTML snippet that shows *data_string* as a centered image.

    Args:
        data_string: value for the ``<img src>`` attribute, normally a
            ``data:image/svg+xml;base64,...`` URI. It can come straight from the
            user-editable "Image string" box, so it is HTML-escaped.
        width: value of the ``<img width>`` attribute.

    Returns:
        For example ``<center><img width=100 src="..."/></center>``.
    """
    # quote=True escapes '"' so a pasted value cannot close the src attribute.
    # Base64 data URIs contain none of the escaped characters, so valid input
    # is emitted unchanged. str() keeps the old "%s" behaviour for non-str values.
    src = html.escape(str(data_string), quote=True)
    return f'<center><img width={width} src="{src}"/></center>'
def xml_string_to_html_img(svg_xml: str):
    """Render SVG XML as the HTML preview snippet, going through its data URI."""
    return html_img_preview(xml_string_to_data_string(svg_xml))
# def replace_color(svg_xml: str, color_from, color_to): | |
# print( | |
# f"replace_color:svg_xml:{svg_xml}, color_from:{color_from}, color_to:{color_to} " | |
# ) | |
# if svg_xml and color_from and color_to: | |
# new_svg_xml = svg_xml.replace(color_from, color_to) | |
# data_string = xml_string_to_data_string(new_svg_xml) | |
# return data_string, new_svg_xml, html_img_preview(data_string) | |
# def switch_colors(color_from, color_to): | |
# print(f"switch_colors:color_from:{color_from}, color_to:{color_to} ") | |
# return color_to, color_from | |
# Functions - watsonx
def wx_prompt_change(prompt_template_name):
    """Look up the template text, example instructions and primer for a template.

    Returns ``(template, example_instructions, primer)`` in that order.
    """
    template = get_prompt_template(prompt_template_name)
    example = get_prompt_example(prompt_template_name)
    primer = get_prompt_primer(prompt_template_name)
    return template, example, primer
def wx_models_dropdown(
    wx_engine_state, recommended_model="ibm/granite-20b-code-instruct"
):
    """Build the model-selection dropdown for a watsonx engine.

    Args:
        wx_engine_state: a connected ``wxEngine``, or ``None`` before "Connect".
        recommended_model: model id to pre-select when the engine lists it.

    Returns:
        A ``gr.Dropdown`` whose choices are ``wx_engine_state.list_models()``
        (empty when not connected). The selection is ``recommended_model`` if
        listed, otherwise the first listed model, otherwise ``None``.
    """
    model_list = []
    if wx_engine_state is not None:
        # "or []" tolerates an engine that returns None for no models.
        model_list = wx_engine_state.list_models() or []
    # Plain scalar selection: the previous "(x,)" form handed the dropdown a
    # 1-tuple instead of a model id, and an empty list raised IndexError.
    if recommended_model in model_list:
        default_value = recommended_model
    elif model_list:
        default_value = model_list[0]
    else:
        default_value = None
    return gr.Dropdown(
        label="Model",
        info=recommended_model + " recommended",
        choices=model_list,
        value=default_value,
    )
def wx_connect(wx_engine_state, apiendpoint, apikey, projectid):
    """Create a watsonx.ai engine from the credentials and refresh the model list.

    The incoming ``wx_engine_state`` is not used; it is replaced by the new engine.
    Returns ``(engine, model_dropdown)`` for the state and dropdown outputs.
    """
    engine = wxEngine(apiendpoint, apikey, projectid)
    print("watsonx.ai activated")
    dropdown = wx_models_dropdown(engine)
    return engine, dropdown
def prepare_prompt(
    wx_engine_state, wx_model, wx_prompt, wx_instructions, wx_primer="", xml_string=""
):
    """Build the final prompt and generation budget for a watsonx.ai call.

    ``wx_prompt`` is filled with ``str.format(svg=..., instructions=...)`` and
    ``wx_primer`` is appended at the end, so the model continues from the primer.

    Args:
        wx_engine_state: connected ``wxEngine`` (``None`` until "Connect" is used).
        wx_model: model id selected in the dropdown.
        wx_prompt: template text; may use the ``{svg}`` and ``{instructions}``
            placeholders.
        wx_instructions: user instructions substituted for ``{instructions}``.
        wx_primer: text appended after the formatted template.
        xml_string: current SVG XML substituted for ``{svg}``.

    Returns:
        ``(wx_status, max_new_tokens, prompt)``. ``wx_status`` is "Done." or a
        warning when the prompt's token count exceeds the model maximum.
        ``max_new_tokens`` is ``max(500, len(xml_string))``.

    Raises:
        gr.Error: when no engine is connected or no model is selected.
    """
    # The engine state starts as None until wx_connect runs; show a readable
    # message in the UI instead of an AttributeError on None.
    if wx_engine_state is None:
        raise gr.Error(
            "Not connected to watsonx.ai: enter your credentials and click 'Connect' first."
        )
    if not wx_model:
        raise gr.Error("No model selected: choose one in the 'Model' dropdown.")
    wx_status = "Done."
    wx_engine = wx_engine_state
    # get model specs
    model_max_tokens = wx_engine.get_model_max_tokens(wx_model)
    # Add "primer" at the end of the prompt
    prompt = wx_prompt.format(svg=xml_string, instructions=wx_instructions) + wx_primer
    # Test and alert if prompt is too long
    prompt_nb_tokens = wx_engine.get_prompt_nb_tokens(prompt, wx_model)
    if prompt_nb_tokens > model_max_tokens:
        wx_status = f"Warning: prompt length ({prompt_nb_tokens}) is more than the model max tokens ({model_max_tokens}), and will be truncated. Please review your instructions."
        print(wx_status)
    # Budget new tokens from the input SVG's character count (a rough proxy for
    # its token count), with a floor of 500 for small or empty inputs.
    # note: prompt will be truncated if too long with GenTextParamsMetaNames.TRUNCATE_INPUT_TOKENS in generate()
    max_new_tokens = max(500, len(xml_string))
    return wx_status, max_new_tokens, prompt
def wx_generate(
    wx_engine_state, wx_model, wx_prompt, wx_instructions, wx_primer="", xml_string=""
):
    """Generate SVG with a single, non-streaming watsonx.ai call.

    The prompt built by ``prepare_prompt`` ends with ``wx_primer``, so the primer
    is prepended to the model's completion to form the full result.
    Returns ``(status_markdown, svg_xml, preview_html)``.
    """
    status, token_budget, full_prompt = prepare_prompt(
        wx_engine_state, wx_model, wx_prompt, wx_instructions, wx_primer, xml_string
    )
    completion = wx_engine_state.generate_text(
        modelid=wx_model,
        prompt=full_prompt,
        max_new_tokens=token_budget,
        stop_sequences=["</svg>"],
    )
    result = wx_primer + completion
    print(f"wx_generate:wx_result:{result}")
    return status, result, xml_string_to_html_img(result)
def wx_stream(
    wx_engine_state, wx_model, wx_prompt, wx_instructions, wx_primer="", xml_string=""
):
    """Stream SVG generation from watsonx.ai, yielding partial results to Gradio.

    Each chunk yields ``(progress_status, svg_so_far, None)``. After the stream
    ends, it yields ``(status, svg, preview_html)`` with the status from
    ``prepare_prompt``. The accumulated text starts with ``wx_primer``, because
    the prompt ends with it.
    """
    status, token_budget, full_prompt = prepare_prompt(
        wx_engine_state, wx_model, wx_prompt, wx_instructions, wx_primer, xml_string
    )
    accumulated = wx_primer
    # https://www.gradio.app/guides/streaming-outputs
    chunks = wx_engine_state.generate_text(
        modelid=wx_model,
        prompt=full_prompt,
        max_new_tokens=token_budget,
        stop_sequences=["</svg>"],
        stream=True,
    )
    for piece in chunks:
        accumulated += piece
        # 1 to 3 dots picked at random so the status visibly changes per chunk.
        dots = "." * randrange(3)
        yield "Processing." + dots, accumulated, None
    print(f"wx_stream:wx_result:{accumulated}")
    yield status, accumulated, xml_string_to_html_img(accumulated)
def wx_result_box_change(wx_result):
    """Re-render the result preview after the user edits the result box."""
    preview_html = xml_string_to_html_img(wx_result)
    return preview_html
def wx_accept(svg_xml):
    """Copy a generated SVG into the editor fields.

    Returns ``(data_string, svg_xml, encoded_preview_html, decoded_preview_html)``,
    matching the image-string box, XML box and both previews.
    """
    uri = xml_string_to_data_string(svg_xml)
    encoded_preview = html_img_preview(uri)
    decoded_preview = xml_string_to_html_img(svg_xml)
    return uri, svg_xml, encoded_preview, decoded_preview
# APP
# Layout: a "Load SVG" section (file upload, base64 image string, XML, previews)
# above a "watsonx.ai" section (credentials, model/prompt selection, generation).
# Event handlers keep the image string, XML and previews in sync.
with gr.Blocks() as demo:
    gr.Markdown("# SVG editor")
    with gr.Accordion("Get started:", open=True):
        gr.Markdown(
            """
- Create a new SVG: Select the prompt template 'Create SVG', enter a description in the 'Instructions' box and click 'Submit'
- Modify an existing SVG: Upload an SVG file, or paste an image string or SVG XML, then Select the prompt template 'Modify SVG', enter a change instruction in the 'Instructions' box and click 'Submit'
- Describe an SVG: Upload, paste or generate an SVG file, Select the prompt template 'Describe SVG' and click 'Submit' """
        )
    # init state - note gr.State() initial value must be deep-copyable - my wx_engine class is not
    wx_engine_state = gr.State(None)
    with gr.Column():
        # Encoded string and preview
        with gr.Accordion("Load SVG:", open=True):
            with gr.Row():
                # Upload an .svg file
                uploaded_file = gr.File(scale=0, label="Upload an SVG file")
                # Paste an image string
                encoded_string_box = gr.Textbox(
                    label="Image string",
                    info="data:image/svg+xml;base64,...",
                    lines=7,
                    max_lines=7,
                    show_copy_button=True,
                    scale=3,
                )
                # original preview
                encoded_svg_preview = gr.HTML(f"<img src='{svg_three_dots}'/>")
            # Decoded string and preview
            with gr.Row():
                xml_string_box = gr.Textbox(
                    label="SVG XML",
                    lines=7,
                    max_lines=7,
                    show_copy_button=True,
                    scale=3,
                )
                # decoded preview
                decoded_svg_preview = gr.HTML(f"<img src='{svg_three_dots}'/>")
            # with gr.Row(): # Color changer
            #     color_from_area = gr.ColorPicker(label="Search color:", value="#000000", scale=0)
            #     color_switch_btn = gr.Button("<->", scale=0)
            #     color_to_area = gr.ColorPicker(label="Replace color:", value="#FFFFFF", scale=0)
            #     color_replace_btn = gr.Button("Replace", scale=0)
        with gr.Accordion("watsonx.ai:", open=True):
            with gr.Row():  # watsonx
                with gr.Column(scale=0):
                    with gr.Group():
                        # credentials
                        with gr.Accordion("Credentials:", open=True):
                            wx_creds_endpoint = gr.Textbox(
                                label="Endpoint",
                                value="https://us-south.ml.cloud.ibm.com",
                                max_lines=1,
                            )
                            # Masked input so the secret is not shown on screen.
                            wx_creds_apikey = gr.Textbox(
                                label="API key", max_lines=1, type="password"
                            )
                            wx_creds_projectid = gr.Textbox(
                                label="Project id", max_lines=1
                            )
                            wx_connect_btn = gr.Button("Connect")
                        # model (empty until connected)
                        wx_models_drop = wx_models_dropdown(None)
                        wx_connect_btn.click(
                            fn=wx_connect,
                            inputs=[
                                wx_engine_state,
                                wx_creds_endpoint,
                                wx_creds_apikey,
                                wx_creds_projectid,
                            ],
                            outputs=[wx_engine_state, wx_models_drop],
                        )
                        # prompt template
                        prompt_template_names = list_prompts()
                        wx_prompt_drop = gr.Dropdown(
                            label="Prompt template",
                            choices=prompt_template_names,
                            value=prompt_template_names[0],
                        )
                        wx_prompt_box = gr.Textbox(
                            info="Text",
                            show_label=False,
                            max_lines=5,
                            value=get_prompt_template(prompt_template_names[0]),
                        )
                        wx_primer_box = gr.Textbox(
                            info="Primer",
                            show_label=False,
                            max_lines=2,
                            value=get_prompt_primer(prompt_template_names[0]),
                        )
                with gr.Column():
                    with gr.Row():
                        wx_instructions_box = gr.Textbox(
                            label="Instructions",
                            scale=3,
                            value=get_prompt_example(prompt_template_names[0]),
                            show_copy_button=True,
                        )
                        wx_submit_btn = gr.Button("↓Submit↓", scale=0)
                        wx_accept_btn = gr.Button("↑Accept↑", scale=0)
                    with gr.Row():
                        wx_result_box = gr.Textbox(
                            label="Result",
                            lines=7,
                            max_lines=7,
                            scale=3,
                            show_copy_button=True,
                        )
                        wx_svg_preview = gr.HTML(f"<img src='{svg_three_dots}'/>")
                    wx_status_box = gr.Markdown("Status")
            # Selecting a template refills the template text, example and primer.
            wx_prompt_drop.input(
                fn=wx_prompt_change,
                inputs=wx_prompt_drop,
                outputs=[wx_prompt_box, wx_instructions_box, wx_primer_box],
            )
            wx_submit_btn.click(
                # fn=wx_generate,
                fn=wx_stream,
                inputs=[
                    wx_engine_state,
                    wx_models_drop,
                    wx_prompt_box,
                    wx_instructions_box,
                    wx_primer_box,
                    xml_string_box,
                ],
                outputs=[wx_status_box, wx_result_box, wx_svg_preview],
                api_name="wx_generate",
            )
            wx_result_box.input(
                fn=wx_result_box_change,
                inputs=[wx_result_box],
                outputs=[wx_svg_preview],
            )
            # "Accept" copies the generated SVG back into the editor fields.
            wx_accept_btn.click(
                fn=wx_accept,
                inputs=[wx_result_box],
                outputs=[
                    encoded_string_box,
                    xml_string_box,
                    encoded_svg_preview,
                    decoded_svg_preview,
                ],
            )
    # Actions
    encoded_string_box.input(
        fn=encoded_string_box_change,
        inputs=[encoded_string_box],
        outputs=[xml_string_box, encoded_svg_preview, decoded_svg_preview],
    )
    uploaded_file.upload(
        fn=read_file,
        inputs=uploaded_file,
        outputs=[
            encoded_string_box,
            xml_string_box,
            encoded_svg_preview,
            decoded_svg_preview,
        ],
    )
    xml_string_box.input(
        fn=xml_string_box_change,
        inputs=[xml_string_box],
        outputs=[encoded_string_box, encoded_svg_preview, decoded_svg_preview],
    )
    # color_switch_btn.click(
    #     fn=switch_colors,
    #     inputs=[color_from_area, color_to_area],
    #     outputs=[color_from_area, color_to_area],
    #     api_name="color_switch",
    # )
    # color_replace_btn.click(
    #     fn=replace_color,
    #     inputs=[xml_string_box, color_from_area, color_to_area],
    #     outputs=[encoded_string_box, xml_string_box, svg_preview],
    #     api_name="color_replace",
    # )
# Main
if __name__ == "__main__":
    demo.launch()