# ex6 / app.py — MobileViT + DeepLabV3 semantic-segmentation demo (Hugging Face Space)
# Setup: pip install transformers gradio torch
import gradio as gr
import numpy as np
import torch
import transformers
from PIL import Image
from transformers import MobileViTFeatureExtractor, MobileViTForSemanticSegmentation
# Load the pretrained DeepLabV3-MobileViT checkpoint and its matching
# preprocessor from the Hugging Face Hub (downloads on first run).
feature_extractor = MobileViTFeatureExtractor.from_pretrained("apple/deeplabv3-mobilevit-small")
model = MobileViTForSemanticSegmentation.from_pretrained("apple/deeplabv3-mobilevit-small")
# RGB palette indexed by predicted class id (PASCAL VOC-style colors).
# NOTE(review): 22 rows are listed although the comment claimed 21 classes;
# row 0 (black) is the background class. Verify the checkpoint's class count
# matches before trimming.
COLORS = np.array([
[0, 0, 0],
[128, 0, 0],
[0, 128, 0],
[128, 128, 0],
[0, 0, 128],
[128, 0, 128],
[0, 128, 128],
[128, 128, 128],
[64, 0, 0],
[192, 0, 0],
[64, 128, 0],
[192, 128, 0],
[64, 0, 128],
[192, 0, 128],
[64, 128, 128],
[192, 128, 128],
[0, 64, 0],
[128, 64, 0],
[0, 192, 0],
[128, 192, 0],
[0, 64, 128],
[128, 64, 128]
], dtype=np.uint8) # uint8 so the array can be fed to PIL.Image.fromarray directly
def segment_image(image):
    """Run semantic segmentation on a PIL image and return a color-coded mask.

    Args:
        image: input PIL.Image (any mode; converted to RGB internally).

    Returns:
        A PIL.Image of the same size as the input, where each pixel's color
        encodes the class predicted for that region (palette: COLORS).
    """
    # Normalize to 3-channel RGB so RGBA/grayscale/palette uploads don't
    # break the feature extractor.
    image = image.convert("RGB")
    inputs = feature_extractor(images=image, return_tensors="pt")
    # Inference only — disable autograd to avoid building a gradient graph
    # (saves memory and time per request).
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits  # shape (1, num_classes, H', W') at model resolution
    # Per-pixel class id; argmax yields an integer tensor safe for .numpy().
    predicted_mask = logits.argmax(1).squeeze(0).numpy()
    # Map class ids to RGB colors via fancy indexing into the palette.
    colored_mask = COLORS[predicted_mask]
    colored_mask_image = Image.fromarray(colored_mask)
    # The model predicts at reduced resolution; upscale the mask back to the
    # original image size. NEAREST keeps class boundaries crisp (no blending).
    colored_mask_resized = colored_mask_image.resize(image.size, Image.NEAREST)
    return colored_mask_resized
# Gradio UI: a single image-in / image-out demo around segment_image.
interface = gr.Interface(
    fn=segment_image,
    inputs=gr.Image(type="pil"),  # deliver uploads as PIL images, matching segment_image
    outputs="image",
    title="Image Segmentation with MobileViT",
    description="Upload an image to see the semantic segmentation result. The segmentation mask uses different colors to indicate different classes.",
)
# NOTE(review): share=True requests a public tunnel link; on Hugging Face
# Spaces this flag is typically ignored — confirm it is needed here.
interface.launch(share=True)