File size: 3,911 Bytes
04a4a6b
 
 
 
 
0a3569d
04a4a6b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6e96c04
04a4a6b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7be5e7c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2f0b442
 
7be5e7c
 
0d33900
7be5e7c
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import streamlit as st 
import tensorflow as tf
from PIL import Image
import numpy as np
import cv2
import matplotlib.pyplot as plt


def _load_model():
    """Load the Keras teeth-segmentation model from ``dental_xray_seg.h5``."""
    return tf.keras.models.load_model("dental_xray_seg.h5")


# Streamlit re-executes this whole script on every widget interaction (each
# example-button click, each upload). Wrap the loader in st.cache_resource
# (Streamlit >= 1.18) so the .h5 file is deserialized once per server process
# instead of on every rerun; older Streamlit versions without it keep the
# previous load-every-run behaviour.
_cache_resource = getattr(st, "cache_resource", None)
if _cache_resource is not None:
    _load_model = _cache_resource(_load_model)

model = _load_model()

st.header("Segmentation of Teeth in Panoramic X-ray Image")

# Bundled example images; relative paths, resolved against the working directory.
examples = ["teeth_01.png", "teeth_02.png", "teeth_03.png", "teeth_04.png", "teeth_05.png"]

def load_image(image_file):
    """Open ``image_file`` with PIL and hand back the resulting image.

    ``image_file`` is whatever ``Image.open`` accepts; in this script that is
    either an example path string or the object returned by
    ``st.file_uploader``.
    """
    return Image.open(image_file)

def convert_one_channel(img):
    """Return a single-channel (H, W) version of ``img`` for the model input.

    Some panoramic X-rays are stored with several channels even though they
    are grayscale. In this script the arrays come from ``np.asarray`` on a PIL
    image, so multi-channel input is in RGB(A) order, not OpenCV's BGR order.

    Args:
        img: a 2-D array, or an (H, W, C) array with C in 1..4.

    Returns:
        A 2-D array. ``img`` itself is returned unchanged when it is already 2-D.
    """
    if img.ndim == 2:
        return img
    channels = img.shape[2]
    if channels in (1, 2):
        # Gray or gray+alpha: channel 0 already holds the intensity. cvtColor
        # rejects these channel counts for the *2GRAY codes, so slice instead,
        # making the result contiguous for the following cv2 calls.
        return np.ascontiguousarray(img[:, :, 0])
    if channels == 4:
        return cv2.cvtColor(img, cv2.COLOR_RGBA2GRAY)
    return cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    
def convert_rgb(img):
    """Return a writable 3-channel RGB copy of ``img`` to draw colored contours on.

    Grayscale inputs (2-D, or with 1 or 2 channels) get their intensity
    replicated into R, G and B; RGBA input drops its alpha channel, which would
    otherwise make contours drawn with a 3-component color fully transparent;
    RGB input is copied. A fresh array is always returned because
    ``cv2.drawContours`` draws in place and ``np.asarray`` on a PIL image can
    be read-only.

    Args:
        img: a 2-D array, or an (H, W, C) array with C in 1..4.

    Returns:
        A new C-contiguous (H, W, 3) array with the same dtype as ``img``.
    """
    if img.ndim == 2:
        return cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
    channels = img.shape[2]
    if channels in (1, 2):
        # Gray or gray+alpha: replicate the intensity channel into R, G, B.
        return np.repeat(img[:, :, :1], 3, axis=2)
    if channels == 4:
        return img[:, :, :3].copy()
    return img.copy()
    
    
st.subheader("Upload Dental Panoramic X-ray Image")
image_file = st.file_uploader("Upload Images", type=["png", "jpg", "jpeg"])

# One column per bundled example: a thumbnail plus an "Example N" button.
# Clicking a button substitutes that example's path for the uploaded file on
# the rerun triggered by the click.
example_columns = st.columns(len(examples))
for number, (column, example_path) in enumerate(zip(example_columns, examples), start=1):
    with column:
        st.image(load_image(example_path), width=200)
        if st.button(f"Example {number}"):
            image_file = example_path
    
if image_file is not None:
    # Main prediction: Otsu threshold on the predicted map, morphological
    # open/close clean-up, then red contours drawn over the input image.
    img = load_image(image_file)

    st.text("Making A Prediction ....")
    st.image(img, width=850)

    # Bilevel, palette, alpha and CMYK images would otherwise become boolean,
    # palette-index or 4-channel arrays that the pipeline below mishandles.
    if img.mode in ("1", "P", "PA", "LA", "RGBA", "CMYK"):
        img = img.convert("RGB")
    # np.array (not np.asarray) makes a writable copy: convert_rgb may return
    # its argument as-is and cv2.drawContours draws into it in place.
    img = np.array(img)

    img_cv = convert_one_channel(img)
    img_cv = cv2.resize(img_cv, (512, 512), interpolation=cv2.INTER_LANCZOS4)
    img_cv = np.float32(img_cv / 255)
    img_cv = np.reshape(img_cv, (1, 512, 512, 1))

    prediction = model.predict(img_cv)
    predicted = prediction[0]
    predicted = cv2.resize(predicted, (img.shape[1], img.shape[0]), interpolation=cv2.INTER_LANCZOS4)
    # Lanczos interpolation overshoots slightly outside [0, 1]; clip before the
    # uint8 cast so those pixels saturate instead of wrapping around.
    mask = np.uint8(np.clip(predicted, 0, 1) * 255)
    _, mask = cv2.threshold(mask, thresh=0, maxval=255, type=cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    kernel = np.ones((5, 5), dtype=np.float32)
    mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel, iterations=1)
    mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel, iterations=1)
    cnts, _ = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
    output = cv2.drawContours(convert_rgb(img), cnts, -1, (255, 0, 0), 3)

    st.subheader("Predicted Image")
    st.image(output, width=850)

    st.text("DONE ! ....")
      
if image_file is not None:
    # Alternative overlay: fixed 0.5 probability threshold, no morphological
    # clean-up and thinner contours. This replaces a Colab notebook cell
    # (plt.imsave to /content, cv2.imread on the upload, cv2_imshow, undefined
    # `predict`) that could not run under Streamlit. It re-runs the model so it
    # does not depend on variables set by the block above.
    src = load_image(image_file)
    if src.mode in ("1", "P", "PA", "LA", "RGBA", "CMYK"):
        src = src.convert("RGB")
    img = np.array(src)

    img_cv = convert_one_channel(img)
    img_cv = cv2.resize(img_cv, (512, 512), interpolation=cv2.INTER_LANCZOS4)
    img_cv = np.float32(img_cv / 255)
    img_cv = np.reshape(img_cv, (1, 512, 512, 1))
    predict_img = model.predict(img_cv)

    # Resize the first (only) batch item back to the input resolution, as the
    # main prediction block does.
    predict1 = cv2.resize(predict_img[0], (img.shape[1], img.shape[0]), interpolation=cv2.INTER_LANCZOS4)
    # Clip Lanczos overshoot so the uint8 cast cannot wrap around.
    mask = np.uint8(np.clip(predict1, 0, 1) * 255)
    _, mask = cv2.threshold(mask, thresh=255 / 2, maxval=255, type=cv2.THRESH_BINARY)
    cnts, _ = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
    overlay = cv2.drawContours(convert_rgb(img), cnts, -1, (255, 0, 0), 2)

    st.subheader("Predicted Image (fixed 0.5 threshold)")
    st.image(overlay, width=850)