File size: 2,429 Bytes
7d8db22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import gradio as gr
import tensorflow as tf
import numpy as np
from PIL import Image
from tensorflow.keras.preprocessing.image import load_img, img_to_array
from tensorflow.keras.applications.imagenet_utils import preprocess_input
import os

# Load the trained Keras model once, at import time, so every call to
# classify_trash() reuses the same in-memory model. The path is resolved
# relative to the current working directory.
model = tf.keras.models.load_model("final_trashnet_transfer_learning_model.keras")

# Model output index -> human-readable category shown in the UI.
# classify_trash() looks up the argmax of the model's scores here, so the
# tuple order presumably matches the label order used in training — confirm.
class_mapping = dict(enumerate(("Compostable", "Recyclables", "Trash")))
# Define a function to preprocess the input image
def preprocess_image(image):
    # Resize the image to 128*128 (as required by your model)
    image = image.resize((128, 128))
    # Convert the image to a NumPy array and normalize it
    img_array = img_to_array(image)
    img_array = preprocess_input(img_array)
    # Ensure the image has the correct shape (32, 32, 3)
    img_array = np.expand_dims(img_array, axis=0) 
    return img_array


def classify_trash(image):
    """Classify one image and format the result for the two output textboxes.

    Args:
        image: The PIL image from the ``gr.Image(type="pil")`` input, or
            ``None`` when the user submits without uploading anything.

    Returns:
        A ``(category_text, confidence_text)`` tuple of display strings.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Show a readable message in the UI instead of an AttributeError.
        raise gr.Error("Please upload an image first.")
    processed_image = preprocess_image(image)
    # verbose=0 suppresses the Keras progress bar that would otherwise be
    # printed on every request.
    predictions = model.predict(processed_image, verbose=0)
    # The batch holds exactly one image; take its row of class scores.
    scores = predictions[0]
    # Plain int so the dict lookup does not depend on numpy scalar hashing.
    class_index = int(np.argmax(scores))
    # The top score is reported as a percentage; this assumes the model ends
    # in a softmax layer — confirm.
    confidence = float(scores[class_index])
    predicted_class = class_mapping[class_index]
    return f"Predicted Category: {predicted_class}", f"Confidence: {confidence * 100:.2f}%"

def get_example_images(
    base_dir="examples",
    categories=("Compostable", "Recyclables", "Trash"),
    extensions=(".jpg", ".jpeg", ".png"),
):
    """Collect example image paths from per-category subfolders.

    For each name in *categories*, scans ``<base_dir>/<category>/``. It keeps
    the files whose name ends with one of *extensions*, compared
    case-insensitively. Category folders that do not exist, or that are not
    directories, are skipped.

    Args:
        base_dir: Root folder holding one subfolder per category.
        categories: Subfolder names to scan, in the order they should appear.
        extensions: File suffixes accepted as images.

    Returns:
        A flat list of paths built with ``os.path.join(base_dir, category, name)``.
        The paths are grouped by category (in *categories* order) and sorted by
        file name within each category. The list is empty when nothing matches.
    """
    suffixes = tuple(ext.lower() for ext in extensions)
    example_images = []
    for category in categories:
        folder_path = os.path.join(base_dir, category)
        # isdir rather than exists: a stray *file* named like a category
        # would make os.listdir raise.
        if not os.path.isdir(folder_path):
            continue
        # os.listdir order is arbitrary; sort so the examples gallery is
        # stable across runs and machines.
        example_images.extend(
            os.path.join(folder_path, name)
            for name in sorted(os.listdir(folder_path))
            if name.lower().endswith(suffixes)
        )
    return example_images

# Clickable example images shown under the input; an empty list when no
# example folders are present.
example_images = get_example_images()

# Gradio UI: one PIL image in, two text boxes out (category, confidence).
interface = gr.Interface(
    fn=classify_trash,
    inputs=gr.Image(type="pil", label="Upload an Image"),
    outputs=[
        gr.Textbox(label="Predicted Category"),
        gr.Textbox(label="Confidence"),
    ],
    examples=example_images,
    title="Trash Classifier",
    description=(
        "Upload an image of trash, and the model will classify it into "
        "'Compostable', 'Recyclables' and 'Trash' based on its category."
    ),
)

# Launch the app only when executed as a script, not when imported.
if __name__ == "__main__":
    interface.launch()