hb-setosys commited on
Commit
1237637
·
verified ·
1 Parent(s): af590ff

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -27
app.py CHANGED
@@ -1,33 +1,30 @@
1
- from tensorflow.keras.models import load_model
2
- from tensorflow.keras.preprocessing.image import load_img, img_to_array
3
- from tensorflow.keras.applications.efficientnet_v2 import preprocess_input
4
  import numpy as np
5
- from google.colab import files
6
-
7
- # Step 1: Upload the model to Colab (run this in a Colab cell)
8
- uploaded = files.upload() # Upload the .h5 model file
9
 
10
- # Step 2: Load the trained model
11
- MODEL_PATH = "setosys_dogs_model.h5" # Adjust if needed
12
- model = load_model(MODEL_PATH)
13
 
14
- # Step 3: Define the class labels manually (as per your model's training setup)
15
- # You need to know the classes that were used during model training
16
- class_labels = ["Labrador Retriever", "German Shepherd", "Golden Retriever", "Bulldog", "Poodle"] # Example, update this list
17
 
18
- # Step 4: Define image preprocessing function for EfficientNetV2
19
- def preprocess_image(image_path):
20
  """Preprocess the image to match the model's input requirements."""
21
- img = load_img(image_path, target_size=(224, 224)) # Resize image to model input size
22
- img_array = img_to_array(img)
23
  img_array = preprocess_input(img_array) # EfficientNetV2 preprocessing
24
  img_array = np.expand_dims(img_array, axis=0) # Add batch dimension
25
  return img_array
26
 
27
- # Step 5: Define prediction function
28
- def predict_dog_breed(image_path):
29
  """Predict the breed of the dog in the uploaded image."""
30
- img_array = preprocess_image(image_path)
31
  predictions = model.predict(img_array)
32
 
33
  # Check the shape of the predictions to make sure the output is correct
@@ -39,12 +36,17 @@ def predict_dog_breed(image_path):
39
  # Get predicted breed and its confidence score
40
  predicted_breed = class_labels[class_idx] if class_idx < len(class_labels) else "Unknown"
41
 
42
- return predicted_breed, confidence
43
 
44
- # Step 6: Upload and test with an image
45
- uploaded_image = files.upload() # Upload a test image
46
- image_path = list(uploaded_image.keys())[0] # Get the filename of the uploaded image
 
 
 
 
 
47
 
48
- # Step 7: Run prediction
49
- breed, confidence = predict_dog_breed(image_path)
50
- print(f"Predicted Breed: {breed}, Confidence: {confidence:.2f}")
 
1
+ import gradio as gr
2
+ import tensorflow as tf
 
3
  import numpy as np
4
+ from tensorflow.keras.preprocessing import image
5
+ from tensorflow.keras.applications.efficientnet_v2 import preprocess_input
6
+ from PIL import Image
 
7
 
8
+ # Load the trained model
9
+ MODEL_PATH = "setosys_dogs_model.h5"
10
+ model = tf.keras.models.load_model(MODEL_PATH)
11
 
12
+ # Get class labels from the model (assuming the model has a 'class_indices' attribute)
13
+ class_labels = list(model.class_indices.keys()) # Fetch class labels from the model
 
14
 
15
+ # Image preprocessing function using EfficientNetV2S
16
+ def preprocess_image(img: Image.Image) -> np.ndarray:
17
  """Preprocess the image to match the model's input requirements."""
18
+ img = img.resize((224, 224)) # Resize image to model input size
19
+ img_array = np.array(img)
20
  img_array = preprocess_input(img_array) # EfficientNetV2 preprocessing
21
  img_array = np.expand_dims(img_array, axis=0) # Add batch dimension
22
  return img_array
23
 
24
+ # Prediction function
25
+ def predict_dog_breed(img: Image.Image) -> dict:
26
  """Predict the breed of the dog in the uploaded image."""
27
+ img_array = preprocess_image(img)
28
  predictions = model.predict(img_array)
29
 
30
  # Check the shape of the predictions to make sure the output is correct
 
36
  # Get predicted breed and its confidence score
37
  predicted_breed = class_labels[class_idx] if class_idx < len(class_labels) else "Unknown"
38
 
39
+ return {predicted_breed: confidence}
40
 
41
+ # Create Gradio interface
42
+ interface = gr.Interface(
43
+ fn=predict_dog_breed,
44
+ inputs=gr.Image(type="pil"),
45
+ outputs=gr.Label(),
46
+ title="Dog Breed Classifier",
47
+ description="Upload an image of a dog to predict its breed."
48
+ )
49
 
50
+ # Launch the Gradio app
51
+ if __name__ == "__main__":
52
+ interface.launch()