# -*- coding: utf-8 -*-
# @Time    : 2022/4/21 5:30 PM
# @Author  : JianingWang
# @File    : span_proto.py

"""
This code is implemented for the paper "SpanProto: A Two-stage Span-based Prototypical Network for Few-shot Named Entity Recognition".
"""

import os
from typing import Optional, Union
import torch
import numpy as np
import torch.nn as nn
from dataclasses import dataclass
from torch.nn import BCEWithLogitsLoss
from transformers import MegatronBertModel, MegatronBertPreTrainedModel
from transformers.file_utils import ModelOutput
from transformers.models.bert import BertPreTrainedModel, BertModel


@dataclass
class TokenProtoOutput(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None


class TokenProto(nn.Module):
    def __init__(self, config):
        """
        word_encoder: Sentence encoder

        You need to set self.cost as your own loss function.
        """
        nn.Module.__init__(self)
        self.config = config
        self.output_dir = "./outputs"
        # self.predict_dir = self.predict_result_path(self.output_dir)
        self.drop = nn.Dropout()
        self.projector = nn.Sequential( # projector
            nn.Linear(self.config.hidden_size, self.config.hidden_size),
            nn.Sigmoid(),
            # nn.LayerNorm(2)
        )
        self.tag_embeddings = nn.Embedding(2, self.config.hidden_size) # tag for labeled / unlabeled span set
        # self.tag_mlp = nn.Linear(self.config.hidden_size, self.config.hidden_size)
        self.dot = False # assumed default: negative squared Euclidean distance in __dist__; set True for dot-product similarity
        self.max_length = 64
        self.margin_distance = 6.0
        self.global_step = 0

    def predict_result_path(self, path=None):
        if path is None:
            # note: this branch expects self.mode, self.num_class and self.num_example to be set on the instance beforehand
            predict_dir = os.path.join(
                self.output_dir, "{}-{}-{}".format(self.mode, self.num_class, self.num_example), "predict"
            )
        else:
            predict_dir = os.path.join(
                path, "predict"
            )
        # if os.path.exists(predict_dir):
        #     os.rmdir(predict_dir) # remove previous prediction results
        if not os.path.exists(predict_dir): # create the directory if it does not exist yet
            os.makedirs(predict_dir)
        return predict_dir

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs):
        # NOTE: pretrained weights are not loaded here; only the config passed via kwargs is used
        config = kwargs.pop("config", None)
        model = cls(config=config)
        return model

    def __dist__(self, x, y, dim):
        # dot-product similarity if self.dot is True, otherwise negative squared Euclidean distance
        if self.dot:
            return (x * y).sum(dim)
        else:
            return -(torch.pow(x - y, 2)).sum(dim)

    def __batch_dist__(self, S, Q, q_mask):
        # S [class, embed_dim], Q [num_of_sent, num_of_tokens, embed_dim]
        assert Q.size()[:2] == q_mask.size()
        Q = Q[q_mask==1].view(-1, Q.size(-1)) # [num_of_all_text_tokens, embed_dim]
        return self.__dist__(S.unsqueeze(0), Q.unsqueeze(1), 2)

    def __get_proto__(self, embedding, tag, mask):
        # average the embeddings of all tokens sharing a label to obtain one prototype per class
        proto = []
        embedding = embedding[mask==1].view(-1, embedding.size(-1))
        tag = torch.cat(tag, 0)
        assert tag.size(0) == embedding.size(0)
        for label in range(torch.max(tag)+1):
            proto.append(torch.mean(embedding[tag==label], 0))
        proto = torch.stack(proto)
        return proto, embedding

    def forward(self, support, query):
        """
        support: Inputs of the support set.
        query: Inputs of the query set.
        N: Num of classes
        K: Num of instances for each class in the support set
        Q: Num of instances in the query set

        support/query = {"index": [], "word": [], "mask": [], "label": [], "sentence_num": [], "text_mask": []}

        """
        # feed the support set and the query set through BERT to obtain token representations for each sentence
        support_emb = self.word_encoder(support["word"], support["mask"]) # [num_sent, number_of_tokens, 768]
        query_emb = self.word_encoder(query["word"], query["mask"]) # [num_sent, number_of_tokens, 768]
        support_emb = self.drop(support_emb)
        query_emb = self.drop(query_emb)

        # Prototypical Networks
        logits = []
        current_support_num = 0
        current_query_num = 0
        assert support_emb.size()[:2] == support["mask"].size()
        assert query_emb.size()[:2] == query["mask"].size()

        for i, sent_support_num in enumerate(support["sentence_num"]): # iterate over each sampled N-way K-shot episode
            sent_query_num = query["sentence_num"][i]
            # Calculate prototype for each class
            # a batch may contain multiple episodes, so the slice current_support_num:current_support_num+sent_support_num
            # indicates which sentences of the input tensors belong to the current N-way K-shot episode
            support_proto, embedding = self.__get_proto__(
                support_emb[current_support_num:current_support_num+sent_support_num],
                support["label"][current_support_num:current_support_num+sent_support_num],
                support["text_mask"][current_support_num: current_support_num+sent_support_num])
            # calculate distance to each prototype
            logits.append(self.__batch_dist__(
                support_proto,
                query_emb[current_query_num:current_query_num+sent_query_num],
                query["text_mask"][current_query_num: current_query_num+sent_query_num])) # [num_of_query_tokens, class_num]
            current_query_num += sent_query_num
            current_support_num += sent_support_num
        logits = torch.cat(logits, 0) # similarity of each query token to the prototype of each class in the support set
        _, pred = torch.max(logits, 1) # take the class whose prototype is most similar as the prediction


        # return logits, pred, embedding

        # no matter whether the returned logits are wrapped in a list or a tuple, the innermost element must be
        # a tensor; otherwise HuggingFace's nested_detach function will raise an error
        return TokenProtoOutput(
            logits=logits
        )
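

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original training code).
# It wires up a dummy word encoder and a fake 2-way episode so that a single
# forward pass can be smoke-tested; _DummyEncoder and all shapes below are
# assumptions, not the encoder actually used in the SpanProto experiments.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import BertConfig

    config = BertConfig()  # hidden_size defaults to 768
    model = TokenProto(config)

    class _DummyEncoder(nn.Module):
        # hypothetical stand-in for the BERT-based word encoder that forward() expects:
        # maps (word ids, attention mask) -> [num_sent, num_tokens, hidden_size]
        def __init__(self, hidden_size):
            super().__init__()
            self.hidden_size = hidden_size

        def forward(self, word, mask):
            return torch.randn(word.size(0), word.size(1), self.hidden_size)

    model.word_encoder = _DummyEncoder(config.hidden_size)

    num_sent, num_tokens = 2, 8
    episode = {
        "word": torch.randint(0, config.vocab_size, (num_sent, num_tokens)),
        "mask": torch.ones(num_sent, num_tokens, dtype=torch.long),
        "text_mask": torch.ones(num_sent, num_tokens, dtype=torch.long),
        # one label per text token, classes {0, 1} -> a 2-way episode
        "label": [torch.tensor([0, 0, 1, 1, 0, 0, 1, 1]) for _ in range(num_sent)],
        "sentence_num": [num_sent],
    }
    output = model(support=episode, query=episode)
    print(output.logits.shape)  # expected: [num_query_text_tokens, num_classes] = [16, 2]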