File size: 967 Bytes
f08fa03 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 |
import transformers
from mtm import MultitaskModel
from config import model_name, drop_out
# Names of the tasks registered with the multitask model. Each one is given
# the same model class and an identically-parameterised configuration below.
_TASK_NAMES = ("document", "paragraph", "sentence")

multitask_model = MultitaskModel.create(
    model_name=model_name,
    # Every task maps to the sequence-classification auto-model class.
    model_type_dict=dict.fromkeys(
        _TASK_NAMES, transformers.AutoModelForSequenceClassification
    ),
    # One separately loaded config per task (not a shared instance): 3 labels,
    # with both the hidden and attention dropout set to `drop_out`.
    model_config_dict={
        task: transformers.AutoConfig.from_pretrained(
            model_name,
            num_labels=3,
            hidden_dropout_prob=drop_out,
            attention_probs_dropout_prob=drop_out,
        )
        for task in _TASK_NAMES
    },
)

# Tokenizer loaded from the same checkpoint name as the model.
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
|