fkunn1326 committed on
Commit
0c43c79
·
1 Parent(s): 05932d9

Upload 5 files

Browse files
Files changed (5) hide show
  1. app.py +64 -0
  2. requirements.txt +7 -0
  3. scripts/convertor.py +63 -0
  4. scripts/main.py +73 -0
  5. scripts/td_abg.py +122 -0
app.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import sys
3
+ import cv2
4
+
5
+ from scripts.td_abg import get_foreground
6
+ from scripts.convertor import pil2cv
7
+
8
class webui:
    """Stand-alone Gradio front end for the background-removal pipeline.

    NOTE(review): this block was recovered from a diff rendering that lost
    all indentation; the nesting of the layout containers below is the most
    plausible reconstruction and should be confirmed against the real file.
    """

    def __init__(self):
        # The Blocks container is created up front and populated in launch().
        self.demo = gr.Blocks()

    def processing(self, input_image, td_abg_enabled, h_split, v_split, n_cluster, alpha, th_rate, cascadePSP_enabled, fast, psp_L):
        """Run foreground extraction on the uploaded image.

        The PIL image is converted with ``pil2cv`` and then passed through
        ``cv2.cvtColor(..., cv2.COLOR_BGR2RGB)`` before being handed to
        ``get_foreground`` together with every UI setting.

        Returns:
            tuple: ``(image, mask)`` - note the order is swapped relative to
            what ``get_foreground`` returns, to match the output components.
        """
        bgr = pil2cv(input_image)
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        mask, cutout = get_foreground(
            rgb, td_abg_enabled, h_split, v_split, n_cluster, alpha,
            th_rate, cascadePSP_enabled, fast, psp_L,
        )
        return cutout, mask

    def launch(self, share):
        """Build the UI inside ``self.demo``, enable queuing and start the server.

        Args:
            share: forwarded unchanged to ``gr.Blocks.launch(share=...)``.
        """
        with self.demo:
            with gr.Row():
                with gr.Column():
                    source = gr.Image(type="pil")
                    with gr.Accordion("tile division ABG", open=True):
                        with gr.Box():
                            use_td_abg = gr.Checkbox(label="enabled", show_label=True)
                            h_slider = gr.Slider(1, 2048, value=256, step=4, label="horizontal split num", show_label=True)
                            v_slider = gr.Slider(1, 2048, value=256, step=4, label="vertical split num", show_label=True)

                            cluster_slider = gr.Slider(1, 1000, value=500, step=10, label="cluster num", show_label=True)
                            alpha_slider = gr.Slider(1, 255, value=100, step=1, label="alpha threshold", show_label=True)
                            rate_slider = gr.Slider(0, 1, value=0.1, step=0.01, label="mask content ratio", show_label=True)

                    with gr.Accordion("cascadePSP", open=True):
                        with gr.Box():
                            use_psp = gr.Checkbox(label="enabled", show_label=True)
                            fast_box = gr.Checkbox(label="fast", show_label=True)
                            l_slider = gr.Slider(1, 2048, value=900, step=1, label="Memory usage", show_label=True)

                    run_button = gr.Button(value="Submit")
                with gr.Row():
                    with gr.Column():
                        with gr.Tab("output"):
                            result_view = gr.Image()
                        with gr.Tab("mask"):
                            mask_view = gr.Image()

            # Input order must match the positional parameters of processing().
            controls = [
                source, use_td_abg, h_slider, v_slider, cluster_slider,
                alpha_slider, rate_slider, use_psp, fast_box, l_slider,
            ]
            run_button.click(
                self.processing,
                inputs=controls,
                outputs=[result_view, mask_view],
            )

        self.demo.queue()
        self.demo.launch(share=share)
