zerovision / app.py
bthndmn12
fixed some bugs
08eeae0
raw
history blame
2.24 kB
import gradio as gr
import torch
import numpy as np
from transformers import AutoModel
from transformers import SamModel, SamConfig, SamProcessor
from PIL import Image
# Run on the GPU when one is available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load the config, processor and weights from the local ./checkpoint directory
# only (local_files_only=True: never download from the Hub).
# NOTE(review): model_config is not referenced by the code in this file.
model_config = SamConfig.from_pretrained("./checkpoint", local_files_only=True)
processor = SamProcessor.from_pretrained("./checkpoint", local_files_only=True)
# Move the weights onto `device`; the inference code sends its input tensors
# to a device as well, and model and inputs must live on the same one.
model = SamModel.from_pretrained("./checkpoint", local_files_only=True).to(device)
def get_bbox(gt_map, max_jitter=20, rng=None):
    """Return a randomly padded bounding box ``[x_min, y_min, x_max, y_max]``.

    Foreground is every pixel whose value is > 0. Each side of the tight box
    around the foreground is pushed outward by a random number of pixels in
    ``[0, max_jitter)`` and then clipped to the map bounds.

    Args:
        gt_map: A 2-D mask of shape (H, W), or an array with a trailing
            channel axis, in which case only the first channel is used.
        max_jitter: Exclusive upper bound of the random outward padding per
            side. ``0`` (or less) disables padding and yields the tight box.
        rng: Optional ``numpy.random.Generator``. When omitted, the legacy
            global ``np.random`` state is used (draw order: x_min, x_max,
            y_min, y_max).

    Returns:
        A list of four Python ints. When the map has no foreground pixel,
        the box covering the whole map, ``[0, 0, W, H]``.
    """
    if gt_map.ndim > 2:
        # Only the first channel is inspected; assumes all channels agree.
        gt_map = gt_map[:, :, 0]
    H, W = gt_map.shape

    foreground = gt_map > 0
    # Use the same "> 0" test for emptiness as for the box itself. A sum()-based
    # check lets maps with no positive pixels (but a non-zero sum) reach
    # min() on an empty index array, which raises ValueError.
    if not foreground.any():
        return [0, 0, W, H]

    y_indices, x_indices = np.nonzero(foreground)
    x_min, x_max = x_indices.min(), x_indices.max()
    y_min, y_max = y_indices.min(), y_indices.max()

    def _pad():
        # One random outward offset in [0, max_jitter); 0 when disabled.
        if max_jitter <= 0:
            return 0
        if rng is None:
            return np.random.randint(0, max_jitter)
        return int(rng.integers(0, max_jitter))

    x_min = max(0, x_min - _pad())
    x_max = min(W, x_max + _pad())
    y_min = max(0, y_min - _pad())
    y_max = min(H, y_max + _pad())
    return [int(x_min), int(y_min), int(x_max), int(y_max)]
def greet(image):
    """Segment *image* with the SAM model and return the predicted mask.

    The image is resized to 256x256. A box prompt is built from the image's
    own pixels with ``get_bbox``. The model's single predicted mask is
    passed through a sigmoid and thresholded at 0.5.

    Args:
        image: The array supplied by the Gradio ``"image"`` input, or ``None``
            when the form is submitted without an image.

    Returns:
        A grayscale ``PIL.Image`` with foreground pixels 255 and background
        0, or ``None`` when no image was given.

    Raises:
        ValueError: If the predicted mask does not squeeze down to 2-D.
    """
    if image is None:
        return None

    image = Image.fromarray(image).resize((256, 256))
    # NOTE(review): the input image itself is used as the "ground-truth" map,
    # so any pixel with a non-zero first channel counts as foreground and the
    # box usually spans most of the image. get_bbox also pads the box randomly,
    # so the prompt (and result) can differ between calls. Confirm intent.
    prompt = get_bbox(np.array(image))

    inputs = processor(images=image, input_boxes=[[prompt]], return_tensors="pt")
    # Send the tensors to wherever the model's weights actually live, so model
    # and inputs are always on the same device.
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    model.eval()
    with torch.no_grad():
        # SamModel.forward's keyword is `multimask_output` (singular). The
        # misspelt `multimask_outputs` is not that parameter, so it cannot
        # turn multi-mask output off.
        outputs = model(**inputs, multimask_output=False)

    seg_prob = torch.sigmoid(outputs.pred_masks).cpu().numpy().squeeze()
    if seg_prob.ndim != 2:
        raise ValueError(
            f"Expected a 2-D mask after squeezing, got shape {seg_prob.shape}"
        )
    # Scale {0, 1} to {0, 255} so the mask is visible as a grayscale image
    # (a 0/1 uint8 image renders as almost pure black).
    mask = (seg_prob > 0.5).astype(np.uint8) * 255
    return Image.fromarray(mask)
# Web UI: a single image input wired to greet(), whose returned mask image is
# shown as the output.
iface = gr.Interface(
    fn=greet,
    inputs="image",
    outputs="image",
    title="Greeter",
)
# Start the Gradio server; this runs as soon as the module is executed.
iface.launch()