Update train.py
Browse files
train.py
CHANGED
@@ -161,7 +161,7 @@ class markerModel(pl.LightningModule):
|
|
161 |
# self.sigmoid = nn.Sigmoid()
|
162 |
|
163 |
#-- Pretrained Model Setting
|
164 |
-
acc_config = AutoConfig.from_pretrained(
|
165 |
if d_pretrained is False:
|
166 |
self.d_model = RobertaModel(acc_config)
|
167 |
print('acceptor model without pretraining')
|
@@ -170,7 +170,7 @@ class markerModel(pl.LightningModule):
|
|
170 |
output_hidden_states=True,
|
171 |
output_attentions=True)
|
172 |
|
173 |
-
don_config = AutoConfig.from_pretrained(
|
174 |
|
175 |
if p_pretrained is False:
|
176 |
self.p_model = RobertaModel(don_config)
|
|
|
161 |
# self.sigmoid = nn.Sigmoid()
|
162 |
|
163 |
#-- Pretrained Model Setting
|
164 |
+
acc_config = AutoConfig.from_pretrained(acc_model_name)
|
165 |
if d_pretrained is False:
|
166 |
self.d_model = RobertaModel(acc_config)
|
167 |
print('acceptor model without pretraining')
|
|
|
170 |
output_hidden_states=True,
|
171 |
output_attentions=True)
|
172 |
|
173 |
+
don_config = AutoConfig.from_pretrained(don_model_name)
|
174 |
|
175 |
if p_pretrained is False:
|
176 |
self.p_model = RobertaModel(don_config)
|