# Hugging Face Space status: Runtime error
import gradio as gr | |
import torch | |
import numpy as np | |
from transformers import AutoModel | |
from transformers import SamModel, SamConfig, SamProcessor | |
from PIL import Image | |
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Directory holding the fine-tuned SAM checkpoint (config, processor, weights).
# local_files_only=True stops from_pretrained from contacting the Hub.
CHECKPOINT_DIR = "./checkpoint"

model_config = SamConfig.from_pretrained(CHECKPOINT_DIR, local_files_only=True)
processor = SamProcessor.from_pretrained(CHECKPOINT_DIR, local_files_only=True)
# Build the model from the config loaded above (previously it was loaded and
# never used).
model = SamModel.from_pretrained(
    CHECKPOINT_DIR, config=model_config, local_files_only=True
)
# Place the weights on `device`, the same device the request tensors are sent
# to. Left on the CPU, a CUDA host fails with a device-mismatch error.
model.to(device)
# Inference-only app: switch to eval mode once at load time.
model.eval()
def get_bbox(gt_map, max_jitter=20):
    """Return a loosely padded bounding box around the positive pixels of a mask.

    Args:
        gt_map: 2-D array, or an H x W x C array whose first channel is used
            as the mask. Pixels with a value > 0 count as foreground.
        max_jitter: Exclusive upper bound of the random outward padding that
            is drawn (via ``np.random.randint``) independently for each box
            edge. Pass 0 (or less) for a tight, deterministic box. Defaults to
            20, the value that was previously hard-coded.

    Returns:
        ``[x_min, y_min, x_max, y_max]`` as Python ints, with x clamped to
        ``[0, W]`` and y to ``[0, H]``. When no pixel is > 0 the whole-image
        box ``[0, 0, W, H]`` is returned.
    """
    if gt_map.ndim > 2:
        # Only the first channel is used; the other channels are assumed to
        # carry the same mask.
        gt_map = gt_map[:, :, 0]
    H, W = gt_map.shape

    # Use the same `> 0` predicate for the emptiness test and for locating
    # pixels. The old `np.sum(gt_map) == 0` test let maps with no positive
    # pixel but a non-zero sum (e.g. all negative) through, and np.min then
    # raised ValueError on the empty index arrays.
    foreground = gt_map > 0
    if not foreground.any():
        return [0, 0, W, H]

    y_indices, x_indices = np.nonzero(foreground)

    def _pad():
        # np.random.randint(0, 0) raises, so a non-positive bound means no pad.
        return np.random.randint(0, max_jitter) if max_jitter > 0 else 0

    # Draws happen in the original order (x_min, x_max, y_min, y_max), so a
    # seeded global RNG yields the same boxes as before.
    x_min = max(0, int(x_indices.min()) - _pad())
    x_max = min(W, int(x_indices.max()) + _pad())
    y_min = max(0, int(y_indices.min()) - _pad())
    y_max = min(H, int(y_indices.max()) + _pad())
    return [x_min, y_min, x_max, y_max]
def greet(image: "np.ndarray | None") -> "np.ndarray | None":
    """Segment an uploaded picture with SAM using an automatically derived box.

    Args:
        image: The uploaded picture as a NumPy array, or ``None`` when the
            form is submitted without an image.

    Returns:
        A 2-D ``uint8`` mask at the resolution of the model's predicted mask,
        with foreground pixels set to 255 and background to 0. Returns
        ``None`` when no image was provided.
    """
    if image is None:
        # Image.fromarray(None) would raise; show an empty output instead.
        return None

    # convert("RGB") normalizes RGBA / grayscale uploads to three channels;
    # it is a no-op for RGB input.
    image = Image.fromarray(image).convert("RGB").resize((256, 256))

    # NOTE(review): the box prompt is computed from the input picture itself,
    # treating its first channel as a mask (get_bbox keeps pixels > 0). For
    # typical photos almost every pixel qualifies, so the box likely spans
    # nearly the whole image. Confirm this is intended.
    prompt = get_bbox(np.array(image))

    inputs = processor(images=image, input_boxes=[[prompt]], return_tensors="pt")
    # Send the tensors to wherever the weights live, so inference works
    # whether or not the model was moved to the GPU at load time.
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    model.eval()
    with torch.no_grad():
        # The keyword is `multimask_output` (singular). The old spelling,
        # `multimask_outputs`, is not a named SamModel.forward parameter, so it
        # never turned multi-mask prediction off.
        outputs = model(**inputs, multimask_output=False)

    # pred_masks presumably has shape (1, 1, 1, h, w) here (one image, one
    # box, one mask) - TODO confirm. Squeezing leaves a 2-D probability map.
    seg_prob = torch.sigmoid(outputs.pred_masks.squeeze(0))
    seg_prob = seg_prob.cpu().numpy().squeeze()
    # Threshold at 0.5, then map {0, 1} -> {0, 255}. A 0/1 uint8 image would
    # look almost black on a 0-255 display scale.
    return (seg_prob > 0.5).astype(np.uint8) * 255
# Gradio front end: one image input, and greet's returned array rendered as an
# image output. Launched as soon as this module runs.
iface = gr.Interface(
    fn=greet,
    inputs="image",
    outputs="image",
    title="Greeter",
)
iface.launch()