Use `bfloat16` in the example
Browse files
README.md
CHANGED
@@ -62,7 +62,7 @@ Then, run the following code:
|
|
62 |
from diffusers import CogView3PlusPipeline
|
63 |
import torch
|
64 |
|
65 |
-
pipe = CogView3PlusPipeline.from_pretrained("THUDM/CogView3-Plus-3B", torch_dtype=torch.
|
66 |
|
67 |
# Enable it to reduce GPU memory usage
|
68 |
pipe.enable_model_cpu_offload()
|
|
|
62 |
from diffusers import CogView3PlusPipeline
|
63 |
import torch
|
64 |
|
65 |
+
pipe = CogView3PlusPipeline.from_pretrained("THUDM/CogView3-Plus-3B", torch_dtype=torch.bfloat16).to("cuda")
|
66 |
|
67 |
# Enable it to reduce GPU memory usage
|
68 |
pipe.enable_model_cpu_offload()
|