Using bitsandbytes
app/main.py (+20 -15)
```diff
@@ -432,10 +432,6 @@ async def startup_event():
         total_memory = torch.cuda.get_device_properties(0).total_memory
         free_memory = torch.cuda.memory_reserved(0) - torch.cuda.memory_allocated(0)
         logger.info(f"GPU Memory - Total: {total_memory/1e9:.2f}GB, Free: {free_memory/1e9:.2f}GB")
-
-        if free_memory < 1e9:  # If less than 2GB free
-            logger.warning("Low GPU memory detected, falling back to CPU")
-            device = "cpu"
     except Exception as e:
         logger.warning(f"Error checking GPU memory: {e}. Falling back to CPU")
         device = "cpu"
```
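The deleted guard computed "free" memory as `memory_reserved - memory_allocated`, which only measures slack inside PyTorch's caching allocator, not memory actually available on the device (its comment also says 2 GB while `1e9` is 1 GB). If a similar fallback is ever reinstated, `torch.cuda.mem_get_info()` asks the CUDA driver for device-wide free/total bytes instead. A minimal sketch, not part of this commit; the 1 GB threshold simply mirrors the deleted check:

```python
import torch

def pick_device(min_free_bytes: int = 1_000_000_000) -> str:
    """Fall back to CPU unless the GPU has enough truly free memory."""
    if not torch.cuda.is_available():
        return "cpu"
    # mem_get_info() queries the driver, so it also accounts for memory
    # held by other processes, unlike memory_reserved()/memory_allocated().
    free, total = torch.cuda.mem_get_info(0)
    print(f"GPU Memory - Total: {total/1e9:.2f}GB, Free: {free/1e9:.2f}GB")
    return "cuda" if free >= min_free_bytes else "cpu"
```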
```diff
@@ -458,25 +454,34 @@ async def startup_event():
 
     # Initialize the model and index with memory optimizations
     try:
+        from transformers import AutoModelForCausalLM, AutoTokenizer
+        import bitsandbytes as bnb
+
+        logger.info("Loading model with 8-bit quantization...")
         model_kwargs = {
-            "device_map": "auto"
-            "
+            "device_map": "auto",
+            "load_in_8bit": True,  # Enable 8-bit quantization
+            "torch_dtype": torch.float16,
             "low_cpu_mem_usage": True,
         }
 
-
-
-        model_kwargs.update({
-            "offload_folder": "offload",
-            "offload_state_dict": True
-        })
+        # Initialize tokenizer
+        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HUGGINGFACE_TOKEN)
 
+        # Load model with 8-bit quantization
+        model = AutoModelForCausalLM.from_pretrained(
+            MODEL_ID,
+            token=HUGGINGFACE_TOKEN,
+            trust_remote_code=True,
+            **model_kwargs
+        )
+
+        # Create pipeline with quantized model
        app.state.pipe = pipeline(
             "text-generation",
-            model=
+            model=model,
+            tokenizer=tokenizer,
             trust_remote_code=True,
-            token=HUGGINGFACE_TOKEN,
-            **model_kwargs
         )
 
         faiss_index, documents, embedding_model = await load_or_create_index()
```
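A note on the `load_in_8bit=True` kwarg: it requires the `bitsandbytes` and `accelerate` packages to be installed (`pip install bitsandbytes accelerate`), and the explicit `import bitsandbytes as bnb` in the diff is not needed for it to take effect, since transformers discovers the library on its own. Recent transformers releases also deprecate passing `load_in_8bit` directly to `from_pretrained` in favor of a `BitsAndBytesConfig`. An equivalent load under that newer API, sketched with placeholder values for `MODEL_ID` and `HUGGINGFACE_TOKEN` (both defined elsewhere in app/main.py):

```python
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    pipeline,
)

MODEL_ID = "some-org/some-model"  # placeholder; real value lives in app/main.py
HUGGINGFACE_TOKEN = "hf_..."      # placeholder; real value lives in app/main.py

# BitsAndBytesConfig is the non-deprecated way to request 8-bit weights.
quant_config = BitsAndBytesConfig(load_in_8bit=True)

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HUGGINGFACE_TOKEN)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    token=HUGGINGFACE_TOKEN,
    trust_remote_code=True,
    device_map="auto",                 # let accelerate place layers
    quantization_config=quant_config,  # replaces load_in_8bit=True
    torch_dtype=torch.float16,         # dtype for the non-quantized parts
    low_cpu_mem_usage=True,
)
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
```

Once loaded, the pipeline is called exactly as before, e.g. `pipe("Hello", max_new_tokens=32)`; 8-bit weights roughly halve the model's memory footprint relative to float16, which is what lets this startup path avoid the removed disk-offload settings.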