Spaces:
Running
Running
Pedro Cuenca
committed on
Commit
·
9f522b8
1
Parent(s):
290e443
Accept changes suggested by linter.
Browse files
- src/dalle_mini/model/modeling.py +7 -3
- tools/train/train.py +3 -1
src/dalle_mini/model/modeling.py
CHANGED
@@ -569,14 +569,18 @@ class DalleBart(FlaxBartPreTrainedModel, FlaxBartForConditionalGeneration):
|
|
569 |
"""
|
570 |
Initializes from a wandb artifact, or delegates loading to the superclass.
|
571 |
"""
|
572 |
-
if
|
|
|
|
|
573 |
# wandb artifact
|
574 |
artifact = wandb.Api().artifact(pretrained_model_name_or_path)
|
575 |
-
|
576 |
# we download everything, including opt_state, so we can resume training if needed
|
577 |
# see also: #120
|
578 |
pretrained_model_name_or_path = artifact.download()
|
579 |
|
580 |
-
model = super(DalleBart, cls).from_pretrained(
|
|
|
|
|
581 |
model.config.resolved_name_or_path = pretrained_model_name_or_path
|
582 |
return model
|
|
|
569 |
"""
|
570 |
Initializes from a wandb artifact, or delegates loading to the superclass.
|
571 |
"""
|
572 |
+
if ":" in pretrained_model_name_or_path and not os.path.isdir(
|
573 |
+
pretrained_model_name_or_path
|
574 |
+
):
|
575 |
# wandb artifact
|
576 |
artifact = wandb.Api().artifact(pretrained_model_name_or_path)
|
577 |
+
|
578 |
# we download everything, including opt_state, so we can resume training if needed
|
579 |
# see also: #120
|
580 |
pretrained_model_name_or_path = artifact.download()
|
581 |
|
582 |
+
model = super(DalleBart, cls).from_pretrained(
|
583 |
+
pretrained_model_name_or_path, *model_args, **kwargs
|
584 |
+
)
|
585 |
model.config.resolved_name_or_path = pretrained_model_name_or_path
|
586 |
return model
|
tools/train/train.py
CHANGED
@@ -437,7 +437,9 @@ def main():
|
|
437 |
if training_args.resume_from_checkpoint is not None:
|
438 |
# load model
|
439 |
model = DalleBart.from_pretrained(
|
440 |
-
training_args.resume_from_checkpoint,
|
|
|
|
|
441 |
)
|
442 |
# avoid OOM on TPU: see https://github.com/google/flax/issues/1658
|
443 |
print(model.params)
|
|
|
437 |
if training_args.resume_from_checkpoint is not None:
|
438 |
# load model
|
439 |
model = DalleBart.from_pretrained(
|
440 |
+
training_args.resume_from_checkpoint,
|
441 |
+
dtype=getattr(jnp, model_args.dtype),
|
442 |
+
abstract_init=True,
|
443 |
)
|
444 |
# avoid OOM on TPU: see https://github.com/google/flax/issues/1658
|
445 |
print(model.params)
|