leonelhs commited on
Commit
1c39d4f
·
1 Parent(s): 001870b

add trimap generator

Browse files
Files changed (1) hide show
  1. app.py +31 -17
app.py CHANGED
@@ -29,8 +29,9 @@ interface = Interface(pre_pipe=preprocessing,
29
  seg_pipe=seg_net)
30
 
31
 
32
- def generate_trimap(original, mask):
33
- trimap(original_image=original, mask=mask)
 
34
 
35
 
36
  def predict(image):
@@ -48,22 +49,35 @@ Demo based on <a href='https://github.com/OPHoperHPO/image-background-remove-too
48
  """
49
 
50
  with gr.Blocks(title="CarveKit") as app:
51
- gr.Markdown("<center><h1><b>CarveKit</b></h1></center>")
52
- gr.HTML(
53
- "<center><h3>Automated high-quality background removal framework for an image using neural networks.</h3></center>")
54
- with gr.Row().style(equal_height=False):
55
- with gr.Column():
56
- input_img = gr.Image(type="pil", label="Input image")
57
- run_btn = gr.Button(variant="primary")
58
- with gr.Column():
59
- output_img = gr.Image(type="pil", label="result")
60
-
61
- run_btn.click(predict, [input_img], [output_img])
62
 
63
- with gr.Row():
64
- examples_data = [[f"examples/{x:02d}.jpg"] for x in range(1, 4)]
65
- examples = gr.Dataset(components=[input_img], samples=examples_data)
66
- examples.click(lambda x: x[0], [examples], [input_img])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
 
68
  with gr.Row():
69
  gr.HTML(footer)
 
29
  seg_pipe=seg_net)
30
 
31
 
32
+ def generate_trimap(original):
33
+ mask = seg_net([original])
34
+ return trimap(original_image=original, mask=mask[0])
35
 
36
 
37
  def predict(image):
 
49
  """
50
 
51
  with gr.Blocks(title="CarveKit") as app:
 
 
 
 
 
 
 
 
 
 
 
52
 
53
+ gr.Markdown("<center><h1><b>CarveKit</b></h1></center>")
54
+ gr.HTML("<center><h3>High-quality image background removal</h3></center>")
55
+
56
+ with gr.Tabs() as tabs:
57
+ with gr.TabItem("Remove background", id=0):
58
+ with gr.Row().style(equal_height=False):
59
+ with gr.Column():
60
+ input_img = gr.Image(type="pil", label="Input image")
61
+ run_btn = gr.Button(variant="primary")
62
+ with gr.Column():
63
+ output_img = gr.Image(type="pil", label="result")
64
+
65
+ run_btn.click(predict, [input_img], [output_img])
66
+
67
+ with gr.TabItem("Generate trimap", id=1):
68
+ with gr.Row().style(equal_height=False):
69
+ with gr.Column():
70
+ trimap_input = gr.Image(type="pil", label="Input image")
71
+ trimap_btn = gr.Button(variant="primary")
72
+ with gr.Column():
73
+ trimap_output = gr.Image(type="pil", label="result")
74
+
75
+ trimap_btn.click(generate_trimap, [trimap_input], [trimap_output])
76
+
77
+ # with gr.Row():
78
+ # examples_data = [[f"examples/{x:02d}.jpg"] for x in range(1, 4)]
79
+ # examples = gr.Dataset(components=[input_img], samples=examples_data)
80
+ # examples.click(lambda x: x[0], [examples], [input_img])
81
 
82
  with gr.Row():
83
  gr.HTML(footer)