Spaces:
Running
Running
fix: update model name
Browse files- tools/train/train.py +1 -1
tools/train/train.py
CHANGED
@@ -398,7 +398,7 @@ def main():
|
|
398 |
artifact_dir = artifact.download()
|
399 |
|
400 |
# load model
|
401 |
-
model =
|
402 |
# avoid OOM on TPU: see https://github.com/google/flax/issues/1658
|
403 |
print(model.params)
|
404 |
|
|
|
398 |
artifact_dir = artifact.download()
|
399 |
|
400 |
# load model
|
401 |
+
model = DalleBart.from_pretrained(artifact_dir)
|
402 |
# avoid OOM on TPU: see https://github.com/google/flax/issues/1658
|
403 |
print(model.params)
|
404 |
|