"""Creativity / effectiveness regression models on a Japanese BERT encoder.

``BERTregressor`` predicts a single score; ``BERT2Phase`` predicts a
creativity score and an effectiveness score from shared encoder features.
``load_model`` builds either network and loads fine-tuned weights from the
Hugging Face Hub.
"""
import copy
import os

import torch
from torch import nn
from transformers import AutoModel
from huggingface_hub import hf_hub_download

# Hub access token read from the environment (None when HF_TOKEN is unset).
token = os.getenv("HF_TOKEN")
repo_id = "Siyunb323/CreativityEvaluation"
# Pretrained encoder downloaded once at import time; ``load_model`` builds each
# network on its own copy of it.
model = AutoModel.from_pretrained("cl-tohoku/bert-base-japanese")


def _encode(encoder, inputs, return_attention):
    """Run ``encoder`` on the tokenizer fields of ``inputs``.

    ``inputs`` must provide 'input_ids', 'token_type_ids' and
    'attention_mask'; any other keys are ignored.
    """
    return encoder(
        input_ids=inputs['input_ids'],
        token_type_ids=inputs['token_type_ids'],
        attention_mask=inputs['attention_mask'],
        output_attentions=return_attention,
    )


def _select_features(encoded, pooling):
    """Pick the sentence representation fed to the regression heads.

    'cls' takes position 0 along the second axis of ``last_hidden_state``;
    'pooler' takes ``pooler_output``.

    Raises:
        ValueError: for any other ``pooling`` value.
    """
    if pooling == 'cls':
        return encoded.last_hidden_state[:, 0, :]
    if pooling == 'pooler':
        return encoded.pooler_output
    raise ValueError(
        f"unknown pooling type {pooling!r}; expected 'cls' or 'pooler'")


class BERTregressor(nn.Module):
    """BERT encoder followed by a single-output regression head.

    Args:
        bert: encoder module; called with ``input_ids``, ``token_type_ids``,
            ``attention_mask`` and ``output_attentions`` keyword arguments.
        hidden_size: feature width produced by the encoder.
        num_linear: 1 -> Dropout, Linear(hidden_size, 1);
            2 -> Linear(hidden_size, 128), ReLU, Dropout, Linear(128, 1).
        dropout: dropout probability inside the head.
        o_type: 'cls' or 'pooler'; which encoder output feeds the head.
        t_type: target type; the [1, 5] rescaling is applied only when this
            is 'C' and ``use_sigmoid`` is true.
        use_sigmoid: append a Sigmoid to the head.

    Raises:
        ValueError: if ``num_linear`` is neither 1 nor 2.
    """

    def __init__(self, bert, hidden_size=768, num_linear=1, dropout=0.1,
                 o_type='cls', t_type='C', use_sigmoid=False):
        super().__init__()
        self.encoder = bert
        self.o_type = o_type
        self.t_type = t_type
        self.sigmoid = use_sigmoid
        if num_linear == 2:
            layers = [nn.Linear(hidden_size, 128), nn.ReLU(),
                      nn.Dropout(dropout), nn.Linear(128, 1)]
        elif num_linear == 1:
            layers = [nn.Dropout(dropout), nn.Linear(hidden_size, 1)]
        else:
            # Reject unsupported depths up front instead of failing on an
            # unbound ``layers`` below.
            raise ValueError(f"num_linear must be 1 or 2, got {num_linear!r}")
        if use_sigmoid:
            layers.append(nn.Sigmoid())
        self.output = nn.Sequential(*layers)

    def forward(self, inputs, return_attention=False):
        """Score a batch.

        Args:
            inputs: mapping with 'input_ids', 'token_type_ids' and
                'attention_mask'.
            return_attention: also return the encoder's attention maps.

        Returns:
            The head output with its trailing size-1 dimension squeezed, or
            ``(scores, attentions)`` when ``return_attention`` is true.

        Raises:
            ValueError: if ``o_type`` is neither 'cls' nor 'pooler'.
        """
        encoded = _encode(self.encoder, inputs, return_attention)
        scores = self.output(_select_features(encoded, self.o_type)).squeeze(-1)
        if self.sigmoid and self.t_type == 'C':
            # Sigmoid output lies in (0, 1); rescale it to the range [1, 5].
            scores = 4 * scores + 1
        if return_attention:
            return scores, encoded.attentions
        return scores


