|
import copy
import os

import torch
from torch import nn
from huggingface_hub import hf_hub_download
from transformers import AutoModel
|
|
|
# Hub repository passed to hf_hub_download() by load_model().
repo_id = "Siyunb323/CreativityEvaluation"

# Access token taken from the environment; None when HF_TOKEN is not set.
token = os.environ.get("HF_TOKEN")

# Base encoder, fetched/loaded once when this module is imported.
model = AutoModel.from_pretrained("cl-tohoku/bert-base-japanese")
|
|
|
class BERTregressor(nn.Module):
    """Encoder followed by a small regression head that emits one score per example.

    Args:
        bert: Callable encoder. It is invoked with the keyword arguments
            ``input_ids``, ``token_type_ids``, ``attention_mask`` and
            ``output_attentions`` and must return an object exposing
            ``last_hidden_state``, ``pooler_output`` and ``attentions``.
        hidden_size: Width of the feature vector fed to the head.
        num_linear: ``1`` for ``Dropout -> Linear``; ``2`` for
            ``Linear -> ReLU -> Dropout -> Linear`` with a 128-unit hidden layer.
        dropout: Dropout probability used inside the head.
        o_type: ``'cls'`` feeds ``last_hidden_state[:, 0, :]`` to the head
            (presumably the [CLS] position); ``'pooler'`` feeds
            ``pooler_output``.
        t_type: Target type. Only the value ``'C'`` is interpreted here: together
            with ``use_sigmoid`` it enables the rescaling described below.
        use_sigmoid: Append a ``Sigmoid`` to the head.

    Raises:
        ValueError: If ``num_linear`` is not 1 or 2, or ``o_type`` is not
            ``'cls'`` or ``'pooler'``.
    """

    _OUTPUT_TYPES = ('cls', 'pooler')

    def __init__(self, bert, hidden_size=768, num_linear=1, dropout=0.1,
                 o_type='cls', t_type='C', use_sigmoid=False):
        super(BERTregressor, self).__init__()

        # Fail early with a clear message instead of hitting an unbound local
        # later (in this constructor or in forward()).
        if o_type not in self._OUTPUT_TYPES:
            raise ValueError(
                f"o_type must be one of {self._OUTPUT_TYPES}, got {o_type!r}")

        self.encoder = bert
        self.o_type = o_type
        self.t_type = t_type
        self.sigmoid = use_sigmoid

        # NOTE: the position of each layer inside the Sequential determines the
        # state_dict keys ("output.<index>.weight"), so the order must not
        # change if previously saved checkpoints are to keep loading.
        if num_linear == 2:
            layers = [nn.Linear(hidden_size, 128),
                      nn.ReLU(),
                      nn.Dropout(dropout),
                      nn.Linear(128, 1)]
        elif num_linear == 1:
            layers = [nn.Dropout(dropout),
                      nn.Linear(hidden_size, 1)]
        else:
            raise ValueError(f"num_linear must be 1 or 2, got {num_linear!r}")

        if use_sigmoid:
            layers.append(nn.Sigmoid())

        self.output = nn.Sequential(*layers)

    def forward(self, inputs, return_attention=False):
        """Score a batch.

        Args:
            inputs: Mapping providing ``'input_ids'``, ``'token_type_ids'`` and
                ``'attention_mask'``; other keys are ignored.
            return_attention: Forwarded to the encoder as ``output_attentions``;
                when true the encoder's ``attentions`` are returned as well.

        Returns:
            The head output with its trailing dimension squeezed, or
            ``(scores, attentions)`` when ``return_attention`` is true. When the
            head ends in a sigmoid and ``t_type == 'C'`` the scores are mapped
            from (0, 1) to (1, 5) via ``4 * x + 1``.
        """
        encoded_X = self.encoder(
            input_ids=inputs['input_ids'],
            token_type_ids=inputs['token_type_ids'],
            attention_mask=inputs['attention_mask'],
            output_attentions=return_attention,
        )

        if self.o_type == 'cls':
            features = encoded_X.last_hidden_state[:, 0, :]
        elif self.o_type == 'pooler':
            features = encoded_X.pooler_output
        else:
            # Only reachable if o_type was reassigned after construction.
            raise ValueError(
                f"o_type must be one of {self._OUTPUT_TYPES}, got {self.o_type!r}")

        output = self.output(features).squeeze(-1)
        if self.sigmoid and self.t_type == 'C':
            output = 4 * output + 1

        if return_attention:
            return output, encoded_X.attentions
        return output
