import io
from pathlib import Path
from typing import Dict, List, Tuple

import matplotlib.colors as mpl_colors
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import shinyswatch
import torch
from PIL import Image
from shiny import App, Inputs, Outputs, Session, reactive, render, req, ui
from shiny.types import ImgData
from transformers import SamConfig, SamModel, SamProcessor
sns.set_theme()
app_dir = Path(__file__).resolve().parent
www_dir = app_dir / "www"
df = pd.read_csv(app_dir / "penguins.csv", na_values="NA")
numeric_cols: List[str] = df.select_dtypes(include=["float64"]).columns.tolist()
species: List[str] = df["Species"].unique().tolist()
species.sort()
### UI ###
app_ui = ui.page_fillable(
shinyswatch.theme.minty(),
ui.layout_sidebar(
ui.sidebar(
ui.input_file("tile_image", "Choose TIFF File", accept=[".tif"], multiple=False),
# Artwork by @allison_horst
ui.input_selectize(
"xvar",
"X variable",
numeric_cols,
selected="Bill Length (mm)",
),
ui.input_selectize(
"yvar",
"Y variable",
numeric_cols,
selected="Bill Depth (mm)",
),
ui.input_checkbox_group(
"species", "Filter by species", species, selected=species
),
ui.hr(),
ui.input_switch("by_species", "Show species", value=True),
ui.input_switch("show_margins", "Show marginal plots", value=True),
),
ui.output_image("uploaded_image"), # display the uploaded TIFF sidewalk tile image
ui.output_plot("prediction_plots"),
ui.output_ui("value_boxes"),
ui.output_plot("scatter", fill=True),
ui.help_text(
"Artwork by ",
ui.a("@allison_horst", href="https://twitter.com/allison_horst"),
class_="text-end",
),
),
)
### HELPER FUNCTIONS ###
def tif_bytes_to_pil_image(tif_bytes):
# Create a BytesIO object from the TIFF bytes
bytes_io = io.BytesIO(tif_bytes)
# Open the BytesIO object as an Image, crop to square, resize to 256
image = Image.open(bytes_io).convert("RGB")
w, h = image.size
dim = min(w, h)
image = image.crop((0, 0, dim, dim))
image = image.resize((256, 256))
return image
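# For example, a 1024x768 TIFF is cropped to its top-left 768x768 square and then
# resized to 256x256 before being handed to the SAM processor in process_image().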
def load_model():
""" Get Model """
# Load the model configuration
model_config = SamConfig.from_pretrained("facebook/sam-vit-base")
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
# Create an instance of the model architecture with the loaded configuration
model = SamModel(config=model_config)
# Update the model by loading the weights from saved file
    model_state_dict = torch.load(str(app_dir / "checkpoint.pth"), map_location=torch.device("cpu"))
model.load_state_dict(model_state_dict)
# set the device to cuda if available, otherwise use cpu
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
return model, processor, device
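# Note: "checkpoint.pth" is assumed to be a plain state dict of fine-tuned SAM
# weights, e.g. saved during training with something like:
#   torch.save(model.state_dict(), "checkpoint.pth")
# load_state_dict() above will not accept a full training checkpoint that wraps
# the weights together with optimizer state or epoch counters.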
def show_mask(mask, ax, random_color=False):
if random_color:
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
else:
color = np.array([30/255, 144/255, 255/255, 0.6])
h, w = mask.shape[-2:]
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image, extent=(0, w, h, 0))  # extent matches the mask (and tile image) dimensions
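# The overlay built above is an RGBA image: broadcasting the binary mask against a
# 4-channel color (alpha 0.6) tints masked pixels and leaves unmasked pixels fully
# transparent, so the underlying tile stays visible.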
### SERVER ###
def server(input: Inputs, output: Outputs, session: Session):
# set model, processor, device once
model, processor, device = load_model()
@reactive.Calc
def uploaded_image_path() -> str:
"""Returns the path to the uploaded image"""
if input.tile_image() is not None:
return input.tile_image()[0]['datapath'] # Assuming multiple=False
else:
return "" # No image uploaded
@render.image
def uploaded_image():
"""Displays the uploaded image"""
img_src = uploaded_image_path()
if img_src:
img: ImgData = {"src": str(dir / uploaded_image_path()), "width": "100px"}
return img
else:
return None # Return an empty string if no image is uploaded
def generate_input_points(image, grid_size=10):
"""
input_points (torch.FloatTensor of shape (batch_size, num_points, 2)) —
Input 2D spatial points, this is used by the prompt encoder to encode the prompt.
Generally yields to much better results. The points can be obtained by passing a
list of list of list to the processor that will create corresponding torch tensors
of dimension 4. The first dimension is the image batch size, the second dimension
is the point batch size (i.e. how many segmentation masks do we want the model to
predict per input point), the third dimension is the number of points per segmentation
mask (it is possible to pass multiple points for a single mask), and the last dimension
is the x (vertical) and y (horizontal) coordinates of the point. If a different number
of points is passed either for each image, or for each mask, the processor will create
“PAD” points that will correspond to the (0, 0) coordinate, and the computation of the
embedding will be skipped for these points using the labels.
"""
# Get the dimensions of the image
array_size = max(image.width, image.height)
# Generate the grid points
x = np.linspace(0, array_size-1, grid_size)
y = np.linspace(0, array_size-1, grid_size)
# Generate a grid of coordinates
xv, yv = np.meshgrid(x, y)
# Convert the numpy arrays to lists
xv_list = xv.tolist()
yv_list = yv.tolist()
# Combine the x and y coordinates into a list of list of lists
        input_points = [[[int(px), int(py)] for px, py in zip(x_row, y_row)] for x_row, y_row in zip(xv_list, yv_list)]
        # Reshape the n x n grid to the expected shape of the input_points tensor
        # (batch_size, point_batch_size, num_points_per_image, 2),
        # where the last dimension of 2 represents the x and y coordinates of each point.
        #   batch_size: the number of images processed at once.
        #   point_batch_size: the number of point sets per image.
        #   num_points_per_image: the number of points in each set.
input_points = torch.tensor(input_points).view(1, 1, grid_size*grid_size, 2)
return input_points
def process_image():
"""Processes the uploaded image, loads the model, and evaluates to get predictions"""
""" Get Image """
img_src = uploaded_image_path()
# Read the image bytes from the file
with open(img_src, 'rb') as f:
image_bytes = f.read()
# Convert the image bytes to a PIL Image
image = tif_bytes_to_pil_image(image_bytes)
""" Prepare Inputs """
# get input points prompt (grid of points)
input_points = generate_input_points(image)
# prepare image and prompt for the model
inputs = processor(image, input_points=input_points, return_tensors="pt")
# # remove batch dimension which the processor adds by default
# inputs = {k:v.squeeze(0) for k,v in inputs.items()}
# Move the input tensor to the GPU if it's not already there
inputs = {k: v.to(device) for k, v in inputs.items()}
""" Get Predictions """
# forward pass
with torch.no_grad():
outputs = model(**inputs, multimask_output=False)
# apply sigmoid
prob = torch.sigmoid(outputs.pred_masks.squeeze(1))
# convert soft mask to hard mask
prob = prob.cpu().numpy().squeeze()
prediction = (prob > 0.5).astype(np.uint8)
# Return the processed result
return image, prob, prediction
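    # With multimask_output=False the model predicts a single mask per prompt set,
    # so prob and prediction above come back as plain 2D (H, W) arrays that the
    # plotting code below can pass straight to imshow().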
@reactive.Calc
def get_predictions():
"""Processes the image when uploaded to get predictions"""
if input.tile_image() is not None:
return process_image()
else:
return None, None, None
@output
@render.plot
def prediction_plots():
# get prediction results when image is uploaded
image, prob, prediction = get_predictions()
# Check if there are no predictions (i.e., no uploaded image)
if image is None or prob is None or prediction is None:
# Return a placeholder plot or message
fig, ax = plt.subplots()
ax.text(0.5, 0.5, "Upload a square image to see predictions. If the image is not a square, the image will be cropped to a square, taking the top left portion of the image. Predictions will take a few moments to load.", fontsize=12, ha="center")
ax.axis("off") # Hide axis
plt.tight_layout()
return fig
fig, axes = plt.subplots(1, 4, constrained_layout=True)
# Extract the image data
#image_data = image.cpu().detach().numpy()
        # Panel 1: the input image
        axes[0].imshow(image)
        axes[0].set_title("Image")
        # Panel 2: the sigmoid probability map
        axes[1].imshow(prob)
        axes[1].set_title("Probability Map")
        # Panel 3: the thresholded (hard) prediction
        axes[2].imshow(prediction)
        axes[2].set_title("Prediction")
        # Panel 4: the predicted mask overlaid on the image
        axes[3].imshow(image)
        show_mask(prediction, axes[3])
        axes[3].set_title("Predicted Mask")
# Hide axis ticks and labels
for ax in axes:
ax.set_xticks([])
ax.set_yticks([])
ax.set_xticklabels([])
ax.set_yticklabels([])
return fig
@reactive.Calc
def filtered_df() -> pd.DataFrame:
"""Returns a Pandas data frame that includes only the desired rows"""
# This calculation "req"uires that at least one species is selected
req(len(input.species()) > 0)
# Filter the rows so we only include the desired species
return df[df["Species"].isin(input.species())]
@output
@render.plot
def scatter():
"""Generates a plot for Shiny to display to the user"""
# The plotting function to use depends on whether margins are desired
plotfunc = sns.jointplot if input.show_margins() else sns.scatterplot
plotfunc(
data=filtered_df(),
x=input.xvar(),
y=input.yvar(),
palette=palette,
hue="Species" if input.by_species() else None,
hue_order=species,
legend=False,
)
@output
@render.ui
def value_boxes():
df = filtered_df()
def penguin_value_box(title: str, count: int, bgcol: str, showcase_img: str):
return ui.value_box(
title,
count,
{"class_": "pt-1 pb-0"},
showcase=ui.fill.as_fill_item(
ui.tags.img(
{"style": "object-fit:contain;"},
src=showcase_img,
)
),
theme_color=None,
style=f"background-color: {bgcol};",
)
if not input.by_species():
return penguin_value_box(
"Penguins",
len(df.index),
bg_palette["default"],
# Artwork by @allison_horst
showcase_img="penguins.png",
)
value_boxes = [
penguin_value_box(
name,
len(df[df["Species"] == name]),
bg_palette[name],
# Artwork by @allison_horst
showcase_img=f"{name}.png",
)
for name in species
# Only include boxes for _selected_ species
if name in input.species()
]
        return ui.layout_column_wrap(*value_boxes, width=1 / len(value_boxes))
# "darkorange", "purple", "cyan4"
colors = [[255, 140, 0], [160, 32, 240], [0, 139, 139]]
colors = [(r / 255.0, g / 255.0, b / 255.0) for r, g, b in colors]
palette: Dict[str, Tuple[float, float, float]] = {
"Adelie": colors[0],
"Chinstrap": colors[1],
"Gentoo": colors[2],
"default": sns.color_palette()[0], # type: ignore
}
bg_palette = {}
# Use `sns.set_style("whitegrid")` to help find approx alpha value
for name, col in palette.items():
# Adjusted n_colors until `axe` accessibility did not complain about color contrast
bg_palette[name] = mpl_colors.to_hex(sns.light_palette(col, n_colors=7)[1]) # type: ignore
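# The loop above maps each species to a very light tint of its marker color
# (step 1 of a 7-step light ramp is close to white), so the value-box text in the
# server remains readable against the background.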
app = App(
app_ui,
server,
static_assets=str(www_dir),
)