|
import tensorflow as tf |
|
from tensorflow.keras.models import Sequential |
|
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout |
|
from tensorflow.keras.preprocessing.image import ImageDataGenerator |
|
import matplotlib.pyplot as plt |
|
|
|
|
|
train_dir = 'Tomato_Plant_Disease'

# Training-time generator: rescales pixels to [0, 1], applies random geometric
# augmentation, and reserves 20% of the files in each class folder for
# validation (selected with subset='validation').
datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    validation_split=0.2,
)

# Validation images should be scored as-is, not randomly rotated/shifted/
# flipped, otherwise validation metrics are noisy and do not reflect real
# inputs. This generator only rescales. Keras takes the split from the sorted
# file list of each class folder, so the same validation_split (0.2) selects
# exactly the held-out files that `datagen` excludes from training.
val_datagen = ImageDataGenerator(rescale=1./255, validation_split=0.2)

train_generator = datagen.flow_from_directory(
    train_dir,
    target_size=(128, 128),
    batch_size=32,
    class_mode='binary',
    subset='training',
)

validation_generator = val_datagen.flow_from_directory(
    train_dir,
    target_size=(128, 128),
    batch_size=32,
    class_mode='binary',
    subset='validation',
)

# class_mode='binary' uses the raw class index as the label. That is only
# meaningful for a sigmoid + binary_crossentropy model when there are exactly
# two class folders; fail loudly instead of silently training on labels 2, 3, ...
if train_generator.num_classes != 2:
    raise ValueError(
        f"class_mode='binary' needs exactly 2 class folders in {train_dir!r}, "
        f"found {train_generator.num_classes}: {train_generator.class_indices}"
    )
|
|
|
|
|
# Three convolution stages (3x3 ReLU conv followed by 2x2 max pooling, with
# 32 -> 64 -> 128 filters) feeding a dropout-regularised dense head that ends
# in a single sigmoid unit for binary classification of 128x128 RGB inputs.
model = Sequential([
    # Feature extractor.
    Conv2D(32, (3, 3), activation='relu', input_shape=(128, 128, 3)),
    MaxPooling2D(pool_size=(2, 2)),
    Conv2D(64, (3, 3), activation='relu'),
    MaxPooling2D(pool_size=(2, 2)),
    Conv2D(128, (3, 3), activation='relu'),
    MaxPooling2D(pool_size=(2, 2)),
    # Classifier head.
    Flatten(),
    Dense(128, activation='relu'),
    Dropout(0.5),
    Dense(1, activation='sigmoid'),
])

model.compile(
    optimizer='adam',
    loss='binary_crossentropy',
    metrics=['accuracy'],
)

model.summary()
|
|
|
# Train for 20 epochs, validating after each one.
# Step counts come from len(generator), i.e. ceil(samples / batch_size): this
# keeps the final partial batch, stays correct if batch_size changes, and never
# yields 0 steps when a split holds fewer images than one batch.
history = model.fit(
    train_generator,
    steps_per_epoch=len(train_generator),
    validation_data=validation_generator,
    validation_steps=len(validation_generator),
    epochs=20,
)
|
|
|
|
|
def _plot_train_vs_val(train_key, val_key, title, ylabel):
    """Show one figure comparing a training metric with its validation twin.

    Args:
        train_key: Key of the training series in ``history.history``.
        val_key: Key of the matching validation series.
        title: Figure title.
        ylabel: Y-axis label.
    """
    curves = history.history
    plt.plot(curves[train_key])
    plt.plot(curves[val_key])
    plt.title(title)
    plt.ylabel(ylabel)
    plt.xlabel('Epoch')
    plt.legend(['Train', 'Validation'], loc='upper left')
    plt.show()


_plot_train_vs_val('accuracy', 'val_accuracy', 'Model accuracy', 'Accuracy')
_plot_train_vs_val('loss', 'val_loss', 'Model loss', 'Loss')
|
|
|
|
|
test_dir = 'Tomato_Plant_Disease'

# When test_dir is the training folder, scoring every image in it would mostly
# measure images the model was trained on and inflate the accuracy. In that
# case reproduce the training split (same validation_split as the training
# datagen -- keep the two 0.2 values in sync) and evaluate only the held-out
# subset. Pointing test_dir at a dedicated test folder evaluates all of it.
evaluate_holdout_only = test_dir == train_dir

# Rescale only: evaluation images must not be augmented.
test_datagen = ImageDataGenerator(
    rescale=1./255,
    validation_split=0.2 if evaluate_holdout_only else 0.0,
)

