Sean Carnahan commited on
Commit
0c12779
·
1 Parent(s): 460d366

Attempt to fix JIT error by forcing CNN on CPU

Browse files
Files changed (1) hide show
  1. app.py +16 -7
app.py CHANGED
@@ -7,6 +7,7 @@ import os
7
  from werkzeug.utils import secure_filename
8
  import sys
9
  import traceback
 
10
  from tensorflow.keras.models import load_model
11
  from tensorflow.keras.preprocessing import image
12
  import time
@@ -92,13 +93,21 @@ cnn_model = load_model(cnn_model_path)
92
  cnn_class_labels = ['side_chest', 'front_double_biceps', 'back_double_biceps', 'front_lat_spread', 'back_lat_spread']
93
 
94
  def predict_pose_cnn(img_path):
95
- img = image.load_img(img_path, target_size=(150, 150))
96
- img_array = image.img_to_array(img)
97
- img_array = np.expand_dims(img_array, axis=0) / 255.0
98
- predictions = cnn_model.predict(img_array)
99
- predicted_class = np.argmax(predictions, axis=1)
100
- confidence = float(np.max(predictions))
101
- return cnn_class_labels[predicted_class[0]], confidence
 
 
 
 
 
 
 
 
102
 
103
  @app.route('/static/uploads/<path:filename>')
104
  def serve_video(filename):
 
7
  from werkzeug.utils import secure_filename
8
  import sys
9
  import traceback
10
+ import tensorflow as tf
11
  from tensorflow.keras.models import load_model
12
  from tensorflow.keras.preprocessing import image
13
  import time
 
93
  cnn_class_labels = ['side_chest', 'front_double_biceps', 'back_double_biceps', 'front_lat_spread', 'back_lat_spread']
94
 
95
  def predict_pose_cnn(img_path):
96
+ try:
97
+ print("[CNN_DEBUG] Forcing CPU for CNN prediction")
98
+ with tf.device('/CPU:0'): # This line requires 'import tensorflow as tf'
99
+ img = image.load_img(img_path, target_size=(150, 150))
100
+ img_array = image.img_to_array(img)
101
+ img_array = np.expand_dims(img_array, axis=0) / 255.0
102
+ predictions = cnn_model.predict(img_array)
103
+ predicted_class = np.argmax(predictions, axis=1)
104
+ confidence = float(np.max(predictions))
105
+ print(f"[CNN_DEBUG] Prediction successful on CPU: {cnn_class_labels[predicted_class[0]]}")
106
+ return cnn_class_labels[predicted_class[0]], confidence
107
+ except Exception as e:
108
+ print(f"[CNN_ERROR] Exception during CNN prediction: {e}")
109
+ traceback.print_exc()
110
+ raise # Re-raise the exception to be caught by the calling function
111
 
112
  @app.route('/static/uploads/<path:filename>')
113
  def serve_video(filename):