Update generator.py
Browse files- generator.py +1 -1
generator.py
CHANGED
@@ -177,7 +177,7 @@ def load_csm_1b(ckpt_path: str = "ckpt.pt", device: str = "cuda") -> Generator:
|
|
177 |
model = Model(model_args).to(device=device, dtype=torch.bfloat16)
|
178 |
state_dict = torch.load(ckpt_path)
|
179 |
model.load_state_dict(state_dict)
|
180 |
-
model.decoder = torch.compile(model.decoder, fullgraph=True, mode='reduce-overhead')
|
181 |
|
182 |
generator = Generator(model)
|
183 |
return generator
|
|
|
177 |
model = Model(model_args).to(device=device, dtype=torch.bfloat16)
|
178 |
state_dict = torch.load(ckpt_path)
|
179 |
model.load_state_dict(state_dict)
|
180 |
+
#model.decoder = torch.compile(model.decoder, fullgraph=True, mode='reduce-overhead')
|
181 |
|
182 |
generator = Generator(model)
|
183 |
return generator
|