Spaces:
Runtime error
Runtime error
[update] edit main
Browse files
main.py
CHANGED
@@ -64,7 +64,7 @@ def main():
|
|
64 |
args.pretrained_model_name_or_path,
|
65 |
trust_remote_code=True,
|
66 |
low_cpu_mem_usage=True,
|
67 |
-
torch_dtype=torch.
|
68 |
device_map="auto",
|
69 |
offload_folder="./offload",
|
70 |
offload_state_dict=True,
|
@@ -78,6 +78,7 @@ def main():
|
|
78 |
return_tensors="pt",
|
79 |
add_special_tokens=False,
|
80 |
).input_ids.to(args.device)
|
|
|
81 |
with torch.no_grad():
|
82 |
outputs = model.generate(
|
83 |
input_ids=input_ids, max_new_tokens=args.max_new_tokens, do_sample=True,
|
|
|
64 |
args.pretrained_model_name_or_path,
|
65 |
trust_remote_code=True,
|
66 |
low_cpu_mem_usage=True,
|
67 |
+
torch_dtype=torch.bfloat16,
|
68 |
device_map="auto",
|
69 |
offload_folder="./offload",
|
70 |
offload_state_dict=True,
|
|
|
78 |
return_tensors="pt",
|
79 |
add_special_tokens=False,
|
80 |
).input_ids.to(args.device)
|
81 |
+
|
82 |
with torch.no_grad():
|
83 |
outputs = model.generate(
|
84 |
input_ids=input_ids, max_new_tokens=args.max_new_tokens, do_sample=True,
|