|
|
|
import streamlit as st |
|
import tensorflow as tf |
|
from tensorflow import keras |
|
from PIL import Image |
|
from tensorflow.keras.preprocessing import image |
|
import io |
|
from collections import Counter |
|
import numpy as np |
|
|
|
def load_image():
    """Show a file-upload widget and preview the uploaded file, if any.

    NOTE: a second ``load_image`` defined further down in this file rebinds
    the name, so that later definition is the one callers actually get.

    Returns:
        None in every case; the uploaded bytes are only displayed.
    """
    upload = st.file_uploader(label='Pick an image to test')
    if upload is None:
        return
    # Hand the raw bytes straight to Streamlit for display.
    st.image(upload.getvalue())
|
|
|
def load_models():
    """Load the Keras model saved at the relative path ``Model/model.h5``.

    Returns:
        The model object produced by ``tf.keras.models.load_model``.
    """
    return tf.keras.models.load_model('Model/model.h5')
|
|
|
def load_labels():
    """Read the flower label file and map 1-based line numbers to labels.

    Returns:
        dict: line N of ``Oxford-102_Flower_dataset_labels.txt`` stored
        under the integer key N, with numbering starting at 1.
    """
    with open('Oxford-102_Flower_dataset_labels.txt', 'r') as label_file:
        names = label_file.read().splitlines()
    # Keys start at 1 rather than 0.
    return {number: name for number, name in enumerate(names, start=1)}
|
|
|
def load_image():
    """Let the user upload an image, preview it, and prepare it for inference.

    Renders a Streamlit file uploader. When a file is present, its bytes are
    shown on the page, decoded with PIL, normalised to RGB and resized to
    224x224.

    Returns:
        The prepared ``PIL.Image.Image``, or ``None`` when no file has been
        uploaded yet.
    """
    uploaded_file = st.file_uploader(label='Pick an image to test')
    if uploaded_file is None:
        return None

    image_data = uploaded_file.getvalue()
    st.image(image_data)

    img = Image.open(io.BytesIO(image_data))
    # Uploads can be RGBA, palette-based or grayscale (typical for PNG/GIF);
    # without this the array handed to the model would have 4 or 1 channels.
    # NOTE(review): assumes the model takes 3-channel RGB input -- confirm.
    img = img.convert('RGB')
    return img.resize((224, 224))
|
|
|
def predict(model, categories, img):
    """Return the three highest-scoring labels for ``img`` with their scores.

    Args:
        model: object exposing ``predict`` that accepts a ``tf.data.Dataset``.
        categories: mapping from 1-based output position to label name.
        img: image accepted by ``img_to_array`` (the caller passes a PIL image).

    Returns:
        tuple: ``(flowers, percentages)`` -- two parallel lists ordered from
        highest to lowest score; each percentage is the raw score multiplied
        by 100 and rounded to two decimals.
    """
    pixels = tf.keras.preprocessing.image.img_to_array(img)

    # Wrap the single image in a one-element dataset; the paired label is a
    # placeholder that only exists to form the (input, label) tuple.
    samples = [pixels]
    placeholder_labels = [1]
    dataset = tf.data.Dataset.from_tensor_slices((samples, placeholder_labels))
    dataset = dataset.cache().batch(32).prefetch(
        buffer_size=tf.data.experimental.AUTOTUNE
    )

    scores = model.predict(dataset)

    # Number the flattened outputs from 1 so they line up with ``categories``.
    scores_by_class = dict(enumerate(scores.flatten(), 1))
    top_three = Counter(scores_by_class).most_common(3)

    flowers = [categories[class_id] for class_id, _ in top_three]
    percentages = [np.round(score * 100, 2) for _, score in top_three]
    return flowers, percentages
|
|
|
def main():
    """Render the demo page and classify the uploaded image on request.

    Loads the model and labels, shows the uploader and a "Run on image"
    button, and when the button is pressed writes the top labels and their
    percentages to the page.
    """
    st.title('Oxford 102 Flower CLassification Demo')
    model = load_models()
    categories = load_labels()
    # Named ``uploaded_img`` so it does not shadow the imported ``image`` module.
    uploaded_img = load_image()

    if not st.button('Run on image'):
        return

    # Without an upload there is nothing to classify; calling predict() with
    # None would fail inside the image-to-array conversion.
    if uploaded_img is None:
        st.warning('Please upload an image first.')
        return

    st.write('Calculating results...')
    flowers, percentages = predict(model, categories, uploaded_img)
    st.text(flowers)
    st.text(percentages)
|
|
|
# Build the page only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
|
|
|
|