from pathlib import Path
from typing import Dict, List, Tuple
import io

import matplotlib.colors as mpl_colors
import pandas as pd
import seaborn as sns
import shinyswatch
import torch
from PIL import Image
from shiny import App, Inputs, Outputs, Session, reactive, render, req, ui
from shiny.types import ImgData
from transformers import SamConfig, SamModel, SamProcessor

sns.set_theme()

app_dir = Path(__file__).resolve().parent
www_dir = app_dir / "www"

df = pd.read_csv(app_dir / "penguins.csv", na_values="NA")
numeric_cols: List[str] = df.select_dtypes(include=["float64"]).columns.tolist()
species: List[str] = df["Species"].unique().tolist()
species.sort()

app_ui = ui.page_fillable(
    shinyswatch.theme.minty(),
    ui.layout_sidebar(
        ui.sidebar(
            ui.input_file("tile_image", "Choose TIFF File", accept=[".tif"], multiple=False),
            # Artwork by @allison_horst
            ui.input_selectize(
                "xvar",
                "X variable",
                numeric_cols,
                selected="Bill Length (mm)",
            ),
            ui.input_selectize(
                "yvar",
                "Y variable",
                numeric_cols,
                selected="Bill Depth (mm)",
            ),
            ui.input_checkbox_group(
                "species", "Filter by species", species, selected=species
            ),
            ui.hr(),
            ui.input_switch("by_species", "Show species", value=True),
            ui.input_switch("show_margins", "Show marginal plots", value=True),
        ),
        ui.output_image("uploaded_image"),  # Display the uploaded TIFF sidewalk tile image
        ui.output_ui("value_boxes"),
        ui.output_plot("scatter", fill=True),
        ui.help_text(
            "Artwork by ",
            ui.a("@allison_horst", href="https://twitter.com/allison_horst"),
            class_="text-end",
        ),
    ),
)


def tif_bytes_to_pil_image(tif_bytes: bytes) -> Image.Image:
    """Converts raw TIFF bytes into a PIL Image object."""
    # Wrap the TIFF bytes in a BytesIO object so PIL can open them like a file
    bytes_io = io.BytesIO(tif_bytes)
    image = Image.open(bytes_io)
    return image


def server(input: Inputs, output: Outputs, session: Session):
    @reactive.Calc
    def uploaded_image_path() -> str:
        """Returns the path to the uploaded image, or "" if nothing has been uploaded"""
        if input.tile_image() is not None:
            return input.tile_image()[0]["datapath"]  # Assuming multiple=False
        else:
            return ""  # No image uploaded

    @render.image
    def uploaded_image():
        """Displays the uploaded image"""
        img_src = uploaded_image_path()
        if img_src:
            # The datapath from the file input is already an absolute path
            img: ImgData = {"src": img_src, "width": "100px"}
            return img
        else:
            return None  # Nothing to display until an image is uploaded

    def process_image():
        """Processes the uploaded image, loads the model, and evaluates it to get predictions"""
        # Read the uploaded file's bytes from the temporary path Shiny provides
        uploaded_image_bytes = Path(input.tile_image()[0]["datapath"]).read_bytes()

        # Convert the uploaded TIFF bytes to a PIL Image object
        uploaded_image = tif_bytes_to_pil_image(uploaded_image_bytes)

        # Perform any preprocessing steps on the image as needed
        # Example: convert the image to the required input format for the model
        # image_array = preprocess_image(uploaded_image)

        # Load the model configuration and processor
        model_config = SamConfig.from_pretrained("facebook/sam-vit-base")
        processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

        # Create an instance of the model architecture with the loaded configuration
        model = SamModel(config=model_config)
        # Update the model by loading the weights from the saved checkpoint
        model_state_dict = torch.load(str(app_dir / "checkpoint.pth"), map_location=torch.device("cpu"))
        model.load_state_dict(model_state_dict)

        # Set the device to cuda if available, otherwise use cpu
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model.to(device)

        # Evaluate the image with the model
        # Example: predictions = model.predict(image_array)

        # Return the processed result (replace the placeholder below with the actual result)
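        # --- Hedged sketch (an assumption, not the app's confirmed pipeline): one way the
        # evaluation step above could be filled in with the Hugging Face SAM API. The
        # full-image bounding-box prompt, the 0.5 sigmoid threshold, and the name
        # `predicted_mask` are illustrative choices; the mask produced this way is the
        # model's low-resolution (256x256) output.
        #
        # with torch.no_grad():
        #     inputs = processor(
        #         uploaded_image.convert("RGB"),
        #         input_boxes=[[[0, 0, uploaded_image.width, uploaded_image.height]]],
        #         return_tensors="pt",
        #     ).to(device)
        #     outputs = model(**inputs, multimask_output=False)
        #     # pred_masks holds logits; sigmoid + threshold gives a binary mask
        #     mask_prob = torch.sigmoid(outputs.pred_masks.squeeze(1))
        #     predicted_mask = (mask_prob > 0.5).squeeze().cpu().numpy()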
return "Processed result" @reactive.Calc def processed_result(): """Processes the image when uploaded""" if input.tile_image() is not None: return process_image() else: return None @output @render.text def processed_output(): """Displays the predictions of the uploaded image""" return processed_result() @reactive.Calc def filtered_df() -> pd.DataFrame: """Returns a Pandas data frame that includes only the desired rows""" # This calculation "req"uires that at least one species is selected req(len(input.species()) > 0) # Filter the rows so we only include the desired species return df[df["Species"].isin(input.species())] @output @render.plot def scatter(): """Generates a plot for Shiny to display to the user""" # The plotting function to use depends on whether margins are desired plotfunc = sns.jointplot if input.show_margins() else sns.scatterplot plotfunc( data=filtered_df(), x=input.xvar(), y=input.yvar(), palette=palette, hue="Species" if input.by_species() else None, hue_order=species, legend=False, ) @output @render.ui def value_boxes(): df = filtered_df() def penguin_value_box(title: str, count: int, bgcol: str, showcase_img: str): return ui.value_box( title, count, {"class_": "pt-1 pb-0"}, showcase=ui.fill.as_fill_item( ui.tags.img( {"style": "object-fit:contain;"}, src=showcase_img, ) ), theme_color=None, style=f"background-color: {bgcol};", ) if not input.by_species(): return penguin_value_box( "Penguins", len(df.index), bg_palette["default"], # Artwork by @allison_horst showcase_img="penguins.png", ) value_boxes = [ penguin_value_box( name, len(df[df["Species"] == name]), bg_palette[name], # Artwork by @allison_horst showcase_img=f"{name}.png", ) for name in species # Only include boxes for _selected_ species if name in input.species() ] return ui.layout_column_wrap(*value_boxes, width = 1 / len(value_boxes)) # "darkorange", "purple", "cyan4" colors = [[255, 140, 0], [160, 32, 240], [0, 139, 139]] colors = [(r / 255.0, g / 255.0, b / 255.0) for r, g, b in colors] palette: Dict[str, Tuple[float, float, float]] = { "Adelie": colors[0], "Chinstrap": colors[1], "Gentoo": colors[2], "default": sns.color_palette()[0], # type: ignore } bg_palette = {} # Use `sns.set_style("whitegrid")` to help find approx alpha value for name, col in palette.items(): # Adjusted n_colors until `axe` accessibility did not complain about color contrast bg_palette[name] = mpl_colors.to_hex(sns.light_palette(col, n_colors=7)[1]) # type: ignore app = App( app_ui, server, static_assets=str(www_dir), )