from transformers import (
    AutoImageProcessor,
    MaskFormerForInstanceSegmentation,
)
import torch
import matplotlib

matplotlib.use("Agg")  # Headless backend so the figure canvas can be rasterized
import matplotlib.pyplot as plt
import gradio as gr
import numpy as np
# Image processor from the base MaskFormer checkpoint; the model weights
# come from a checkpoint fine-tuned for the sidewalk dataset.
processor = AutoImageProcessor.from_pretrained("facebook/maskformer-swin-base-coco")
model = MaskFormerForInstanceSegmentation.from_pretrained(
    "sna89/segmentation_model"
)

# Run on GPU when available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
def segment_image(img):
    # Preprocess the PIL image into model-ready tensors.
    img_pt = processor(img, return_tensors="pt")
    img_pt = img_pt.to(device)

    # Inference only, so skip gradient tracking.
    with torch.no_grad():
        outputs = model(**img_pt)

    # Resize the predicted map back to the original (height, width).
    predicted_semantic_map = processor.post_process_semantic_segmentation(
        outputs, target_sizes=[img.size[::-1]]
    )[0]

    # Render the class-id map through matplotlib's default colormap.
    fig, ax = plt.subplots(figsize=(5, 5))
    ax.axis("off")
    ax.imshow(predicted_semantic_map.to("cpu"))
    fig.canvas.draw()  # Render the figure
    image_array = np.array(fig.canvas.renderer.buffer_rgba())
    plt.close(fig)  # Free the figure to avoid leaking memory across requests
    return image_array
    # return predicted_semantic_map.to("cpu").numpy()
demo = gr.Interface(
    fn=segment_image,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="numpy"),  # segment_image returns an RGBA numpy array
    title="Semantic segmentation for sidewalk dataset",
    examples=[["image.jpg"], ["image (1).jpg"]],
    live=True,
)
demo.launch(share=True)