|
|
|
class Effectiveness(nn.Module):
    """Regression head mapping a feature vector to a single score.

    ``num_layers == 2`` builds ``Linear(hidden_size, 128) -> ReLU -> Dropout ->
    Linear(128, 1)``; any other value builds ``ReLU -> Dropout ->
    Linear(hidden_size, 1)``. With ``use_sigmoid`` a ``Sigmoid`` is appended and
    ``forward`` rescales the result with ``4 * x + 1`` (range (1, 5)).
    Extra keyword arguments are passed on to ``nn.Module.__init__``.
    """

    def __init__(self, num_layers, hidden_size=768, use_sigmoid=True, dropout=0.2, **kwargs):
        super().__init__(**kwargs)
        self.sigmoid = use_sigmoid

        # Module positions define the state_dict keys, so the order is fixed.
        if num_layers == 2:
            stack = [nn.Linear(hidden_size, 128), nn.ReLU(),
                     nn.Dropout(dropout), nn.Linear(128, 1)]
        else:
            stack = [nn.ReLU(), nn.Dropout(dropout), nn.Linear(hidden_size, 1)]

        if use_sigmoid:
            stack += [nn.Sigmoid()]

        self.output = nn.Sequential(*stack)

    def forward(self, X):
        """Return one score per row of ``X`` (trailing size-1 dimension removed)."""
        scores = self.output(X).squeeze(-1)
        return 4 * scores + 1 if self.sigmoid else scores
|
|
|
class Creativity(nn.Module):
    """Regression head that turns a feature vector into a single score.

    The layer stack is ``Linear(hidden_size, 128) -> ReLU -> Dropout ->
    Linear(128, 1)`` when ``num_layers == 2`` and ``ReLU -> Dropout ->
    Linear(hidden_size, 1)`` for every other value. If ``use_sigmoid`` is true
    a ``Sigmoid`` closes the stack and ``forward`` maps its output from (0, 1)
    to (1, 5) with ``4 * x + 1``. Remaining keyword arguments go to
    ``nn.Module.__init__``.
    """

    def __init__(self, num_layers, hidden_size=768, use_sigmoid=True, dropout=0.2, **kwargs):
        super().__init__(**kwargs)
        self.sigmoid = use_sigmoid

        # Keep the module order: it determines the keys stored in checkpoints.
        if num_layers != 2:
            modules = [
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(hidden_size, 1),
            ]
        else:
            modules = [
                nn.Linear(hidden_size, 128),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(128, 1),
            ]

        if use_sigmoid:
            modules = modules + [nn.Sigmoid()]

        self.output = nn.Sequential(*modules)

    def forward(self, X):
        """Return one score per row of ``X`` with the last dimension squeezed."""
        raw = self.output(X)
        if not self.sigmoid:
            return raw.squeeze(-1)
        return 4 * raw.squeeze(-1) + 1