54
+
55
+
56
if __name__ == "__main__":
    # Public sharing is enabled only when the first CLI argument is exactly
    # "share"; any other argument (or none) launches locally.
    ui = webui()
    ui.launch(share=len(sys.argv) > 1 and sys.argv[1] == "share")
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ onnx
2
+ onnxruntime
3
+ opencv-python
4
+ numpy
5
+ pillow
6
+ segmentation-refinement
7
+ scikit-learn
scripts/convertor.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import pandas as pd
3
+ from PIL import Image
4
+
5
def rgb2df(img):
    """
    Flatten an image into a per-pixel DataFrame.

    Args:
        img (np.ndarray): 3-D array; the first three channels are read.

    Returns:
        df (pd.DataFrame): one row per pixel with columns ``x_l`` (index
            along axis 0, i.e. the row), ``y_l`` (index along axis 1, i.e.
            the column) and the channel values ``r``, ``g``, ``b``.
    """
    height, width, _ = img.shape
    # np.indices yields the same row/column grids as an 'ij' meshgrid.
    row_idx, col_idx = np.indices((height, width))
    columns = {
        "x_l": row_idx.ravel(),
        "y_l": col_idx.ravel(),
    }
    for position, channel in enumerate(("r", "g", "b")):
        columns[channel] = img[:, :, position].ravel()
    return pd.DataFrame(columns)
26
+
27
def df2rgba(img_df):
    """
    Convert a per-pixel DataFrame back to an RGBA image.

    Args:
        img_df (pd.DataFrame): DataFrame with coordinate columns ``x_l``
            (row) and ``y_l`` (column) and value columns ``r``, ``g``,
            ``b`` and ``a``.

    Returns:
        img (np.ndarray): ``uint8`` array of shape (rows, cols, 4) in
            R, G, B, A channel order.
    """
    # One pivot per channel: rows are keyed by x_l, columns by y_l.
    planes = [
        img_df.pivot_table(index="x_l", columns="y_l", values=channel).values
        for channel in ("r", "g", "b", "a")
    ]
    return np.stack(planes, 2).astype(np.uint8)
43
+
44
def pil2cv(image):
    """Return *image* as a ``uint8`` array with the colour channels reversed.

    Three-channel input has its channel axis reversed (RGB <-> BGR);
    four-channel input swaps the first and third channels and keeps the
    fourth in place. Two-dimensional input, and any other channel count,
    is returned unchanged.
    """
    pixels = np.array(image, dtype=np.uint8)
    if pixels.ndim != 2:
        channel_count = pixels.shape[2]
        if channel_count == 3:
            pixels = pixels[:, :, ::-1]
        elif channel_count == 4:
            pixels = pixels[:, :, [2, 1, 0, 3]]
    return pixels
53
+
54
def cv2pil(image):
    """Build a PIL image from an array, reversing the colour channel order.

    Works on a copy of *image*. Three-channel arrays have the channel axis
    reversed, four-channel arrays swap channels 0 and 2 while keeping
    channel 3, and two-dimensional arrays are used as they are.
    """
    pixels = image.copy()
    if pixels.ndim != 2:
        channel_count = pixels.shape[2]
        if channel_count == 3:
            pixels = pixels[:, :, ::-1]
        elif channel_count == 4:
            pixels = pixels[:, :, [2, 1, 0, 3]]
    return Image.fromarray(pixels)
scripts/main.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import io
3
+ import json
4
+ import numpy as np
5
+ import cv2
6
+
7
+ import gradio as gr
8
+
9
+ import modules.scripts as scripts
10
+ from modules import script_callbacks
11
+
12
+ from scripts.td_abg import get_foreground
13
+ from scripts.convertor import pil2cv
14
+
15
+
16
def processing(input_image, td_abg_enabled, h_split, v_split, n_cluster, alpha, th_rate, cascadePSP_enabled, fast, psp_L):
    """Gradio click handler: extract the foreground of the uploaded image.

    The PIL input goes through ``pil2cv`` and then
    ``cv2.cvtColor(..., cv2.COLOR_BGR2RGB)`` before ``get_foreground`` is
    called with all UI settings.

    Returns:
        tuple: ``(image, mask)`` - the reverse of ``get_foreground``'s
        ``(mask, image)`` order, matching the output components.
    """
    bgr = pil2cv(input_image)
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    mask, cutout = get_foreground(
        rgb, td_abg_enabled, h_split, v_split, n_cluster, alpha,
        th_rate, cascadePSP_enabled, fast, psp_L,
    )
    return cutout, mask
