File size: 1,747 Bytes
b8e4c73
 
 
 
 
 
 
 
 
 
 
 
6370dc6
 
 
 
 
b8e4c73
6370dc6
b8e4c73
 
 
 
 
 
 
6370dc6
b8e4c73
6370dc6
 
 
b8e4c73
6370dc6
 
 
 
 
 
 
b8e4c73
 
6370dc6
 
 
b8e4c73
 
6370dc6
9f412c6
6370dc6
9f412c6
 
b8e4c73
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
# All imports
import streamlit as st
import tensorflow as tf
from PIL import Image
import io
import numpy as np

def load_image(size=(224, 224)):
    """Show an upload widget and return the uploaded picture as a PIL image.

    The uploaded bytes are echoed back to the page with ``st.image``, decoded
    with PIL, converted to 3-channel RGB and resized to *size*.

    Args:
        size: ``(width, height)`` passed to ``Image.resize``. Defaults to
            ``(224, 224)``, the size this function previously hard-coded.

    Returns:
        The processed ``PIL.Image.Image``, or ``None`` when nothing has been
        uploaded yet.
    """
    uploaded_file = st.file_uploader(label='Pick an image to test')
    if uploaded_file is None:
        return None

    image_data = uploaded_file.getvalue()
    st.image(image_data)
    img = Image.open(io.BytesIO(image_data))
    # Normalise the colour mode first. PNGs with alpha (RGBA), palette (P) or
    # greyscale (L) uploads would otherwise become 4- or 1-channel arrays in
    # predict(). The model presumably expects 3 channels (TODO confirm against
    # Model/model.h5).
    img = img.convert('RGB')
    return img.resize(size)

def load_model():
    """Load and return the Keras model saved at ``Model/model.h5``.

    NOTE(review): main() calls this on every script run, so the model file is
    re-read from disk each time the page reruns. Consider caching it (for
    example with ``st.cache_resource``, if the installed Streamlit provides it).
    """
    return tf.keras.models.load_model('Model/model.h5')

def load_labels(path='Oxford-102_Flower_dataset_labels.txt'):
    """Read the class-label file and return its lines as a list of strings.

    ``predict`` indexes this list with the model's argmax class index, so
    line ``i`` of the file is the name shown for class ``i``.

    Args:
        path: Label file to read. Defaults to the previously hard-coded
            ``Oxford-102_Flower_dataset_labels.txt``.

    Returns:
        One string per line, without line terminators.
    """
    # Explicit UTF-8: without it, open() uses the platform's locale encoding,
    # so non-ASCII labels could decode differently (or fail) across machines.
    with open(path, 'r', encoding='utf-8') as file:
        return file.read().splitlines()

def predict(model, labels, img):
    """Classify *img* with *model* and return ``(label, score_percent)``.

    The image is converted to an array and given a leading batch dimension
    of size 1. The first row of ``model.predict``'s output is scanned for its
    highest entry; the matching entry of *labels* is returned together with
    that entry multiplied by 100 and rounded to two decimals.

    NOTE(review): the score reads as a confidence percentage only if the
    model's final layer outputs probabilities (presumably softmax; confirm).
    """
    as_array = tf.keras.preprocessing.image.img_to_array(img)
    batch = tf.expand_dims(as_array, 0)  # single-image batch
    scores = model.predict(batch)[0]

    best = np.argmax(scores, axis=-1)
    return labels[best], np.round(scores[best] * 100, 2)

def main():
    """Render the demo page: upload an image, run the model, show the result.

    Loads the model and labels, shows the uploader and a "Run on image"
    button, and when the button is pressed with an image present, displays
    the predicted flower name and its score.
    """
    import html  # escape label text before embedding it in raw HTML below

    st.title('Oxford 102 Flower Classification Demo')
    model = load_model()
    labels = load_labels()
    image = load_image()
    result = st.button('Run on image')
    if not result:
        return
    if image is None:
        # Previously a click without an upload silently did nothing.
        st.warning('Please upload an image first.')
        return

    # A spinner disappears once prediction finishes. The old static
    # "Calculating results..." markdown stayed on screen next to the result.
    with st.spinner('Calculating results...'):
        flower, closeness = predict(model, labels, image)

    # unsafe_allow_html renders the string as HTML, so escape the label text.
    flower_text = html.escape(str(flower))
    st.markdown(f'<h3 style="color:blue;">Flower Type: <span style="color:black;">{flower_text}</span></h3>', unsafe_allow_html=True)
    st.markdown(f'<h3 style="color:green;">Closeness: <span style="color:black;">{closeness}%</span></h3>', unsafe_allow_html=True)

# Entry point: build the page when this file is executed as the main script.
if __name__ == '__main__':
    main()