AK1239 committed on
Commit
e8f7557
·
1 Parent(s): 925a57e

Memory Optimizations

Browse files
Files changed (1) hide show
  1. app/main.py +35 -5
app/main.py CHANGED
@@ -418,12 +418,30 @@ async def startup_event():
418
  logger = logging.getLogger(__name__)
419
  logger.info("Starting application initialization...")
420
 
421
- # Check if CUDA is available
 
 
 
422
  device = "cuda" if torch.cuda.is_available() else "cpu"
423
  logger.info(f"Using device: {device}")
424
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
425
  if device == "cpu":
426
- logger.warning("GPU not detected. Model will run slower on CPU.")
427
 
428
  # Set NLTK data path
429
  nltk_data_dir = os.environ.get('NLTK_DATA', os.path.join(os.path.expanduser('~'), 'nltk_data'))
@@ -438,15 +456,27 @@ async def startup_event():
438
  logger.error(f"Error downloading NLTK data: {str(e)}")
439
  raise Exception(f"Failed to initialize application: {str(e)}")
440
 
441
- # Initialize the model and index
442
  try:
 
 
 
 
 
 
 
 
 
 
 
 
 
443
  app.state.pipe = pipeline(
444
  "text-generation",
445
  model=MODEL_ID,
446
  trust_remote_code=True,
447
  token=HUGGINGFACE_TOKEN,
448
- device_map="auto",
449
- torch_dtype=torch.float16 if device == "cuda" else torch.float32
450
  )
451
 
452
  faiss_index, documents, embedding_model = await load_or_create_index()
 
418
  logger = logging.getLogger(__name__)
419
  logger.info("Starting application initialization...")
420
 
421
+ # Set PyTorch memory management settings
422
+ os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
423
+
424
+ # Check if CUDA is available and has enough memory
425
  device = "cuda" if torch.cuda.is_available() else "cpu"
426
  logger.info(f"Using device: {device}")
427
 
428
+ if device == "cuda":
429
+ try:
430
+ # Try to estimate available GPU memory
431
+ torch.cuda.empty_cache()
432
+ total_memory = torch.cuda.get_device_properties(0).total_memory
433
+ free_memory = torch.cuda.memory_reserved(0) - torch.cuda.memory_allocated(0)
434
+ logger.info(f"GPU Memory - Total: {total_memory/1e9:.2f}GB, Free: {free_memory/1e9:.2f}GB")
435
+
436
+ if free_memory < 1e9: # If less than 1GB free
437
+ logger.warning("Low GPU memory detected, falling back to CPU")
438
+ device = "cpu"
439
+ except Exception as e:
440
+ logger.warning(f"Error checking GPU memory: {e}. Falling back to CPU")
441
+ device = "cpu"
442
+
443
  if device == "cpu":
444
+ logger.warning("Using CPU. Model will run slower.")
445
 
446
  # Set NLTK data path
447
  nltk_data_dir = os.environ.get('NLTK_DATA', os.path.join(os.path.expanduser('~'), 'nltk_data'))
 
456
  logger.error(f"Error downloading NLTK data: {str(e)}")
457
  raise Exception(f"Failed to initialize application: {str(e)}")
458
 
459
+ # Initialize the model and index with memory optimizations
460
  try:
461
+ model_kwargs = {
462
+ "device_map": "auto" if device == "cuda" else "cpu",
463
+ "torch_dtype": torch.float16 if device == "cuda" else torch.float32,
464
+ "low_cpu_mem_usage": True,
465
+ }
466
+
467
+ if device == "cpu":
468
+ # Additional CPU optimizations
469
+ model_kwargs.update({
470
+ "offload_folder": "offload",
471
+ "offload_state_dict": True
472
+ })
473
+
474
  app.state.pipe = pipeline(
475
  "text-generation",
476
  model=MODEL_ID,
477
  trust_remote_code=True,
478
  token=HUGGINGFACE_TOKEN,
479
+ **model_kwargs
 
480
  )
481
 
482
  faiss_index, documents, embedding_model = await load_or_create_index()