# sam / app.py
# Nguyen Thai Thao Uyen
# Update file format
# c44061e
from pathlib import Path
from typing import List, Dict, Tuple
import matplotlib.colors as mpl_colors
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import shinyswatch
import predictor
import PIL
from shiny import App, Inputs, Outputs, Session, reactive, render, req, ui
from transformers import SamModel, SamConfig, SamProcessor
import torch
# Directory of static assets served alongside the app.
www_dir = Path(__file__).parent.resolve().joinpath("www")
# Apply seaborn's default theme to every matplotlib figure drawn below.
sns.set_theme()
# Page layout: a sidebar with the upload widget, and a main area showing the
# uploaded image and the prediction plot.
app_ui = ui.page_fillable(
    shinyswatch.theme.minty(),
    ui.layout_sidebar(
        ui.sidebar(
            # The server only reads the first uploaded file (index 0), so accept
            # a single file instead of silently ignoring extras. `accept` is a
            # browser-side hint that narrows the file picker to image types.
            ui.input_file(
                "image_input",
                "Upload image: ",
                multiple=False,
                accept=["image/*"],
            ),
        ),
        ui.output_image("image"),
        ui.output_plot("plot"),
    ),
)
def server(input: Inputs, output: Outputs, session: Session):
    """Register the image preview and prediction plot outputs for one session.

    Both outputs use only the first file of the ``image_input`` upload and
    render nothing until a file has been uploaded.
    """
    # Single-entry cache of the most recent prediction, keyed by upload path.
    # render.plot can re-run its function without a new upload (e.g. when the
    # plot output is resized), and predictor.pred is presumably the expensive
    # model call, so avoid repeating it for the same file. Assumes each upload
    # gets a distinct datapath -- TODO confirm.
    last_prediction = {}

    def _uploaded_path():
        """Return the server-side path of the first uploaded file, or None."""
        files = input.image_input()
        if not files:
            return None
        return files[0]["datapath"]

    def _predict(path):
        """Return ``predictor.pred(path)``, reusing the cached result for a repeated path."""
        if last_prediction.get("path") != path:
            result = predictor.pred(path)
            # Store only after pred succeeds so a failure never leaves a stale entry.
            last_prediction["path"] = path
            last_prediction["result"] = result
        return last_prediction["result"]

    @output
    @render.image
    def image():
        """Preview the uploaded image at a fixed 500px width."""
        src = _uploaded_path()
        if src is None:
            return None
        return {"src": src, "width": "500px"}

    @output
    @render.plot
    def plot():
        """Plot the predicted mask and its probability map side by side."""
        src = _uploaded_path()
        if src is None:
            return None
        # predictor.pred returns (probabilities, prediction); both are passed to
        # imshow, so presumably 2-D image-shaped arrays -- TODO confirm.
        pred_prob, pred_prediction = _predict(src)
        print("plotting...")
        fig, (ax_pred, ax_prob) = plt.subplots(1, 2, figsize=(15, 5))
        ax_pred.imshow(pred_prediction, cmap="gray")
        ax_pred.set_title("Prediction")
        prob_image = ax_prob.imshow(pred_prob)
        ax_prob.set_title("Probability Map")
        fig.colorbar(prob_image, ax=ax_prob)
        for ax in (ax_pred, ax_prob):
            # Pixel coordinates carry no meaning here; with no ticks there are
            # also no tick labels, so clearing labels separately is unnecessary.
            ax.set_xticks([])
            ax.set_yticks([])
        return fig
# The Shiny application object: the UI and server defined above, with files
# under www_dir exposed as static assets.
app = App(app_ui, server, static_assets=str(www_dir))