class _ScoreHead(nn.Module):
    """Shared regression head used by ``Effectiveness`` and ``Creativity``.

    Args:
        num_layers: 2 -> Linear(hidden_size, 128), ReLU, Dropout,
            Linear(128, 1); any other value -> ReLU, Dropout,
            Linear(hidden_size, 1).
        hidden_size: input feature width.
        use_sigmoid: append a Sigmoid and rescale the result to [1, 5].
        dropout: dropout probability.
        **kwargs: forwarded to ``nn.Module.__init__``.
    """

    def __init__(self, num_layers, hidden_size=768, use_sigmoid=True,
                 dropout=0.2, **kwargs):
        super().__init__(**kwargs)
        self.sigmoid = use_sigmoid
        if num_layers == 2:
            layers = [nn.Linear(hidden_size, 128), nn.ReLU(),
                      nn.Dropout(dropout), nn.Linear(128, 1)]
        else:
            layers = [nn.ReLU(), nn.Dropout(dropout), nn.Linear(hidden_size, 1)]
        if use_sigmoid:
            # Only add the Sigmoid layer when requested.
            layers.append(nn.Sigmoid())
        # Layer order and the ``output`` attribute name are part of the
        # checkpoint's state-dict keys; keep them stable.
        self.output = nn.Sequential(*layers)

    def forward(self, X):
        """Return the head output for features ``X``, last dim squeezed."""
        output = self.output(X).squeeze(-1)
        if self.sigmoid:
            # With the Sigmoid layer, rescale the output range to [1, 5].
            return 4 * output + 1
        return output


class Effectiveness(_ScoreHead):
    """Regression head producing the effectiveness score."""


class Creativity(_ScoreHead):
    """Regression head producing the creativity score."""


class BERT2Phase(nn.Module):
    """BERT encoder with separate creativity and effectiveness heads.

    Args:
        bert: encoder module (see ``BERTregressor``).
        hidden_size: feature width produced by the encoder.
        type: 'cls' or 'pooler'; which encoder output feeds both heads.
            (The name shadows the builtin but is kept for callers.)
        num_linear: head depth, passed to both heads as ``num_layers``.
        dropout: dropout probability in both heads.
        use_sigmoid: whether both heads use a Sigmoid plus [1, 5] rescaling.
    """

    def __init__(self, bert, hidden_size=768, type='cls', num_linear=1,
                 dropout=0.1, use_sigmoid=False):
        super().__init__()
        self.encoder = bert
        self.type = type
        self.sigmoid = use_sigmoid
        self.effectiveness = Effectiveness(num_linear, hidden_size, use_sigmoid, dropout)
        self.creativity = Creativity(num_linear, hidden_size, use_sigmoid, dropout)

    def forward(self, inputs, return_attention=False):
        """Score a batch with both heads.

        Returns:
            ``(creativity, effectiveness)``, or
            ``(creativity, effectiveness, attentions)`` when
            ``return_attention`` is true.

        Raises:
            ValueError: if ``type`` is neither 'cls' nor 'pooler'.
        """
        encoded = _encode(self.encoder, inputs, return_attention)
        # Both heads read the same features; select them once.
        features = _select_features(encoded, self.type)
        e_pred = self.effectiveness(features)
        c_pred = self.creativity(features)
        if return_attention:
            return c_pred, e_pred, encoded.attentions
        return c_pred, e_pred


def load_model(model_name, pooling_method):
    """Build a scoring network and load its fine-tuned weights from the Hub.

    Args:
        model_name: "One-phase Fine-tuned BERT" or "Two-phase Fine-tuned BERT".
        pooling_method: 'cls' selects CLS-token features; any other value
            selects the pooler output. The raw value is also used in the
            checkpoint path ``model/{OnePhase|TwoPhase}_BERT_{pooling_method}.pth``.

    Returns:
        The network in eval mode.

    Raises:
        ValueError: if ``model_name`` is not one of the supported names.
    """
    pooling = 'cls' if pooling_method == 'cls' else 'pooler'
    if model_name == "One-phase Fine-tuned BERT":
        filename = f"model/OnePhase_BERT_{pooling_method}.pth"
    elif model_name == "Two-phase Fine-tuned BERT":
        filename = f"model/TwoPhase_BERT_{pooling_method}.pth"
    else:
        raise ValueError(
            f"unknown model_name {model_name!r}; expected "
            "'One-phase Fine-tuned BERT' or 'Two-phase Fine-tuned BERT'")

    # load_state_dict overwrites the encoder's weights, so each network gets
    # its own copy; sharing the module-level ``model`` would let a later call
    # silently change a previously returned network.
    encoder = copy.deepcopy(model)
    if model_name == "One-phase Fine-tuned BERT":
        loaded_net = BERTregressor(encoder, hidden_size=768, num_linear=1,
                                   dropout=0.1, o_type=pooling, t_type='C',
                                   use_sigmoid=True)
    else:
        loaded_net = BERT2Phase(encoder, hidden_size=768, num_linear=1,
                                dropout=0.1, type=pooling, use_sigmoid=True)

    model_path = hf_hub_download(repo_id=repo_id, filename=filename, token=token)
    # The network is built on CPU; mapping the checkpoint to CPU lets it load
    # on hosts without a GPU even if it was saved from CUDA tensors.
    state_dict = torch.load(model_path, map_location="cpu")
    loaded_net.load_state_dict(state_dict)
    loaded_net.eval()
    return loaded_net