import time

import torch
import torchvision.transforms as transforms
from torch.nn import functional as F
import cv2
import gradio as gr
import numpy as np
from PIL import Image

from pipline import Transformer_Regression, extract_regions_Last, compute_ratios

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
## Define some parameters
image_shape = 384   # input resolution (author note: "512 got 87")
batch_size = 1
dim_patch = 4
num_classes = 3     # 3 segmentation classes (background, disc, cup)
label_smoothing = 0.1
scale = 1

start = time.time()
torch.manual_seed(0)
# Preprocessing: resize, convert to tensor, normalize to [-1, 1]
tfms = transforms.Compose([
    transforms.Resize((image_shape, image_shape)),
    transforms.ToTensor(),
    transforms.Normalize(0.5, 0.5),
])
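# Two-stage inference on a single fundus image:
#   1) segment the full (resized) image,
#   2) crop the detected optic-disc region and re-segment it with the auxiliary
#      head (num_head=2) for a sharper boundary,
#   3) compute the vertical cup-to-disc ratio (VCDR) from the final mask.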
def Final_Compute_regression_results_Sample(Model, batch_sampler, num_head=2):
    Model.eval()
    yreg_pred = []
    with torch.no_grad():
        train_batch_tfms = batch_sampler['image'].to(device=device)
        ytrue_seg = batch_sampler['image_original']
        # First pass: segment the resized input and upsample the logits back to
        # the original image resolution.
        scores = Model(train_batch_tfms.unsqueeze(0))
        yseg_pred = F.interpolate(scores['seg'], size=(ytrue_seg.shape[0], ytrue_seg.shape[1]),
                                  mode='bilinear', align_corners=True)
        # Crop the region around the predicted optic disc.
        Regions_crop = extract_regions_Last(np.array(batch_sampler['image_original']),
                                            yseg_pred.argmax(1).long()[0].detach().cpu().numpy())
        Regions_crop['image'] = Image.fromarray(np.uint8(Regions_crop['image'])).convert('RGB')
        # Crop the same window from the original image (only its shape is needed below).
        ytrue_seg_crop = ytrue_seg[Regions_crop['cord'][0]:Regions_crop['cord'][1],
                                   Regions_crop['cord'][2]:Regions_crop['cord'][3]]
        ytrue_seg_crop = np.expand_dims(ytrue_seg_crop, axis=0)
        if num_head == 2:
            # Second pass: re-segment the cropped disc region with the auxiliary head
            # and paste the refined logits back into the full-size prediction.
            scores = Model((tfms(Regions_crop['image']).unsqueeze(0)).to(device))
            yseg_pred_crop = F.interpolate(scores['seg_aux_1'],
                                           size=(ytrue_seg_crop.shape[1], ytrue_seg_crop.shape[2]),
                                           mode='bilinear', align_corners=True)
            yseg_pred[:, :, Regions_crop['cord'][0]:Regions_crop['cord'][1],
                      Regions_crop['cord'][2]:Regions_crop['cord'][3]] = yseg_pred_crop
        # Final class map.
        yseg_pred = torch.softmax(yseg_pred, dim=1)
        yseg_pred = yseg_pred.argmax(1).long().detach().cpu().numpy()
        # Vertical cup-to-disc ratio computed from the predicted mask.
        ratios = compute_ratios(yseg_pred[0])
        yreg_pred.append(ratios.vcdr)

        # Resize the normalized input to the mask resolution and undo the
        # Normalize(0.5, 0.5) transform so the image can be displayed.
        p_img = batch_sampler['image'].to(device=device).unsqueeze(0)
        p_img = F.interpolate(p_img, size=(yseg_pred.shape[1], yseg_pred.shape[2]),
                              mode='bilinear', align_corners=True)
        image_orig = (p_img[0] * 0.5 + 0.5).permute(1, 2, 0).detach().cpu().numpy()
        image_orig = np.uint8(image_orig * 255)
        # Draw the predicted contours on a copy (cv2.drawContours modifies in place).
        image_cont = image_orig.copy()
        # Contour of label 2 (inner / cup region), drawn in green.
        ret, thresh = cv2.threshold(np.uint8(yseg_pred[0]), 1, 2, 0)
        conts, hir = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
        cv2.drawContours(image_cont, conts, -1, (0, 255, 0), 2)
        # Contour of labels 1 and 2 (disc region), drawn in blue.
        ret, thresh = cv2.threshold(np.uint8(yseg_pred[0]), 0, 2, 0)
        conts, hir = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
        cv2.drawContours(image_cont, conts, -1, (0, 0, 255), 2)

    # Glaucoma classification is not implemented; only the VCDR is reported.
    glaucoma = 'not implemented'
    return image_cont, ratios.vcdr, glaucoma, Regions_crop
# Load the trained model weights.
DeepLab = Transformer_Regression(image_dim=image_shape, dim_patch=dim_patch,
                                 num_classes=num_classes, scale=scale, feat_dim=128)
DeepLab.to(device=device)
DeepLab.load_state_dict(torch.load("TrainAll_Maghrabi84_50iteration_SWIN.pth.tar",
                                   map_location=device))
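# Gradio callback: runs the two-stage pipeline on an uploaded image and returns
# the VCDR, a placeholder diagnosis, the contoured image, and a zoomed disc crop.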
def infer(img):
    # Gradio passes the uploaded image as an RGB numpy array.
    sample_batch = dict()
    sample_batch['image_original'] = img
    sample_batch['image'] = tfms(Image.fromarray(img))
    result, ratio, diagnosis, cropped = Final_Compute_regression_results_Sample(
        DeepLab, sample_batch, num_head=2)
    # Zoom in on the optic-disc region with a 100-pixel margin, clamped at the image border.
    y0 = max(cropped['cord'][0] - 100, 0)
    y1 = cropped['cord'][1] + 100
    x0 = max(cropped['cord'][2] - 100, 0)
    x1 = cropped['cord'][3] + 100
    cropped = result[y0:y1, x0:x1]
    return ratio, diagnosis, result, cropped
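# Gradio Blocks UI: image input and example files on the left; VCDR, diagnosis,
# the contoured image, and the zoomed crop on the right.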
title = "Glaucoma detection" | |
description = "Using vertical ratio" | |
outputs = [gr.Textbox(label="Vertical cup to disc ratio:"), gr.Textbox(label="predicted diagnosis"), gr.Image(label='labeled image'), gr.Image(label='zoomed in')] | |
with gr.Blocks(css='#title {text-align : center;} ') as demo: | |
with gr.Row(): | |
gr.Markdown( | |
f''' | |
# {title} | |
{description} | |
''', | |
elem_id='title' | |
) | |
with gr.Row(): | |
with gr.Column(): | |
prompt = gr.Image(label="Enter Your Retina Image") | |
btn = gr.Button(value='Submit') | |
examples = gr.Examples( | |
['M00027.png','M00056.png','M00073.png','M00093.png', 'M00018.png', 'M00034.png'], | |
inputs=[prompt], fn=infer, outputs=[outputs], cache_examples=False) | |
with gr.Column(): | |
with gr.Row(): | |
text1 = gr.Textbox(label="Vertical cup to disc ratio:") | |
text2 = gr.Textbox(label="predicted diagnosis") | |
img = gr.Image(label='labeled image') | |
zoom = gr.Image(label='zoomed in') | |
outputs = [text1,text2,img,zoom] | |
btn.click(fn=infer, inputs=prompt, outputs=outputs) | |
if __name__ == '__main__':
    demo.launch()