import functools
import os
import tempfile
from pathlib import Path

from PIL import Image
import torch
import numpy as np
import torchvision.transforms as transforms
from shiny import App, Inputs, Outputs, Session, reactive, render, ui
from shiny.types import FileInfo

# Set before importing `transformers`, which reads its cache location from the
# environment when it is imported.
os.environ["TRANSFORMERS_CACHE"] = "/code/"
from transformers import SamModel

# Side length (pixels) that the model input, the displayed original and the
# mask overlay are all resized to.
IMAGE_SIZE = 1024

# Opacity (0-255) of the mask layer when composited over the original image.
MASK_ALPHA = 128

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

image_resize_transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
])

app_ui = ui.page_fluid(
    ui.input_file(
        "file2",
        "Choose Image",
        accept=".jpg, .jpeg, .png, .tiff, .tif",
        multiple=False,
    ),
    ui.output_image("original_image"),
    ui.output_image("image_display"),
)


@functools.lru_cache(maxsize=1)
def _load_model() -> SamModel:
    """Build the fine-tuned SAM model once and reuse it for every upload.

    The ``facebook/sam-vit-base`` architecture is instantiated and its weights
    are replaced with the state dict stored in ``model.pth``.
    """
    model = SamModel.from_pretrained("facebook/sam-vit-base")
    # NOTE(review): torch.load unpickles the file; this is only acceptable
    # because model.pth is a local, trusted checkpoint.
    model.load_state_dict(torch.load("model.pth", map_location=DEVICE))
    model.eval()
    return model.to(DEVICE)


def _predict_mask(image: Image.Image) -> Image.Image:
    """Segment *image* and return the mask as a semi-transparent RGBA image.

    The returned image is ``IMAGE_SIZE`` x ``IMAGE_SIZE`` with a constant alpha
    of ``MASK_ALPHA``, ready to be alpha-composited over the resized original.
    """
    model = _load_model()
    # NOTE(review): only resize + ToTensor is applied (no mean/std
    # normalisation); confirm this matches how model.pth was trained.
    pixel_values = image_resize_transform(image).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        outputs = model(pixel_values=pixel_values, multimask_output=False)

    # pred_masks holds logits. Map them to [0, 1] probabilities before scaling
    # to 8-bit. Casting raw logits to uint8 wraps negative values around.
    probs = torch.sigmoid(outputs.pred_masks.squeeze(1))[:, 0, :, :]
    mask_array = (probs.cpu().squeeze().numpy() * 255).astype(np.uint8)

    mask = Image.fromarray(mask_array)
    mask = mask.resize((IMAGE_SIZE, IMAGE_SIZE), Image.LANCZOS).convert("RGBA")
    mask.putalpha(MASK_ALPHA)
    return mask


def _save_temp_png(image: Image.Image) -> str:
    """Write *image* to a new temporary PNG file and return the file's path.

    ``delete=False`` keeps the file on disk after the handle is closed, so the
    path can be handed to the image renderers.
    """
    with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp:
        image.save(tmp, "PNG")
    return tmp.name


def server(input: Inputs, output: Outputs, session: Session):
    @reactive.calc
    def loaded_image():
        """Segment the uploaded image.

        Returns ``None`` until a file is uploaded. Otherwise it returns a tuple
        ``(original_path, overlay_path)`` of temporary PNG files: the resized
        original, and the original with the predicted mask blended on top.
        """
        file: list[FileInfo] | None = input.file2()
        if not file:
            return None

        # convert() loads the pixel data, so the source file can be closed.
        with Image.open(file[0]["datapath"]) as src:
            image = src.convert("RGB")

        mask = _predict_mask(image)

        original = image.resize((IMAGE_SIZE, IMAGE_SIZE), Image.LANCZOS)
        original = original.convert("RGBA")
        combined = Image.alpha_composite(original, mask)

        return _save_temp_png(original), _save_temp_png(combined)

    @render.image
    def original_image():
        """Show the resized original upload, or nothing before an upload."""
        result = loaded_image()
        if result is None:
            return None
        img_path, _ = result
        return {"src": img_path, "width": "300px"}

    @render.image
    def image_display():
        """Show the original with the predicted mask overlaid."""
        result = loaded_image()
        if result is None:
            return None
        _, img_path = result
        return {"src": img_path, "width": "300px"}


app = App(app_ui, server)