"""Gradio demo: run a pre-trained U-Net over an uploaded image and show the class mask."""

import gradio as gr
import torch
from PIL import Image
from torchvision import transforms

# Load pre-trained U-Net model.
# NOTE(review): confirm that 'unet' is an entrypoint exposed by the
# nvidia/DeepLearningExamples torchhub hubconf.py; startup fails otherwise.
model = torch.hub.load('nvidia/DeepLearningExamples:torchhub', 'unet', pretrained=True)
# Put the network in inference mode (dropout / batch-norm use eval behaviour).
model.eval()

# Built once at import time rather than on every request.
# These mean/std values are the common ImageNet statistics; presumably the
# checkpoint was trained with them — TODO confirm.
preprocess = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def segment_image(image):
    """Segment *image* and return its per-pixel class mask.

    Args:
        image: A PIL image (what ``gr.Image(type="pil")`` supplies) or an
            array accepted by ``PIL.Image.fromarray``.

    Returns:
        A single-channel ``PIL.Image`` with the same size as the input, where
        each pixel holds the argmax class index produced by the model.
    """
    # The Gradio input is declared with type="pil", so the value normally
    # arrives as a PIL image already; Image.fromarray would reject it.
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    # Drop alpha / expand grayscale so the 3-channel Normalize applies.
    image = image.convert("RGB")
    original_size = image.size

    batch = preprocess(image).unsqueeze(0)

    # No autograd graph is needed for inference; saves memory and time.
    with torch.no_grad():
        logits = model(batch)

    # assumes logits are (batch, classes, H, W) — TODO confirm against the model.
    mask = torch.argmax(logits, dim=1).squeeze(0).cpu().numpy()

    mask_image = Image.fromarray(mask.astype('uint8'))
    # Nearest-neighbour keeps labels discrete when scaling back to input size.
    return mask_image.resize(original_size, Image.NEAREST)


# Create Gradio app
demo = gr.Interface(
    fn=segment_image,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="pil"),
    title="Segment Anything",
    description="Segment any image using a pre-trained U-Net model",
)

if __name__ == "__main__":
    # Launch only when run as a script, not when imported.
    demo.launch()