# File size: 3,674 Bytes
# 966ae59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import os
import subprocess
import sys
import tempfile

import gradio as gr
from PIL import Image

sys.path.append('/home/user/app/code')


def _run_setup_step(cmd):
    """Run one diffvg setup command, reporting (never raising on) failure.

    Setup stays best-effort so the app can still start, but failures are
    reported instead of being silently ignored.

    Args:
        cmd: argv list passed to ``subprocess.run`` (no shell involved).

    Returns:
        True if the command ran and exited with status 0, False otherwise.
    """
    try:
        completed = subprocess.run(cmd)
    except OSError as exc:
        # e.g. the executable is not on PATH; keep startup best-effort.
        print(f"WARNING: could not run {' '.join(cmd)}: {exc}")
        return False
    if completed.returncode != 0:
        print(f"WARNING: {' '.join(cmd)} exited with status {completed.returncode}")
        return False
    return True


# set up diffvg
# The clone step is disabled: the diffvg sources are expected in ./diffvg,
# relative to the working directory at startup.
# os.system('git clone https://github.com/BachiLi/diffvg.git')
os.chdir('diffvg')
try:
    # Both steps always run, matching the original sequence even if the first fails.
    _submodules_ok = _run_setup_step(["git", "submodule", "update", "--init", "--recursive"])
    # sys.executable installs into the interpreter that is running this app,
    # rather than whatever "python" happens to resolve to on PATH.
    _install_ok = _run_setup_step([sys.executable, "setup.py", "install", "--user"])
finally:
    os.chdir('/home/user/app')

# The --user install places an egg under ~/.local for the running Python version.
_py_tag = f"{sys.version_info.major}.{sys.version_info.minor}"
sys.path.append(
    f"/home/user/.local/lib/python{_py_tag}/site-packages/diffvg-0.0.1-py{_py_tag}-linux-x86_64.egg"
)
if _submodules_ok and _install_ok:
    print("diffvg installed.")
else:
    print("WARNING: diffvg setup did not complete cleanly; diffvg may be unavailable.")


def process_images(prompt, num_paths, token_index, seed, optimize_width=False, optimize_color=False):
    """Render a DiffSketcher sketch by running ``svg_render.py`` in a subprocess.

    Args:
        prompt: Text prompt, forwarded as the ``prompt=`` override.
        num_paths: Number of strokes, forwarded as ``x.num_paths``.
        token_index: CLIP token index, forwarded as ``x.token_ind``.
        seed: Random seed, forwarded as ``seed``.
        optimize_width: Forwarded as ``x.optim_width``.
        optimize_color: Forwarded as ``x.optim_rgba``.

    Returns:
        A fully loaded ``PIL.Image.Image`` read from ``final_render.png``.

    Raises:
        subprocess.CalledProcessError: if ``svg_render.py`` exits non-zero.
        FileNotFoundError: if ``final_render.png`` is not found in the temp dir.
    """
    command = [
        # Use the interpreter running this app rather than whatever "python" is on PATH.
        sys.executable, "svg_render.py",
        "x=diffsketcher",
        f"prompt={prompt}",
        f"x.num_paths={num_paths}",
        f"x.token_ind={token_index}",
        f"seed={seed}",
        f"x.optim_width={optimize_width}",
        f"x.optim_rgba={optimize_color}",
        "x.optim_opacity=False",
    ]
    with tempfile.TemporaryDirectory() as tmpdirname:
        # check=True raises CalledProcessError on a non-zero exit, so reaching
        # the next statement means the script succeeded.
        subprocess.run(command, check=True)
        # NOTE(review): tmpdirname is never passed to svg_render.py, so nothing
        # tells the script to write its output here. Confirm the script's
        # output-directory override and add it to `command`; otherwise this
        # lookup cannot find the render.
        render_path = os.path.join(tmpdirname, "final_render.png")
        if not os.path.isfile(render_path):
            raise FileNotFoundError(f"svg_render.py did not produce {render_path}")
        # Image.open is lazy: load the pixels and close the file before the
        # temporary directory (and the file in it) is deleted.
        with Image.open(render_path) as rendered:
            output_image = rendered.copy()
    return output_image


# Sample DiffSketcher outputs displayed in the header gallery.
_GALLERY_URLS = (
    "https://raw.githubusercontent.com/ximinng/DiffSketcher/main/img/cat.svg",
    "https://raw.githubusercontent.com/ximinng/DiffSketcher/main/img/rose.svg",
    "https://raw.githubusercontent.com/ximinng/DiffSketcher/main/img/elephant.svg",
    "https://raw.githubusercontent.com/ximinng/DiffSketcher/main/img/elephant_silhouette.svg",
    "https://raw.githubusercontent.com/ximinng/DiffSketcher/main/img/horse_width.svg",
    "https://raw.githubusercontent.com/ximinng/DiffSketcher/main/img/horse_rgba.svg",
    "https://ximinng.github.io/PyTorch-SVGRender-project/assets/diffsketcher/Sydney_opera.svg",
    "https://ximinng.github.io/PyTorch-SVGRender-project/assets/diffsketcher/Sydney_opera_width.svg",
    "https://ximinng.github.io/PyTorch-SVGRender-project/assets/diffsketcher/Sydney_opera_width_color.svg",
)

# Preset configurations; each row follows the input-component order used below:
# [prompt, path number, token index, seed, optimize width, optimize color]
_EXAMPLE_ROWS = [
    ["A photo of Sydney opera house.", 96, 5, 8019, False, False],
    ["A photo of Sydney opera house.", 96, 5, 8019, True, False],
    ["A photo of Sydney opera house.", 128, 5, 8019, True, True],
]

with gr.Blocks() as demo:
    gr.Markdown("# DiffSketcher")
    gr.Markdown("DiffSketcher synthesizes **vector sketches** based on **text prompts**.")
    li = list(_GALLERY_URLS)
    gr.Gallery(li, columns=6)

    with gr.Row():
        # Left column: generation controls.
        with gr.Column():
            text = gr.Textbox(label="prompt")
            num_paths = gr.Slider(label="path number", value=96, minimum=1, maximum=500, step=1)
            token_index = gr.Textbox(label="token_index", info="CLIP embedding token index. Starting from 1.")
            seed = gr.Slider(minimum=0, maximum=10000, label="random seed", value=8019)
            with gr.Accordion("Selectable Inputs"):
                optimize_width = gr.Checkbox(label="optimize stroke width")
                optimize_color = gr.Checkbox(label="optimize stroke color")
            btn = gr.Button("Synthesize")
        # Right column: rendered result.
        with gr.Column():
            output = gr.Image(label="output image", height=512)

    # Order must match process_images' positional parameters.
    synth_inputs = [text, num_paths, token_index, seed, optimize_width, optimize_color]

    btn.click(process_images, inputs=list(synth_inputs), outputs=[output])

    gr.Markdown("## Examples")
    gr.Markdown("Here are some config examples. Feel free to try your own prompts!")
    gr.Examples(
        examples=_EXAMPLE_ROWS,
        inputs=list(synth_inputs),
        outputs=[output],
        fn=process_images,
    )

demo.launch()