import base64
import functools
import os
import tempfile
from io import BytesIO
from pathlib import Path

import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
from shiny import App, Inputs, Outputs, Session, reactive, render, ui
from shiny.types import FileInfo
from transformers import AutoModel
from transformers import SamModel

# Preprocessing applied to each uploaded image before inference: resize to a
# fixed 1024x1024 and convert the PIL image to a tensor.
image_resize_transform = transforms.Compose(
    [
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
    ]
)

# Page layout: a single-file image upload followed by two image outputs whose
# ids match the render functions defined in `server`.
app_ui = ui.page_fluid(
    ui.input_file(
        "file2",
        "Choose Image",
        accept=".jpg, .jpeg, .png, .tiff, .tif",
        multiple=False,
    ),
    # Resized copy of the uploaded image.
    ui.output_image("original_image"),
    # Segmentation result for the uploaded image.
    ui.output_image("image_display"),
)

@functools.lru_cache(maxsize=1)
def _load_segmentation_model():
    """Load the segmentation model once and reuse it for every upload.

    Loading from the hub is expensive, so the model is created lazily on the
    first upload and cached for the lifetime of the process instead of being
    rebuilt each time the reactive calculation runs.
    """
    model = AutoModel.from_pretrained("ansal/sidewalk-segment", device_map="cpu")
    model.eval()
    return model


def _save_png(picture):
    """Write *picture* to a new temporary PNG file and return the file's path."""
    # The handle is closed on leaving the `with` block; `delete=False` keeps
    # the file on disk so the renderer can read it afterwards.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as handle:
        picture.save(handle, format="PNG")
    return handle.name


def server(input: Inputs, output: Outputs, session: Session):
    """Wire the upload input to the two image outputs.

    When a file is uploaded through ``file2``, the image is run through the
    segmentation model and two PNG files are produced: the resized original
    and the original with the semi-transparent mask composited on top.
    """
    # Paths of the temporary PNGs written for this session; they are removed
    # when the session ends so they do not pile up on disk.
    temp_paths: list[str] = []

    def _remove_temp_files():
        for path in temp_paths:
            try:
                os.remove(path)
            except OSError:
                # Best-effort cleanup: the file may already be gone or locked.
                pass
        temp_paths.clear()

    session.on_ended(_remove_temp_files)

    @reactive.calc
    def loaded_image():
        """Return ``(original_path, overlay_path)`` or ``None`` if nothing is uploaded."""
        file: list[FileInfo] | None = input.file2()
        if not file:
            return None

        model = _load_segmentation_model()

        # Open the upload once and release the file handle right away;
        # `convert` returns an independent, fully loaded image.
        with Image.open(file[0]["datapath"]) as source:
            rgb_image = source.convert("RGB")

        # Put the input on whatever device the model lives on, so the two can
        # never disagree (the model is loaded on CPU).
        pixel_values = image_resize_transform(rgb_image).unsqueeze(0).to(model.device)

        with torch.no_grad():
            outputs = model(pixel_values=pixel_values, multimask_output=False)

        # Drop dim 1, then keep the first entry along the next dim.
        predicted_masks = outputs.pred_masks.squeeze(1)
        predicted_masks = predicted_masks[:, 0, :, :]

        # NOTE(review): the raw model output is scaled by 255 and clipped; if
        # `pred_masks` holds logits a sigmoid may be intended here - confirm.
        mask_array = predicted_masks.cpu().detach().squeeze().numpy()
        mask_array = np.nan_to_num(mask_array)
        mask_array = np.clip(mask_array * 255, 0, 255).astype(np.uint8)

        mask = Image.fromarray(mask_array)
        mask = mask.resize((1024, 1024), Image.LANCZOS).convert("RGBA")
        # Uniform 50% alpha so the image underneath stays visible.
        mask.putalpha(Image.new("L", mask.size, 128))

        image = rgb_image.resize((1024, 1024), Image.LANCZOS).convert("RGBA")
        combined = Image.alpha_composite(image, mask)

        original_path = _save_png(image)
        temp_paths.append(original_path)
        # Save the composited overlay (not the bare mask) as the second image.
        combined_path = _save_png(combined)
        temp_paths.append(combined_path)

        return original_path, combined_path

    @render.image
    def original_image():
        """Show the resized original image, or nothing before an upload."""
        result = loaded_image()
        if result is None:
            return None
        img_path, _ = result
        return {"src": img_path, "width": "300px"}

    @render.image
    def image_display():
        """Show the image with the mask overlaid, or nothing before an upload."""
        result = loaded_image()
        if result is None:
            return None
        _, img_path = result
        return {"src": img_path, "width": "300px"}

# Application object combining the UI definition with the server logic.
app = App(app_ui, server)