21
+
22
class Script(scripts.Script):
    """Script entry registered under the title "PBRemTools".

    It contributes no controls of its own (``ui`` returns an empty tuple);
    the actual interface is built by ``on_ui_tabs``.
    """

    def __init__(self) -> None:
        super().__init__()

    def title(self):
        """Return the display name of this script."""
        return "PBRemTools"

    def show(self, is_img2img):
        """Return ``scripts.AlwaysVisible`` regardless of *is_img2img*."""
        return scripts.AlwaysVisible

    def ui(self, is_img2img):
        """Return an empty tuple: no per-script components are created."""
        return ()
34
+
35
def on_ui_tabs():
    """Build the PBRemTools tab and return it in the tab-callback format.

    NOTE(review): this block was recovered from a diff rendering that lost
    all indentation; the nesting of the layout containers is the most
    plausible reconstruction and should be confirmed against the real file.

    Returns:
        list: a single ``(blocks, title, elem_id)`` tuple.
    """
    with gr.Blocks(analytics_enabled=False) as PBRemTools:
        with gr.Row():
            with gr.Column():
                source = gr.Image(type="pil")
                with gr.Accordion("tile division BG Remover", open=True):
                    with gr.Box():
                        use_td_abg = gr.Checkbox(label="enabled", show_label=True)
                        h_slider = gr.Slider(1, 2048, value=256, step=4, label="horizontal split num", show_label=True)
                        v_slider = gr.Slider(1, 2048, value=256, step=4, label="vertical split num", show_label=True)

                        cluster_slider = gr.Slider(1, 1000, value=500, step=10, label="cluster num", show_label=True)
                        alpha_slider = gr.Slider(1, 255, value=50, step=1, label="alpha threshold", show_label=True)
                        rate_slider = gr.Slider(0, 1, value=0.1, step=0.01, label="mask content ratio", show_label=True)

                with gr.Accordion("cascadePSP", open=True):
                    with gr.Box():
                        use_psp = gr.Checkbox(label="enabled", show_label=True)
                        fast_box = gr.Checkbox(label="fast", show_label=True)
                        l_slider = gr.Slider(1, 2048, value=900, step=1, label="Memory usage", show_label=True)

                run_button = gr.Button(value="Submit")
            with gr.Row():
                with gr.Column():
                    with gr.Tab("output"):
                        result_view = gr.Image()
                    with gr.Tab("mask"):
                        mask_view = gr.Image()

        # Input order must match the positional parameters of processing().
        controls = [
            source, use_td_abg, h_slider, v_slider, cluster_slider,
            alpha_slider, rate_slider, use_psp, fast_box, l_slider,
        ]
        run_button.click(
            processing,
            inputs=controls,
            outputs=[result_view, mask_view],
        )

    return [(PBRemTools, "PBRemTools", "pbremtools")]


# Register the tab builder with the host application's UI callbacks.
script_callbacks.on_ui_tabs(on_ui_tabs)
scripts/td_abg.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import matplotlib.pyplot as plt
3
+ import numpy as np
4
+ import pandas as pd
5
+ from sklearn.cluster import KMeans, MiniBatchKMeans
6
+
7
+ from scripts.convertor import rgb2df, df2rgba
8
+
9
+ import gradio as gr
10
+ import huggingface_hub
11
+ import onnxruntime as rt
12
+ import copy
13
+ from PIL import Image
14
+
15
+ import segmentation_refinement as refine
16
+
17
+
18
# Execution providers handed to onnxruntime; CUDA is listed before CPU.
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']

