AK1239 committed
Commit ee55a04 · Parent: e8f7557

Using bits and bytes

Files changed (1):
  app/main.py  +20 -15
app/main.py CHANGED
@@ -432,10 +432,6 @@ async def startup_event():
             total_memory = torch.cuda.get_device_properties(0).total_memory
             free_memory = torch.cuda.memory_reserved(0) - torch.cuda.memory_allocated(0)
             logger.info(f"GPU Memory - Total: {total_memory/1e9:.2f}GB, Free: {free_memory/1e9:.2f}GB")
-
-            if free_memory < 1e9:  # If less than 2GB free
-                logger.warning("Low GPU memory detected, falling back to CPU")
-                device = "cpu"
         except Exception as e:
             logger.warning(f"Error checking GPU memory: {e}. Falling back to CPU")
             device = "cpu"
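The fallback removed above had two problems worth recording: the threshold 1e9 is 1 GB even though the comment says 2 GB, and torch.cuda.memory_reserved(0) - torch.cuda.memory_allocated(0) measures slack inside PyTorch's caching allocator rather than unused device memory, so at startup it is near zero and the CPU fallback would fire even on an idle GPU. A minimal sketch of a device-level check instead, assuming a recent PyTorch (torch.cuda.mem_get_info wraps cudaMemGetInfo); the 2 GB threshold and the pick_device name are illustrative, not from this repo:

import torch

def pick_device(min_free_bytes: int = 2_000_000_000) -> str:
    # Illustrative helper, not part of this commit: mem_get_info() reports
    # free/total bytes for the whole device, unlike reserved - allocated,
    # which only sees slack inside PyTorch's caching allocator.
    if not torch.cuda.is_available():
        return "cpu"
    free, total = torch.cuda.mem_get_info()
    return "cuda" if free >= min_free_bytes else "cpu"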
@@ -458,25 +454,34 @@ async def startup_event():
 
     # Initialize the model and index with memory optimizations
     try:
+        from transformers import AutoModelForCausalLM, AutoTokenizer
+        import bitsandbytes as bnb
+
+        logger.info("Loading model with 8-bit quantization...")
         model_kwargs = {
-            "device_map": "auto" if device == "cuda" else "cpu",
-            "torch_dtype": torch.float16 if device == "cuda" else torch.float32,
+            "device_map": "auto",
+            "load_in_8bit": True,  # Enable 8-bit quantization
+            "torch_dtype": torch.float16,
             "low_cpu_mem_usage": True,
         }
 
-        if device == "cpu":
-            # Additional CPU optimizations
-            model_kwargs.update({
-                "offload_folder": "offload",
-                "offload_state_dict": True
-            })
+        # Initialize tokenizer
+        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HUGGINGFACE_TOKEN)
 
+        # Load model with 8-bit quantization
+        model = AutoModelForCausalLM.from_pretrained(
+            MODEL_ID,
+            token=HUGGINGFACE_TOKEN,
+            trust_remote_code=True,
+            **model_kwargs
+        )
+
+        # Create pipeline with quantized model
         app.state.pipe = pipeline(
             "text-generation",
-            model=MODEL_ID,
+            model=model,
+            tokenizer=tokenizer,
             trust_remote_code=True,
-            token=HUGGINGFACE_TOKEN,
-            **model_kwargs
         )
 
         faiss_index, documents, embedding_model = await load_or_create_index()
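Two notes on the new hunk. The import bitsandbytes as bnb line is never referenced afterwards; transformers loads bitsandbytes itself when 8-bit loading is requested, so the explicit import only fails fast if the package is absent. Also, recent transformers releases deprecate passing load_in_8bit=True straight to from_pretrained in favour of an explicit BitsAndBytesConfig. A sketch of the equivalent load under that API, reusing the app's existing MODEL_ID and HUGGINGFACE_TOKEN:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# Same 8-bit setup expressed through the explicit config object.
quant_config = BitsAndBytesConfig(load_in_8bit=True)

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HUGGINGFACE_TOKEN)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    token=HUGGINGFACE_TOKEN,
    trust_remote_code=True,
    device_map="auto",
    torch_dtype=torch.float16,  # dtype for the modules that stay unquantized
    quantization_config=quant_config,
    low_cpu_mem_usage=True,
)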
 
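For completeness, a hypothetical call into the pipeline stored on app.state; the prompt and generation parameters here are made up for illustration. Text-generation pipelines return a list of dicts keyed by "generated_text":

# Inside a request handler, once startup has run:
outputs = app.state.pipe(
    "Answer using the retrieved context:",  # placeholder prompt
    max_new_tokens=128,
    do_sample=False,
)
print(outputs[0]["generated_text"])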