philipp-zettl commited on
Commit
6cc9fa2
·
verified ·
1 Parent(s): 88254c7

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +2 -2
model.py CHANGED
@@ -335,8 +335,8 @@ class MultiHeadClassification(nn.Module):
335
  dropout (float): Dropout rate
336
  l2_reg (float): L2 regularization rate
337
  """
338
- backbone = AutoModel.from_pretrained(os.path.join(model_path, 'pretrained/backbone.pth'))
339
  instance = cls(backbone, head_config, dropout, l2_reg)
340
- instance.load(os.path.join(model_path, 'pretrained/model.pth'))
341
  instance.head_config = {k: v. instance.heads}
342
  return instance
 
335
  dropout (float): Dropout rate
336
  l2_reg (float): L2 regularization rate
337
  """
338
+ backbone = AutoModel.from_pretrained(os.path.join(model_path, 'pretrained/backbone'))
339
  instance = cls(backbone, head_config, dropout, l2_reg)
340
+ instance.load(os.path.join(model_path, 'pretrained/multi-head-sequence-classification-model-model.pth'))
341
  instance.head_config = {k: v. instance.heads}
342
  return instance