skinsam_demo / app.py
ahishamm's picture
Updated interface
87f6406
import pandas as pd
import numpy as np
import streamlit as st
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg
from PIL import Image
from streamlit_image_select import image_select
from tqdm import tqdm
import os
import shutil
from PIL import Image
import torch
import matplotlib.pyplot as plt
from datasets import load_dataset
from transformers import AutoProcessor, AutoModelForMaskGeneration
def show_mask(image, mask, ax=None):
    """Render ``image`` with a translucent blue ``mask`` overlay as a PIL image.

    Args:
        image: Anything ``np.array`` can turn into an array that
            ``Axes.imshow`` accepts (the demo passes a PIL image).
        mask: Array whose last two dimensions are the mask height and width
            and which holds exactly ``h * w`` elements. Each element scales
            the overlay colour, so zeros stay fully transparent.
        ax: Unused. Kept only so callers that pass it keep working.

    Returns:
        An RGB ``PIL.Image.Image`` of the rasterized Matplotlib figure.
    """
    fig, axes = plt.subplots()
    try:
        axes.imshow(np.array(image))
        # RGBA overlay colour (60% opacity); multiplying by the mask zeroes
        # every channel, alpha included, wherever the mask is 0.
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
        h, w = mask.shape[-2:]
        mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
        axes.imshow(mask_image)
        canvas = FigureCanvasAgg(fig)
        canvas.draw()
        # ``tostring_rgb`` is deprecated since Matplotlib 3.8 and removed in
        # later releases; read the RGBA buffer and drop the alpha channel.
        rgba = np.asarray(canvas.buffer_rgba())
        pil_image = Image.fromarray(rgba).convert('RGB')
    finally:
        # Close the figure even if rendering fails, so reruns do not
        # accumulate open figures.
        plt.close(fig)
    return pil_image
def get_bounding_box(ground_truth_map):
    """Return a randomly padded bounding box around the positive mask pixels.

    The tight box around all entries ``> 0`` is grown on each side by an
    independent random amount drawn from ``np.random.randint(0, 20)`` and
    then clamped to the image bounds.

    Args:
        ground_truth_map: 2-D array-like mask of shape ``(H, W)``.

    Returns:
        ``[x_min, y_min, x_max, y_max]`` as plain Python ints. When the mask
        has no positive pixel the whole-image box ``[0, 0, W, H]`` is
        returned instead of failing on an empty reduction.

    Raises:
        ValueError: If ``ground_truth_map`` is not two-dimensional.
    """
    ground_truth_map = np.asarray(ground_truth_map)
    H, W = ground_truth_map.shape
    y_indices, x_indices = np.where(ground_truth_map > 0)
    if x_indices.size == 0:
        # Nothing to enclose: fall back to a box covering the full image.
        return [0, 0, W, H]
    x_min, x_max = int(np.min(x_indices)), int(np.max(x_indices))
    y_min, y_max = int(np.min(y_indices)), int(np.max(y_indices))
    # Random draws stay in the original order (x_min, x_max, y_min, y_max) so
    # a seeded run produces the same boxes as before.
    x_min = max(0, x_min - int(np.random.randint(0, 20)))
    x_max = min(W, x_max + int(np.random.randint(0, 20)))
    y_min = max(0, y_min - int(np.random.randint(0, 20)))
    y_max = min(H, y_max + int(np.random.randint(0, 20)))
    return [x_min, y_min, x_max, y_max]
def get_output(image,prompt):
    """Predict a binary mask for ``image`` using ``prompt`` as the box prompt.

    Relies on the module-level ``processor``, ``model`` and ``device``.

    Args:
        image: Input image handed to ``processor``.
        prompt: A single box, passed to the processor as ``input_boxes``
            (presumably ``[x_min, y_min, x_max, y_max]`` as produced by
            ``get_bounding_box`` -- confirm against the processor docs).

    Returns:
        A ``uint8`` numpy array holding 1 where the sigmoid of the predicted
        mask logits exceeds 0.5 and 0 elsewhere.
    """
    encoded = processor(image, input_boxes=[[prompt]], return_tensors='pt').to(device)
    model.eval()
    with torch.no_grad():
        prediction = model(**encoded, multimask_output=False)
    mask_logits = prediction.pred_masks.squeeze(1)
    probabilities = torch.sigmoid(mask_logits).cpu().numpy().squeeze()
    binary_mask = probabilities > 0.5
    return binary_mask.astype(np.uint8)