# Fetch the segmentation model from the Hugging Face Hub at import time and
# open one inference session that is shared by every call to get_mask().
model_path = huggingface_hub.hf_hub_download(
    repo_id="skytnt/anime-seg",
    filename="isnetis.onnx",
)
rmbg_model = rt.InferenceSession(model_path, providers=providers)
25
+
26
def get_mask(img, s=1024):
    """Run the segmentation model on *img* and return its mask.

    The image is scaled by 1/255, stripped of a fourth channel if present,
    resized so its longer side equals *s* (aspect ratio preserved), and
    centred on a zero-filled ``s`` x ``s`` canvas before inference. The
    prediction is cropped back to the resized region and resized to the
    original dimensions.

    Args:
        img: array of shape (height, width, channels).
        s: side length of the square model input.

    Returns:
        Array of shape (height, width, 1) holding the model output
        (callers scale it by 255, so presumably in [0, 1] - TODO confirm).
    """
    pixels = (img / 255).astype(np.float32)
    channels = pixels.shape[2]
    if channels == 4:
        # Discard the fourth channel; only the first three are fed in.
        pixels = pixels[..., :3]
        channels = 3

    src_h, src_w = pixels.shape[:-1]
    if src_h > src_w:
        fit_h, fit_w = s, int(s * src_w / src_h)
    else:
        fit_h, fit_w = int(s * src_h / src_w), s

    # Offsets that centre the resized image on the square canvas.
    top = (s - fit_h) // 2
    left = (s - fit_w) // 2

    canvas = np.zeros([s, s, channels], dtype=np.float32)
    canvas[top:top + fit_h, left:left + fit_w] = cv2.resize(pixels, (fit_w, fit_h))

    # HWC -> CHW, then add a leading batch axis.
    batch = np.transpose(canvas, (2, 0, 1))[np.newaxis, :]
    prediction = rmbg_model.run(None, {'img': batch})[0][0]

    prediction = np.transpose(prediction, (1, 2, 0))
    prediction = prediction[top:top + fit_h, left:left + fit_w]
    return cv2.resize(prediction, (src_w, src_h))[:, :, np.newaxis]
45
+
46
def assign_tile(row, tile_width, tile_height):
    """Return the tile label for one pixel row of an ``rgb2df`` DataFrame.

    ``rgb2df`` stores the row index (vertical position) in ``x_l`` and the
    column index (horizontal position) in ``y_l``. The vertical position is
    therefore divided by *tile_height* and the horizontal one by
    *tile_width*; the previous code paired them the other way round, which
    produced the wrong tile grid for non-square images.

    Args:
        row: mapping with integer entries ``'x_l'`` and ``'y_l'``.
        tile_width: tile size in pixels along the horizontal axis.
        tile_height: tile size in pixels along the vertical axis.

    Returns:
        str: ``"tile_<tile row>_<tile column>"``.
    """
    # A tile size of 0 occurs when the split count exceeds the image
    # dimension; clamp to one pixel so the division stays defined.
    tile_width = max(tile_width, 1)
    tile_height = max(tile_height, 1)
    tile_x = row['y_l'] // tile_width
    tile_y = row['x_l'] // tile_height
    return f"tile_{tile_y}_{tile_x}"
50
+
51
def rmbg_fn(img):
    """Remove the background of *img* using the model mask alone.

    The image is blended towards white (255) where the mask is low, and the
    mask scaled to 0-255 is appended as an extra last channel.

    Returns:
        tuple: ``(mask, img)`` where ``mask`` is the ``uint8`` mask repeated
        to three channels and ``img`` is the blended image with the
        single-channel mask concatenated on axis 2.
    """
    soft_mask = get_mask(img)
    blended = (soft_mask * img + 255 * (1 - soft_mask)).astype(np.uint8)
    mask_u8 = (soft_mask * 255).astype(np.uint8)
    with_alpha = np.concatenate([blended, mask_u8], axis=2, dtype=np.uint8)
    return mask_u8.repeat(3, axis=2), with_alpha
58
+
59
def refinement(img, mask, fast, psp_L, device=None):
    """Refine a segmentation mask with the ``segmentation_refinement`` Refiner.

    Args:
        img: image array passed unchanged to ``Refiner.refine``.
        mask: three-channel mask; it is converted to a single channel with
            ``cv2.COLOR_RGB2GRAY`` before refinement.
        fast: forwarded as ``fast=``; fast mode runs the global step only.
        psp_L: forwarded as ``L=``; a smaller value uses less memory and is
            faster in fast mode.
        device: device string for the Refiner (for example ``'cuda:0'`` or
            ``'cpu'``). When ``None``, ``'cuda:0'`` is used if CUDA is
            available and ``'cpu'`` otherwise, so CPU-only hosts no longer
            fail on a hard-coded CUDA device.

    Returns:
        The refined mask as returned by ``Refiner.refine``.
    """
    if device is None:
        # Imported here so the module-level imports stay untouched.
        import torch
        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    gray_mask = cv2.cvtColor(mask, cv2.COLOR_RGB2GRAY)
    refiner = refine.Refiner(device=device)
    return refiner.refine(img, gray_mask, fast=fast, L=psp_L)
