# Spaces: Sleeping
import streamlit as st | |
from PIL import Image | |
import tensorflow as tf | |
import numpy as np | |
@st.cache_resource
def load_model():
    """Return the ImageNet-pretrained MobileNetV2 classifier.

    Streamlit re-executes the script on every widget interaction, so the
    model is wrapped in ``st.cache_resource``: it is built once per server
    process and the same instance is handed back on later calls instead of
    being reconstructed for every prediction.

    Returns:
        A ``tf.keras`` MobileNetV2 model with the classification head,
        expecting input batches of shape ``(N, 224, 224, 3)``.
    """
    # Built from tf.keras.applications; the ImageNet weights are fetched by
    # Keras the first time they are requested.
    model = tf.keras.applications.MobileNetV2(
        input_shape=(224, 224, 3),
        include_top=True,
        weights="imagenet",
    )
    return model
def predict_disease(image_file):
    """Classify an image with the ImageNet MobileNetV2 model.

    Despite the name, this uses a general ImageNet classifier; mapping the
    output to plant diseases would require a different model.

    Args:
        image_file: The uploaded image file (anything ``PIL.Image.open``
            accepts).

    Returns:
        ``"<label> (confidence: <score>)"`` for the top-1 ImageNet class,
        or ``"Error: <message>"`` if any step fails. Callers detect failure
        by the ``"Error"`` prefix, so that format is part of the contract.
    """
    try:
        # Imported inside the try so an import failure is also reported
        # through the "Error: ..." return value.
        from tensorflow.keras.applications.mobilenet_v2 import (
            decode_predictions,
            preprocess_input,
        )

        model = load_model()

        # Force 3-channel RGB at the network's fixed 224x224 input size.
        image = Image.open(image_file).convert("RGB").resize((224, 224))

        # Keep raw 0-255 pixel values and let preprocess_input scale them:
        # MobileNetV2's ImageNet weights expect inputs in [-1, 1], not the
        # [0, 1] range that dividing by 255 would produce.
        batch = np.expand_dims(np.asarray(image, dtype=np.float32), axis=0)
        batch = preprocess_input(batch)

        predictions = model.predict(batch)

        # decode_predictions yields (class_id, label, score) tuples per
        # batch item; take the single best entry of the only item.
        _, label, confidence = decode_predictions(predictions, top=1)[0][0]
        return f"{label} (confidence: {confidence:.2f})"
    except Exception as e:
        # UI boundary: report the failure as text rather than crash the app.
        return f"Error: {str(e)}"
def main():
    """Render the page: upload an image, preview it, classify on demand."""
    st.title("Image Classification App")
    st.caption("Note: This is using a general ImageNet model, not a plant disease model")

    uploaded = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
    if uploaded is None:
        # Nothing to show until the user provides a file.
        return

    # Preview the upload before offering classification.
    preview = Image.open(uploaded)
    st.image(preview, caption="Uploaded Image", use_column_width=True)

    if not st.button("Classify Image"):
        return

    with st.spinner("Analyzing image..."):
        outcome = predict_disease(uploaded)

    # predict_disease signals failure with an "Error"-prefixed string.
    if outcome.startswith("Error"):
        st.error(outcome)
    else:
        st.success(f"Prediction: {outcome}")


# Run the app only when executed as a script, not when imported.
if __name__ == "__main__":
    main()