File size: 2,599 Bytes
586d4f8
 
 
fbc5057
586d4f8
 
 
 
4c1b88d
586d4f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fbc5057
 
 
 
586d4f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9e43181
 
 
 
 
 
 
 
 
 
8d22280
 
9e43181
fbc5057
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9e43181
8d22280
9e43181
 
8d22280
fbc5057
 
586d4f8
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
from pathlib import Path
from typing import List, Dict, Tuple
import matplotlib.colors as mpl_colors
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import shinyswatch
import run
import PIL

from shiny import App, Inputs, Outputs, Session, reactive, render, req, ui
from transformers import SamModel, SamConfig, SamProcessor
import torch

# Apply seaborn's default theme to every matplotlib figure drawn by the app.
sns.set_theme()

# Directory whose contents are served as static assets (passed to App below).
www_dir = Path(__file__).parent.resolve() / "www"

app_ui = ui.page_fillable(
    shinyswatch.theme.minty(),
    ui.layout_sidebar(
        ui.sidebar(
            # The server only ever reads the first uploaded file, so accept a
            # single file; `accept` hints the browser to offer images only,
            # since the upload is decoded as an image for display/prediction.
            ui.input_file(
                "image_input",
                "Upload image: ",
                multiple=False,
                accept=["image/*"],
            ),
        ),
        # Echo of the uploaded image.
        ui.output_image("image"),
        # Input image, probability map and prediction side by side.
        ui.output_plot("plot"),
    ),
)


def server(input: Inputs, output: Outputs, session: Session):
    """Wire the app's outputs to the ``image_input`` file upload.

    Registers two outputs:

    * ``image`` -- echoes the first uploaded file back to the browser.
    * ``plot`` -- runs ``run.pred`` on the first uploaded file and draws the
      uploaded image, the returned probability map and the returned
      prediction side by side. Without an upload it draws a placeholder
      figure that says "No image received".
    """

    def _first_upload_path():
        """Return the server-side path of the first uploaded file, or None."""
        files = input.image_input()
        if not files:
            return None
        return files[0]["datapath"]

    @output
    @render.image
    def image():
        src = _first_upload_path()
        if src is None:
            return None
        return {"src": src, "width": "500px"}

    @output
    @render.plot
    def plot():
        new_image = _first_upload_path()
        if new_image is None:
            print("no image received")
            fig, _ = plt.subplots()  # Create an empty figure if no image received
            fig.text(0.5, 0.5, "No image received", ha="center", va="center", fontsize=14)
            return fig

        # `datapath` is a filesystem path (a str); imshow needs pixel data,
        # so decode the file before plotting it.
        pixels = plt.imread(new_image)

        # NOTE(review): run.pred is expected to return a
        # (probability map, prediction) pair of image-like arrays -- confirm.
        pred_prob, pred_prediction = run.pred(new_image)

        fig, axes = plt.subplots(1, 3, figsize=(15, 5))

        axes[0].imshow(pixels, cmap="gray")
        axes[0].set_title("Image")

        prob_im = axes[1].imshow(pred_prob)
        axes[1].set_title("Probability Map")
        fig.colorbar(prob_im, ax=axes[1])

        axes[2].imshow(pred_prediction, cmap="gray")
        axes[2].set_title("Prediction")

        # Hide ticks on all panels; with no ticks there are no tick labels.
        for ax in axes:
            ax.set_xticks([])
            ax.set_yticks([])

        return fig


# Combine the UI tree and server function into the Shiny application;
# files under `www_dir` are served as static assets.
app = App(app_ui, server, static_assets=f"{www_dir}")