File size: 1,262 Bytes
02768a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from collections.abc import Mapping

import tensorflow as tf
from transformers import TFAutoModel


class FixMatchTune(tf.keras.Model):
    """FixMatch-style classifier on top of a pretrained Hugging Face encoder.

    The encoder's ``pooler_output`` is perturbed twice with additive Gaussian
    noise -- a weak and a strong perturbation, standing in for FixMatch's
    weak/strong data augmentations -- and both perturbed embeddings are scored
    by the same classification head (shared weights).

    Note: ``tf.keras.layers.GaussianNoise`` only adds noise in training mode,
    so at inference both outputs are computed from the same clean embedding.
    """

    def __init__(
        self,
        encoder_name="readerbench/RoBERT-base",
        num_classes=4,
        weak_stddev=0.5,
        strong_stddev=5,
        dropout_rate=0.2,
        **kwargs
    ):
        """Build the encoder, the two noise layers and the classification head.

        Args:
            encoder_name: Model id or local path passed to
                ``TFAutoModel.from_pretrained``.
            num_classes: Number of output units (classes) of the head.
            weak_stddev: Std-dev of the Gaussian noise used as the weak
                perturbation.
            strong_stddev: Std-dev of the Gaussian noise used as the strong
                perturbation. NOTE(review): if the pooled output is bounded
                (e.g. tanh-activated, as in BERT-style poolers), a std-dev of
                5 dominates the signal -- confirm this is intended.
            dropout_rate: Dropout rate applied inside the classification head.
            **kwargs: Forwarded unchanged to ``tf.keras.Model.__init__``.
        """
        super().__init__(**kwargs)

        self.bert = TFAutoModel.from_pretrained(encoder_name)
        self.num_classes = num_classes
        self.weak_augment = tf.keras.layers.GaussianNoise(stddev=weak_stddev)
        self.strong_augment = tf.keras.layers.GaussianNoise(stddev=strong_stddev)

        # One head instance is reused for both views, so the weak and strong
        # predictions come from identical weights.
        self.cls_head = tf.keras.Sequential([
            tf.keras.layers.Dense(256, activation="relu"),
            tf.keras.layers.Dropout(dropout_rate),
            tf.keras.layers.Dense(64, activation="relu"),
            # Softmax output: predictions are probabilities, not logits, so
            # losses applied to them presumably need from_logits=False.
            tf.keras.layers.Dense(self.num_classes, activation="softmax"),
        ])

    def call(self, inputs, training=None):
        """Return class probabilities for the weakly and strongly noised views.

        Args:
            inputs: Either an ``(input_ids, attention_mask)`` pair, or a
                mapping with ``"input_ids"`` and ``"attention_mask"`` keys
                (e.g. a tokenizer's ``BatchEncoding``); any other keys in the
                mapping are ignored.
            training: Keras training flag, forwarded to the encoder, both noise
                layers and the head. Defaults to ``None`` per Keras convention.

        Returns:
            Tuple ``(weak_preds, strong_preds)`` of softmax outputs whose last
            dimension has ``num_classes`` entries.
        """
        if isinstance(inputs, Mapping):
            # ``a, b = mapping`` would silently bind the *keys*, not the
            # tensors, so look the tensors up by name instead.
            ids = inputs["input_ids"]
            mask = inputs["attention_mask"]
        else:
            ids, mask = inputs

        embeds = self.bert(
            input_ids=ids, attention_mask=mask, training=training
        ).pooler_output

        # Strong view is drawn first to keep the original random-op order.
        strongs = self.strong_augment(embeds, training=training)
        weaks = self.weak_augment(embeds, training=training)

        strong_preds = self.cls_head(strongs, training=training)
        weak_preds = self.cls_head(weaks, training=training)

        return weak_preds, strong_preds