def generate_image(np_array):
    """Turn a mask array into an 8-bit grayscale (mode ``'L'``) PIL image.

    Values are multiplied by 255 and cast to ``uint8``, so a 0/1 mask becomes
    a black/white image.
    """
    scaled = np_array * 255
    pixels = scaled.astype('uint8')
    return Image.fromarray(pixels, mode='L')
def iou_calculation(result1, result2):
    """Return the intersection-over-union of two masks, rounded to 4 decimals.

    Any non-zero element counts as foreground.

    Args:
        result1: First mask; anything ``np.asarray`` accepts.
        result2: Second mask, broadcast-compatible with ``result1``.

    Returns:
        The IoU as a ``float`` in ``[0, 1]``. When neither mask has any
        foreground the union is empty and 0.0 is returned rather than NaN,
        matching what ``calculate_f1_score`` yields for that case.
    """
    mask1 = np.asarray(result1).astype(bool)
    mask2 = np.asarray(result2).astype(bool)
    union = np.count_nonzero(np.logical_or(mask1, mask2))
    if union == 0:
        # Avoid 0/0 (NaN plus a RuntimeWarning) for two empty masks.
        return 0.0
    intersection = np.count_nonzero(np.logical_and(mask1, mask2))
    return round(intersection / union, 4)
def calculate_pixel_accuracy(image1, image2):
    """Return the fraction of pixels whose RGB values match in both images.

    ``image1`` is first brought to ``image2``'s size (bilinear resampling)
    and mode; both images are then compared as RGB.

    Args:
        image1: PIL image to compare.
        image2: PIL image that defines the target size and mode.

    Returns:
        ``matching_pixels / total_pixels`` as a ``float``.
    """
    # NOTE(review): bilinear resampling can introduce intermediate values
    # along mask edges; confirm whether nearest-neighbour is wanted here.
    if image1.size != image2.size:
        image1 = image1.resize(image2.size, Image.BILINEAR)
    if image1.mode != image2.mode:
        image1 = image1.convert(image2.mode)
    pixels1 = np.asarray(image1.convert("RGB"))
    pixels2 = np.asarray(image2.convert("RGB"))
    # A pixel matches only when all three channels are equal; doing this in
    # numpy replaces a per-pixel Python loop.
    matches = np.all(pixels1 == pixels2, axis=-1)
    return np.count_nonzero(matches) / matches.size
def calculate_f1_score(image1, image2):
    """Return the F1 score between two mask images, with 255 as foreground.

    ``image1`` acts as the reference: a false positive is a pixel that is 255
    in ``image2`` but not in ``image1``, and a false negative is the reverse.
    ``image1`` is resized (bilinear) and mode-converted to match ``image2``
    when they differ, then both are compared as 8-bit grayscale.

    Args:
        image1: Reference PIL image.
        image2: PIL image scored against the reference.

    Returns:
        The F1 score; a 1e-7 term in every denominator keeps the result
        finite when there are no positive pixels.
    """
    needs_alignment = image1.size != image2.size or image1.mode != image2.mode
    if needs_alignment:
        image1 = image1.resize(image2.size, Image.BILINEAR)
        if image1.mode != image2.mode:
            image1 = image1.convert(image2.mode)
    reference_on = np.array(image1.convert("L")).flatten() == 255
    candidate_on = np.array(image2.convert("L")).flatten() == 255
    eps = 1e-7
    tp = np.sum(reference_on & candidate_on)
    fp = np.sum(~reference_on & candidate_on)
    fn = np.sum(reference_on & ~candidate_on)
    precision = tp / (tp + fp + eps)
    recall = tp / (tp + fn + eps)
    return 2 * (precision * recall) / (precision + recall + eps)
