File size: 2,159 Bytes
586d4f8
 
 
fbc5057
586d4f8
 
 
7fee2e3
4c1b88d
586d4f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c44061e
7fee2e3
586d4f8
 
 
 
 
 
 
 
 
 
 
 
 
 
8d22280
 
9e43181
fbc5057
 
 
 
7fee2e3
fbc5057
8e17922
 
fbc5057
8e17922
 
 
fbc5057
 
 
 
 
 
 
 
 
 
990d0b6
 
 
 
 
 
8d22280
990d0b6
fbc5057
586d4f8
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
from pathlib import Path
from typing import List, Dict, Tuple
import matplotlib.colors as mpl_colors
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import shinyswatch
import predictor
import PIL

from shiny import App, Inputs, Outputs, Session, reactive, render, req, ui
from transformers import SamModel, SamConfig, SamProcessor
import torch

# Apply seaborn's default theme before any figure is created.
sns.set_theme()

# Resolved path of the "www" folder located next to this file; it is handed
# to the App constructor as the static-assets directory.
www_dir = Path(__file__).parent.resolve().joinpath("www")

# Page layout: a sidebar holding the upload control, with the uploaded image
# and the prediction plot in the main area. The output ids ("image", "plot")
# must match the renderer function names defined in `server`.
app_ui = ui.page_fillable(
    shinyswatch.theme.minty(),
    ui.layout_sidebar(
        ui.sidebar(
            ui.input_file(
                id="image_input",
                label="Upload image: ",
                multiple=True,
            ),
        ),
        ui.output_image(id="image"),
        ui.output_plot(id="plot"),
    ),
)


def server(input: Inputs, output: Outputs, session: Session):
    """Define the reactive outputs of the app for one session.

    Two renderers are attached through the ``output`` decorator:

    * ``image`` -- echoes the first uploaded file back to the page.
    * ``plot``  -- runs ``predictor.pred`` on the first uploaded file and
      draws its two results side by side.

    Args:
        input: Reactive inputs; only ``image_input`` (the upload) is read.
        output: Used as a decorator on each renderer below.
        session: Not used directly in this function.
    """

    @output
    @render.image
    def image():
        """Return the image spec for the first uploaded file, or None."""
        uploads = input.image_input()
        if not uploads:
            # Nothing uploaded yet: leave the image output empty.
            return None
        return {"src": uploads[0]["datapath"], "width": "500px"}

    @output
    @render.plot
    def plot():
        """Return a figure with the prediction and probability map, or None."""
        uploads = input.image_input()
        if not uploads:
            # Nothing uploaded yet: leave the plot output empty.
            return None

        # Only the first uploaded file is processed, even if several were
        # selected in the upload control.
        pred_prob, pred_prediction = predictor.pred(uploads[0]["datapath"])

        print("plotting...")
        fig, (prediction_ax, prob_ax) = plt.subplots(1, 2, figsize=(15, 5))

        prediction_ax.imshow(pred_prediction, cmap='gray')
        prediction_ax.set_title("Prediction")

        prob_image = prob_ax.imshow(pred_prob)
        prob_ax.set_title("Probability Map")
        # The colorbar is attached to the probability panel only.
        fig.colorbar(prob_image, ax=prob_ax)

        # Removing the ticks also removes their labels, so no separate
        # tick-label calls are needed.
        for ax in (prediction_ax, prob_ax):
            ax.set_xticks([])
            ax.set_yticks([])
        return fig


# Application object: combines the page layout with the server logic and
# passes the "www" directory as the static-assets location.
app = App(ui=app_ui, server=server, static_assets=str(www_dir))