# skinsam_demo / app.py
# ahishamm's picture
# Updated interface
# 87f6406
# raw
# history blame
# 6.41 kB
import pandas as pd
import numpy as np
import streamlit as st
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg
from PIL import Image
from streamlit_image_select import image_select
from tqdm import tqdm
import os
import shutil
from PIL import Image
import torch
import matplotlib.pyplot as plt
from datasets import load_dataset
from transformers import AutoProcessor, AutoModelForMaskGeneration
def show_mask(image, mask, ax=None):
    """Render ``image`` with ``mask`` overlaid in translucent blue.

    Args:
        image: Anything ``np.array`` can convert into data accepted by
            ``Axes.imshow`` (the script passes a PIL image).
        mask: Array whose last two dimensions are the mask height and width.
            It is reshaped to ``(h, w, 1)``, so it must contain exactly
            ``h * w`` elements.
        ax: Unused. Kept only so existing callers that pass it keep working;
            a fresh figure is always created.

    Returns:
        A PIL ``RGB`` image of the whole rendered figure.
    """
    fig, axes = plt.subplots()
    try:
        axes.imshow(np.array(image))
        # RGBA overlay colour; multiplying by the mask leaves alpha at 0
        # wherever the mask is 0, so the image shows through there.
        overlay_rgba = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
        h, w = mask.shape[-2:]
        mask_image = mask.reshape(h, w, 1) * overlay_rgba.reshape(1, 1, -1)
        axes.imshow(mask_image)
        canvas = FigureCanvasAgg(fig)
        canvas.draw()
        # tostring_rgb() was deprecated in Matplotlib 3.8 and removed in later
        # releases; buffer_rgba() is the supported way to read the pixels.
        rgba = np.asarray(canvas.buffer_rgba())
        return Image.fromarray(rgba).convert('RGB')
    finally:
        # Always release the figure, even if rendering failed, so repeated
        # calls do not accumulate open pyplot figures.
        plt.close(fig)
def get_bounding_box(ground_truth_map, max_jitter=20):
    """Return a randomly loosened bounding box around the positive pixels.

    Args:
        ground_truth_map: 2-D array-like; entries greater than 0 are treated
            as foreground.
        max_jitter: Exclusive upper bound of the random number of pixels by
            which each side of the box is pushed outwards (drawn with
            ``np.random.randint(0, max_jitter)``). A value of 0 or less
            disables the perturbation and yields the tight box.

    Returns:
        ``[x_min, y_min, x_max, y_max]`` as plain Python ints, clamped to
        ``0 <= x <= W`` and ``0 <= y <= H``.

    Raises:
        ValueError: If the map is not 2-D or contains no positive pixel.
    """
    mask = np.asarray(ground_truth_map)
    if mask.ndim != 2:
        raise ValueError(
            f"ground_truth_map must be 2-D, got {mask.ndim} dimension(s)"
        )
    y_indices, x_indices = np.nonzero(mask > 0)
    if x_indices.size == 0:
        # np.min/np.max on an empty array fail with an opaque reduction error.
        raise ValueError(
            "ground_truth_map has no positive pixels; cannot derive a bounding box"
        )
    H, W = mask.shape

    def _jitter():
        # randint(0, n) requires n > 0, so a non-positive bound means no jitter.
        return int(np.random.randint(0, max_jitter)) if max_jitter > 0 else 0

    # Draw order (x_min, x_max, y_min, y_max) is kept so a seeded RNG yields
    # the same boxes as before.
    x_min = max(0, int(x_indices.min()) - _jitter())
    x_max = min(W, int(x_indices.max()) + _jitter())
    y_min = max(0, int(y_indices.min()) - _jitter())
    y_max = min(H, int(y_indices.max()) + _jitter())
    return [x_min, y_min, x_max, y_max]
