Spaces:
Sleeping
Sleeping
Sean Carnahan
commited on
Commit
·
0c12779
1
Parent(s):
460d366
Attempt to fix JIT error by forcing CNN on CPU
Browse files
app.py
CHANGED
@@ -7,6 +7,7 @@ import os
|
|
7 |
from werkzeug.utils import secure_filename
|
8 |
import sys
|
9 |
import traceback
|
|
|
10 |
from tensorflow.keras.models import load_model
|
11 |
from tensorflow.keras.preprocessing import image
|
12 |
import time
|
@@ -92,13 +93,21 @@ cnn_model = load_model(cnn_model_path)
|
|
92 |
cnn_class_labels = ['side_chest', 'front_double_biceps', 'back_double_biceps', 'front_lat_spread', 'back_lat_spread']
|
93 |
|
94 |
def predict_pose_cnn(img_path):
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
102 |
|
103 |
@app.route('/static/uploads/<path:filename>')
|
104 |
def serve_video(filename):
|
|
|
7 |
from werkzeug.utils import secure_filename
|
8 |
import sys
|
9 |
import traceback
|
10 |
+
import tensorflow as tf
|
11 |
from tensorflow.keras.models import load_model
|
12 |
from tensorflow.keras.preprocessing import image
|
13 |
import time
|
|
|
93 |
cnn_class_labels = ['side_chest', 'front_double_biceps', 'back_double_biceps', 'front_lat_spread', 'back_lat_spread']
|
94 |
|
95 |
def predict_pose_cnn(img_path):
|
96 |
+
try:
|
97 |
+
print("[CNN_DEBUG] Forcing CPU for CNN prediction")
|
98 |
+
with tf.device('/CPU:0'): # This line requires 'import tensorflow as tf'
|
99 |
+
img = image.load_img(img_path, target_size=(150, 150))
|
100 |
+
img_array = image.img_to_array(img)
|
101 |
+
img_array = np.expand_dims(img_array, axis=0) / 255.0
|
102 |
+
predictions = cnn_model.predict(img_array)
|
103 |
+
predicted_class = np.argmax(predictions, axis=1)
|
104 |
+
confidence = float(np.max(predictions))
|
105 |
+
print(f"[CNN_DEBUG] Prediction successful on CPU: {cnn_class_labels[predicted_class[0]]}")
|
106 |
+
return cnn_class_labels[predicted_class[0]], confidence
|
107 |
+
except Exception as e:
|
108 |
+
print(f"[CNN_ERROR] Exception during CNN prediction: {e}")
|
109 |
+
traceback.print_exc()
|
110 |
+
raise # Re-raise the exception to be caught by the calling function
|
111 |
|
112 |
@app.route('/static/uploads/<path:filename>')
|
113 |
def serve_video(filename):
|