Update code/inference.py
code/inference.py  CHANGED  (+12 -6)
@@ -7,8 +7,11 @@ import fcntl # For file locking
 import os # For file operations
 import time # For sleep function
 
-# Set
-os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:
+# Set max_split_size globally to prevent memory fragmentation
+os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:64"
+
+# Enable detailed distributed logs
+os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
 
 # Print to verify the environment variable is correctly set
 print(f"PYTORCH_CUDA_ALLOC_CONF: {os.environ.get('PYTORCH_CUDA_ALLOC_CONF')}")
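A note on the new environment variables (not part of the diff): PyTorch parses PYTORCH_CUDA_ALLOC_CONF when the CUDA caching allocator is first initialized, and TORCH_DISTRIBUTED_DEBUG is read when torch.distributed sets up its process group, so both assignments must run before anything touches a GPU or initializes distributed state. A minimal sketch of that ordering, using setdefault so a value exported by the serving environment still takes precedence:

import os

# Must be set before the first CUDA allocation; importing torch alone is fine.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "max_split_size_mb:64")
os.environ.setdefault("TORCH_DISTRIBUTED_DEBUG", "DETAIL")

import torch

assert not torch.cuda.is_initialized(), "set PYTORCH_CUDA_ALLOC_CONF before any CUDA call"
x = torch.zeros(1, device="cuda")  # first allocation: the allocator config takes effect here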
@@ -70,15 +73,18 @@ def model_fn(model_dir, context=None):
         model = load_checkpoint_and_dispatch(
             model,
             model_dir,
-            device_map="
-            offload_folder=offload_dir,
-            max_memory={i: "
-            no_split_module_classes=["QwenForCausalLM"] #
+            device_map="balanced", # Evenly distribute across GPUs
+            offload_folder=offload_dir,
+            max_memory={i: "18GiB" for i in range(torch.cuda.device_count())}, # Allocate 18 GiB per GPU
+            no_split_module_classes=["QwenForCausalLM"] # Split model across GPUs
         )
 
         # Load the tokenizer
         tokenizer = AutoTokenizer.from_pretrained(model_dir)
 
+        # Free up any unused memory after loading
+        torch.cuda.empty_cache()
+
     except Exception as e:
         print(f"Error loading model and tokenizer: {e}")
         raise
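A note on the dispatch arguments (not part of the diff): in Accelerate, no_split_module_classes names module classes that must not be split across devices, so the inline comment overstates what it does; the splitting itself is driven by device_map and max_memory, and listing the top-level causal-LM class can pin the whole network to a single device with the remainder offloaded to offload_dir. The usual value is the checkpoint's decoder-block class. When the placement needs to be inspected or tuned, the map can be computed explicitly and passed to load_checkpoint_and_dispatch instead of the "balanced" string. A minimal sketch, assuming a standard transformers architecture and an illustrative block class name (Qwen2DecoderLayer; the real name depends on the model):

import torch
from accelerate import infer_auto_device_map, init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM

model_dir = "/opt/ml/model"  # assumption: the model directory passed to model_fn

config = AutoConfig.from_pretrained(model_dir)
with init_empty_weights():  # build the module tree without allocating weights
    meta_model = AutoModelForCausalLM.from_config(config)

device_map = infer_auto_device_map(
    meta_model,
    max_memory={i: "18GiB" for i in range(torch.cuda.device_count())},
    no_split_module_classes=["Qwen2DecoderLayer"],  # illustrative; keeps each block whole on one device
)
print(device_map)  # e.g. {'model.embed_tokens': 0, 'model.layers.0': 0, ..., 'lm_head': 1}
# The resulting dict can then be passed to load_checkpoint_and_dispatch(..., device_map=device_map).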
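On the cleanup step: torch.cuda.empty_cache() returns cached blocks that no live tensor is using back to the driver; it does not shrink the memory occupied by the loaded weights. A small sketch for checking how the 18 GiB per-GPU budget held up once model_fn has returned (device count and figures depend on the actual hardware):

import torch

for i in range(torch.cuda.device_count()):
    allocated = torch.cuda.memory_allocated(i) / 2**30  # GiB held by live tensors
    reserved = torch.cuda.memory_reserved(i) / 2**30    # GiB held by the caching allocator
    print(f"cuda:{i}: allocated={allocated:.1f} GiB, reserved={reserved:.1f} GiB")

# dispatch_model also records the placement chosen for the sharded model:
# print(model.hf_device_map)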