from pathlib import Path
from typing import List, Dict, Tuple
import io

import matplotlib.colors as mpl_colors
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import shinyswatch
import torch
from PIL import Image
from shiny import App, Inputs, Outputs, Session, reactive, render, req, ui
from shiny.types import ImgData
from transformers import SamModel, SamConfig, SamProcessor

sns.set_theme()

dir = Path(__file__).resolve().parent
www_dir = dir / "www"

df = pd.read_csv(dir / "penguins.csv", na_values="NA")
numeric_cols: List[str] = df.select_dtypes(include=["float64"]).columns.tolist()
species: List[str] = df["Species"].unique().tolist()
species.sort()


### UI ###

app_ui = ui.page_fillable(
    shinyswatch.theme.minty(),
    ui.layout_sidebar(
        ui.sidebar(
            ui.input_file(
                "tile_image", "Choose TIFF File", accept=[".tif"], multiple=False
            ),
            # Artwork by @allison_horst
            ui.input_selectize(
                "xvar",
                "X variable",
                numeric_cols,
                selected="Bill Length (mm)",
            ),
            ui.input_selectize(
                "yvar",
                "Y variable",
                numeric_cols,
                selected="Bill Depth (mm)",
            ),
            ui.input_checkbox_group(
                "species", "Filter by species", species, selected=species
            ),
            ui.hr(),
            ui.input_switch("by_species", "Show species", value=True),
            ui.input_switch("show_margins", "Show marginal plots", value=True),
        ),
        # Display the uploaded TIFF sidewalk tile image
        ui.output_image("uploaded_image"),
        ui.output_plot("prediction_plots"),
        ui.output_ui("value_boxes"),
        ui.output_plot("scatter", fill=True),
        ui.help_text(
            "Artwork by ",
            ui.a("@allison_horst", href="https://twitter.com/allison_horst"),
            class_="text-end",
        ),
    ),
)


### HELPER FUNCTIONS ###


def tif_bytes_to_pil_image(tif_bytes):
    """Convert raw TIFF bytes into a square, 256x256 RGB PIL Image."""
    # Create a BytesIO object from the TIFF bytes
    bytes_io = io.BytesIO(tif_bytes)

    # Open the BytesIO object as an Image, crop to a square (top-left), and resize to 256x256
    image = Image.open(bytes_io).convert("RGB")
    w, h = image.size
    dim = min(w, h)
    image = image.crop((0, 0, dim, dim))
    image = image.resize((256, 256))

    return image


def load_model():
    """Load the SAM model, its processor, and the fine-tuned checkpoint weights."""
    # Load the model configuration and processor
    model_config = SamConfig.from_pretrained("facebook/sam-vit-base")
    processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

    # Create an instance of the model architecture with the loaded configuration
    model = SamModel(config=model_config)

    # Update the model by loading the weights from the saved checkpoint
    model_state_dict = torch.load(
        str(dir / "checkpoint.pth"), map_location=torch.device("cpu")
    )
    model.load_state_dict(model_state_dict)

    # Set the device to cuda if available, otherwise use cpu
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)

    return model, processor, device


def show_mask(mask, ax, random_color=False):
    """Overlay a binary mask on a matplotlib axis as a translucent color layer."""
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)


### SERVER ###


def server(input: Inputs, output: Outputs, session: Session):
    # Load the model, processor, and device once per session
    model, processor, device = load_model()

    @reactive.Calc
    def uploaded_image_path() -> str:
        """Returns the path to the uploaded image, or "" if nothing has been uploaded"""
        if input.tile_image() is not None:
            return input.tile_image()[0]["datapath"]  # Assuming multiple=False
        else:
            return ""  # No image uploaded

    @output
    @render.image
    def uploaded_image():
        """Displays the uploaded image"""
        img_src = uploaded_image_path()
        if img_src:
            img: ImgData = {"src": img_src, "width": "100px"}
            return img
        else:
            return None  # Nothing to render if no image has been uploaded

    def generate_input_points(image, grid_size=10):
        """
        Build the grid of 2D prompt points passed to SAM as `input_points`.

        input_points (torch.FloatTensor of shape (batch_size, num_points, 2)):
        input 2D spatial points used by the prompt encoder to encode the prompt;
        they generally yield much better results. The points can be obtained by
        passing a list of list of lists to the processor, which will create the
        corresponding torch tensors of dimension 4. The first dimension is the
        image batch size, the second dimension is the point batch size (i.e. how
        many segmentation masks we want the model to predict per input point),
        the third dimension is the number of points per segmentation mask (it is
        possible to pass multiple points for a single mask), and the last
        dimension is the x (vertical) and y (horizontal) coordinates of the
        point. If a different number of points is passed either for each image
        or for each mask, the processor will create "PAD" points corresponding
        to the (0, 0) coordinate, and the computation of the embedding will be
        skipped for these points using the labels.
        """
        # Get the dimensions of the image
        array_size = max(image.width, image.height)

        # Generate evenly spaced coordinates along each axis
        x = np.linspace(0, array_size - 1, grid_size)
        y = np.linspace(0, array_size - 1, grid_size)

        # Generate a grid of coordinates
        xv, yv = np.meshgrid(x, y)

        # Convert the numpy arrays to lists
        xv_list = xv.tolist()
        yv_list = yv.tolist()

        # Combine the x and y coordinates into a list of lists of lists
        input_points = [
            [[int(x), int(y)] for x, y in zip(x_row, y_row)]
            for x_row, y_row in zip(xv_list, yv_list)
        ]

        # Reshape the n x n grid to the expected shape of the input_points tensor:
        # (batch_size, point_batch_size, num_points_per_image, 2),
        # where the last dimension of 2 holds the x and y coordinates of each point.
        #   batch_size: the number of images processed at once
        #   point_batch_size: the number of point sets per image
        #   num_points_per_image: the number of points in each set
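        # For example, with the default grid_size=10 the tensor built below has
        # shape (1, 1, 100, 2): one image, one point set, and 100 prompt points,
        # each an (x, y) pair.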
        input_points = torch.tensor(input_points).view(1, 1, grid_size * grid_size, 2)

        return input_points

    def process_image():
        """Processes the uploaded image and evaluates the model to get predictions"""
        # --- Get the image ---
        img_src = uploaded_image_path()

        # Read the image bytes from the file
        with open(img_src, "rb") as f:
            image_bytes = f.read()

        # Convert the image bytes to a PIL Image
        image = tif_bytes_to_pil_image(image_bytes)

        # --- Prepare the inputs ---
        # Get the input points prompt (a grid of points)
        input_points = generate_input_points(image)

        # Prepare the image and prompt for the model
        inputs = processor(image, input_points=input_points, return_tensors="pt")

        # # remove batch dimension which the processor adds by default
        # inputs = {k: v.squeeze(0) for k, v in inputs.items()}

        # Move the input tensors to the same device as the model
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # --- Get the predictions ---
        # Forward pass
        with torch.no_grad():
            outputs = model(**inputs, multimask_output=False)

        # Apply sigmoid to get per-pixel probabilities
        prob = torch.sigmoid(outputs.pred_masks.squeeze(1))

        # Convert the soft mask to a hard (binary) mask
        prob = prob.cpu().numpy().squeeze()
        prediction = (prob > 0.5).astype(np.uint8)

        # Return the processed result
        return image, prob, prediction

    @reactive.Calc
    def get_predictions():
        """Processes the image when uploaded to get predictions"""
        if input.tile_image() is not None:
            return process_image()
        else:
            return None, None, None

    @output
    @render.plot
    def prediction_plots():
        # Get prediction results once an image has been uploaded
        image, prob, prediction = get_predictions()

        # Check if there are no predictions (i.e., no uploaded image)
        if image is None or prob is None or prediction is None:
            # Return a placeholder plot with instructions
            fig, ax = plt.subplots()
            ax.text(
                0.5,
                0.5,
                "Upload a square image to see predictions. If the image is not "
                "square, it will be cropped to a square, keeping the top-left "
                "portion of the image. Predictions will take a few moments to load.",
                fontsize=12,
                ha="center",
            )
            ax.axis("off")  # Hide axis
            plt.tight_layout()
            return fig

        fig, axes = plt.subplots(1, 4, constrained_layout=True)

        # Plot the original image in the first panel
        axes[0].imshow(image)
        axes[0].set_title("Image")

        # Plot the probability map in the second panel
        axes[1].imshow(prob)
        axes[1].set_title("Probability Map")

        # Plot the binary prediction in the third panel
        axes[2].imshow(prediction)
        axes[2].set_title("Prediction")

        # Overlay the predicted mask on the image in the fourth panel
        axes[3].imshow(image)
        show_mask(prediction, axes[3])
        axes[3].set_title("Predicted Mask")

        # Hide axis ticks and labels
        for ax in axes:
            ax.set_xticks([])
            ax.set_yticks([])
            ax.set_xticklabels([])
            ax.set_yticklabels([])

        return fig

    @reactive.Calc
    def filtered_df() -> pd.DataFrame:
        """Returns a Pandas data frame that includes only the desired rows"""
        # This calculation "req"uires that at least one species is selected
        req(len(input.species()) > 0)

        # Filter the rows so we only include the desired species
        return df[df["Species"].isin(input.species())]

    @output
    @render.plot
    def scatter():
        """Generates a plot for Shiny to display to the user"""
        # The plotting function to use depends on whether margins are desired
        plotfunc = sns.jointplot if input.show_margins() else sns.scatterplot

        plotfunc(
            data=filtered_df(),
            x=input.xvar(),
            y=input.yvar(),
            palette=palette,
            hue="Species" if input.by_species() else None,
            hue_order=species,
            legend=False,
        )

    @output
    @render.ui
    def value_boxes():
        df = filtered_df()

        def penguin_value_box(title: str, count: int, bgcol: str, showcase_img: str):
            return ui.value_box(
                title,
                count,
                {"class_": "pt-1 pb-0"},
                showcase=ui.fill.as_fill_item(
                    ui.tags.img(
                        {"style": "object-fit:contain;"},
                        src=showcase_img,
                    )
                ),
                theme_color=None,
                style=f"background-color: {bgcol};",
            )

        if not input.by_species():
            return penguin_value_box(
                "Penguins",
                len(df.index),
                bg_palette["default"],
                # Artwork by @allison_horst
                showcase_img="penguins.png",
            )

        value_boxes = [
            penguin_value_box(
                name,
                len(df[df["Species"] == name]),
                bg_palette[name],
                # Artwork by @allison_horst
                showcase_img=f"{name}.png",
            )
            for name in species
            # Only include boxes for _selected_ species
            if name in input.species()
        ]

        return ui.layout_column_wrap(*value_boxes, width=1 / len(value_boxes))


# "darkorange", "purple", "cyan4"
colors = [[255, 140, 0], [160, 32, 240], [0, 139, 139]]
colors = [(r / 255.0, g / 255.0, b / 255.0) for r, g, b in colors]

palette: Dict[str, Tuple[float, float, float]] = {
    "Adelie": colors[0],
    "Chinstrap": colors[1],
    "Gentoo": colors[2],
    "default": sns.color_palette()[0],  # type: ignore
}

bg_palette = {}
# Use `sns.set_style("whitegrid")` to help find approx alpha value
for name, col in palette.items():
    # Adjusted n_colors until `axe` accessibility did not complain about color contrast
    bg_palette[name] = mpl_colors.to_hex(sns.light_palette(col, n_colors=7)[1])  # type: ignore

app = App(
    app_ui,
    server,
    static_assets=str(www_dir),
)
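
# To run the app locally (a usage sketch; assumes this file is saved as app.py and
# that checkpoint.pth, penguins.csv, and the www/ assets live alongside it):
#   shiny run --reload app.py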