Spaces:
Runtime error
Runtime error
import torch | |
from PIL import Image | |
from torchvision import transforms | |
from clipseg import CLIPDensePredT | |
transform = transforms.Compose([ | |
transforms.ToTensor(), | |
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), | |
transforms.Resize((352, 352), antialias=True), | |
]) | |
model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64) | |
model.eval() | |
model.load_state_dict(torch.load('weights/rd64-uni.pth', | |
map_location=torch.device('cpu')), strict=False) | |
def predict(image, prompts): | |
""" | |
Predict segmentation masks for the given image based on the provided prompts. | |
Parameters: | |
- image (PIL.Image): The input image. | |
- prompts (str): A comma-separated string of prompts. | |
- Model (torch.nn): Segmentation Model. | |
Returns: | |
- tuple: A tuple containing the resized input image and a list of segmentation masks. | |
""" | |
img = transform(image).unsqueeze(0) | |
# Split the prompts string into a list of individual prompts | |
prompts = prompts.split(',') | |
num_prompts = len(prompts) | |
# Ensure no gradient computation during prediction for performance | |
with torch.no_grad(): | |
# Get model predictions for each prompt | |
preds = model(img.repeat(len(prompts), 1, 1, 1), prompts)[0] | |
# Convert model predictions to segmentation masks | |
masks = [torch.sigmoid(preds[i][0]) for i in range(num_prompts)] | |
masks = [(m.squeeze(0).numpy(), prompts[i]) for i, m in enumerate(masks)] | |
# Return the resized input image and the list of segmentation masks | |
return (image.resize((352, 352), Image.LANCZOS), masks) | |
def get_examples(): | |
examples = [ | |
['images/000013.jpg', 'deer, tree, grass'], | |
['images/000002.jpg', 'train, tracks, electric pole, house'], | |
['images/00125.jpg', 'dog, flowers'], | |
['images/000010.jpg', 'horse, man, fence, buildings, hill'], | |
['images/000004.jpg', 'car, truck, building, sky, traffic light, tree, clouds'] | |
] | |
return(examples) | |
def get_html(): | |
html_string = """ | |
<!DOCTYPE html> | |
<html lang="en"> | |
<head> | |
<meta charset="UTF-8"> | |
<meta name="viewport" content="width=device-width, initial-scale=1.0"> | |
<title>Multi-Prompt Image Segmentation</title> | |
<link href="https://fonts.googleapis.com/css2?family=Roboto+Slab:wght@400;700&display=swap" rel="stylesheet"> | |
<style> | |
/* General styling */ | |
body { | |
font-family: 'Roboto Slab', serif; | |
margin: 0; | |
padding: 0; | |
background-color: #f4f4f4; | |
} | |
.app-header { | |
background: linear-gradient(135deg, #4a90e2, #50e3c2); | |
color: #fff; | |
text-align: center; | |
padding: 40px 0; | |
border-radius: 20px; | |
position: relative; | |
overflow: hidden; | |
box-shadow: 0px 10px 20px rgba(0, 0, 0, 0.1); | |
} | |
/* Ellipse Overlay */ | |
.app-header::before { | |
content: ""; | |
position: absolute; | |
top: -50%; | |
left: -50%; | |
width: 200%; | |
height: 200%; | |
background: rgba(255, 255, 255, 0.1); | |
transform: rotate(45deg); | |
border-radius: 50%; | |
} | |
/* Floating Shapes */ | |
.app-header::after { | |
content: ""; | |
position: absolute; | |
top: 20%; | |
right: 10%; | |
width: 70px; | |
height: 70px; | |
background: rgba(255, 255, 255, 0.2); | |
border-radius: 50%; | |
} | |
.floating-shape { | |
content: ""; | |
position: absolute; | |
top: 10%; | |
left: 5%; | |
width: 50px; | |
height: 50px; | |
background: rgba(255, 255, 255, 0.2); | |
border-radius: 50%; | |
} | |
/* Text Styling */ | |
.app-title { | |
font-size: 28px; | |
margin: 0; | |
font-weight: 700; | |
text-shadow: 2px 2px 4px rgba(0, 0, 0, 0.2); | |
} | |
.app-description { | |
font-size: 18px; | |
margin-top: 15px; | |
opacity: 0.9; | |
text-shadow: 1px 1px 3px rgba(0, 0, 0, 0.1); | |
} | |
/* Wavy Bottom */ | |
.wavy-bottom { | |
position: absolute; | |
bottom: -10px; | |
left: 0; | |
width: 100%; | |
height: 20px; | |
background: #f4f4f4; | |
border-radius: 100% 100% 0 0; | |
} | |
</style> | |
</head> | |
<body> | |
<!-- App Header --> | |
<div class="app-header"> | |
<h1 class="app-title">Multi-Prompt Image Segmentation</h1> | |
<p class="app-description">Upload an image & provide multiple text prompts separated by commas. Get segmented masks for each prompt.</p> | |
<div class="floating-shape"></div> | |
<div class="wavy-bottom"></div> | |
</div> | |
<!-- Rest of the app content will go here --> | |
</body> | |
</html> | |
""" | |
return(html_string) |