aagoluoglu's picture
Update app.py
3549727 verified
raw
history blame
6.14 kB
from functools import lru_cache
from pathlib import Path
from typing import List, Dict, Tuple

import matplotlib.colors as mpl_colors
import pandas as pd
import seaborn as sns
import shinyswatch
import torch
from shiny import App, Inputs, Outputs, Session, reactive, render, req, ui
from shiny.types import ImgData
from transformers import SamModel, SamConfig, SamProcessor
# Apply seaborn's default theme globally so every plot rendered below shares it.
sns.set_theme()

# Static assets (e.g. the value-box showcase images) live in "www" next to this file.
www_dir = Path(__file__).parent.resolve() / "www"

# Dataset shipped alongside the app; the literal string "NA" is read as missing.
df = pd.read_csv(Path(__file__).parent / "penguins.csv", na_values="NA")

# Only float64 columns are offered as X / Y plot variables.
numeric_cols: List[str] = list(df.select_dtypes(include=["float64"]).columns)

# Distinct values of the Species column in sorted order; reused for the
# species filter, the plot hue order and the per-species value boxes.
species: List[str] = sorted(df["Species"].unique().tolist())
# Page layout: a sidebar of controls beside the uploaded-tile preview, the
# count value boxes and the scatter plot.
app_ui = ui.page_fillable(
    shinyswatch.theme.minty(),
    ui.layout_sidebar(
        ui.sidebar(
            # Accept both common TIFF extensions so ".tiff" files can be selected too.
            ui.input_file(
                "tile_image",
                "Choose TIFF File",
                accept=[".tif", ".tiff"],
                multiple=False,
            ),
            ui.input_selectize(
                "xvar",
                "X variable",
                numeric_cols,
                selected="Bill Length (mm)",
            ),
            ui.input_selectize(
                "yvar",
                "Y variable",
                numeric_cols,
                selected="Bill Depth (mm)",
            ),
            ui.input_checkbox_group(
                "species", "Filter by species", species, selected=species
            ),
            ui.hr(),
            ui.input_switch("by_species", "Show species", value=True),
            ui.input_switch("show_margins", "Show marginal plots", value=True),
        ),
        # Preview of the uploaded TIFF sidewalk tile image.
        ui.output_image("uploaded_image"),
        ui.output_ui("value_boxes"),
        ui.output_plot("scatter", fill=True),
        ui.help_text(
            "Artwork by ",
            ui.a("@allison_horst", href="https://twitter.com/allison_horst"),
            class_="text-end",
        ),
    ),
)
@lru_cache(maxsize=1)
def _load_sam() -> Tuple[SamModel, SamProcessor, str]:
    """Load the fine-tuned SAM model and its processor once per process.

    ``server`` runs once for every browser session, so caching here avoids
    re-reading the weights and re-moving them to the device on each new
    connection. A failed load is not cached, so a later session retries.

    Returns:
        ``(model, processor, device)``. ``device`` is ``"cuda"`` when
        ``torch.cuda.is_available()`` is true and ``"cpu"`` otherwise, and
        ``model`` has already been moved to it.
    """
    # Load the base model configuration and processor
    model_config = SamConfig.from_pretrained("facebook/sam-vit-base")
    processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
    # Create the model from the fine-tuned checkpoint with the loaded configuration
    model = SamModel.from_pretrained("aagoluoglu/SAM_Sidewalks", config=model_config)
    # Use the GPU when available, otherwise fall back to the CPU
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    return model, processor, device


