ancebuc committed
Commit 2561975
1 Parent(s): c6cb1c5

Update app.py

Files changed (1)
app.py +14 -8
app.py CHANGED
@@ -5,6 +5,7 @@ import torch
 
 import matplotlib.pyplot as plt
 import numpy as np
+import io
 
 import gradio as gr
 
@@ -14,15 +15,20 @@ processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
 model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
 
 def visualize_segmentation(image, prompts, preds):
-    _, ax = plt.subplots(1, len(prompts) + 1, figsize=(3*(len(prompts) + 1), 4))
-    [a.axis('off') for a in ax.flatten()]
-    ax[0].imshow(image)
-    [ax[i+1].imshow(torch.sigmoid(preds[i][0])) for i in range(len(prompts))];
-    [ax[i+1].text(0, -15, prompt) for i, prompt in enumerate(prompts)];
+    fig, ax = plt.subplots(1, len(prompts) + 1, figsize=(3*(len(prompts) + 1), 4))
+    [a.axis('off') for a in ax.flatten()]
+    ax[0].imshow(image)
+    [ax[i+1].imshow(torch.sigmoid(preds[i][0])) for i in range(len(prompts))];
+    [ax[i+1].text(0, -15, prompt) for i, prompt in enumerate(prompts)];
+
+    buf = io.BytesIO()
+    fig.savefig(buf, format='png')
+    plt.close(fig)
+
+    return Image.open(buf)
 
 
 def segment(img, clases):
-    print(img)
     image = Image.fromarray(img, 'RGB')
     prompts = clases.split(',')
 
@@ -31,7 +37,7 @@ def segment(img, clases):
     outputs = model(**inputs)
     preds = outputs.logits.unsqueeze(1)
 
-    return "Hello " + prompts[0] + "!!"
+    return visualize_segmentation(image, prompts, preds)
 
-demo = gr.Interface(fn=segment, inputs=["image","text"], outputs="text")
+demo = gr.Interface(fn=segment, inputs=["image","text"], outputs="image")
 demo.launch()
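
In short, the commit swaps the placeholder text output for the rendered figure: visualize_segmentation now writes the Matplotlib figure to an in-memory PNG buffer and returns it as a PIL image, segment returns that image instead of a greeting string, and the Gradio output component changes from "text" to "image" to match. The diff hides the first four lines of the file and the preprocessing step before model(**inputs), so the sketch below fills those gaps with the standard CLIPSeg example from the transformers documentation; the imports, the processor(...) call, and the torch.no_grad() wrapper are assumptions, not part of the commit.

# Sketch of app.py as it stands after this commit, not the verbatim file.
# The imports, the processor(...) call, and the no_grad() wrapper are
# assumptions based on standard CLIPSeg usage; everything else mirrors the diff.
import io

import torch
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import gradio as gr
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation

processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")

def visualize_segmentation(image, prompts, preds):
    # One panel for the input image plus one heatmap panel per prompt.
    fig, ax = plt.subplots(1, len(prompts) + 1, figsize=(3 * (len(prompts) + 1), 4))
    for a in ax.flatten():
        a.axis('off')
    ax[0].imshow(image)
    for i, prompt in enumerate(prompts):
        ax[i + 1].imshow(torch.sigmoid(preds[i][0]))  # per-prompt probability map
        ax[i + 1].text(0, -15, prompt)                # prompt as panel label
    # Render the figure into an in-memory PNG and return it as a PIL image,
    # closing the figure so repeated requests do not leak memory.
    buf = io.BytesIO()
    fig.savefig(buf, format='png')
    plt.close(fig)
    return Image.open(buf)

def segment(img, clases):
    image = Image.fromarray(img, 'RGB')
    prompts = clases.split(',')
    # Assumed preprocessing (elided from the diff): one copy of the image per
    # text prompt, batched together.
    inputs = processor(text=prompts, images=[image] * len(prompts),
                       padding="max_length", return_tensors="pt")
    with torch.no_grad():  # assumption: inference only, keeps logits plottable
        outputs = model(**inputs)
    preds = outputs.logits.unsqueeze(1)
    return visualize_segmentation(image, prompts, preds)

demo = gr.Interface(fn=segment, inputs=["image", "text"], outputs="image")
demo.launch()

The no_grad() wrapper appears in the sketch because plt.imshow cannot convert a tensor that still tracks gradients; whether the actual app handles this is not visible in the diff.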