68
+
69
+
70
def get_foreground(img, td_abg_enabled, h_split, v_split, n_cluster, alpha, th_rate, cascadePSP_enabled, fast, psp_L):
    """Separate the foreground of *img* using the selected strategy.

    Three modes are supported:

    * ``td_abg_enabled``: the image is split into tiles and its colours are
      clustered; every (cluster, tile) group whose share of mask pixels
      above *alpha* exceeds *th_rate* is made opaque, all others
      transparent. If ``cascadePSP_enabled`` is also set, the model mask is
      refined first.
    * ``cascadePSP_enabled`` only: the model mask is refined and appended to
      the image as an extra channel.
    * neither: plain model-based removal via ``rmbg_fn``.

    Args:
        img: image array of shape (height, width, 3).
        td_abg_enabled: enable the tile/cluster strategy.
        h_split: number of tiles along the image width.
        v_split: number of tiles along the image height.
        n_cluster: number of colour clusters (capped at the pixel count).
        alpha: threshold a mask pixel must exceed in all three channels to
            count towards a group's ratio.
        th_rate: minimum ratio of such pixels for a group to be kept.
        cascadePSP_enabled: enable mask refinement.
        fast: forwarded to the refiner as ``fast=``.
        psp_L: forwarded to the refiner as ``L=``.

    Returns:
        tuple: ``(mask, img)``.
    """
    if td_abg_enabled:
        mask = get_mask(img)
        mask = (mask * 255).astype(np.uint8)
        mask = mask.repeat(3, axis=2)
        if cascadePSP_enabled:
            mask = refinement(img, mask, fast, psp_L)
            mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2RGB)
        df = rgb2df(img)

        image_height, image_width = img.shape[:2]
        # Clamp to one pixel: a split count larger than the image dimension
        # would otherwise give a tile size of 0 and a division by zero.
        tile_width = max(image_width // max(int(h_split), 1), 1)
        tile_height = max(image_height // max(int(v_split), 1), 1)

        # Vectorised tile labelling (replaces a per-pixel DataFrame.apply).
        # rgb2df stores the row index in x_l and the column index in y_l,
        # so x_l pairs with tile_height and y_l with tile_width.
        tile_row = (df["x_l"] // tile_height).astype(str)
        tile_col = (df["y_l"] // tile_width).astype(str)
        df["tile"] = "tile_" + tile_row + "_" + tile_col

        # KMeans cannot fit more clusters than there are samples.
        cluster_count = max(1, min(int(n_cluster), len(df)))
        cls = MiniBatchKMeans(n_clusters=cluster_count, batch_size=100)
        cls.fit(df[["r", "g", "b"]])
        df["label"] = cls.labels_

        mask_df = rgb2df(mask)
        mask_df['bg_label'] = (mask_df['r'] > alpha) & (mask_df['g'] > alpha) & (mask_df['b'] > alpha)

        img_df = df.copy()
        img_df["bg_label"] = mask_df["bg_label"]
        img_df["label"] = img_df["label"].astype(str) + "-" + img_df["tile"]
        # Mean of a boolean column == sum / count, computed on that column
        # only instead of aggregating every column (including strings).
        bg_rate = img_df.groupby("label")["bg_label"].mean()
        img_df['bg_cls'] = img_df['label'].isin(bg_rate[bg_rate > th_rate].index).astype(int)
        img_df['a'] = np.where(img_df['bg_cls'] == 0, 0, 255)
        img = df2rgba(img_df)

    elif cascadePSP_enabled:
        mask = get_mask(img)
        mask = (mask * 255).astype(np.uint8)
        # Fall back to CPU when CUDA is unavailable instead of failing on a
        # hard-coded 'cuda:0'.
        import torch
        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        refiner = refine.Refiner(device=device)
        mask = refiner.refine(img, mask, fast=fast, L=psp_L)
        img = np.dstack((img, mask))

    else:
        mask, img = rmbg_fn(img)

    return mask, img
117
+
118
+
119
+
120
+
121
+
122
+