def get_output(image, prompt):
    """Segment ``image`` with the module-level model using one box prompt.

    Relies on the globals ``processor``, ``model`` and ``device``.

    Args:
        image: Image handed to ``processor`` unchanged.
        prompt: A single bounding box; it is wrapped as ``[[prompt]]`` for
            the processor's ``input_boxes`` argument.

    Returns:
        A ``uint8`` NumPy array holding 1 where the sigmoid of the predicted
        mask logits exceeds 0.5 and 0 elsewhere, with all singleton
        dimensions squeezed away.
    """
    batch = processor(image, input_boxes=[[prompt]], return_tensors='pt').to(device)
    model.eval()
    with torch.no_grad():
        prediction = model(**batch, multimask_output=False)
        probabilities = torch.sigmoid(prediction.pred_masks.squeeze(1))
        probabilities = probabilities.cpu().numpy().squeeze()
    binary_mask = probabilities > 0.5
    return binary_mask.astype(np.uint8)
def generate_image(np_array):
    """Convert a mask array into an 8-bit greyscale (``'L'``) PIL image.

    The array is multiplied by 255 and cast to ``uint8`` first, so a 0/1
    mask becomes a 0/255 image.
    """
    scaled = np_array * 255
    as_bytes = scaled.astype('uint8')
    return Image.fromarray(as_bytes, mode='L')
def iou_calculation(result1, result2):
    """Return the intersection-over-union of two masks, rounded to 4 decimals.

    Both inputs are interpreted element-wise as booleans (non-zero means
    foreground) and must be broadcastable against each other.

    Args:
        result1: First mask (array-like).
        result2: Second mask (array-like).

    Returns:
        The IoU as a Python float. When neither mask has any foreground the
        ratio is undefined; 0.0 is returned instead of NaN.
    """
    overlap = np.logical_and(result1, result2)
    combined = np.logical_or(result1, result2)
    union_size = np.count_nonzero(combined)
    if union_size == 0:
        # Avoid 0/0, which previously produced NaN plus a RuntimeWarning.
        return 0.0
    ratio = np.count_nonzero(overlap) / union_size
    return float(f"{ratio:.4f}")
def calculate_pixel_accuracy(image1, image2):
    """Return the fraction of pixels whose RGB values are identical.

    ``image1`` is first brought to the size and mode of ``image2``; both are
    then compared as RGB.

    Args:
        image1: PIL image to evaluate.
        image2: PIL image that defines the target size and mode.

    Returns:
        Number of exactly matching pixels divided by the total pixel count.
    """
    if image1.size != image2.size:
        # Nearest-neighbour keeps the original pixel values. An interpolating
        # filter would invent intermediate values that can never match under
        # the exact-equality comparison below.
        image1 = image1.resize(image2.size, Image.NEAREST)
    if image1.mode != image2.mode:
        image1 = image1.convert(image2.mode)
    width, height = image1.size
    total_pixels = width * height
    rgb1 = np.asarray(image1.convert("RGB"))
    rgb2 = np.asarray(image2.convert("RGB"))
    # A pixel counts as correct only if all three channels agree.
    num_correct_pixels = int(np.count_nonzero(np.all(rgb1 == rgb2, axis=-1)))
    return num_correct_pixels / total_pixels
def calculate_f1_score(image1, image2):
    """Return the F1 score between two binary mask images.

    Pixels equal to 255 after conversion to greyscale are foreground.
    ``image1`` acts as the reference and ``image2`` as the prediction: a
    false positive is a pixel that is foreground only in ``image2``.

    Args:
        image1: Reference PIL image; resized/converted to match ``image2``.
        image2: Predicted PIL image.

    Returns:
        The F1 score. A 1e-7 term in each denominator keeps the result
        finite (evaluating to 0) when there are no positives.
    """
    if image1.size != image2.size:
        # Nearest-neighbour keeps mask values at exactly 0/255. Bilinear
        # interpolation blurs edges into intermediate greys that fail the
        # ``== 255`` test below.
        image1 = image1.resize(image2.size, Image.NEAREST)
    if image1.mode != image2.mode:
        image1 = image1.convert(image2.mode)
    reference_on = np.asarray(image1.convert("L")) == 255
    predicted_on = np.asarray(image2.convert("L")) == 255
    true_positives = np.sum(reference_on & predicted_on)
    false_positives = np.sum(~reference_on & predicted_on)
    false_negatives = np.sum(reference_on & ~predicted_on)
    precision = true_positives / (true_positives + false_positives + 1e-7)
    recall = true_positives / (true_positives + false_negatives + 1e-7)
    f1_score = 2 * (precision * recall) / (precision + recall + 1e-7)
    return f1_score
