Siyunb323's picture
update
e1af10a
raw
history blame
5.78 kB
import os
import torch
from torch import nn
from transformers import AutoModel
from huggingface_hub import hf_hub_download
# Hugging Face access token used when downloading checkpoints in load_model();
# None when the HF_TOKEN environment variable is not set.
token=os.getenv("HF_TOKEN")
# Hub repository that hosts the fine-tuned checkpoint files fetched by load_model().
repo_id = "Siyunb323/CreativityEvaluation"
# Pretrained encoder ("cl-tohoku/bert-base-japanese"), loaded once at import time,
# so importing this module may download or read the weights as a side effect.
# load_model() builds its networks from this encoder.
model = AutoModel.from_pretrained("cl-tohoku/bert-base-japanese")
class BERTregressor(nn.Module):
    """Single-score regressor on top of a BERT-style encoder.

    The encoder's sentence representation (first-token hidden state or pooler
    output, chosen by ``o_type``) is passed through a small head that yields
    one scalar per example.

    Args:
        bert: Encoder module. It is called with the keyword arguments
            ``input_ids``, ``token_type_ids``, ``attention_mask`` and
            ``output_attentions`` and must return an object exposing
            ``last_hidden_state`` / ``pooler_output`` (and ``attentions``
            when attentions are requested).
        hidden_size: Width of the representation fed to the head.
        num_linear: Number of linear layers in the head; must be 1 or 2.
        dropout: Dropout probability used inside the head.
        o_type: ``'cls'`` to use ``last_hidden_state[:, 0, :]`` or
            ``'pooler'`` to use ``pooler_output``.
        t_type: Target type. When it is ``'C'`` and ``use_sigmoid`` is true,
            the sigmoid output is rescaled from (0, 1) to (1, 5).
        use_sigmoid: Append a ``Sigmoid`` layer to the head.

    Raises:
        ValueError: If ``num_linear`` is not 1 or 2, or ``o_type`` is not
            ``'cls'`` or ``'pooler'``.
    """

    def __init__(self, bert, hidden_size=768, num_linear=1, dropout=0.1,
                 o_type='cls', t_type='C', use_sigmoid=False):
        super().__init__()
        # Fail at construction time instead of with an UnboundLocalError in
        # forward().
        if o_type not in ('cls', 'pooler'):
            raise ValueError(f"o_type must be 'cls' or 'pooler', got {o_type!r}")
        self.encoder = bert
        self.o_type = o_type
        self.t_type = t_type
        self.sigmoid = use_sigmoid
        # The layer order fixes the state_dict keys (output.<index>.*), which
        # must match checkpoints restored with load_state_dict().
        if num_linear == 2:
            layers = [
                nn.Linear(hidden_size, 128),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(128, 1),
            ]
        elif num_linear == 1:
            layers = [
                nn.Dropout(dropout),
                nn.Linear(hidden_size, 1),
            ]
        else:
            raise ValueError(f"num_linear must be 1 or 2, got {num_linear!r}")
        if use_sigmoid:
            layers.append(nn.Sigmoid())
        self.output = nn.Sequential(*layers)

    def forward(self, inputs, return_attention=False):
        """Score a tokenized batch.

        Args:
            inputs: Mapping with ``'input_ids'``, ``'token_type_ids'`` and
                ``'attention_mask'`` entries.
            return_attention: Also return the encoder's ``attentions``.

        Returns:
            The head output with its trailing singleton dimension squeezed,
            or ``(scores, attentions)`` when ``return_attention`` is true.

        Raises:
            ValueError: If ``self.o_type`` was changed to an unsupported value.
        """
        encoded = self.encoder(
            input_ids=inputs['input_ids'],
            token_type_ids=inputs['token_type_ids'],
            attention_mask=inputs['attention_mask'],
            output_attentions=return_attention,
        )
        if self.o_type == 'cls':
            # Hidden state at the first position of every sequence.
            features = encoded.last_hidden_state[:, 0, :]
        elif self.o_type == 'pooler':
            features = encoded.pooler_output
        else:
            raise ValueError(f"o_type must be 'cls' or 'pooler', got {self.o_type!r}")
        output = self.output(features).squeeze(-1)
        if self.sigmoid and self.t_type == 'C':
            # Map the sigmoid range (0, 1) onto the score range (1, 5).
            output = 4 * output + 1
        if return_attention:
            return output, encoded.attentions
        return output
class _ScoreHead(nn.Module):
    """Head that maps an encoder feature vector to one score per example.

    Shared implementation of the ``Effectiveness`` and ``Creativity`` heads.

    Args:
        num_layers: ``2`` builds Linear(hidden_size, 128) -> ReLU -> Dropout
            -> Linear(128, 1); any other value builds ReLU -> Dropout ->
            Linear(hidden_size, 1).
        hidden_size: Width of the input features.
        use_sigmoid: Append a ``Sigmoid`` layer and rescale the output to the
            range (1, 5).
        dropout: Dropout probability.
        **kwargs: Forwarded to ``nn.Module.__init__``.
    """

    def __init__(self, num_layers, hidden_size=768, use_sigmoid=True, dropout=0.2, **kwargs):
        super().__init__(**kwargs)
        self.sigmoid = use_sigmoid
        # The layer order fixes the state_dict keys (output.<index>.*); keep it
        # unchanged so existing checkpoints still load.
        if num_layers == 2:
            layers = [
                nn.Linear(hidden_size, 128),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(128, 1),
            ]
        else:
            # NOTE(review): unlike BERTregressor's one-layer head, this applies
            # ReLU to the raw features before the linear layer. It is kept as-is
            # because changing it would alter both the model's function and the
            # checkpoint key indices.
            layers = [
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(hidden_size, 1),
            ]
        if use_sigmoid:
            # Only added when requested.
            layers.append(nn.Sigmoid())
        self.output = nn.Sequential(*layers)

    def forward(self, X):
        """Return the head output with its trailing singleton dimension squeezed."""
        output = self.output(X).squeeze(-1)
        if self.sigmoid:
            # When a Sigmoid layer is used, map its output range (0, 1) to (1, 5).
            return 4 * output + 1
        return output


