import io
from pathlib import Path
from typing import Dict, List, Tuple

import matplotlib.colors as mpl_colors
import numpy as np
import pandas as pd
import seaborn as sns
import shinyswatch
import torch
from PIL import Image
from shiny import App, Inputs, Outputs, Session, reactive, render, req, ui
from shiny.types import ImgData
from transformers import SamConfig, SamModel, SamProcessor

sns.set_theme()
app_dir = Path(__file__).resolve().parent
www_dir = app_dir / "www"
df = pd.read_csv(app_dir / "penguins.csv", na_values="NA")
numeric_cols: List[str] = df.select_dtypes(include=["float64"]).columns.tolist()
species: List[str] = df["Species"].unique().tolist()
species.sort()
app_ui = ui.page_fillable(
shinyswatch.theme.minty(),
ui.layout_sidebar(
ui.sidebar(
ui.input_file("tile_image", "Choose TIFF File", accept=[".tif", ".tiff"], multiple=False),
# Artwork by @allison_horst
ui.input_selectize(
"xvar",
"X variable",
numeric_cols,
selected="Bill Length (mm)",
),
ui.input_selectize(
"yvar",
"Y variable",
numeric_cols,
selected="Bill Depth (mm)",
),
ui.input_checkbox_group(
"species", "Filter by species", species, selected=species
),
ui.hr(),
ui.input_switch("by_species", "Show species", value=True),
ui.input_switch("show_margins", "Show marginal plots", value=True),
),
ui.output_image("uploaded_image"), # display the uploaded TIFF sidewalk tile image
ui.output_text("processed_output"),
ui.output_ui("value_boxes"),
ui.output_plot("scatter", fill=True),
ui.help_text(
"Artwork by ",
ui.a("@allison_horst", href="https://twitter.com/allison_horst"),
class_="text-end",
),
),
)
def tif_bytes_to_pil_image(tif_bytes):
# Create a BytesIO object from the TIFF bytes
bytes_io = io.BytesIO(tif_bytes)
# Open the BytesIO object as an Image
image = Image.open(bytes_io)
return image
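# Usage sketch for the helper above ("sample_tile.tif" is a hypothetical file,
# not shipped with the app):
#   img = tif_bytes_to_pil_image(Path("sample_tile.tif").read_bytes())
#   img.size  # e.g. (256, 256)
# Note: Image.open() reads lazily, so the BytesIO buffer must stay available
# until the pixel data is actually loaded.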
def load_model():
""" Get Model """
# Load the model configuration
model_config = SamConfig.from_pretrained("facebook/sam-vit-base")
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
# Create an instance of the model architecture with the loaded configuration
model = SamModel(config=model_config)
# Update the model by loading the weights from saved file
    model_state_dict = torch.load(str(app_dir / "checkpoint.pth"), map_location=torch.device('cpu'))
model.load_state_dict(model_state_dict)
# set the device to cuda if available, otherwise use cpu
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
return model, processor, device
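# Note: the load above assumes `checkpoint.pth` holds a plain state_dict saved
# via torch.save(model.state_dict(), ...). Had the fine-tuned weights been
# exported with model.save_pretrained("checkpoint_dir") instead, the
# equivalent (hypothetical) load would be a single call:
#   model = SamModel.from_pretrained(str(app_dir / "checkpoint_dir"))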
def server(input: Inputs, output: Outputs, session: Session):
# set model, processor, device once
model, processor, device = load_model()
@reactive.Calc
def uploaded_image_path() -> str:
"""Returns the path to the uploaded image"""
if input.tile_image() is not None:
return input.tile_image()[0]['datapath'] # Assuming multiple=False
else:
return "" # No image uploaded
    @render.image
    def uploaded_image():
        """Displays the uploaded image"""
        img_src = uploaded_image_path()
        if img_src:
            # The file input's datapath is already an absolute temp-file path
            img: ImgData = {"src": img_src, "width": "100px"}
            return img
        else:
            return None  # Nothing to render when no image has been uploaded
@reactive.Calc
def generate_input_points():
"""
input_points (torch.FloatTensor of shape (batch_size, num_points, 2)) —
Input 2D spatial points, this is used by the prompt encoder to encode the prompt.
        Generally yields much better results. The points can be obtained by passing a
list of list of list to the processor that will create corresponding torch tensors
of dimension 4. The first dimension is the image batch size, the second dimension
is the point batch size (i.e. how many segmentation masks do we want the model to
predict per input point), the third dimension is the number of points per segmentation
mask (it is possible to pass multiple points for a single mask), and the last dimension
is the x (vertical) and y (horizontal) coordinates of the point. If a different number
of points is passed either for each image, or for each mask, the processor will create
“PAD” points that will correspond to the (0, 0) coordinate, and the computation of the
embedding will be skipped for these points using the labels.
"""
        # Size of the image the prompt grid spans (the app assumes 256x256 tiles)
        array_size = 256
        # Number of grid points per axis (10 x 10 = 100 prompt points)
        grid_size = 10
# Generate the grid points
x = np.linspace(0, array_size-1, grid_size)
y = np.linspace(0, array_size-1, grid_size)
# Generate a grid of coordinates
xv, yv = np.meshgrid(x, y)
# Convert the numpy arrays to lists
xv_list = xv.tolist()
yv_list = yv.tolist()
# Combine the x and y coordinates into a list of list of lists
input_points = [[[int(x), int(y)] for x, y in zip(x_row, y_row)] for x_row, y_row in zip(xv_list, yv_list)]
        # We need to reshape our nxn grid to the expected shape of the input_points tensor
        # (batch_size, point_batch_size, num_points_per_image, 2),
        # where the last dimension of 2 represents the x and y coordinates of each point.
        # batch_size: the number of images you're processing at once.
        # point_batch_size: the number of point sets you have for each image.
        # num_points_per_image: the number of points in each set.
        input_points = torch.tensor(input_points).view(1, 1, grid_size * grid_size, 2)
return input_points
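    # Worked example of the shape contract described in the docstring above:
    # with grid_size=10 this reactive returns a (1, 1, 100, 2) tensor: one
    # image, one point set, 100 prompt points, each an (x, y) pair.
    #   pts = generate_input_points()  # within a reactive context
    #   pts.shape  # torch.Size([1, 1, 100, 2])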
    def process_image():
        """Processes the uploaded image and runs the model to get predictions"""
        """ Get Image """
        # Read the uploaded TIFF bytes from the temp file created by the upload
        uploaded_image_bytes = Path(input.tile_image()[0]["datapath"]).read_bytes()
        # Convert the uploaded TIFF bytes to a PIL Image object
        pil_image = tif_bytes_to_pil_image(uploaded_image_bytes)
        """ Prepare Inputs """
        # get input points prompt (grid of points)
        input_points = generate_input_points()
        # prepare image and prompt for the model
        inputs = processor(pil_image, input_points=input_points, return_tensors="pt")
        # move the (already batched) tensors to the model's device
        inputs = {k: v.to(device) for k, v in inputs.items()}
        """ Get Predictions """
        # Evaluate the image with the model (single forward pass, no gradients)
        model.eval()
        with torch.no_grad():
            outputs = model(**inputs, multimask_output=False)
        # Threshold the sigmoid of the mask logits to get a binary mask
        probs = torch.sigmoid(outputs.pred_masks.squeeze(1))
        mask = (probs > 0.5).cpu().numpy().squeeze()
        # Return a human-readable summary for the text output
        return f"Predicted mask covers {mask.mean():.1%} of the tile"
@reactive.Calc
def processed_result():
"""Processes the image when uploaded"""
if input.tile_image() is not None:
return process_image()
else:
return None
@output
@render.text
def processed_output():
"""Displays the predictions of the uploaded image"""
return processed_result()
@reactive.Calc
def filtered_df() -> pd.DataFrame:
"""Returns a Pandas data frame that includes only the desired rows"""
# This calculation "req"uires that at least one species is selected
req(len(input.species()) > 0)
# Filter the rows so we only include the desired species
return df[df["Species"].isin(input.species())]
@output
@render.plot
def scatter():
"""Generates a plot for Shiny to display to the user"""
# The plotting function to use depends on whether margins are desired
plotfunc = sns.jointplot if input.show_margins() else sns.scatterplot
plotfunc(
data=filtered_df(),
x=input.xvar(),
y=input.yvar(),
palette=palette,
hue="Species" if input.by_species() else None,
hue_order=species,
legend=False,
)
@output
@render.ui
def value_boxes():
df = filtered_df()
def penguin_value_box(title: str, count: int, bgcol: str, showcase_img: str):
return ui.value_box(
title,
count,
{"class_": "pt-1 pb-0"},
showcase=ui.fill.as_fill_item(
ui.tags.img(
{"style": "object-fit:contain;"},
src=showcase_img,
)
),
theme_color=None,
style=f"background-color: {bgcol};",
)
if not input.by_species():
return penguin_value_box(
"Penguins",
len(df.index),
bg_palette["default"],
# Artwork by @allison_horst
showcase_img="penguins.png",
)
value_boxes = [
penguin_value_box(
name,
len(df[df["Species"] == name]),
bg_palette[name],
# Artwork by @allison_horst
showcase_img=f"{name}.png",
)
for name in species
# Only include boxes for _selected_ species
if name in input.species()
]
        return ui.layout_column_wrap(*value_boxes, width=1 / len(value_boxes))
# "darkorange", "purple", "cyan4"
colors = [[255, 140, 0], [160, 32, 240], [0, 139, 139]]
colors = [(r / 255.0, g / 255.0, b / 255.0) for r, g, b in colors]
palette: Dict[str, Tuple[float, float, float]] = {
"Adelie": colors[0],
"Chinstrap": colors[1],
"Gentoo": colors[2],
"default": sns.color_palette()[0], # type: ignore
}
bg_palette = {}
# Use `sns.set_style("whitegrid")` to help find approx alpha value
for name, col in palette.items():
# Adjusted n_colors until `axe` accessibility did not complain about color contrast
bg_palette[name] = mpl_colors.to_hex(sns.light_palette(col, n_colors=7)[1]) # type: ignore
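# Sketch of what the loop above produces: sns.light_palette(color, n_colors=7)
# returns a 7-step ramp from near-white to the full species color, and index 1
# picks a pale tint that keeps the value-box text readable (exact hex values
# depend on seaborn's interpolation).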
app = App(
app_ui,
server,
static_assets=str(www_dir),
)