File size: 2,935 Bytes
b064fda
ab52a15
 
b064fda
ab52a15
 
 
b064fda
 
 
263020d
b064fda
 
ab52a15
 
 
 
 
 
 
 
 
 
8b61b70
 
 
ab52a15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b064fda
ab52a15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import os
import tempfile
from pathlib import Path
from PIL import Image
import torch
import numpy as np
import torchvision.transforms as transforms
from shiny import App, Inputs, Outputs, Session, reactive, render, ui
from shiny.types import FileInfo

os.environ["TRANSFORMERS_CACHE"] = "/code/"

from transformers import SamModel

# Preprocessing pipeline shared by all inference calls: resize the uploaded
# image to the 1024x1024 input size expected by SAM, then convert the PIL
# image to a float tensor (ToTensor also scales pixel values to [0, 1]).
image_resize_transform = transforms.Compose([
    transforms.Resize((1024, 1024)), 
    transforms.ToTensor()  
])

# Page layout: one file-upload control and two image slots.
# "original_image" shows the uploaded picture; "image_display" shows the
# segmentation result produced by the server's loaded_image() calc.
app_ui = ui.page_fluid(
    ui.input_file("file2", "Choose Image", accept=".jpg, .jpeg, .png, .tiff, .tif", multiple=False),
    ui.output_image("original_image"),  
    ui.output_image("image_display")    
)

def server(input: Inputs, output: Outputs, session: Session):
    """Shiny server: run SAM segmentation on an uploaded image.

    Registers one reactive calc (``loaded_image``) that performs inference
    and two image renderers that display the original picture and the
    mask/image composite respectively.
    """

    @reactive.calc
    def loaded_image():
        """Segment the uploaded image and return (original_path, overlay_path).

        Returns None until a file has been uploaded. Both paths point to
        freshly written temporary PNG files (delete=False, so they survive
        past this function; they are served to the browser by the renderers).
        """
        file: list[FileInfo] | None = input.file2()
        if file is None:
            return None

        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # NOTE(review): the model is re-downloaded/re-loaded on every upload;
        # if uploads are frequent, consider loading it once at module level.
        model2 = SamModel.from_pretrained("facebook/sam-vit-base")
        model2.load_state_dict(torch.load('model.pth', map_location=device))
        model2.eval()
        model2.to(device)

        # Load once and reuse below for the display copy (the original code
        # reopened the same file a second time).
        image = Image.open(file[0]["datapath"]).convert('RGB')
        image_tensor = image_resize_transform(image).to(device)

        with torch.no_grad():
            outputs = model2(pixel_values=image_tensor.unsqueeze(0), multimask_output=False)
        # pred_masks: drop the per-prompt axis, keep the first mask channel.
        predicted_masks = outputs.pred_masks.squeeze(1)
        predicted_masks = predicted_masks[:, 0, :, :]

        mask_tensor = predicted_masks.cpu().detach().squeeze()
        mask_array = mask_tensor.numpy()
        # NOTE(review): pred_masks appear to be raw scores; scaling by 255
        # without a sigmoid/threshold can wrap around in uint8 — confirm
        # this matches how the checkpoint was trained.
        mask_array = (mask_array * 255).astype(np.uint8)
        mask = Image.fromarray(mask_array)
        mask = mask.resize((1024, 1024), Image.LANCZOS)
        mask = mask.convert('RGBA')

        # Make the mask half-transparent so the photo shows through.
        alpha = Image.new('L', mask.size, 128)
        mask.putalpha(alpha)

        image = image.resize((1024, 1024), Image.LANCZOS)
        image = image.convert('RGBA')

        combined = Image.alpha_composite(image, mask)

        combined_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
        original_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
        image.save(original_file.name, "PNG")
        # BUG FIX: the original code saved `mask` here, so the composited
        # overlay computed above was never shown. Save the composite.
        combined.save(combined_file.name, "PNG")
        # Close our handles; the files persist (delete=False) for serving.
        combined_file.close()
        original_file.close()

        return original_file.name, combined_file.name

    @render.image
    def original_image():
        """Render the resized original upload, or nothing before upload."""
        result = loaded_image()
        if result is None:
            return None
        img_path, _ = result
        return {"src": img_path, "width": "300px"}

    @render.image
    def image_display():
        """Render the segmentation overlay, or nothing before upload."""
        result = loaded_image()
        if result is None:
            return None
        _, img_path = result
        return {"src": img_path, "width": "300px"}

app = App(app_ui, server)