NightRaven109 commited on
Commit
2661499
·
verified ·
1 Parent(s): 50288f4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -8
app.py CHANGED
@@ -1,5 +1,7 @@
1
  import os
2
- os.environ["XLA_FLAGS"] = "--xla_gpu_force_compilation_parallelism=1"
 
 
3
  os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" # Suppress TensorFlow warnings
4
 
5
  # Suppress JAX compilation warnings
@@ -176,13 +178,8 @@ class SimpleMAXIMPredictor:
176
 
177
  try:
178
  import jax
179
- # Check if GPU is available
180
- gpu_available = len(jax.devices('gpu')) > 0
181
- device_info = f"Using device: {jax.default_backend()}"
182
- if gpu_available:
183
- device_info += f" (GPU: {jax.devices('gpu')[0]})"
184
- else:
185
- device_info += " (CPU fallback)"
186
  print(device_info)
187
 
188
  # Get model and parameters
 
1
  import os
2
+ # Force CPU-only mode
3
+ os.environ["JAX_PLATFORM_NAME"] = "cpu"
4
+ os.environ["CUDA_VISIBLE_DEVICES"] = ""
5
  os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" # Suppress TensorFlow warnings
6
 
7
  # Suppress JAX compilation warnings
 
178
 
179
  try:
180
  import jax
181
+ # Force CPU mode - device info
182
+ device_info = f"Using device: {jax.default_backend()} (CPU-only mode)"
 
 
 
 
 
183
  print(device_info)
184
 
185
  # Get model and parameters