from transformers import (
    MaskFormerImageProcessor,
    AutoImageProcessor,
    MaskFormerForInstanceSegmentation,
)
import torch
from torchvision import transforms
import matplotlib.pyplot as plt
import gradio as gr
import numpy as np

# Select the compute device first so the model can be moved in one chained call.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Image processor comes from the base MaskFormer checkpoint; the fine-tuned
# segmentation weights come from a separate hub repo.
processor = AutoImageProcessor.from_pretrained("facebook/maskformer-swin-base-coco")
model = MaskFormerForInstanceSegmentation.from_pretrained("sna89/segmentation_model").to(device)

def segment_image(img):
  """Run semantic segmentation on a PIL image and return the map as an RGBA array.

  Args:
    img: input PIL.Image (Gradio supplies it via `gr.Image(type="pil")`).

  Returns:
    np.ndarray of shape (H, W, 4), uint8 RGBA render of the predicted
    semantic segmentation map.
  """
  img_pt = processor(img, return_tensors="pt")
  img_pt = img_pt.to(device)
  with torch.no_grad():
    outputs = model(**img_pt)

  # target_sizes expects (height, width); PIL's img.size is (width, height),
  # hence the [::-1] reversal.
  predicted_semantic_map = processor.post_process_semantic_segmentation(
    outputs, target_sizes=[img.size[::-1]]
  )[0]

  fig, ax = plt.subplots(figsize=(5, 5))
  try:
    # Use the object-oriented API on `ax` rather than pyplot global state,
    # which is fragile when multiple requests are served.
    ax.axis('off')
    ax.imshow(predicted_semantic_map.to("cpu"))
    fig.canvas.draw()  # render so the RGBA buffer is populated
    # Public canvas API instead of the private `canvas.renderer` attribute.
    image_array = np.array(fig.canvas.buffer_rgba())
  finally:
    # Close the figure to avoid leaking one figure per request in a
    # long-running Gradio server.
    plt.close(fig)
  return image_array

# Assemble the Gradio app: a single image-in / image-out interface around
# segment_image. `live=True` re-runs inference automatically on input change.
interface_kwargs = dict(
    fn=segment_image,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="pil"),
    title="Semantic segmentation for sidewalk dataset",
    examples=[["image.jpg"], ["image (1).jpg"]],
    live=True,
)
demo = gr.Interface(**interface_kwargs)

# Launch with a public share link.
demo.launch(share=True)