Spaces:
Runtime error
Runtime error
import gradio as gr | |
import torch | |
import numpy as np | |
from transformers import AutoModel | |
from transformers import SamModel, SamConfig, SamProcessor | |
from PIL import Image | |
import matplotlib.pyplot as plt | |
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# All artifacts come from the local "./checkpoint" directory;
# local_files_only=True prevents any download from the Hugging Face Hub.
CHECKPOINT_DIR = "./checkpoint"

# NOTE(review): model_config is not used anywhere in the visible code; kept
# in case other code imports it.
model_config = SamConfig.from_pretrained(CHECKPOINT_DIR, local_files_only=True)
processor = SamProcessor.from_pretrained(CHECKPOINT_DIR, local_files_only=True)

# Move the weights to `device` so they live on the same device as the input
# tensors built for inference (otherwise a CUDA host gets a device mismatch).
model = SamModel.from_pretrained(CHECKPOINT_DIR, local_files_only=True).to(device)
# This app only does inference: switch off training-mode layers once at load.
model.eval()
def get_bbox(gt_map, max_jitter=20):
    """Return a randomly enlarged bounding box around the foreground of a mask.

    Foreground is every pixel whose value is > 0. The tight box around it is
    widened on each side by an independent random margin drawn with
    ``np.random.randint(0, max_jitter)`` (NumPy's global RNG), then clipped
    to the map.

    Args:
        gt_map: 2-D mask of shape (H, W). If it has more than two dimensions,
            only ``gt_map[:, :, 0]`` is used.
        max_jitter: Exclusive upper bound of the random margin added on each
            side. A value <= 0 disables the jitter.

    Returns:
        ``[x_min, y_min, x_max, y_max]`` as Python ints. If the map has no
        foreground pixel, the box spans the whole map: ``[0, 0, W, H]``.
    """
    if gt_map.ndim > 2:
        gt_map = gt_map[:, :, 0]
    height, width = gt_map.shape

    y_indices, x_indices = np.where(gt_map > 0)
    # Decide "empty" with the same `> 0` criterion used to find the foreground.
    # A plain sum-is-zero test misses maps whose values cancel out. It also lets
    # maps with only non-positive values reach .min() on an empty array, which
    # raises ValueError.
    if x_indices.size == 0:
        return [0, 0, width, height]

    def jitter():
        # randint(0, 0) raises, so a non-positive bound means "no margin".
        return int(np.random.randint(0, max_jitter)) if max_jitter > 0 else 0

    # Draw order (x_min, x_max, y_min, y_max) is fixed so seeded runs are
    # reproducible.
    x_min = max(0, int(x_indices.min()) - jitter())
    x_max = min(width, int(x_indices.max()) + jitter())
    y_min = max(0, int(y_indices.min()) - jitter())
    y_max = min(height, int(y_indices.max()) + jitter())
    return [x_min, y_min, x_max, y_max]
def process_image(image_input):
    """Segment an image with the fine-tuned SAM model, prompting with the full frame.

    The image is converted to RGB and resized to 256x256. A box covering the
    whole resized image is used as the prompt, and the predicted mask
    probability is thresholded at 0.5.

    Args:
        image_input: Image array as delivered by the Gradio image input, or
            ``None`` when no image was provided.

    Returns:
        A PIL image in which foreground pixels are 255 and background pixels
        are 0, or ``None`` when ``image_input`` is ``None``.
    """
    # Gradio passes None when the image input is empty; Image.fromarray(None)
    # would raise.
    if image_input is None:
        return None

    image = Image.fromarray(image_input).convert("RGB").resize((256, 256))

    # One image with one box [x_min, y_min, x_max, y_max] spanning the whole
    # frame, nested as the processor expects for `input_boxes`.
    prompt = [[[0, 0, image.width, image.height]]]
    inputs = processor(image, input_boxes=prompt, return_tensors="pt")

    # Send tensors to wherever the model's weights actually are. This avoids a
    # device mismatch whether or not the model was moved to `device`.
    target_device = model.device
    model.eval()  # idempotent; guarantees inference-mode layers
    with torch.no_grad():
        # Pass only the tensors the model consumes. The processor also returns
        # size bookkeeping (e.g. original_sizes) that forward() does not need.
        outputs = model(
            pixel_values=inputs["pixel_values"].to(target_device),
            input_boxes=inputs["input_boxes"].to(target_device),
            multimask_output=False,
        )

    # NOTE(review): pred_masks is presumably (batch, point_batch, num_masks, H, W).
    # With one image, one box and a single mask, squeeze() leaves a 2-D map.
    seg_prob = torch.sigmoid(outputs.pred_masks).cpu().numpy().squeeze()
    seg = (seg_prob > 0.5).astype(np.uint8)
    # Scale the 0/1 mask to 0/255 so it displays as a black-and-white image.
    return Image.fromarray(seg * 255)
# Gradio UI: one image in, one binary segmentation mask out.
# (The previous title "Greeter" was a leftover from the Gradio quickstart.)
iface = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="numpy", label="Input image"),
    outputs=gr.Image(type="pil", label="Predicted mask"),
    title="SAM Segmentation",
)

# Start the server only when executed as a script (e.g. `python app.py`), so
# importing this module does not block on launch().
if __name__ == "__main__":
    iface.launch()