def calculate_dice_coefficient(image1, image2):
    """Return the Dice coefficient between two binary mask images.

    Pixels equal to 255 after conversion to greyscale are foreground.
    ``image1`` acts as the reference and ``image2`` as the prediction.

    Args:
        image1: Reference PIL image; resized/converted to match ``image2``.
        image2: Predicted PIL image.

    Returns:
        ``2*TP / (2*TP + FP + FN)``. When neither image has any foreground
        the ratio is undefined; 0.0 is returned instead of NaN, matching the
        value ``calculate_f1_score`` yields in that situation.
    """
    if image1.size != image2.size:
        # Nearest-neighbour keeps mask values at exactly 0/255. Bilinear
        # interpolation blurs edges into intermediate greys that fail the
        # ``== 255`` test below.
        image1 = image1.resize(image2.size, Image.NEAREST)
    if image1.mode != image2.mode:
        image1 = image1.convert(image2.mode)
    reference_on = np.asarray(image1.convert("L")) == 255
    predicted_on = np.asarray(image2.convert("L")) == 255
    true_positives = np.sum(reference_on & predicted_on)
    false_positives = np.sum(~reference_on & predicted_on)
    false_negatives = np.sum(reference_on & ~predicted_on)
    denominator = 2 * true_positives + false_positives + false_negatives
    if denominator == 0:
        # Avoid 0/0, which previously produced NaN plus a RuntimeWarning.
        return 0.0
    dice_coefficient = (2 * true_positives) / denominator
    return dice_coefficient
# ---------------------------------------------------------------------------
# Streamlit page: pick one of four sample lesions, segment it with the model
# using a box prompt derived from its label mask, then show overlay + metrics.
# ---------------------------------------------------------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"
# Must remain the first Streamlit command executed by the script.
st.set_page_config(layout='wide')


# Streamlit re-executes this script on every widget interaction; caching the
# heavy objects avoids reloading the dataset and the model on each rerun.
@st.cache_resource
def _load_dataset():
    """Load the demo dataset once per server process."""
    return load_dataset('ahishamm/combined_masks', split='train')


@st.cache_resource
def _load_model(target_device):
    """Load the processor and the segmentation model once per device."""
    loaded_processor = AutoProcessor.from_pretrained('ahishamm/skinsam')
    loaded_model = AutoModelForMaskGeneration.from_pretrained(
        'ahishamm/skinsam_focalloss_base_combined'
    )
    loaded_model.to(target_device)
    return loaded_processor, loaded_model


ds = _load_dataset()

# Dataset rows offered as selectable samples.
SAMPLE_INDICES = (3, 4, 5, 6)
image_arr = [ds[i]['image'] for i in SAMPLE_INDICES]
label_arr = [ds[i]['label'] for i in SAMPLE_INDICES]

# `img` is the position of the chosen sample within image_arr / label_arr.
img = image_select(
    label="Select a Skin Lesion Image",
    images=image_arr,
    captions=["sample 1", "sample 2", "sample 3", "sample 4"],
    return_value='index',
)

# Kept as module-level names because get_output() reads them as globals.
processor, model = _load_model(device)

# Box prompt derived from the ground-truth label of the selected sample.
p = get_bounding_box(np.array(label_arr[img]))
predicted_mask_array = get_output(image_arr[img], p)
predicted_mask = generate_image(predicted_mask_array)
result_image = show_mask(image_arr[img], predicted_mask_array)

with st.container():
    tab1, tab2 = st.tabs(['Visualizations', 'Metrics'])
    with tab1:
        col1, col2 = st.columns(2)
        with col1:
            st.image(image_arr[img], caption='Original Skin Lesion Image', use_column_width=True)
        with col2:
            st.image(result_image, caption='Mask Overlay', use_column_width=True)
    with tab2:
        st.write(f'The IOU Score: {iou_calculation(label_arr[img],predicted_mask)}')
        st.write(f'The Pixel Accuracy: {calculate_pixel_accuracy(label_arr[img],predicted_mask)}')
        st.write(f'The Dice Coefficient Score: {calculate_dice_coefficient(label_arr[img],predicted_mask)}')