import gradio as gr
import torch
import numpy as np
from transformers import AutoModel
from transformers import SamModel, SamConfig, SamProcessor
from PIL import Image
import matplotlib.pyplot as plt

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the SAM config, processor and weights from the local "./checkpoint"
# directory only (local_files_only=True: nothing is fetched remotely).
model_config = SamConfig.from_pretrained("./checkpoint", local_files_only=True)
processor = SamProcessor.from_pretrained("./checkpoint", local_files_only=True)
model = SamModel.from_pretrained("./checkpoint", local_files_only=True)
# process_image moves every input tensor to `device`, so the weights must live
# there as well; leaving them on the CPU breaks inference on CUDA machines.
model.to(device)
# The app only runs inference, so switch to eval mode once at load time
# instead of on every request.
model.eval()


def get_bbox(gt_map):
    """Return a randomly enlarged bounding box around the foreground of a mask.

    Args:
        gt_map: Ground-truth mask as a NumPy array. If it has more than two
            dimensions, only the first channel (``gt_map[:, :, 0]``) is used.
            Pixels with a value > 0 count as foreground.

    Returns:
        ``[x_min, y_min, x_max, y_max]``. Each edge of the tight foreground box
        is pushed outward by a random integer in ``[0, 20)`` pixels and clipped
        to ``[0, W]`` / ``[0, H]``. When the mask has no foreground pixels, the
        full-image box ``[0, 0, W, H]`` is returned.
    """
    if gt_map.ndim > 2:
        gt_map = gt_map[:, :, 0]

    H, W = gt_map.shape
    y_indices, x_indices = np.where(gt_map > 0)

    # Test emptiness with the same condition used to find foreground pixels.
    # A `np.sum(gt_map) == 0` test disagrees with it for masks holding
    # non-positive values, and np.min would then fail on an empty array.
    if x_indices.size == 0:
        return [0, 0, W, H]

    x_min, x_max = np.min(x_indices), np.max(x_indices)
    y_min, y_max = np.min(y_indices), np.max(y_indices)

    # Random outward jitter (drawn in the order x_min, x_max, y_min, y_max)
    # so repeated calls yield slightly different boxes.
    x_min = max(0, x_min - np.random.randint(0, 20))
    x_max = min(W, x_max + np.random.randint(0, 20))
    y_min = max(0, y_min - np.random.randint(0, 20))
    y_max = min(H, y_max + np.random.randint(0, 20))

    bbox = [x_min, y_min, x_max, y_max]
    return bbox


def process_image(image_input):
    """Segment an image with SAM, prompting with a box over the whole image.

    Args:
        image_input: Image array from the Gradio ``"image"`` input component
            (anything ``PIL.Image.fromarray`` accepts), or ``None`` when the
            form is submitted without an image.

    Returns:
        A single-channel PIL image whose pixels are 255 where the predicted
        foreground probability exceeds 0.5 and 0 elsewhere, or ``None`` when
        no image was provided.
    """
    if image_input is None:
        # Image.fromarray(None) would raise; returning None clears the output.
        return None

    # Normalise to 3-channel RGB at a fixed 256x256 working size.
    image = Image.fromarray(image_input).convert("RGB").resize((256, 256))

    # Prompt with a box spanning the whole resized image, nested as the
    # processor's `input_boxes` argument expects (per image, per box).
    prompt = [[[0, 0, image.width, image.height]]]

    inputs = processor(image, input_boxes=prompt, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model(**inputs, multimask_output=False)

    # Map the mask logits to probabilities, drop the singleton dimensions and
    # threshold at 0.5 to get a 0/1 mask.
    seg_prob = torch.sigmoid(outputs.pred_masks.squeeze(1))
    seg_prob = seg_prob.cpu().numpy().squeeze()
    seg = (seg_prob > 0.5).astype(np.uint8)

    # Scale 0/1 to 0/255 so the mask displays as a black-and-white image.
    return Image.fromarray(seg * 255)


# NOTE(review): "Greeter" looks like a leftover from the Gradio quick-start
# example; consider a title that describes this segmentation demo.
iface = gr.Interface(fn=process_image, inputs="image", outputs="image", title="Greeter")

# Only start the web server when run as a script, so importing this module
# (e.g. from tests) does not block on launch().
if __name__ == "__main__":
    iface.launch()