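"""Gradio demo: glaucoma screening from retinal fundus images.

A Swin-transformer-based segmentation model (Transformer_Regression from the
local `pipline` module) segments the optic disc and cup in two passes, the
vertical cup-to-disc ratio (vCDR) is computed from the predicted mask, and the
contoured image is displayed in a Gradio interface.
"""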
import time

import cv2
import gradio as gr
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
from torch.nn import functional as F

from pipline import Transformer_Regression, extract_regions_Last, compute_ratios
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
## Model / preprocessing parameters
image_shape = 384  # input resolution (512 got 87 in earlier runs)
batch_size = 1
dim_patch = 4
num_classes = 3
label_smoothing = 0.1
scale = 1
start = time.time()
torch.manual_seed(0)  # fixed seed for reproducible behavior
tfms = transforms.Compose([
    transforms.Resize((image_shape, image_shape)),
    transforms.ToTensor(),
    transforms.Normalize(0.5, 0.5),
])
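# NOTE: Normalize(0.5, 0.5) maps pixel values to [-1, 1]; the plotting code
# below undoes this with (x * 0.5 + 0.5) before drawing the contours.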
def Final_Compute_regression_results_Sample(Model, batch_sampler, num_head=2):
    """Segment the optic disc/cup in one image, refine with a cropped second
    pass, and return the contoured image plus the vertical cup-to-disc ratio."""
    Model.eval()
    with torch.no_grad():
        train_batch_tfms = batch_sampler['image'].to(device=device)
        ytrue_seg = batch_sampler['image_original']
        # First pass: coarse segmentation on the full image
        scores = Model(train_batch_tfms.unsqueeze(0))
        yseg_pred = F.interpolate(scores['seg'], size=(ytrue_seg.shape[0], ytrue_seg.shape[1]),
                                  mode='bilinear', align_corners=True)
        # Crop the detected optic-disc region for a refined second pass
        Regions_crop = extract_regions_Last(np.array(batch_sampler['image_original']),
                                            yseg_pred.argmax(1).long()[0].detach().cpu().numpy())
        Regions_crop['image'] = Image.fromarray(np.uint8(Regions_crop['image'])).convert('RGB')
        ytrue_seg_crop = ytrue_seg[Regions_crop['cord'][0]:Regions_crop['cord'][1],
                                   Regions_crop['cord'][2]:Regions_crop['cord'][3]]
        ytrue_seg_crop = np.expand_dims(ytrue_seg_crop, axis=0)
        if num_head == 2:
            # Second pass: re-segment the crop with the auxiliary head and
            # paste the refined prediction back into the full-size map
            scores = Model(tfms(Regions_crop['image']).unsqueeze(0).to(device))
            yseg_pred_crop = F.interpolate(scores['seg_aux_1'],
                                           size=(ytrue_seg_crop.shape[1], ytrue_seg_crop.shape[2]),
                                           mode='bilinear', align_corners=True)
            yseg_pred[:, :, Regions_crop['cord'][0]:Regions_crop['cord'][1],
                      Regions_crop['cord'][2]:Regions_crop['cord'][3]] = yseg_pred_crop
        yseg_pred = torch.softmax(yseg_pred, dim=1)
        yseg_pred = yseg_pred.argmax(1).long().detach().cpu().numpy()
        ratios = compute_ratios(yseg_pred[0])
        ### Plot
        p_img = batch_sampler['image'].to(device=device).unsqueeze(0)
        p_img = F.interpolate(p_img, size=(yseg_pred.shape[1], yseg_pred.shape[2]),
                              mode='bilinear', align_corners=True)
        # Undo the Normalize(0.5, 0.5) to get a displayable image back
        image_orig = (p_img[0] * 0.5 + 0.5).permute(1, 2, 0).detach().cpu().numpy()
        image_orig = np.uint8(image_orig * 255)
        # Copy before drawing, since cv2.drawContours modifies its input in place
        image_cont = image_orig.copy()
        # Threshold at label 2 (the inner cup region) and draw its contour in green
        ret, thresh = cv2.threshold(np.uint8(yseg_pred[0]), 1, 2, 0)
        conts, hir = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
        cv2.drawContours(image_cont, conts, -1, (0, 255, 0), 2)
        # Threshold at labels >= 1 (the outer disc region) and draw its contour in blue
        ret, thresh = cv2.threshold(np.uint8(yseg_pred[0]), 0, 2, 0)
        conts, hir = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
        cv2.drawContours(image_cont, conts, -1, (0, 0, 255), 2)
        glaucoma = 'not implemented'  # diagnosis classification is not implemented yet
    return image_cont, ratios.vcdr, glaucoma, Regions_crop
# Load the trained model (the `scores` dict it returns is expected to hold a
# main segmentation head 'seg' and an auxiliary head 'seg_aux_1')
DeepLab = Transformer_Regression(image_dim=image_shape, dim_patch=dim_patch,
                                 num_classes=num_classes, scale=scale, feat_dim=128)
DeepLab.to(device=device)
DeepLab.load_state_dict(torch.load("TrainAll_Maghrabi84_50iteration_SWIN.pth.tar",
                                   map_location=torch.device(device)))
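# NOTE: the checkpoint "TrainAll_Maghrabi84_50iteration_SWIN.pth.tar" must be
# present in the working directory; this script does not download it.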
def infer(img):
    """Run the two-pass pipeline on an RGB numpy image supplied by Gradio."""
    sample_batch = dict()
    sample_batch['image_original'] = img
    im_retina_pil = Image.fromarray(img)
    sample_batch['image'] = tfms(im_retina_pil)
    result, ratio, diagnosis, cropped = Final_Compute_regression_results_Sample(
        DeepLab, sample_batch, num_head=2)
    # Zoom into the detected region with a 100-pixel margin, clamping the
    # start indices at 0 so that negative values cannot wrap around
    y0 = max(cropped['cord'][0] - 100, 0)
    x0 = max(cropped['cord'][2] - 100, 0)
    cropped = result[y0:cropped['cord'][1] + 100, x0:cropped['cord'][3] + 100]
    return ratio, diagnosis, result, cropped
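# Example of calling the pipeline directly (assuming one of the bundled sample
# images, e.g. "M00027.png", sits next to this script):
#   vcdr, dx, overlay, zoomed = infer(np.array(Image.open("M00027.png").convert("RGB")))
#   print("vertical cup-to-disc ratio:", vcdr)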
title = "Glaucoma detection"
description = "Using the vertical cup-to-disc ratio"
with gr.Blocks(css='#title {text-align: center;}') as demo:
    with gr.Row():
        gr.Markdown(
            f'''
            # {title}
            {description}
            ''',
            elem_id='title'
        )
    with gr.Row():
        with gr.Column():
            prompt = gr.Image(label="Enter your retina image")
            btn = gr.Button(value='Submit')
            # Clicking an example fills the input; press Submit to run inference
            examples = gr.Examples(
                ['M00027.png', 'M00056.png', 'M00073.png',
                 'M00093.png', 'M00018.png', 'M00034.png'],
                inputs=[prompt])
        with gr.Column():
            with gr.Row():
                text1 = gr.Textbox(label="Vertical cup-to-disc ratio")
                text2 = gr.Textbox(label="Predicted diagnosis")
            img = gr.Image(label='Labeled image')
            zoom = gr.Image(label='Zoomed in')
    outputs = [text1, text2, img, zoom]
    btn.click(fn=infer, inputs=prompt, outputs=outputs)

if __name__ == '__main__':
    demo.launch()