"""Command-line tool that classifies a bodybuilding pose from an image file."""
import argparse
import os
import sys

import numpy as np
import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing import image

# Class labels, indexed by the model's output unit. This order must match the
# label order used when the model was trained.
CNN_CLASS_LABELS = ['Side Chest', 'Front Double Biceps', 'Back Double Biceps', 'Front Lat Spread', 'Back Lat Spread']

# Default location of the trained Keras model (overridable with --model-path).
MODEL_PATH = 'bodybuilding_pose_classifier_savedmodel.keras'

# (height, width) images are resized to before prediction.
# NOTE(review): assumed to match the model's training input size — confirm.
IMAGE_SIZE = (150, 150)


def predict_pose_from_image(model, img_path, target_size=IMAGE_SIZE):
    """Load an image, preprocess it and predict the bodybuilding pose.

    Args:
        model: The loaded Keras model. Its ``predict`` output is expected to
            hold one score per entry of ``CNN_CLASS_LABELS``.
        img_path (str): Path to the image file.
        target_size (tuple[int, int]): ``(height, width)`` the image is
            resized to before prediction. Defaults to ``(150, 150)``.

    Returns:
        tuple: ``(predicted_class_label, confidence_score)`` on success, or
        ``(None, None)`` if the image is missing, the model output does not
        match ``CNN_CLASS_LABELS``, or any step raises. Errors are reported on
        stderr.
    """
    try:
        if not os.path.exists(img_path):
            print(f"Error: Image path not found: {img_path}", file=sys.stderr)
            return None, None

        img = image.load_img(img_path, target_size=target_size)
        # Add a batch dimension of 1 and scale pixels from [0, 255] to [0, 1].
        img_array = np.expand_dims(image.img_to_array(img), axis=0) / 255.0

        # verbose=0 keeps Keras' progress bar out of the CLI output.
        predictions = np.asarray(model.predict(img_array, verbose=0))
        scores = predictions[0]  # scores for the single image in the batch

        # Guard against a model/label mismatch, which would otherwise raise an
        # obscure IndexError or silently map outputs to the wrong labels.
        if scores.shape != (len(CNN_CLASS_LABELS),):
            print(
                f"Error during prediction: model produced scores of shape "
                f"{scores.shape}, expected ({len(CNN_CLASS_LABELS)},) to match "
                f"CNN_CLASS_LABELS",
                file=sys.stderr,
            )
            return None, None

        predicted_class_index = int(np.argmax(scores))
        # Score of the chosen class. It is a probability only if the model
        # ends in a softmax — NOTE(review): confirm against the training code.
        confidence = float(scores[predicted_class_index])
        return CNN_CLASS_LABELS[predicted_class_index], confidence
    except Exception as e:  # CLI boundary: report, then signal failure to caller.
        print(f"Error during prediction: {e}", file=sys.stderr)
        return None, None


def main(argv=None):
    """Classify the pose shown in one image (command-line entry point).

    Args:
        argv (list[str] | None): Arguments to parse; ``None`` means
            ``sys.argv[1:]``.

    Returns:
        int: Process exit status. It is 0 on success and 1 if the model could
        not be loaded or the image could not be classified.
    """
    parser = argparse.ArgumentParser(description="Classify a bodybuilding pose from an image.")
    parser.add_argument("image_path", help="Path to the input image file.")
    parser.add_argument(
        "--model-path",
        default=MODEL_PATH,
        help="Path to the Keras model file (default: %(default)s).",
    )
    args = parser.parse_args(argv)

    print(f"Loading model from: {args.model_path}")
    try:
        model = load_model(args.model_path)
    except Exception as e:  # CLI boundary: any load failure is fatal.
        print(f"Error loading model: {e}", file=sys.stderr)
        return 1

    print(f"Classifying image: {args.image_path}")
    predicted_pose, confidence_score = predict_pose_from_image(model, args.image_path)
    if predicted_pose is None or confidence_score is None:
        # predict_pose_from_image has already reported the reason on stderr.
        return 1

    print(f"Predicted Pose: {predicted_pose}")
    print(f"Confidence: {confidence_score:.2f}")
    return 0


if __name__ == "__main__":
    sys.exit(main())