import io
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import shinyswatch
import torch
from PIL import Image
from shiny import App, Inputs, Outputs, Session, reactive, render, ui
from transformers import SamConfig, SamModel, SamProcessor

sns.set_theme()
app_dir = Path(__file__).resolve().parent
www_dir = app_dir / "www"
### UI ###
app_ui = ui.page_fillable(
    shinyswatch.theme.minty(),
    ui.layout_sidebar(
        ui.sidebar(
            ui.input_file("tile_image", "Choose an Image", accept=[".tif", ".tiff", ".png"], multiple=False),
        ),
        # ui.output_image("uploaded_image"),  # display the uploaded sidewalk tile image; for some reason doesn't work on all accepted files
        ui.output_plot("prediction_plots", fill=True),
        ui.help_text(
            "Project by ",
            ui.a("@agoluoglu", href="https://github.com/agoluoglu"),
            class_="text-end",
        ),
    ),
)
### HELPER FUNCTIONS ###
def bytes_to_pil_image(image_bytes):
    """Decode raw image bytes into a square 256x256 RGB PIL image."""
    # Wrap the bytes in a BytesIO object so PIL can read them
    bytes_io = io.BytesIO(image_bytes)
    # Open as an RGB image, crop to a square, then resize to 256x256
    image = Image.open(bytes_io).convert("RGB")
    w, h = image.size
    dim = min(w, h)
    image = image.crop((0, 0, dim, dim))
    image = image.resize((256, 256))
    return image
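
# For example, a 1024x768 upload is cropped to its top-left 768x768 square and
# then resized to 256x256 before being handed to the processor.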
def load_model():
    """Load the fine-tuned SAM model, its processor, and the target device."""
    # Load the base SAM configuration and processor
    model_config = SamConfig.from_pretrained("facebook/sam-vit-base")
    processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
    # Create an instance of the model architecture with the loaded configuration
    model = SamModel(config=model_config)
    # Update the model with the fine-tuned weights from the saved checkpoint
    model_state_dict = torch.load(str(app_dir / "checkpoint.pth"), map_location=torch.device("cpu"))
    model.load_state_dict(model_state_dict)
    # Use CUDA if available, otherwise fall back to the CPU
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    return model, processor, device
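
# checkpoint.pth is assumed to be a plain state dict saved next to this file
# during fine-tuning, e.g. with torch.save(model.state_dict(), "checkpoint.pth").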
def show_mask(mask, ax, random_color=False):
    """Overlay a binary mask on a matplotlib Axes as a semi-transparent color layer."""
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)
def generate_input_points(image, grid_size=10):
    """
    Build an evenly spaced grid of point prompts for SAM.

    input_points (torch.FloatTensor of shape (batch_size, num_points, 2)) --
    Input 2D spatial points used by the prompt encoder to encode the prompt;
    providing them generally yields much better results. The points can be
    obtained by passing a list of lists of lists to the processor, which will
    create corresponding torch tensors of dimension 4. The first dimension is
    the image batch size, the second dimension is the point batch size (i.e.
    how many segmentation masks we want the model to predict per input point),
    the third dimension is the number of points per segmentation mask (it is
    possible to pass multiple points for a single mask), and the last dimension
    holds the x (vertical) and y (horizontal) coordinates of the point. If a
    different number of points is passed either for each image or for each
    mask, the processor will create "PAD" points at the (0, 0) coordinate, and
    the computation of the embedding will be skipped for these points using
    the labels.
    """
    # Use the larger image dimension as the extent of the grid
    array_size = max(image.width, image.height)
    # Generate evenly spaced coordinates along each axis
    x = np.linspace(0, array_size - 1, grid_size)
    y = np.linspace(0, array_size - 1, grid_size)
    # Generate a grid of coordinates
    xv, yv = np.meshgrid(x, y)
    # Convert the numpy arrays to lists
    xv_list = xv.tolist()
    yv_list = yv.tolist()
    # Combine the x and y coordinates into a list of lists of lists
    input_points = [[[int(xi), int(yi)] for xi, yi in zip(x_row, y_row)] for x_row, y_row in zip(xv_list, yv_list)]
    # Reshape the n x n grid to the expected shape of the input_points tensor,
    # (batch_size, point_batch_size, num_points_per_image, 2), where the last
    # dimension of 2 holds the x and y coordinates of each point:
    #   batch_size: the number of images processed at once
    #   point_batch_size: the number of point sets per image
    #   num_points_per_image: the number of points in each set
    input_points = torch.tensor(input_points).view(1, 1, grid_size * grid_size, 2)
    return input_points
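
# Illustrative check (assuming the 256x256 images produced by bytes_to_pil_image):
# the default grid_size of 10 yields 100 prompt points spanning the image, so
#   generate_input_points(Image.new("RGB", (256, 256))).shape
# is torch.Size([1, 1, 100, 2]).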
### SERVER ###
def server(input: Inputs, output: Outputs, session: Session):
    # Load the model, processor, and device once per session
    model, processor, device = load_model()

    def uploaded_image_path() -> str:
        """Returns the path to the uploaded image, or "" if none is uploaded."""
        if input.tile_image() is not None:
            return input.tile_image()[0]["datapath"]  # multiple=False, so take the first file
        else:
            return ""  # No image uploaded
    # For some reason the function below does not work on all accepted files:
    # it works on one screenshot that was converted to .tif but not another *shrug*
    # @render.image
    # def uploaded_image():
    #     """Displays the uploaded image"""
    #     img_src = uploaded_image_path()
    #     if img_src:
    #         img: ImgData = {"src": str(img_src), "width": "200px"}
    #         return img
    #     else:
    #         return None  # No image uploaded
    def process_image():
        """Processes the uploaded image and runs the model to get predictions."""
        # --- Get Image ---
        img_src = uploaded_image_path()
        # Read the image bytes from the file
        with open(img_src, "rb") as f:
            image_bytes = f.read()
        # Convert the image bytes to a PIL Image
        image = bytes_to_pil_image(image_bytes)
        # --- Prepare Inputs ---
        # Get the input points prompt (a grid of points)
        input_points = generate_input_points(image)
        # Prepare the image and prompt for the model
        inputs = processor(image, input_points=input_points, return_tensors="pt")
        # Move the input tensors to the model's device
        inputs = {k: v.to(device) for k, v in inputs.items()}
        # --- Get Predictions ---
        # Forward pass without tracking gradients
        with torch.no_grad():
            outputs = model(**inputs, multimask_output=False)
        # Apply sigmoid to turn mask logits into probabilities
        prob = torch.sigmoid(outputs.pred_masks.squeeze(1))
        # Threshold the soft mask into a hard binary mask
        prob = prob.cpu().numpy().squeeze()
        prediction = (prob > 0.5).astype(np.uint8)
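        # Shape note (assuming the 256x256 resize in bytes_to_pil_image): prob is
        # a (256, 256) float array of per-pixel probabilities and prediction is
        # the corresponding binary uint8 mask at SAM's low-res mask resolution.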
        # Return the processed result
        return image, prob, prediction
    @reactive.Calc
    def get_predictions():
        """Processes the image when one is uploaded; caches the expensive model run."""
        if input.tile_image() is not None:
            return process_image()
        else:
            return None, None, None
    @render.plot
    def prediction_plots():
        # Get prediction results once an image is uploaded
        image, prob, prediction = get_predictions()
        # Check if there are no predictions (i.e., no uploaded image)
        if image is None or prob is None or prediction is None:
            # Return a placeholder plot with instructions
            fig, ax = plt.subplots()
            ax.text(0.5, 0.5, "Upload an image to see predictions. Predictions will take a few moments to load.", fontsize=12, ha="center")
            ax.axis("off")  # Hide axis
            plt.tight_layout()
            return fig
        fig, axes = plt.subplots(1, 4, figsize=(15, 30))
        # Plot the original image
        axes[0].imshow(image)
        axes[0].set_title("Image")
        # Plot the probability map
        axes[1].imshow(prob)
        axes[1].set_title("Probability Map")
        # Plot the thresholded prediction
        axes[2].imshow(prediction)
        axes[2].set_title("Prediction")
        # Plot the predicted mask overlaid on the image
        axes[3].imshow(image)
        show_mask(prediction, axes[3])
        axes[3].set_title("Predicted Mask")
        # Hide axis ticks and labels
        for ax in axes:
            ax.set_xticks([])
            ax.set_yticks([])
        plt.tight_layout()
        return fig
### APP ###
app = App(
    app_ui,
    server,
    static_assets=str(www_dir),
)
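
# To run locally (assuming this file is saved as app.py):
#   shiny run --reload app.py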