# zerovision / app.py
# bthndmn12
# Fixed some bugs
# 4c0393b
import gradio as gr
import torch
import numpy as np
from transformers import AutoModel
from transformers import SamModel, SamConfig, SamProcessor
from PIL import Image
import matplotlib.pyplot as plt
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Directory holding the SAM checkpoint; everything is loaded strictly from
# local files (no Hub download).
CHECKPOINT_DIR = "./checkpoint"

model_config = SamConfig.from_pretrained(CHECKPOINT_DIR, local_files_only=True)
processor = SamProcessor.from_pretrained(CHECKPOINT_DIR, local_files_only=True)
# Move the weights to ``device`` so inference actually runs on the GPU when
# one was selected above (inference inputs are sent to the model's device),
# and switch to inference mode once at start-up.
model = SamModel.from_pretrained(CHECKPOINT_DIR, local_files_only=True).to(device)
model.eval()
def get_bbox(gt_map, max_jitter=20, rng=None):
    """Return a randomly enlarged bounding box ``[x_min, y_min, x_max, y_max]``.

    The box first tightly encloses every pixel of ``gt_map`` whose value is
    > 0. Each of its four edges is then pushed outward by an independent
    random amount in ``[0, max_jitter)`` pixels and clamped to the map extent
    (``0..W`` horizontally, ``0..H`` vertically).

    Args:
        gt_map: Mask array. If it has more than two dimensions, only
            ``gt_map[:, :, 0]`` is used (i.e. a trailing channel axis is
            assumed).
        max_jitter: Exclusive upper bound on the outward padding applied to
            each edge. ``0`` (or a negative value) disables the jitter and
            yields the tight box.
        rng: Optional ``numpy.random.Generator`` used for the jitter. When
            ``None``, the legacy global ``np.random.randint`` is used.

    Returns:
        A list of four Python ints ``[x_min, y_min, x_max, y_max]``. If no
        pixel is > 0, the box covering the whole map, ``[0, 0, W, H]``.
    """
    gt_map = np.asarray(gt_map)
    if gt_map.ndim > 2:
        gt_map = gt_map[:, :, 0]
    H, W = gt_map.shape

    y_indices, x_indices = np.nonzero(gt_map > 0)
    # Use the same "> 0" test as the extraction above; a sum-based emptiness
    # check lets maps with only non-positive values through, and np.min on
    # the resulting empty index array raises ValueError.
    if x_indices.size == 0:
        return [0, 0, W, H]

    def _jitter():
        """Draw one outward padding amount in [0, max_jitter)."""
        if max_jitter <= 0:
            return 0
        if rng is None:
            return np.random.randint(0, max_jitter)
        return int(rng.integers(0, max_jitter))

    # Draws happen in the order x_min, x_max, y_min, y_max so callers that
    # seed the global RNG get the same sequence as before.
    x_min = max(0, int(x_indices.min()) - _jitter())
    x_max = min(W, int(x_indices.max()) + _jitter())
    y_min = max(0, int(y_indices.min()) - _jitter())
    y_max = min(H, int(y_indices.max()) + _jitter())
    return [x_min, y_min, x_max, y_max]
def process_image(image_input):
    """Segment an image with the SAM model and return a binary mask image.

    The input is converted to RGB and resized to 256x256. The full resized
    frame is then passed as a single box prompt, so the model segments
    whatever it finds inside the whole image.

    Args:
        image_input: Image as a NumPy array (it is passed to
            ``PIL.Image.fromarray``), or ``None`` when the user submits
            without providing an image.

    Returns:
        A single-channel ``PIL.Image`` with value 255 where the predicted
        foreground probability exceeds 0.5 and 0 elsewhere, or ``None`` if no
        image was supplied. The mask is at the model's output resolution
        (presumably SAM's 256x256 low-res mask — confirm) and is not resized
        back to the original input size.
    """
    # Gradio passes None when the form is submitted with no image; there is
    # nothing to segment, and Image.fromarray(None) would raise.
    if image_input is None:
        return None

    image = Image.fromarray(image_input).convert("RGB").resize((256, 256))

    # One box covering the entire resized image, nested as
    # [images][boxes per image][x_min, y_min, x_max, y_max].
    full_frame_box = [[[0, 0, image.width, image.height]]]
    inputs = processor(image, input_boxes=full_frame_box, return_tensors="pt")

    # Send tensors to wherever the model's weights live, so this works
    # whether or not the model was moved to ``device`` at load time.
    model_device = model.device
    inputs = {k: v.to(model_device) for k, v in inputs.items()}

    model.eval()
    with torch.no_grad():
        outputs = model(**inputs, multimask_output=False)

    # Logits -> probabilities -> {0, 1} mask, dropping singleton dimensions.
    seg_prob = torch.sigmoid(outputs.pred_masks.squeeze(1))
    seg_prob = seg_prob.cpu().numpy().squeeze()
    seg = (seg_prob > 0.5).astype(np.uint8)

    # Scale {0, 1} to {0, 255} so the mask is visible as a grayscale image.
    return Image.fromarray(seg * 255)
# Web UI: one image in, the predicted segmentation mask image out.
iface = gr.Interface(
    fn=process_image,
    inputs="image",
    outputs="image",
    title="zerovision",
)
# Start the Gradio server; this runs when the module is executed/imported.
iface.launch()