# Hugging Face Space (status when captured: Sleeping)
import gradio as gr | |
import torch | |
from PIL import Image | |
from torchvision import transforms | |
# Load pre-trained U-Net model
# NOTE(review): confirm that the 'nvidia/DeepLearningExamples:torchhub' hub
# repo really exposes a 'unet' entrypoint; torch.hub.load raises at import
# time if it does not.
model = torch.hub.load('nvidia/DeepLearningExamples:torchhub', 'unet', pretrained=True)
# This script only runs inference, so put the model in evaluation mode:
# layers such as dropout and batch norm behave differently while training.
model.eval()
# Preprocessing pipeline, built once at import time instead of on every
# request: resize to 256x256, convert to a float tensor, then normalize each
# of the three channels with the mean/std values below.
_PREPROCESS = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


# Define a function to segment an image
def segment_image(image):
    """Run the module-level ``model`` on one image and return its label mask.

    Args:
        image: A PIL image, or an array accepted by ``Image.fromarray``.
            ``None`` is tolerated (Gradio passes it when no image was
            submitted).

    Returns:
        A PIL image whose pixels are the per-pixel argmax over dimension 1
        of the model output, stored as ``uint8``; ``None`` when ``image`` is
        ``None``.
    """
    # Nothing was uploaded: return an empty result instead of crashing.
    if image is None:
        return None

    # The Gradio input is declared with type="pil", so the value is normally
    # already a PIL image; only convert when an array is passed directly.
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)

    # Normalize is configured for exactly three channels, so RGBA, grayscale
    # and palette images must be converted first.
    image = image.convert("RGB")

    # Add the leading batch dimension the model call expects.
    batch = _PREPROCESS(image).unsqueeze(0)

    # Run segmentation model; gradients are not needed for inference.
    with torch.no_grad():
        # assumes the output is (batch, classes, H, W) — TODO confirm
        logits = model(batch)
    labels = torch.argmax(logits, dim=1)

    # Postprocess output: drop the batch dimension and hand back a PIL image.
    # NOTE(review): pixel values are raw class indices, so a mask with few
    # classes will look almost black when displayed — confirm this is wanted.
    mask = labels.squeeze(0).cpu().numpy()
    return Image.fromarray(mask.astype('uint8'))
# Build the Gradio interface: one PIL image in, one PIL image out, both
# wired to segment_image.
_input_component = gr.Image(type="pil")
_output_component = gr.Image(type="pil")

demo = gr.Interface(
    fn=segment_image,
    inputs=_input_component,
    outputs=_output_component,
    title="Segment Anything",
    description="Segment any image using a pre-trained U-Net model",
)

# Start serving the app
demo.launch()