|
|
|
class BERT2Phase(nn.Module):
    """Shared encoder with two regression heads: ``effectiveness`` and ``creativity``.

    Args:
        bert: Callable encoder. It is invoked with the keyword arguments
            ``input_ids``, ``token_type_ids``, ``attention_mask`` and
            ``output_attentions`` and must return an object exposing
            ``last_hidden_state``, ``pooler_output`` and ``attentions``.
        hidden_size: Width of the feature vector fed to both heads.
        type: ``'cls'`` feeds ``last_hidden_state[:, 0, :]`` to the heads
            (presumably the [CLS] position); ``'pooler'`` feeds
            ``pooler_output``. The name shadows the builtin but is kept because
            callers pass it by keyword.
        num_linear: Forwarded to both heads as their ``num_layers`` argument.
        dropout: Dropout probability used inside both heads.
        use_sigmoid: Forwarded to both heads.

    Raises:
        ValueError: If ``type`` is not ``'cls'`` or ``'pooler'``.
    """

    _FEATURE_TYPES = ('cls', 'pooler')

    def __init__(self, bert, hidden_size=768, type='cls',
                 num_linear=1, dropout=0.1, use_sigmoid=False):
        super(BERT2Phase, self).__init__()

        # Reject unsupported values here; previously they surfaced only in
        # forward() as an unbound-local error.
        if type not in self._FEATURE_TYPES:
            raise ValueError(
                f"type must be one of {self._FEATURE_TYPES}, got {type!r}")

        self.encoder = bert
        self.type = type
        self.sigmoid = use_sigmoid

        self.effectiveness = Effectiveness(num_linear, hidden_size, use_sigmoid, dropout)
        self.creativity = Creativity(num_linear, hidden_size, use_sigmoid, dropout)

    def forward(self, inputs, return_attention=False):
        """Run the encoder once and score the result with both heads.

        Args:
            inputs: Mapping providing ``'input_ids'``, ``'token_type_ids'`` and
                ``'attention_mask'``; other keys are ignored.
            return_attention: Forwarded to the encoder as ``output_attentions``;
                when true the encoder's ``attentions`` are returned as well.

        Returns:
            ``(c_pred, e_pred)`` -- creativity head output first, effectiveness
            head output second -- or ``(c_pred, e_pred, attentions)`` when
            ``return_attention`` is true.
        """
        encoded_X = self.encoder(
            input_ids=inputs['input_ids'],
            token_type_ids=inputs['token_type_ids'],
            attention_mask=inputs['attention_mask'],
            output_attentions=return_attention,
        )

        # Select the feature tensor once; both heads consume the same input.
        if self.type == 'cls':
            features = encoded_X.last_hidden_state[:, 0, :]
        elif self.type == 'pooler':
            features = encoded_X.pooler_output
        else:
            # Only reachable if self.type was reassigned after construction.
            raise ValueError(
                f"type must be one of {self._FEATURE_TYPES}, got {self.type!r}")

        e_pred = self.effectiveness(features)
        c_pred = self.creativity(features)

        if return_attention:
            return c_pred, e_pred, encoded_X.attentions
        return c_pred, e_pred
|
|
|
def load_model(model_name, pooling_method):
    """Build a fine-tuned network and load its checkpoint from the Hub.

    Args:
        model_name: ``"One-phase Fine-tuned BERT"`` (builds a ``BERTregressor``)
            or ``"Two-phase Fine-tuned BERT"`` (builds a ``BERT2Phase``).
        pooling_method: ``'cls'`` selects CLS features; every other value
            selects the pooler output. The value is also inserted verbatim into
            the checkpoint filename, so it must match a file in the repository.

    Returns:
        The network with the checkpoint weights loaded, switched to eval mode.

    Raises:
        ValueError: If ``model_name`` is not one of the two supported names.
    """
    if model_name == "One-phase Fine-tuned BERT":
        checkpoint_stem = "OnePhase_BERT"
    elif model_name == "Two-phase Fine-tuned BERT":
        checkpoint_stem = "TwoPhase_BERT"
    else:
        # Previously this fell through and crashed on an unbound local.
        raise ValueError(f"Unknown model_name: {model_name!r}")

    pooling = pooling_method if pooling_method == 'cls' else 'pooler'

    # Give every network its own encoder. load_state_dict() below writes the
    # checkpoint's encoder weights in place, so sharing the module-level
    # `model` would let a later load_model() call silently overwrite the
    # encoder of a network returned earlier.
    encoder = copy.deepcopy(model)

    if checkpoint_stem == "OnePhase_BERT":
        loaded_net = BERTregressor(encoder, hidden_size=768, num_linear=1, dropout=0.1,
                                   o_type=pooling, t_type='C', use_sigmoid=True)
    else:
        loaded_net = BERT2Phase(encoder, hidden_size=768, num_linear=1, dropout=0.1,
                                type=pooling, use_sigmoid=True)

    filename = f"model/{checkpoint_stem}_{pooling_method}.pth"

    # NOTE(review): newer huggingface_hub releases spell this argument
    # `token=`; confirm against the pinned version before renaming it.
    model_path = hf_hub_download(repo_id=repo_id, filename=filename, use_auth_token=token)

    # map_location="cpu" lets checkpoints saved on a GPU load on CPU-only
    # hosts; load_state_dict() copies the values into the existing parameters.
    # SECURITY: torch.load unpickles the file -- only use trusted repositories
    # (consider weights_only=True where the installed torch supports it).
    state_dict = torch.load(model_path, map_location="cpu")
    loaded_net.load_state_dict(state_dict)
    loaded_net.eval()
    return loaded_net
|
|