def server(input: Inputs, output: Outputs, session: Session):
    """Shiny server function: define the reactive calculations and outputs for one session."""
    # Shared across all sessions via the cache on `_load_sam`.
    # NOTE(review): `model`, `processor` and `device` are not used by any output
    # in this file yet — presumably reserved for segmenting the uploaded tile.
    model, processor, device = _load_sam()

    @reactive.Calc
    def uploaded_image_path() -> str:
        """Return the server-side path of the uploaded file, or "" if none."""
        files = input.tile_image()
        # `not files` covers both "nothing uploaded yet" (None) and an empty list.
        if not files:
            return ""  # No image uploaded
        # The input is declared with multiple=False, so only the first entry matters.
        return files[0]["datapath"]

    @render.image
    def uploaded_image():
        """Display the uploaded image at 100px wide, or nothing if none was uploaded."""
        img_src = uploaded_image_path()
        if not img_src:
            return None  # Render nothing until a file is uploaded
        # Resolve relative to this file's directory; if `img_src` is already
        # absolute, pathlib's `/` returns it unchanged.
        app_dir = Path(__file__).resolve().parent
        # NOTE(review): many browsers cannot render TIFF in an <img> tag —
        # confirm the preview displays, or convert to PNG before returning.
        img: ImgData = {"src": str(app_dir / img_src), "width": "100px"}
        return img

    @reactive.Calc
    def filtered_df() -> pd.DataFrame:
        """Return only the rows whose Species is currently selected."""
        # req() halts dependent outputs until at least one species is selected
        req(len(input.species()) > 0)
        return df[df["Species"].isin(input.species())]

    @output
    @render.plot
    def scatter():
        """Draw a scatter plot (or a joint plot with marginals) of the chosen X/Y variables."""
        # Marginal plots require a jointplot; otherwise a plain scatterplot suffices
        plotfunc = sns.jointplot if input.show_margins() else sns.scatterplot
        by_species = input.by_species()
        plotfunc(
            data=filtered_df(),
            x=input.xvar(),
            y=input.yvar(),
            # seaborn ignores `palette` when no `hue` is assigned, so only pass
            # it when colouring by species.
            palette=palette if by_species else None,
            hue="Species" if by_species else None,
            hue_order=species,
            legend=False,
        )

    @output
    @render.ui
    def value_boxes():
        """Render one count box for all penguins, or one box per selected species."""
        data = filtered_df()

        def penguin_value_box(title: str, count: int, bgcol: str, showcase_img: str):
            """Build a value box showing `count`, a showcase image and a background colour."""
            return ui.value_box(
                title,
                count,
                {"class_": "pt-1 pb-0"},
                showcase=ui.fill.as_fill_item(
                    ui.tags.img(
                        {"style": "object-fit:contain;"},
                        src=showcase_img,
                    )
                ),
                theme_color=None,
                style=f"background-color: {bgcol};",
            )

        if not input.by_species():
            return penguin_value_box(
                "Penguins",
                len(data.index),
                bg_palette["default"],
                # Artwork by @allison_horst
                showcase_img="penguins.png",
            )

        # Read the reactive input once instead of on every comprehension step.
        selected = set(input.species())
        boxes = [
            penguin_value_box(
                name,
                len(data[data["Species"] == name]),
                bg_palette[name],
                # Artwork by @allison_horst
                showcase_img=f"{name}.png",
            )
            for name in species
            # Only include boxes for _selected_ species
            if name in selected
        ]
        # filtered_df() req()s at least one selected species, so `boxes` is non-empty.
        return ui.layout_column_wrap(*boxes, width=1 / len(boxes))
# "darkorange", "purple", "cyan4"
# 0-255 RGB triples rescaled to the 0-1 float range that matplotlib/seaborn expect.
colors = [
    tuple(channel / 255.0 for channel in rgb)
    for rgb in ([255, 140, 0], [160, 32, 240], [0, 139, 139])
]

# Plot colour per species, plus seaborn's first default colour for the
# "all penguins" case.
palette: Dict[str, Tuple[float, float, float]] = dict(
    zip(("Adelie", "Chinstrap", "Gentoo"), colors)
)
palette["default"] = sns.color_palette()[0]  # type: ignore

# Value-box background per palette key: the second entry of a 7-step seaborn
# light_palette, as a hex string. `sns.set_style("whitegrid")` helps find an
# approximate alpha value; n_colors was adjusted until the `axe`
# accessibility checker stopped complaining about colour contrast.
bg_palette = {
    name: mpl_colors.to_hex(sns.light_palette(col, n_colors=7)[1])  # type: ignore
    for name, col in palette.items()
}
# Combine the UI and server into the Shiny app; files under `www_dir` are
# served as static assets (the value boxes reference images by bare filename).
app = App(app_ui, server, static_assets=str(www_dir))