def calculate_dice_coefficient(image1, image2):
    """Return the Dice coefficient between two mask images (255 = foreground).

    ``image1`` is resized (bilinear) and mode-converted to match ``image2``
    when they differ, then both are compared as 8-bit grayscale.

    Args:
        image1: First PIL image.
        image2: Second PIL image; defines the target size and mode.

    Returns:
        ``2 * |A & B| / (|A| + |B|)`` as a ``float``, where ``A`` and ``B``
        are the sets of pixels equal to 255. When neither image contains a
        255 pixel, 0.0 is returned instead of dividing by zero, matching
        what ``calculate_f1_score`` yields for that case.
    """
    if image1.size != image2.size:
        image1 = image1.resize(image2.size, Image.BILINEAR)
    if image1.mode != image2.mode:
        image1 = image1.convert(image2.mode)
    mask1 = np.asarray(image1.convert("L")) == 255
    mask2 = np.asarray(image2.convert("L")) == 255
    # |A| + |B| equals 2*TP + FP + FN, the original denominator.
    denominator = np.count_nonzero(mask1) + np.count_nonzero(mask2)
    if denominator == 0:
        return 0.0
    overlap = np.count_nonzero(mask1 & mask2)
    return 2 * overlap / denominator
# Run inference on the GPU when PyTorch can see one, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
st.set_page_config(layout='wide')

# Streamlit re-executes this script on every widget interaction, so the
# expensive loads are cached to avoid reloading the dataset and the model on
# each click. Streamlit releases without ``st.cache_resource`` fall back to
# the previous uncached behaviour.
_cache_resource = getattr(st, 'cache_resource', lambda func: func)


@_cache_resource
def _load_samples_dataset():
    """Load the ``train`` split of the dataset that provides the samples."""
    return load_dataset('ahishamm/combined_masks',split='train')


@_cache_resource
def _load_processor_and_model():
    """Load the processor and the mask-generation model, moved to ``device``."""
    loaded_processor = AutoProcessor.from_pretrained('ahishamm/skinsam')
    loaded_model = AutoModelForMaskGeneration.from_pretrained('ahishamm/skinsam_focalloss_base_combined')
    loaded_model.to(device)
    return loaded_processor, loaded_model


ds = _load_samples_dataset()

# Dataset rows 3-6 are the four samples offered in the picker.
s1 = ds[3]['image']
s2 = ds[4]['image']
s3 = ds[5]['image']
s4 = ds[6]['image']
s1_label = ds[3]['label']
s2_label = ds[4]['label']
s3_label = ds[5]['label']
s4_label = ds[6]['label']
image_arr = [s1,s2,s3,s4]
label_arr = [s1_label,s2_label,s3_label,s4_label]

# ``return_value='index'`` makes ``img`` an index into image_arr / label_arr.
img = image_select(
    label="Select a Skin Lesion Image",
    images=[
        s1,s2,s3,s4
    ],
    captions=["sample 1","sample 2","sample 3","sample 4"],
    return_value='index'
)

processor, model = _load_processor_and_model()

# The box prompt is derived from the ground-truth label of the selected
# sample (with random padding added by get_bounding_box).
p = get_bounding_box(np.array(label_arr[img]))
predicted_mask_array = get_output(image_arr[img],p)
predicted_mask = generate_image(predicted_mask_array)
result_image = show_mask(image_arr[img],predicted_mask_array)

# NOTE(review): newer Streamlit releases deprecate ``use_column_width`` in
# favour of ``use_container_width`` -- confirm the pinned version before
# switching.
with st.container():
    tab1, tab2 = st.tabs(['Visualizations','Metrics'])
    with tab1:
        col1, col2 = st.columns(2)
        with col1:
            st.image(image_arr[img],caption='Original Skin Lesion Image',use_column_width=True)
        with col2:
            st.image(result_image,caption='Mask Overlay',use_column_width=True)
    with tab2:
        # Metrics compare the ground-truth label with the predicted mask.
        st.write(f'The IOU Score: {iou_calculation(label_arr[img],predicted_mask)}')
        st.write(f'The Pixel Accuracy: {calculate_pixel_accuracy(label_arr[img],predicted_mask)}')
        st.write(f'The Dice Coefficient Score: {calculate_dice_coefficient(label_arr[img],predicted_mask)}')