from pathlib import Path
from typing import List, Dict, Tuple
import matplotlib.colors as mpl_colors
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import shinyswatch
import predictor
import PIL

from shiny import App, Inputs, Outputs, Session, reactive, render, req, ui
from transformers import SamModel, SamConfig, SamProcessor
import torch

# Apply seaborn's theme globally so the matplotlib prediction plot uses it.
sns.set_theme()

# Local "www" folder next to this file; handed to the App as static assets.
www_dir = Path(__file__).parent.resolve() / "www"

app_ui = ui.page_fillable(
    shinyswatch.theme.minty(),
    ui.layout_sidebar(
        ui.sidebar(
            # The server only ever reads the first uploaded file, so allow a
            # single selection instead of silently ignoring extra files.
            # `accept` hints the browser to offer image files only, since the
            # upload is displayed as an image and fed to the predictor.
            ui.input_file(
                "image_input",
                "Upload image: ",
                multiple=False,
                accept=["image/*"],
            ),
        ),
        ui.output_image("image"),
        ui.output_plot("plot"),
    ),
)


def server(input: Inputs, output: Outputs, session: Session):
    """Wire the uploaded image to the preview and prediction outputs.

    Both outputs use only the first uploaded file. Until a file has been
    uploaded they return ``None``, so nothing is rendered.
    """

    def _first_upload_path():
        """Return the server-side path of the first uploaded file, or None."""
        uploads = input.image_input()
        # Falsy until the user uploads something; afterwards a sequence of
        # dicts whose "datapath" entry is the uploaded file's path.
        if not uploads:
            return None
        return uploads[0]["datapath"]

    @output
    @render.image
    def image():
        """Show the uploaded image at a fixed display width of 500px."""
        src = _first_upload_path()
        if src is None:
            return None
        return {"src": src, "width": "500px"}

    @output
    @render.plot
    def plot():
        """Run the predictor on the upload and plot its outputs side by side."""
        image_path = _first_upload_path()
        if image_path is None:
            return None

        # predictor.pred is project code; it is assumed (not verified here)
        # to return two image-like arrays: a probability map and a
        # prediction, both suitable for imshow.
        pred_prob, pred_prediction = predictor.pred(image_path)

        print("plotting...")
        fig, (ax_pred, ax_prob) = plt.subplots(1, 2, figsize=(15, 5))

        ax_pred.imshow(pred_prediction, cmap="gray")
        ax_pred.set_title("Prediction")

        prob_im = ax_prob.imshow(pred_prob)
        ax_prob.set_title("Probability Map")
        fig.colorbar(prob_im, ax=ax_prob)

        # Pixel coordinates carry no meaning for the viewer. Clearing the
        # ticks also removes their labels, so no separate label calls are
        # needed.
        for ax in (ax_pred, ax_prob):
            ax.set_xticks([])
            ax.set_yticks([])
        return fig


# Assemble the application from the UI tree and server function, exposing
# the local "www" directory as static assets.
app = App(app_ui, server, static_assets=str(www_dir))