class Effectiveness(_ScoreHead):
    """Score head used for the effectiveness prediction (``e_pred``) in BERT2Phase."""


class Creativity(_ScoreHead):
    """Score head used for the creativity prediction (``c_pred``) in BERT2Phase."""
class BERT2Phase(nn.Module):
    """Shared encoder with two score heads: creativity and effectiveness.

    Args:
        bert: Encoder module. It is called with the keyword arguments
            ``input_ids``, ``token_type_ids``, ``attention_mask`` and
            ``output_attentions`` and must return an object exposing
            ``last_hidden_state`` / ``pooler_output`` (and ``attentions``
            when attentions are requested).
        hidden_size: Width of the representation fed to both heads.
        type: ``'cls'`` to use ``last_hidden_state[:, 0, :]`` or ``'pooler'``
            to use ``pooler_output``. (The name shadows the builtin but is
            part of the public keyword interface.)
        num_linear: ``num_layers`` argument passed to both heads.
        dropout: Dropout probability passed to both heads.
        use_sigmoid: Passed to both heads; when true, their outputs are
            rescaled to (1, 5).

    Raises:
        ValueError: If ``type`` is not ``'cls'`` or ``'pooler'``.
    """

    def __init__(self, bert, hidden_size=768, type='cls',
                 num_linear=1, dropout=0.1, use_sigmoid=False):
        super().__init__()
        # Fail at construction time instead of with an UnboundLocalError in
        # forward().
        if type not in ('cls', 'pooler'):
            raise ValueError(f"type must be 'cls' or 'pooler', got {type!r}")
        self.encoder = bert
        self.type = type
        self.sigmoid = use_sigmoid
        self.effectiveness = Effectiveness(num_linear, hidden_size, use_sigmoid, dropout)
        self.creativity = Creativity(num_linear, hidden_size, use_sigmoid, dropout)

    def forward(self, inputs, return_attention=False):
        """Score a tokenized batch with both heads.

        Args:
            inputs: Mapping with ``'input_ids'``, ``'token_type_ids'`` and
                ``'attention_mask'`` entries.
            return_attention: Also return the encoder's ``attentions``.

        Returns:
            ``(c_pred, e_pred)``, with creativity first and effectiveness
            second, or ``(c_pred, e_pred, attentions)`` when
            ``return_attention`` is true.

        Raises:
            ValueError: If ``self.type`` was changed to an unsupported value.
        """
        encoded = self.encoder(
            input_ids=inputs['input_ids'],
            token_type_ids=inputs['token_type_ids'],
            attention_mask=inputs['attention_mask'],
            output_attentions=return_attention,
        )
        if self.type == 'cls':
            # Hidden state at the first position of every sequence.
            features = encoded.last_hidden_state[:, 0, :]
        elif self.type == 'pooler':
            features = encoded.pooler_output
        else:
            raise ValueError(f"type must be 'cls' or 'pooler', got {self.type!r}")
        # Both heads read the same features. Effectiveness is evaluated first,
        # matching the original call order.
        e_pred = self.effectiveness(features)
        c_pred = self.creativity(features)
        if return_attention:
            return c_pred, e_pred, encoded.attentions
        return c_pred, e_pred
def load_model(model_name, pooling_method):
    """Download a fine-tuned checkpoint from the Hub and return the network in eval mode.

    Args:
        model_name: ``"One-phase Fine-tuned BERT"`` (builds a BERTregressor) or
            ``"Two-phase Fine-tuned BERT"`` (builds a BERT2Phase).
        pooling_method: ``'cls'`` selects the first-token representation; any
            other value selects the pooler output. The raw value is also
            used in the checkpoint filename.

    Returns:
        The network with its checkpoint weights loaded, in eval mode.

    Raises:
        ValueError: If ``model_name`` is not one of the two names above.
    """
    import copy  # only needed here; keeps the module's import block unchanged

    pooling = pooling_method if pooling_method == 'cls' else 'pooler'
    # Every network gets its own copy of the pretrained encoder.
    # load_state_dict() writes into the parameters in place, so passing the
    # module-level `model` directly would make all returned networks share one
    # encoder, and each new load would overwrite the previous networks' weights.
    if model_name == "One-phase Fine-tuned BERT":
        net = BERTregressor(copy.deepcopy(model), hidden_size=768, num_linear=1, dropout=0.1,
                            o_type=pooling, t_type='C', use_sigmoid=True)
        filename = f"model/OnePhase_BERT_{pooling_method}.pth"
    elif model_name == "Two-phase Fine-tuned BERT":
        net = BERT2Phase(copy.deepcopy(model), hidden_size=768, num_linear=1, dropout=0.1,
                         type=pooling, use_sigmoid=True)
        filename = f"model/TwoPhase_BERT_{pooling_method}.pth"
    else:
        raise ValueError(f"Unknown model_name: {model_name!r}")
    # `token=` replaces the deprecated `use_auth_token=` argument.
    model_path = hf_hub_download(repo_id=repo_id, filename=filename, token=token)
    # Deserialize onto the CPU so checkpoints saved from a GPU also load on
    # CPU-only hosts. load_state_dict() then copies the tensors into the
    # network's existing parameters.
    state_dict = torch.load(model_path, map_location="cpu")
    net.load_state_dict(state_dict)
    net.eval()
    return net