File size: 2,096 Bytes
eca813c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import gradio as gr
import os

from model import VirtualStagingToolV2


# Shared staging tool, created on first use. The constructor takes a diffusion
# checkpoint id, so it presumably loads model weights; building it once per
# process avoids paying that cost on every button click.
_VS_TOOL = None


def _get_vs_tool():
    """Return the process-wide VirtualStagingToolV2, creating it on first call."""
    global _VS_TOOL
    if _VS_TOOL is None:
        _VS_TOOL = VirtualStagingToolV2(diffusion_version="stabilityai/stable-diffusion-2-inpainting")
    return _VS_TOOL


def predict(image, style, color_preference):
    """Run virtual staging on an uploaded room photo.

    Args:
        image: PIL image from the upload component (``type="pil"``), or None
            when the user has not uploaded anything.
        style: Design theme chosen in the dropdown.
        color_preference: Free-text color preference.

    Returns:
        A 5-tuple:
        - the first image produced by ``virtual_stage``;
        - the transparent mask image returned alongside it;
        - three ``gr.update(visible=True)`` objects.
        The trailing updates are kept for interface compatibility; the
        current ``btn.click`` wiring only consumes the first two values.

    Raises:
        gr.Error: if no image was uploaded, so Gradio shows a readable message
            instead of an AttributeError traceback.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")

    # Normalize mode and size before handing the image to the staging tool.
    init_image = image.convert("RGB").resize((512, 512))

    # NOTE(review): the tool instance is reused across requests; this assumes
    # virtual_stage keeps no per-call state on the instance — confirm in model.py.
    vs_tool = _get_vs_tool()
    output_images, transparent_mask_image = vs_tool.virtual_stage(
        image=init_image, style=style, color_preference=color_preference, number_images=1)
    return output_images[0], transparent_mask_image, gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)


# Gradio UI: upload plus style/color controls on the left; the mask and the
# staged output on the right.
image_blocks = gr.Blocks()
with image_blocks as demo:
    with gr.Group():
        with gr.Box():
            with gr.Row():
                with gr.Column():
                    image = gr.Image(source='upload', elem_id="image_upload",
                                     type="pil", label="Upload",
                                     ).style(height=400)
                    with gr.Row(elem_id="prompt-container").style(mobile_collapse=False, equal_height=True):
                        # The selected choice is passed verbatim to predict() as `style`.
                        # NOTE(review): confirm model.py does not match on the exact
                        # style string before relying on the corrected spelling.
                        style = gr.Dropdown(
                            ["Modern", "Coastal", "French country"],
                            label="Design theme", elem_id="input-style"
                        )

                        color_preference = gr.Textbox(placeholder='Enter color preference',
                                                      label="Color preference", elem_id="input-color")
                        btn = gr.Button("Inpaint!").style(
                            margin=False,
                            rounded=(False, True, True, False),
                            full_width=False,
                        )
                with gr.Column():
                    mask_image = gr.Image(label="Mask image", elem_id="mask-img").style(height=400)
                    image_out = gr.Image(label="Output", elem_id="output-img").style(height=400)

            # predict returns (output image, mask image, 3 visibility updates);
            # only the first two values are mapped to components here.
            btn.click(fn=predict, inputs=[image, style, color_preference], outputs=[image_out, mask_image])

image_blocks.launch()