# Source: zerovision / app.py (bthndmn12)
# Commit: "fixed bugs" (7bc59eb) · raw · history · blame · 1.87 kB
import gradio as gr
import torch
import numpy as np
from transformers import AutoModel
from transformers import SamModel, SamConfig, SamProcessor
from PIL import Image
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Everything is loaded from the local ./checkpoint directory (local_files_only=True).
model_config = SamConfig.from_pretrained("./checkpoint", local_files_only=True)
processor = SamProcessor.from_pretrained("./checkpoint", local_files_only=True)
model = SamModel.from_pretrained("./checkpoint", local_files_only=True)

# The request handler moves its input tensors to `device`, so the weights must
# live there as well; otherwise inference fails with a device mismatch on GPU hosts.
model.to(device)
def get_bbox(gt_map):
    """Return a randomly padded bounding box around the positive pixels of a map.

    Args:
        gt_map: Array-like of shape ``(H, W)`` or ``(H, W, C, ...)``. A pixel is
            treated as foreground when any of its values is greater than zero.

    Returns:
        ``[x_min, y_min, x_max, y_max]`` as Python ints. Each side of the tight
        box is pushed outwards by a random amount in ``[0, 20)`` pixels and
        clipped to the map (``x_max <= W``, ``y_max <= H``). When the map has no
        foreground pixel, the full extent ``[0, 0, W, H]`` is returned.

    Raises:
        ValueError: If ``gt_map`` has fewer than two dimensions.
    """
    gt_map = np.asarray(gt_map)
    if gt_map.ndim < 2:
        raise ValueError("gt_map must have at least two dimensions (H, W)")
    height, width = gt_map.shape[:2]

    # Collapse any trailing (channel) axes so RGB/RGBA inputs work like 2-D masks.
    foreground = gt_map > 0
    if foreground.ndim > 2:
        foreground = foreground.any(axis=tuple(range(2, foreground.ndim)))

    # Test the foreground itself rather than the sum of the map: a sum can be
    # non-zero (e.g. negative values) while no pixel is actually positive.
    if not foreground.any():
        return [0, 0, width, height]

    y_indices, x_indices = np.nonzero(foreground)
    x_min, x_max = x_indices.min(), x_indices.max()
    y_min, y_max = y_indices.min(), y_indices.max()

    # Random outward jitter, drawn in the same order as before
    # (x_min, x_max, y_min, y_max) so seeded runs stay reproducible.
    x_min = max(0, x_min - np.random.randint(0, 20))
    x_max = min(width, x_max + np.random.randint(0, 20))
    y_min = max(0, y_min - np.random.randint(0, 20))
    y_max = min(height, y_max + np.random.randint(0, 20))

    # Plain ints instead of a mix of NumPy scalars and Python ints.
    return [int(x_min), int(y_min), int(x_max), int(y_max)]
def greet(image):
    """Segment ``image`` with the loaded SAM model and return a binary mask image.

    Args:
        image: The input picture as a NumPy array (what a Gradio ``"image"``
            input delivers by default), a ``PIL.Image.Image``, or a file
            path / file object that ``PIL.Image.open`` can read.

    Returns:
        A ``uint8`` NumPy array with 255 where the predicted mask probability
        exceeds 0.5 and 0 elsewhere, so the mask is visible when displayed.

    Raises:
        ValueError: If no image was supplied.
    """
    if image is None:
        raise ValueError("No input image provided.")

    # Normalise every supported input kind to a PIL image.
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    elif not isinstance(image, Image.Image):
        image = Image.open(image)
    image = image.convert("RGB").resize((256, 256))

    # The box prompt is the (jittered) extent of the non-black pixels. Reduce
    # the channels to a 2-D map first, since get_bbox works on (H, W) input.
    gt_mask = np.asarray(image).max(axis=-1)
    prompt = get_bbox(gt_mask)

    inputs = processor(images=image, input_boxes=[[prompt]], return_tensors="pt")
    # Send the tensors to wherever the model weights actually are, which avoids
    # a device mismatch regardless of how the model was placed.
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    model.eval()
    with torch.no_grad():
        # The keyword is `multimask_output` (singular); the misspelt
        # `multimask_outputs` is not a recognised argument of SamModel.forward.
        outputs = model(**inputs, multimask_output=False)

    seg_prob = torch.sigmoid(outputs.pred_masks.squeeze(0))
    seg_prob = seg_prob.cpu().numpy().squeeze()
    # Scale to 0/255: a 0/1 uint8 array renders as an almost entirely black image.
    return np.where(seg_prob > 0.5, 255, 0).astype(np.uint8)
# Wire `greet` into a Gradio interface (one image in, one image out) and serve it.
iface = gr.Interface(
    fn=greet,
    inputs="image",
    outputs="image",
    title="Greeter",
)
iface.launch()