jacaranda-app / app.py
lily-hust's picture
Update app.py
214e0e3
raw
history blame
2.66 kB
import streamlit as st
import pandas as pd
from PIL import Image
import numpy as np
import tensorflow as tf
from tensorflow.keras.applications.resnet50 import preprocess_input
from tensorflow.keras.preprocessing.image import img_to_array
def main():
    """Render the page title and usage notes, then run the app body."""
    st.title('Jacaranda Identification')

    # Introductory notes shown below the title, in display order.
    intro_notes = (
        "This is a Deep Learning application to identify if a satellite image clip contains Jacaranda trees.\n",
        'The predicting result will be "Jacaranda", or "Others".',
        'You can click "Browse files" multiple times until adding all images before generating prediction.\n',
    )
    for note in intro_notes:
        st.markdown(note)

    run_the_app()
@st.cache_resource()
def load_model():
    """Load and return the Keras model saved under the relative path ``'model'``.

    The function is wrapped in ``st.cache_resource`` so the network is
    loaded once and the cached object is reused afterwards.
    """
    return tf.keras.models.load_model('model')
@st.cache_data()
def generate_df():
    """Return an empty results table with the two output columns."""
    # Both columns start with no rows; rows are appended per prediction.
    empty_columns = {
        'Image file name': [],
        'Class name': [],
    }
    return pd.DataFrame(empty_columns)
@st.cache_data()
def write_df(df, file, cls):
    """Return a new DataFrame: ``df`` plus one row for ``file`` and ``cls``.

    Args:
        df: Results table to extend.
        file: Object whose ``name`` attribute is stored as the image file name.
        cls: Predicted class name stored alongside the file name.

    Returns:
        The concatenated table, re-indexed from zero.
    """
    new_row = pd.DataFrame([{'Image file name': file.name, 'Class name': cls}])
    return pd.concat([df, new_row], ignore_index=True)
@st.cache_data()
def convert_df(df):
    """Serialise ``df`` to CSV text, omitting the index column."""
    csv_text = df.to_csv(index=False, encoding='utf-8')
    return csv_text
def run_the_app():
    """Upload images, classify each one, and offer the results as a CSV.

    Flow:
      1. Show a multi-file uploader for jpg/jpeg/bmp/png/tif images.
      2. Once files are present, show thumbnails plus two buttons:
         "Clear uploaded images" resets the uploader, and
         "Generate prediction" runs the model on every uploaded file.
      3. After prediction, print one line per image and expose a download
         button for the accumulated results table.
    """
    class_names = ['Jacaranda', 'Others']
    model = load_model()
    df = generate_df()

    # The uploader's widget key carries a counter. Bumping the counter makes
    # Streamlit build a fresh, empty uploader, which is what actually clears
    # the previously uploaded files.
    if "uploader_key" not in st.session_state:
        st.session_state["uploader_key"] = 0

    uploaded_files = st.file_uploader(
        "Upload images",
        # Must be a list: `"jpg" or 'jpeg' or ...` evaluates to just "jpg",
        # which silently rejected every other extension.
        type=["jpg", "jpeg", "bmp", "png", "tif"],
        accept_multiple_files=True,
        key="uploader_{}".format(st.session_state["uploader_key"]),
    )

    # Nothing else to render until at least one file has been uploaded.
    if not uploaded_files:
        return

    st.image(uploaded_files, width=100)

    if st.button("Clear uploaded images"):
        st.session_state["uploader_key"] += 1
        # `st.experimental_rerun` is deprecated in favour of `st.rerun`;
        # prefer the new name and fall back on older Streamlit releases.
        rerun = getattr(st, "rerun", None) or st.experimental_rerun
        rerun()

    if st.button("Generate prediction"):
        for file in uploaded_files:
            # Force three colour channels: uploads may be RGBA, palette or
            # grayscale, while the ResNet50 preprocessing works on RGB input.
            img = Image.open(file).convert("RGB")
            img_array = img_to_array(img)
            img_array = tf.expand_dims(img_array, axis=0)  # Create a batch of one
            processed_image = preprocess_input(img_array)

            predictions = model.predict(processed_image)
            score = predictions[0]
            cls = class_names[np.argmax(score)]

            # Report the file's name, not the repr of the upload object.
            st.markdown("Predicted class of the image {} is : {}".format(file.name, cls))

            df = write_df(df, file, cls)

        csv = convert_df(df)
        st.download_button(
            "Download the results as CSV",
            data=csv,
            file_name="jacaranda_identification.csv",
        )
# Script entry point: build the page only when this file is the main module.
if __name__ == "__main__":
    main()