anamargarida committed
Commit 55b8086 · verified · 1 Parent(s): 55bc5ab

Update modeling_st2.py

Files changed (1):
  1. modeling_st2.py +15 -37
modeling_st2.py CHANGED
@@ -33,9 +33,9 @@ class ST2ModelV2(nn.Module):
         super(ST2ModelV2, self).__init__()
         self.args = args
 
-        self.config = AutoConfig.from_pretrained(args.model_name_or_path)
-        self.model = AutoModel.from_pretrained(args.model_name_or_path)
-        self.tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
+        self.config = AutoConfig.from_pretrained("roberta-large")
+        self.model = AutoModel.from_pretrained("roberta-large")
+        self.tokenizer = AutoTokenizer.from_pretrained("roberta-large")
 
         classifier_dropout = self.args.dropout
         self.dropout = nn.Dropout(classifier_dropout)
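
Note: after this hunk the backbone checkpoint is fixed, so `args.model_name_or_path` no longer controls what gets loaded. A minimal standalone sketch of what the constructor now does (assuming only the `transformers` library; names outside the class are illustrative):

    from transformers import AutoConfig, AutoModel, AutoTokenizer

    # The encoder is now always roberta-large, regardless of
    # args.model_name_or_path; the three objects below mirror
    # self.config, self.model, and self.tokenizer in __init__.
    config = AutoConfig.from_pretrained("roberta-large")
    model = AutoModel.from_pretrained("roberta-large")
    tokenizer = AutoTokenizer.from_pretrained("roberta-large")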
@@ -104,40 +104,18 @@ class ST2ModelV2(nn.Module):
                 return_dict=return_dict,
             )
         else:
-            if self.args.model_name_or_path in ['facebook/bart-large', 'facebook/bart-base', 'facebook/bart-large-cnn']:
-                outputs = self.model(
-                    input_ids,
-                    attention_mask=attention_mask,
-                    head_mask=head_mask,
-                    inputs_embeds=inputs_embeds,
-                    output_attentions=output_attentions,
-                    output_hidden_states=output_hidden_states,
-                    return_dict=return_dict,
-                )
-            elif self.args.model_name_or_path in ['microsoft/deberta-base']:
-                outputs = self.model(
-                    input_ids,
-                    attention_mask=attention_mask,
-                    token_type_ids=token_type_ids,
-                    position_ids=position_ids,
-                    inputs_embeds=inputs_embeds,
-                    output_attentions=output_attentions,
-                    output_hidden_states=output_hidden_states,
-                    return_dict=return_dict,
-                )
-
-            else:
-                outputs = self.model(
-                    input_ids,
-                    attention_mask=attention_mask,
-                    token_type_ids=token_type_ids,
-                    position_ids=position_ids,
-                    head_mask=head_mask,
-                    inputs_embeds=inputs_embeds,
-                    output_attentions=output_attentions,
-                    output_hidden_states=output_hidden_states,
-                    return_dict=return_dict,
-                )
+
+            outputs = self.model(
+                input_ids,
+                attention_mask=attention_mask,
+                token_type_ids=token_type_ids,
+                position_ids=position_ids,
+                head_mask=head_mask,
+                inputs_embeds=inputs_embeds,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict,
+            )
 
         sequence_output = outputs[0]
         if signal_bias_mask is not None and self.args.signal_bias_on_top_of_lm:
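
With roberta-large hardcoded, the BART- and DeBERTa-specific branches become dead code, and the forward pass collapses to the single call kept above; RoBERTa's forward() accepts every one of these keyword arguments. A minimal standalone sketch of the surviving call (the example sentence is illustrative):

    from transformers import AutoModel, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("roberta-large")
    model = AutoModel.from_pretrained("roberta-large")

    enc = tokenizer("Heavy rain caused severe flooding.", return_tensors="pt")
    outputs = model(
        enc["input_ids"],
        attention_mask=enc["attention_mask"],
        token_type_ids=None,   # RoBERTa has type_vocab_size=1; None/zeros are safe
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,    # must stay None when input_ids is given
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
    )
    sequence_output = outputs[0]  # last_hidden_state, shape (1, seq_len, 1024)

One caveat worth flagging: unlike BERT, the RoBERTa tokenizer does not emit token_type_ids, and non-zero segment ids would index past RoBERTa's single-entry token-type embedding, so None (all zeros) is the safe value on this path.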