File size: 2,915 Bytes
0c43c79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import os
import io
import json
import numpy as np
import cv2

import gradio as gr

import modules.scripts as scripts
from modules import script_callbacks

from scripts.td_abg import get_foreground
from scripts.convertor import pil2cv


def processing(input_image, td_abg_enabled, h_split, v_split, n_cluster, alpha, th_rate, cascadePSP_enabled, fast, psp_L):
    """Run background removal on the image submitted from the PBRemTools tab.

    Every argument after ``input_image`` is the raw value of one of the tab's
    controls and is forwarded unchanged to ``get_foreground``.

    Args:
        input_image: PIL image from the ``gr.Image(type="pil")`` input, or
            ``None`` when Submit is pressed before any image was uploaded.
        td_abg_enabled: "tile division BG Remover" -> "enabled" checkbox.
        h_split: "horizontal split num" slider value.
        v_split: "vertical split num" slider value.
        n_cluster: "cluster num" slider value.
        alpha: "alpha threshold" slider value.
        th_rate: "mask content ratio" slider value.
        cascadePSP_enabled: "cascadePSP" -> "enabled" checkbox.
        fast: "cascadePSP" -> "fast" checkbox.
        psp_L: "Memory usage" slider value.

    Returns:
        ``(image, mask)``, ordered to match the tab's ``[output, mask]``
        output components, or ``(None, None)`` when no input image was given.
    """
    # With no uploaded image, pil2cv would receive None and presumably fail
    # with an opaque error; returning empty outputs just leaves both result
    # images blank instead.
    if input_image is None:
        return None, None

    image = pil2cv(input_image)
    # NOTE(review): presumably pil2cv yields OpenCV (BGR) channel order and
    # get_foreground expects RGB — confirm against scripts.convertor.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    mask, image = get_foreground(image, td_abg_enabled, h_split, v_split, n_cluster, alpha, th_rate, cascadePSP_enabled, fast, psp_L)
    # get_foreground returns (mask, image); the UI expects (image, mask).
    return image, mask

class Script(scripts.Script):
    """WebUI script entry for the PBRemTools extension.

    ``show`` always reports ``scripts.AlwaysVisible`` and ``ui`` contributes
    no components (an empty tuple); the extension's controls are built in a
    separate tab by ``on_ui_tabs`` in this module.
    """

    def __init__(self) -> None:
        super().__init__()

    def title(self):
        """Return the script's display name."""
        return "PBRemTools"

    def show(self, is_img2img):
        """Return ``scripts.AlwaysVisible`` regardless of ``is_img2img``."""
        return scripts.AlwaysVisible

    def ui(self, is_img2img):
        """Return an empty tuple: this script adds no per-tab controls."""
        return ()

def on_ui_tabs():
    """Build the "PBRemTools" tab.

    Layout: a left column with the input image, the "tile division BG
    Remover" and "cascadePSP" settings and a Submit button; a right-hand
    area with "output" and "mask" tabs. Submit calls ``processing`` with the
    settings in declaration order and routes its two return values to the
    "output" and "mask" images respectively.

    Returns:
        ``[(blocks, "PBRemTools", "pbremtools")]`` where ``blocks`` is the
        ``gr.Blocks`` built here.
    """
    with gr.Blocks(analytics_enabled=False) as pbrem_tab:
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(type="pil")

                with gr.Accordion("tile division BG Remover", open=True):
                    with gr.Box():
                        td_abg_enabled = gr.Checkbox(label="enabled", show_label=True)
                        # NOTE(review): an HTML range input measures steps from
                        # its minimum, so with minimum=1/step=4 the reachable
                        # values are presumably 1, 5, 9, ... and the default 256
                        # is off-grid; likewise n_cluster (minimum=1, step=10,
                        # default 500). Confirm intended values before changing.
                        h_split = gr.Slider(minimum=1, maximum=2048, value=256, step=4, label="horizontal split num", show_label=True)
                        v_split = gr.Slider(minimum=1, maximum=2048, value=256, step=4, label="vertical split num", show_label=True)
                        n_cluster = gr.Slider(minimum=1, maximum=1000, value=500, step=10, label="cluster num", show_label=True)
                        alpha = gr.Slider(minimum=1, maximum=255, value=50, step=1, label="alpha threshold", show_label=True)
                        th_rate = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.01, label="mask content ratio", show_label=True)

                with gr.Accordion("cascadePSP", open=True):
                    with gr.Box():
                        cascadePSP_enabled = gr.Checkbox(label="enabled", show_label=True)
                        fast = gr.Checkbox(label="fast", show_label=True)
                        psp_L = gr.Slider(minimum=1, maximum=2048, value=900, step=1, label="Memory usage", show_label=True)

                submit = gr.Button(value="Submit")

            with gr.Row():
                with gr.Column():
                    with gr.Tab("output"):
                        output_img = gr.Image()
                    with gr.Tab("mask"):
                        output_mask = gr.Image()

        # Order must match processing's positional parameters / return tuple.
        processing_inputs = [
            input_image,
            td_abg_enabled,
            h_split,
            v_split,
            n_cluster,
            alpha,
            th_rate,
            cascadePSP_enabled,
            fast,
            psp_L,
        ]
        processing_outputs = [output_img, output_mask]
        submit.click(processing, inputs=processing_inputs, outputs=processing_outputs)

    return [(pbrem_tab, "PBRemTools", "pbremtools")]

# Hand the tab builder above to the WebUI's script_callbacks so it can add the
# "PBRemTools" tab (presumably invoked when the WebUI assembles its UI tabs).
script_callbacks.on_ui_tabs(on_ui_tabs)