zerovision / app.py
bthndmn12
fixed some bugs
fba9efa
raw
history blame
2.03 kB
import gradio as gr
import torch
import numpy as np
from transformers import AutoModel
from transformers import SamModel, SamConfig, SamProcessor
from PIL import Image
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the fine-tuned SAM checkpoint from the local ./checkpoint directory only
# (local_files_only=True: no Hub download is attempted).
# NOTE(review): model_config is not used anywhere in this file; kept in case
# something else imports it.
model_config = SamConfig.from_pretrained("./checkpoint", local_files_only=True)
processor = SamProcessor.from_pretrained("./checkpoint", local_files_only=True)
model = SamModel.from_pretrained("./checkpoint", local_files_only=True)

# from_pretrained leaves the weights on the CPU by default. Move them to
# `device` so inference actually uses the GPU when present and matches input
# tensors placed on `device`. Switch to inference mode once, up front.
model.to(device)
model.eval()
def get_bbox(gt_map, max_jitter=20, rng=None):
    """Return a randomly padded ``[x_min, y_min, x_max, y_max]`` box around the positive pixels.

    Args:
        gt_map: 2-D array, or an array with a trailing channel axis. Pixels
            with value ``> 0`` are treated as foreground. When there are more
            than two dimensions, only channel 0 is inspected.
        max_jitter: Exclusive upper bound, in pixels, on the random padding
            added independently to each side of the tight box. ``0`` or less
            disables the padding and yields the tight box.
        rng: Optional ``numpy.random.Generator`` used for the padding. When
            omitted, the global ``np.random`` state is used.

    Returns:
        A list of four Python ints. Minima are clamped to ``>= 0`` and maxima
        to ``<= width`` / ``<= height``. If no pixel is ``> 0``, the whole map
        ``[0, 0, width, height]`` is returned.
    """
    gt_map = np.asarray(gt_map)
    if gt_map.ndim > 2:
        # Assume the mask is replicated across channels; use the first one.
        gt_map = gt_map[:, :, 0]
    height, width = gt_map.shape

    y_indices, x_indices = np.nonzero(gt_map > 0)
    # Test emptiness with the same ``> 0`` criterion used for selection. The old
    # ``np.sum(...) == 0`` check let maps with only non-positive nonzero values
    # through, which then crashed in ``np.min`` on an empty index array.
    if y_indices.size == 0:
        return [0, 0, width, height]

    draw = np.random.randint if rng is None else rng.integers

    def jitter():
        # randint/integers require high > low, so a non-positive bound means "no padding".
        return int(draw(0, max_jitter)) if max_jitter > 0 else 0

    # Draws happen in the order x_min, x_max, y_min, y_max, so a seeded global
    # RNG reproduces the boxes the previous implementation produced.
    x_min = max(0, int(x_indices.min()) - jitter())
    x_max = min(width, int(x_indices.max()) + jitter())
    y_min = max(0, int(y_indices.min()) - jitter())
    y_max = min(height, int(y_indices.max()) + jitter())
    return [x_min, y_min, x_max, y_max]
def greet(image):
    """Segment *image* with the fine-tuned SAM model and return a viewable binary mask.

    This is the Gradio Interface callback (the name is a leftover from the
    Gradio template).

    Args:
        image: Array handed over by the Gradio ``"image"`` input (presumably
            HxW or HxWxC uint8 — confirm against the Gradio version in use),
            or ``None`` when the form is submitted without an upload.

    Returns:
        A uint8 array holding 0 for background and 255 for foreground,
        thresholded at probability 0.5. Returns ``None`` when no image was
        given.
    """
    if image is None:
        return None

    resized = Image.fromarray(image).resize((256, 256))
    # NOTE(review): the box prompt is computed from the input image itself, not
    # from a real ground-truth mask. Any pixel > 0 in channel 0 counts as
    # foreground, so for most photos the box covers nearly the whole image.
    prompt = get_bbox(np.array(resized))

    inputs = processor(images=resized, input_boxes=[[prompt]], return_tensors="pt")
    # Follow the model's actual device, so this works whether or not the model
    # has been moved off the CPU.
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    model.eval()
    with torch.no_grad():
        # The SAM keyword is `multimask_output` (singular). The former
        # misspelling `multimask_outputs` was silently absorbed by **kwargs, so
        # the model returned a stack of three candidate masks instead of one.
        outputs = model(**inputs, multimask_output=False)

    seg_prob = torch.sigmoid(outputs.pred_masks.squeeze(0)).cpu().numpy().squeeze()
    # Scale to 0/255: a 0/1 uint8 image renders as (almost) pure black.
    return (seg_prob > 0.5).astype(np.uint8) * 255
# Expose the segmentation callback as a one-image-in, one-image-out Gradio UI
# and start serving it.
iface = gr.Interface(
    fn=greet,
    inputs="image",
    outputs="image",
    title="Greeter",
)
iface.launch()