boris commited on
Commit
a5ed112
·
1 Parent(s): f69b21b

feat: use_artifact if run existing

Browse files
src/dalle_mini/model/configuration.py CHANGED
@@ -18,7 +18,7 @@ import warnings
18
  from transformers.configuration_utils import PretrainedConfig
19
  from transformers.utils import logging
20
 
21
- from .wandb_pretrained import PretrainedFromWandbMixin
22
 
23
  logger = logging.get_logger(__name__)
24
 
 
18
  from transformers.configuration_utils import PretrainedConfig
19
  from transformers.utils import logging
20
 
21
+ from .utils import PretrainedFromWandbMixin
22
 
23
  logger = logging.get_logger(__name__)
24
 
src/dalle_mini/model/modeling.py CHANGED
@@ -46,7 +46,7 @@ from transformers.models.bart.modeling_flax_bart import (
46
  from transformers.utils import logging
47
 
48
  from .configuration import DalleBartConfig
49
- from .wandb_pretrained import PretrainedFromWandbMixin
50
 
51
  logger = logging.get_logger(__name__)
52
 
 
46
  from transformers.utils import logging
47
 
48
  from .configuration import DalleBartConfig
49
+ from .utils import PretrainedFromWandbMixin
50
 
51
  logger = logging.get_logger(__name__)
52
 
src/dalle_mini/model/tokenizer.py CHANGED
@@ -2,7 +2,7 @@
2
  from transformers import BartTokenizer
3
  from transformers.utils import logging
4
 
5
- from .wandb_pretrained import PretrainedFromWandbMixin
6
 
7
  logger = logging.get_logger(__name__)
8
 
 
2
  from transformers import BartTokenizer
3
  from transformers.utils import logging
4
 
5
+ from .utils import PretrainedFromWandbMixin
6
 
7
  logger = logging.get_logger(__name__)
8
 
src/dalle_mini/model/{wandb_pretrained.py → utils.py} RENAMED
@@ -13,7 +13,10 @@ class PretrainedFromWandbMixin:
13
  pretrained_model_name_or_path
14
  ):
15
  # wandb artifact
16
- artifact = wandb.Api().artifact(pretrained_model_name_or_path)
 
 
 
17
  pretrained_model_name_or_path = artifact.download()
18
 
19
  return super(PretrainedFromWandbMixin, cls).from_pretrained(
 
13
  pretrained_model_name_or_path
14
  ):
15
  # wandb artifact
16
+ if wandb.run is not None:
17
+ artifact = wandb.run.use_artifact(pretrained_model_name_or_path)
18
+ else:
19
+ artifact = wandb.Api().artifact(pretrained_model_name_or_path)
20
  pretrained_model_name_or_path = artifact.download()
21
 
22
  return super(PretrainedFromWandbMixin, cls).from_pretrained(