test_generator = test_datagen.flow_from_directory(
    test_dir,
    target_size=(128, 128),
    batch_size=32,
    class_mode='binary',
    subset='validation' if evaluate_holdout_only else None,
    shuffle=False,  # order is irrelevant for evaluation; keeps runs reproducible
)

# No explicit `steps`: evaluate() makes one full pass over the generator,
# including the final partial batch that `samples // 32` used to drop.
test_loss, test_acc = model.evaluate(test_generator)

print(f"Test Accuracy: {test_acc}")

model.save('tomato_disease_detection_model.h5')
|
|
|
from tensorflow.keras.models import load_model |
|
from tensorflow.keras.preprocessing import image |
|
import numpy as np |
|
from PIL import Image |
|
|
|
|
|
# Reload the model from the file written by model.save so the inference
# helper below runs against the saved artifact.
model = load_model(filepath='tomato_disease_detection_model.h5')
|
|
|
def predict_disease(img_path):
    """Classify one leaf image stored on disk, print the verdict and return it.

    The image is preprocessed the way the training generators load images:
    converted to RGB, resized to 128x128 with nearest-neighbour resampling
    (the flow_from_directory default), and scaled to [0, 1].

    Args:
        img_path: Path to the image file.

    Returns:
        The printed verdict: "The plant is healthy." or "The plant is infected.".
    """
    # The context manager closes the file handle Image.open keeps open.
    # convert('RGB') guarantees 3 channels for grayscale/RGBA/palette files.
    with Image.open(img_path) as raw:
        rgb = raw.convert('RGB').resize((128, 128), Image.NEAREST)

    batch = np.expand_dims(np.asarray(rgb) / 255.0, axis=0)  # (1, 128, 128, 3)

    prediction = model.predict(batch)

    # Sigmoid > 0.5 means class index 1 of flow_from_directory's class_indices.
    # NOTE(review): assumes index 1 is the healthy folder -- confirm against
    # train_generator.class_indices.
    if prediction[0][0] > 0.5:
        verdict = "The plant is healthy."
    else:
        verdict = "The plant is infected."

    print(verdict)
    return verdict
|
|
|
|
|
# Manual smoke test of predict_disease on one image from the dataset folder.
# NOTE(review): this file lives under train_dir, so it may have been in the
# training subset; a correct verdict here does not show generalisation.
img_path = 'Tomato_Plant_Disease/0/0045ba29-ed1b-43b4-afde-719cc7adefdb___GCREC_Bact.Sp 6254.JPG'
predict_disease(img_path=img_path)
|
|
|
import gradio as gr |
|
from tensorflow.keras.models import load_model |
|
from PIL import Image |
|
import numpy as np |
|
|
|
|
|
# Load the trained model saved to disk for use by the Gradio app below.
model = load_model(filepath='tomato_disease_detection_model.h5')
|
|
|
|
|
def predict_disease(img):
    """Classify a leaf image coming from the Gradio upload widget.

    Args:
        img: PIL image from gr.Image(type="pil"), or None when the user
            submits without an image.

    Returns:
        Verdict text for the output textbox.
    """
    # Gradio hands over None when no image is present; without this guard
    # the resize call below would raise AttributeError.
    if img is None:
        return "Please upload an image of a tomato leaf."

    # Match training preprocessing: 3-channel RGB, 128x128 nearest-neighbour
    # resize (the flow_from_directory default), pixel values scaled to [0, 1].
    rgb = img.convert('RGB').resize((128, 128), Image.NEAREST)
    batch = np.expand_dims(np.asarray(rgb) / 255.0, axis=0)  # (1, 128, 128, 3)

    prediction = model.predict(batch)

    # Sigmoid > 0.5 means class index 1 of flow_from_directory's class_indices.
    # NOTE(review): assumes index 1 is the healthy folder -- confirm.
    if prediction[0][0] > 0.5:
        return "The plant is healthy."
    return "The plant is infected."
|
|
|
|
|
# Web UI: a single PIL image upload wired to predict_disease, with the
# returned verdict shown in a text box.
image_input = gr.Image(type="pil")
verdict_output = gr.Textbox()

interface = gr.Interface(
    fn=predict_disease,
    inputs=image_input,
    outputs=verdict_output,
    title="Tomato Plant Disease Detection",
    description="Upload an image of a tomato leaf to determine if it's infected or healthy.",
)

interface.launch()