Jonathan Malott commited on
Commit
4831980
·
1 Parent(s): 3caf1a1

Updated model location

Browse files
Files changed (1) hide show
  1. dalle/models/__init__.py +7 -6
dalle/models/__init__.py CHANGED
@@ -43,20 +43,21 @@ class Dalle(nn.Module):
43
  @classmethod
44
  def from_pretrained(cls,
45
  path: str) -> nn.Module:
46
- path = _MODELS[path] if path in _MODELS else path
47
- path = utils.realpath_url_or_path(path, root=os.path.expanduser(".cache/minDALL-E"))
 
48
 
49
  config_base = get_base_config()
50
- config_new = OmegaConf.load(os.path.join(path, 'config.yaml'))
51
  config_update = OmegaConf.merge(config_base, config_new)
52
 
53
  model = cls(config_update)
54
- model.tokenizer = build_tokenizer(os.path.join(path, 'tokenizer'),
55
  context_length=model.config_dataset.context_length,
56
  lowercase=True,
57
  dropout=None)
58
- model.stage1.from_ckpt(os.path.join(path, 'stage1_last.ckpt'))
59
- model.stage2.from_ckpt(os.path.join(path, 'stage2_last.ckpt'))
60
  return model
61
 
62
  @torch.no_grad()
 
43
  @classmethod
44
  def from_pretrained(cls,
45
  path: str) -> nn.Module:
46
+ #path = _MODELS[path] if path in _MODELS else path
47
+ #path = utils.realpath_url_or_path(path, root=os.path.expanduser(".cache/minDALL-E"))
48
+ path = ".cache/minDALL-E/1.3B/"
49
 
50
  config_base = get_base_config()
51
+ config_new = OmegaConf.load(path+'config.yaml')
52
  config_update = OmegaConf.merge(config_base, config_new)
53
 
54
  model = cls(config_update)
55
+ model.tokenizer = build_tokenizer(path+'tokenizer',
56
  context_length=model.config_dataset.context_length,
57
  lowercase=True,
58
  dropout=None)
59
+ model.stage1.from_ckpt(path+'stage1_last.ckpt')
60
+ model.stage2.from_ckpt(path+'stage2_last.ckpt')
61
  return model
62
 
63
  @torch.no_grad()