ahishamm commited on
Commit
c7dec5e
·
1 Parent(s): d1dc382

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +151 -0
  2. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import numpy as np
3
+ import streamlit as st
4
+ import numpy as np
5
+ import matplotlib.pyplot as plt
6
+ from matplotlib.backends.backend_agg import FigureCanvasAgg
7
+ from PIL import Image
8
+ from streamlit_image_select import image_select
9
+ from tqdm import tqdm
10
+ import os
11
+ import shutil
12
+ from PIL import Image
13
+ import torch
14
+ import matplotlib.pyplot as plt
15
+ from datasets import load_dataset
16
+ from transformers import AutoProcessor, AutoModelForMaskGeneration
17
+
18
+ def show_mask(image, mask, ax=None):
19
+ fig, axes = plt.subplots()
20
+ axes.imshow(np.array(image))
21
+ color = np.array([30/255, 144/255, 255/255, 0.6])
22
+ h, w = mask.shape[-2:]
23
+ mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
24
+ axes.imshow(mask_image)
25
+ canvas = FigureCanvasAgg(fig)
26
+ canvas.draw()
27
+ pil_image = Image.frombytes('RGB', canvas.get_width_height(), canvas.tostring_rgb())
28
+ plt.close(fig)
29
+ return pil_image
30
+ def get_bounding_box(ground_truth_map):
31
+ y_indices, x_indices = np.where(ground_truth_map > 0)
32
+ x_min, x_max = np.min(x_indices), np.max(x_indices)
33
+ y_min, y_max = np.min(y_indices), np.max(y_indices)
34
+ H, W = ground_truth_map.shape
35
+ x_min = max(0, x_min - np.random.randint(0, 20))
36
+ x_max = min(W, x_max + np.random.randint(0, 20))
37
+ y_min = max(0, y_min - np.random.randint(0, 20))
38
+ y_max = min(H, y_max + np.random.randint(0, 20))
39
+ bbox = [x_min, y_min, x_max, y_max]
40
+ return bbox
41
+ def get_output(image,prompt):
42
+ inputs = processor(image,input_boxes=[[prompt]],return_tensors='pt').to(device)
43
+ model.eval()
44
+ with torch.no_grad():
45
+ outputs = model(**inputs,multimask_output=False)
46
+ output_proba = torch.sigmoid(outputs.pred_masks.squeeze(1))
47
+ output_proba = output_proba.cpu().numpy().squeeze()
48
+ output = (output_proba > 0.5).astype(np.uint8)
49
+ return output
50
+ def generate_image(np_array):
51
+ return Image.fromarray((np_array*255).astype('uint8'),mode='L')
52
+ def iou_calculation(result1, result2):
53
+ intersection = np.logical_and(result1, result2)
54
+ union = np.logical_or(result1, result2)
55
+ iou_score = np.sum(intersection) / np.sum(union)
56
+ iou_score = "{:.4f}".format(iou_score)
57
+ return float(iou_score)
58
+ def calculate_pixel_accuracy(image1, image2):
59
+ if image1.size != image2.size or image1.mode != image2.mode:
60
+ image1 = image1.resize(image2.size, Image.BILINEAR)
61
+ if image1.mode != image2.mode:
62
+ image1 = image1.convert(image2.mode)
63
+ width, height = image1.size
64
+ total_pixels = width * height
65
+ image1 = image1.convert("RGB")
66
+ image2 = image2.convert("RGB")
67
+ pixels1 = image1.load()
68
+ pixels2 = image2.load()
69
+ num_correct_pixels = 0
70
+ for y in range(height):
71
+ for x in range(width):
72
+ if pixels1[x, y] == pixels2[x, y]:
73
+ num_correct_pixels += 1
74
+ accuracy = num_correct_pixels / total_pixels
75
+ return accuracy
76
+ def calculate_f1_score(image1, image2):
77
+ if image1.size != image2.size or image1.mode != image2.mode:
78
+ image1 = image1.resize(image2.size, Image.BILINEAR)
79
+ if image1.mode != image2.mode:
80
+ image1 = image1.convert(image2.mode)
81
+ width, height = image1.size
82
+ image1 = image1.convert("L")
83
+ image2 = image2.convert("L")
84
+ np_image1 = np.array(image1)
85
+ np_image2 = np.array(image2)
86
+ np_image1_flat = np_image1.flatten()
87
+ np_image2_flat = np_image2.flatten()
88
+ true_positives = np.sum(np.logical_and(np_image1_flat == 255, np_image2_flat == 255))
89
+ false_positives = np.sum(np.logical_and(np_image1_flat != 255, np_image2_flat == 255))
90
+ false_negatives = np.sum(np.logical_and(np_image1_flat == 255, np_image2_flat != 255))
91
+ precision = true_positives / (true_positives + false_positives + 1e-7)
92
+ recall = true_positives / (true_positives + false_negatives + 1e-7)
93
+ f1_score = 2 * (precision * recall) / (precision + recall + 1e-7)
94
+ return f1_score
95
+ def calculate_dice_coefficient(image1, image2):
96
+ if image1.size != image2.size or image1.mode != image2.mode:
97
+ image1 = image1.resize(image2.size, Image.BILINEAR)
98
+ if image1.mode != image2.mode:
99
+ image1 = image1.convert(image2.mode)
100
+ width, height = image1.size
101
+ image1 = image1.convert("L")
102
+ image2 = image2.convert("L")
103
+ np_image1 = np.array(image1)
104
+ np_image2 = np.array(image2)
105
+ np_image1_flat = np_image1.flatten()
106
+ np_image2_flat = np_image2.flatten()
107
+ true_positives = np.sum(np.logical_and(np_image1_flat == 255, np_image2_flat == 255))
108
+ false_positives = np.sum(np.logical_and(np_image1_flat != 255, np_image2_flat == 255))
109
+ false_negatives = np.sum(np.logical_and(np_image1_flat == 255, np_image2_flat != 255))
110
+ dice_coefficient = (2 * true_positives) / (2 * true_positives + false_positives + false_negatives)
111
+ return dice_coefficient
112
+
113
+
114
+ device = "cuda" if torch.cuda.is_available() else "cpu"
115
+ st.set_page_config(layout='wide')
116
+ ds = load_dataset('ahishamm/combined_masks',split='train')
117
+ s1 = ds[0]['image']
118
+ s2 = ds[1]['image']
119
+ s3 = ds[2]['image']
120
+ s4 = ds[3]['image']
121
+ image_arr = [s1,s2,s3,s4]
122
+ img = image_select(
123
+ label="Select a Skin Lesion Image",
124
+ images=[
125
+ s1,s2,s3,s4
126
+ ],
127
+ captions=["sample 1","sample 2","sample 3","sample 4"],
128
+ return_value='index'
129
+ )
130
+ processor = AutoProcessor.from_pretrained('ahishamm/skinsam')
131
+ model = AutoModelForMaskGeneration.from_pretrained('ahishamm/skinsam_focalloss_base_combined')
132
+ model.to(device)
133
+ p = get_bounding_box(np.array(ds[img]['label']))
134
+ predicted_mask_array = get_output(ds[img]['image'],p)
135
+ predicted_mask = generate_image(predicted_mask_array)
136
+ result_image = show_mask(ds[img]['image'],predicted_mask_array)
137
+ with st.container():
138
+ col1, col2, col3 = st.columns(3)
139
+ with col1:
140
+ st.image(ds[img]['image'],caption='Original Skin Lesion Image',use_column_width=True)
141
+ with col2:
142
+ st.image(predicted_mask,caption='Predicted Mask',use_column_width=True)
143
+ with col3:
144
+ st.write(f'The IOU Score: {iou_calculation(ds[img]["label"],predicted_mask)}')
145
+ st.write(f'The Pixel Accuracy: {calculate_pixel_accuracy(ds[img]["label"],predicted_mask)}')
146
+ st.write(f'The Dice Coefficient Score: {calculate_dice_coefficient(ds[img]["label"],predicted_mask)}')
147
+ with st.container():
148
+ col4,col5,col6 = st.columns(3)
149
+ with col5:
150
+ st.image(result_image,caption='Mask Overlay',use_column_width=True)
151
+
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ streamlit-image-select
2
+ streamlit
3
+ tqdm