oleksandrfluxon committed
Commit 73a8df3 · 1 Parent(s): 2497126

Test max_memory for GPU

Files changed (1): pipeline.py +24 -3
pipeline.py CHANGED
@@ -1,5 +1,7 @@
 import torch
 import transformers
+from accelerate import dispatch_model, infer_auto_device_map
+from accelerate.utils import get_balanced_memory
 from typing import Dict, List, Any
 
 class PreTrainedPipeline():
@@ -21,17 +23,36 @@ class PreTrainedPipeline():
             path,
             config=config,
             # torch_dtype=torch.bfloat16, # Load model weights in bfloat16
+            torch_dtype=torch.float16,
             trust_remote_code=True
             # load_in_4bit=True, # Load model in the lowest 4-bit precision quantization
         )
-        model.to('cuda')
+        # model.to('cuda')
         print("===> model loaded")
 
         # removed device_map="auto"
         tokenizer = transformers.AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b', padding_side="left")
+
+
+        max_memory = get_balanced_memory(
+            model,
+            max_memory=None,
+            no_split_module_classes=["MPTBlock"],
+            dtype='float16',
+            low_zero=False,
+        )
+
+        device_map = infer_auto_device_map(
+            model,
+            max_memory=max_memory,
+            no_split_module_classes=["MPTBlock"],
+            dtype='float16'
+        )
+        model = dispatch_model(model, device_map=device_map)
 
-
-        self.pipeline = transformers.pipeline('text-generation', model=model, tokenizer=tokenizer, device='cuda:0')
+
+        # device='cuda:0'
+        self.pipeline = transformers.pipeline('text-generation', model=model, tokenizer=tokenizer)
         print("===> init finished")
 
     def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
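In short, the commit stops force-moving the whole model onto one GPU (model.to('cuda') and device='cuda:0') and instead lets accelerate balance the fp16 weights across all visible GPUs. A minimal standalone sketch of that pattern follows, assuming a multi-GPU host; the MPTBlock class name, the gpt-neox-20b tokenizer, and the float16 dtype come from the diff, while the model id and prompt are illustrative stand-ins for the repo's actual path.

    import torch
    import transformers
    from accelerate import dispatch_model, infer_auto_device_map
    from accelerate.utils import get_balanced_memory

    # Load weights in fp16 on CPU first; "mosaicml/mpt-7b" is an assumed
    # stand-in for the repo's `path`, which the diff does not show.
    model = transformers.AutoModelForCausalLM.from_pretrained(
        "mosaicml/mpt-7b",
        torch_dtype=torch.float16,
        trust_remote_code=True,  # MPT checkpoints ship custom modeling code
    )

    # Split the per-GPU memory budget evenly, never slicing through an
    # MPTBlock, then map each submodule to a device and move the weights.
    max_memory = get_balanced_memory(
        model, no_split_module_classes=["MPTBlock"], dtype="float16"
    )
    device_map = infer_auto_device_map(
        model,
        max_memory=max_memory,
        no_split_module_classes=["MPTBlock"],
        dtype="float16",
    )
    model = dispatch_model(model, device_map=device_map)

    # No `device` argument here: accelerate's hooks already route each
    # layer's inputs to the GPU holding that layer's weights.
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        "EleutherAI/gpt-neox-20b", padding_side="left"
    )
    pipe = transformers.pipeline("text-generation", model=model, tokenizer=tokenizer)
    print(pipe("Hello, my name is", max_new_tokens=20)[0]["generated_text"])

Once dispatch_model has placed the weights, the transformers.pipeline call must not be given a device argument, since the model is already spread over several devices; that is presumably why the commit comments out device='cuda:0'. Passing low_zero=True to get_balanced_memory would instead reserve headroom on GPU 0, which can help when generation buffers and activations land there.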