qgyd2021 commited on
Commit
b18f746
·
1 Parent(s): 44adba5

[update]edit main

Browse files
Files changed (1) hide show
  1. main.py +2 -1
main.py CHANGED
@@ -64,7 +64,7 @@ def main():
64
  args.pretrained_model_name_or_path,
65
  trust_remote_code=True,
66
  low_cpu_mem_usage=True,
67
- torch_dtype=torch.float32,
68
  device_map="auto",
69
  offload_folder="./offload",
70
  offload_state_dict=True,
@@ -78,6 +78,7 @@ def main():
78
  return_tensors="pt",
79
  add_special_tokens=False,
80
  ).input_ids.to(args.device)
 
81
  with torch.no_grad():
82
  outputs = model.generate(
83
  input_ids=input_ids, max_new_tokens=args.max_new_tokens, do_sample=True,
 
64
  args.pretrained_model_name_or_path,
65
  trust_remote_code=True,
66
  low_cpu_mem_usage=True,
67
+ torch_dtype=torch.bfloat16,
68
  device_map="auto",
69
  offload_folder="./offload",
70
  offload_state_dict=True,
 
78
  return_tensors="pt",
79
  add_special_tokens=False,
80
  ).input_ids.to(args.device)
81
+
82
  with torch.no_grad():
83
  outputs = model.generate(
84
  input_ids=input_ids, max_new_tokens=args.max_new_tokens, do_sample=True,