Spaces:
Paused
Paused
"""Model definitions""" | |
import tensorflow as tf | |
from transformers import TFAutoModel, TFViTModel | |
from kapre.augmentation import SpecAugment | |
class FixMatchTune(tf.keras.Model):
    """FixMatch-style text classifier on top of a pretrained Hugging Face encoder.

    The encoder's ``pooler_output`` is perturbed twice with additive Gaussian
    noise (a weak and a strong level). Both perturbed copies are scored by the
    same softmax classification head. The resulting (weak, strong) prediction
    pair is presumably consumed by a FixMatch training loop elsewhere.

    Args:
        encoder_name: Model id or path passed to ``TFAutoModel.from_pretrained``.
        num_classes: Number of units of the final softmax layer.
        weak_stddev: Standard deviation of the weak ``GaussianNoise`` layer.
        strong_stddev: Standard deviation of the strong ``GaussianNoise`` layer.
        **kwargs: Forwarded to ``tf.keras.Model``.
    """

    def __init__(
        self,
        encoder_name="readerbench/RoBERT-base",
        num_classes=4,
        weak_stddev=0.5,
        strong_stddev=5.0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.bert = TFAutoModel.from_pretrained(encoder_name)
        self.num_classes = num_classes
        self.weak_augment = tf.keras.layers.GaussianNoise(stddev=weak_stddev)
        self.strong_augment = tf.keras.layers.GaussianNoise(stddev=strong_stddev)
        self.cls_head = tf.keras.Sequential([
            tf.keras.layers.Dense(256, activation="relu"),
            tf.keras.layers.Dropout(0.2),
            tf.keras.layers.Dense(64, activation="relu"),
            tf.keras.layers.Dense(self.num_classes, activation="softmax"),
        ])

    def call(self, inputs, training=None):
        """Score weakly and strongly noised views of the pooled encoding.

        Args:
            inputs: Pair ``(input_ids, attention_mask)`` fed to the encoder.
            training: Keras training flag, forwarded to the encoder, the noise
                layers and the head. Defaults to ``None`` per the Keras
                ``call`` convention.

        Returns:
            Tuple ``(weak_preds, strong_preds)`` of softmax outputs over
            ``num_classes``.
        """
        ids, mask = inputs
        embeds = self.bert(
            input_ids=ids, attention_mask=mask, training=training
        ).pooler_output
        # One encoder pass is shared by both views; only the noise level differs.
        # NOTE(review): Keras GaussianNoise is documented as inactive outside
        # training, so both outputs presumably coincide at inference — confirm.
        strongs = self.strong_augment(embeds, training=training)
        weaks = self.weak_augment(embeds, training=training)
        strong_preds = self.cls_head(strongs, training=training)
        weak_preds = self.cls_head(weaks, training=training)
        return weak_preds, strong_preds
class MixMatch(tf.keras.Model):
    """MixMatch-style text classifier on top of a pretrained Hugging Face encoder.

    The encoder's ``pooler_output`` is perturbed with additive Gaussian noise
    and then scored by a small softmax classification head.

    Args:
        bert_model: Model id or path passed to ``TFAutoModel.from_pretrained``.
        num_classes: Number of units of the final softmax layer.
        noise_stddev: Standard deviation of the ``GaussianNoise`` augmentation.
        **kwargs: Forwarded to ``tf.keras.Model``.
    """

    def __init__(
        self,
        bert_model="readerbench/RoBERT-base",
        num_classes=4,
        noise_stddev=2.0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.bert = TFAutoModel.from_pretrained(bert_model)
        self.num_classes = num_classes
        self.cls_head = tf.keras.Sequential([
            tf.keras.layers.Dense(256, activation="relu"),
            tf.keras.layers.Dropout(0.2),
            tf.keras.layers.Dense(64, activation="relu"),
            tf.keras.layers.Dense(self.num_classes, activation="softmax"),
        ])
        self.augment = tf.keras.layers.GaussianNoise(stddev=noise_stddev)

    def call(self, inputs, training=None):
        """Return softmax class probabilities for the noised pooled encoding.

        Args:
            inputs: Pair ``(input_ids, attention_mask)`` fed to the encoder.
            training: Keras training flag, forwarded to every sub-layer.
                Defaults to ``None`` per the Keras ``call`` convention.
        """
        ids, mask = inputs
        embeds = self.bert(
            input_ids=ids, attention_mask=mask, training=training
        ).pooler_output
        augs = self.augment(embeds, training=training)
        return self.cls_head(augs, training=training)
class LPModel(tf.keras.Model):
    """Encoder + softmax head used for label propagation.

    The model itself only encodes text and classifies the pooled output. Any
    label-propagation logic is presumably implemented by the caller — verify.

    Args:
        bert_model: Model id or path passed to ``TFAutoModel.from_pretrained``.
        num_classes: Number of units of the final softmax layer.
        **kwargs: Forwarded to ``tf.keras.Model``.
    """

    def __init__(self, bert_model="readerbench/RoBERT-base", num_classes=4, **kwargs):
        super().__init__(**kwargs)
        self.bert = TFAutoModel.from_pretrained(bert_model)
        self.num_classes = num_classes
        self.cls_head = tf.keras.Sequential([
            tf.keras.layers.Dense(256, activation="relu"),
            tf.keras.layers.Dropout(0.2),
            tf.keras.layers.Dense(64, activation="relu"),
            tf.keras.layers.Dense(self.num_classes, activation="softmax"),
        ])

    def call(self, inputs, training=None):
        """Return softmax class probabilities for the pooled encoding.

        Args:
            inputs: Pair ``(input_ids, attention_mask)`` fed to the encoder.
            training: Keras training flag, forwarded to the encoder and head.
                Defaults to ``None`` per the Keras ``call`` convention.
        """
        ids, mask = inputs
        embeds = self.bert(
            input_ids=ids, attention_mask=mask, training=training
        ).pooler_output
        return self.cls_head(embeds, training=training)
class AudioFixMatch(tf.keras.Model):
    """FixMatch-style audio classifier on top of a pretrained ViT encoder.

    The first input channel is augmented twice with kapre ``SpecAugment``
    (a strong and a weak masking configuration). Each view is tiled to three
    channels and encoded by the ViT, and a shared softmax head scores both
    pooled encodings.

    Args:
        encoder_name: Model id or path passed to ``TFViTModel.from_pretrained``.
        num_classes: Number of units of the final softmax layer.
        **kwargs: Forwarded to ``tf.keras.Model``.
    """

    def __init__(self, encoder_name='google/vit-base-patch16-224', num_classes=6, **kwargs):
        super().__init__(**kwargs)
        self.vit = TFViTModel.from_pretrained(encoder_name)
        self.num_classes = num_classes
        self.cls_head = tf.keras.Sequential([
            tf.keras.layers.Dense(256, activation="relu"),
            tf.keras.layers.Dropout(0.2),
            tf.keras.layers.Dense(64, activation="relu"),
            tf.keras.layers.Dense(self.num_classes, activation="softmax"),
        ])
        self.strong_augment = SpecAugment(
            freq_mask_param=8,
            time_mask_param=8,
            n_freq_masks=2,
            n_time_masks=2,
            mask_value=0.0,
            data_format="channels_first",
        )
        self.weak_augment = SpecAugment(
            freq_mask_param=2,
            time_mask_param=2,
            n_freq_masks=2,
            n_time_masks=2,
            mask_value=0.0,
            data_format="channels_first",
        )

    def call(self, inputs, training=None):
        """Score weakly and strongly augmented views of the first input channel.

        Args:
            inputs: 4-D tensor. Only index 0 along axis 1 is used; axis 1 is
                presumably the channel axis, matching ``channels_first`` — confirm.
            training: Keras training flag, forwarded to the augmenters, the
                ViT and the head. Defaults to ``None`` per the Keras convention.

        Returns:
            Tuple ``(weak_preds, strong_preds)`` of softmax outputs over
            ``num_classes``.
        """
        # Keep only channel 0 while preserving rank 4; equivalent to
        # inputs[:, 0, :, :][:, tf.newaxis, :, :], computed once.
        first_channel = inputs[:, :1, :, :]
        strong = self.strong_augment(first_channel, training=training)
        weak = self.weak_augment(first_channel, training=training)
        # Tile the single channel to 3 along axis 1 (presumably the ViT's
        # expected RGB channel count).
        embeds_strong = self.vit(
            pixel_values=tf.repeat(strong, 3, axis=1), training=training
        ).pooler_output
        embeds_weak = self.vit(
            pixel_values=tf.repeat(weak, 3, axis=1), training=training
        ).pooler_output
        # Pass `training` explicitly so the head's Dropout matches the rest of
        # the forward pass, as in the text models.
        return (
            self.cls_head(embeds_weak, training=training),
            self.cls_head(embeds_strong, training=training),
        )
class AudioMixMatch(tf.keras.Model):
    """MixMatch-style audio classifier on top of a pretrained ViT encoder.

    The first input channel is augmented with kapre ``SpecAugment``, tiled to
    three channels and encoded by the ViT. A softmax head scores the pooled
    encoding.

    Args:
        encoder_name: Model id or path passed to ``TFViTModel.from_pretrained``.
        num_classes: Number of units of the final softmax layer.
        **kwargs: Forwarded to ``tf.keras.Model``.
    """

    def __init__(self, encoder_name='google/vit-base-patch16-224', num_classes=6, **kwargs):
        super().__init__(**kwargs)
        self.vit = TFViTModel.from_pretrained(encoder_name)
        self.num_classes = num_classes
        self.cls_head = tf.keras.Sequential([
            tf.keras.layers.Dense(256, activation="relu"),
            tf.keras.layers.Dropout(0.2),
            tf.keras.layers.Dense(64, activation="relu"),
            tf.keras.layers.Dense(self.num_classes, activation="softmax"),
        ])
        self.augment = SpecAugment(
            freq_mask_param=3,
            time_mask_param=3,
            n_freq_masks=2,
            n_time_masks=2,
            mask_value=0.0,
            data_format="channels_first",
        )

    def aug_features(self, inputs, training=None):
        """Return the ViT ``pooler_output`` for the augmented first channel.

        Args:
            inputs: 4-D tensor. Only index 0 along axis 1 is used; axis 1 is
                presumably the channel axis, matching ``channels_first`` — confirm.
            training: Keras training flag, forwarded to the augmenter and ViT.
        """
        # Keep only channel 0 while preserving rank 4; equivalent to
        # inputs[:, 0, :, :][:, tf.newaxis, :, :].
        aug = self.augment(inputs[:, :1, :, :], training=training)
        # Tile the single channel to 3 along axis 1 (presumably the ViT's
        # expected RGB channel count).
        return self.vit(
            pixel_values=tf.repeat(aug, 3, axis=1), training=training
        ).pooler_output

    def call(self, inputs, training=None):
        """Return softmax class probabilities for the augmented input.

        Reuses ``aug_features`` so the feature path is defined in one place.
        """
        embeds = self.aug_features(inputs, training=training)
        return self.cls_head(embeds, training=training)