Andrei-Iulian SĂCELEANU committed on
Commit 1f3a9b6
1 Parent(s): 7d0a00c

added audio tab

app.py CHANGED
@@ -1,11 +1,14 @@
 import re
 import gradio as gr
-from transformers import AutoTokenizer
+import librosa
+import numpy as np
+from transformers import AutoTokenizer, ViTImageProcessor
 from unidecode import unidecode
 from models import *
 
 
 tok = AutoTokenizer.from_pretrained("readerbench/RoBERT-base")
+processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')
 
 def preprocess(x):
     """Preprocess input string x"""
@@ -21,6 +24,7 @@ def preprocess(x):
     return s
 
 label_names = ["ABUSE", "INSULT", "OTHER", "PROFANITY"]
+audio_label_names = ["Laughter", "Sigh", "Cough", "Throat clearing", "Sneeze", "Sniff"]
 
 def ssl_predict(in_text, model_type):
     """main predict function"""
@@ -39,12 +43,12 @@ def ssl_predict(in_text, model_type):
         model = FixMatchTune(encoder_name="readerbench/RoBERT-base")
         model.load_weights("./checkpoints/fixmatch_tune")
         preds, _ = model([toks["input_ids"], toks["attention_mask"]], training=False)
-
+
     elif model_type == "freematch":
         model = FixMatchTune(encoder_name="andrei-saceleanu/ro-offense-freematch")
         model.cls_head.load_weights("./checkpoints/freematch_tune")
         preds, _ = model([toks["input_ids"], toks["attention_mask"]], training=False)
-
+
     elif model_type == "mixmatch":
         model = MixMatch(bert_model="andrei-saceleanu/ro-offense-mixmatch")
         model.cls_head.load_weights("./checkpoints/mixmatch")
@@ -68,37 +72,118 @@ def ssl_predict(in_text, model_type):
     return d
 
 
-
-
-with gr.Blocks() as ssl_interface:
-    with gr.Row():
-        with gr.Column():
-            in_text = gr.Textbox(label="Input text")
-            model_list = gr.Dropdown(
-                choices=["fixmatch", "freematch", "mixmatch", "contrastive_reg", "label_propagation"],
-                max_choices=1,
-                label="Training method",
-                allow_custom_value=False,
-                info="Select trained model according to different SSL techniques from paper",
-            )
-
-            with gr.Row():
-                clear_btn = gr.Button(value="Clear")
-                submit_btn = gr.Button(value="Submit")
-
-        with gr.Column():
-            out_field = gr.Label(num_top_classes=4, label="Prediction")
-
-    submit_btn.click(
-        fn=ssl_predict,
-        inputs=[in_text, model_list],
-        outputs=[out_field]
-    )
-
-    clear_btn.click(
-        fn=lambda: [None for _ in range(2)],
-        inputs=None,
-        outputs=[in_text, out_field]
-    )
+def ssl_predict2(audio_file, model_type):
+    """main predict function"""
+
+    signal, sr = librosa.load(audio_file.name, sr=16000)
+
+    length = 5 * 16000  # pad/crop every clip to exactly 5 s at 16 kHz
+    if len(signal) < length:
+        signal = np.pad(signal, (0, length - len(signal)), 'constant')
+    else:
+        signal = signal[:length]
+
+    spectrogram = librosa.feature.melspectrogram(y=signal, sr=sr, n_mels=128)
+    spectrogram = librosa.power_to_db(S=spectrogram, ref=np.max)
+    spectrogram_min, spectrogram_max = spectrogram.min(), spectrogram.max()
+    spectrogram = (spectrogram - spectrogram_min) / (spectrogram_max - spectrogram_min)
+    spectrogram = spectrogram.astype("float32")
+
+    inputs = processor.preprocess(
+        np.repeat(spectrogram[np.newaxis, :, :, np.newaxis], 3, -1),  # (1, 128, T) -> (1, 128, T, 3)
+        image_mean=(-3.05, -3.05, -3.05),
+        image_std=(2.33, 2.33, 2.33),
+        return_tensors="tf"
+    )
+
+    preds = None
+    if model_type == "fixmatch":
+        model = AudioFixMatch(encoder_name="andrei-saceleanu/vit-base-fixmatch")
+        model.cls_head.load_weights("./checkpoints/audio_fixmatch")
+        preds, _ = model(inputs["pixel_values"], training=False)
+
+    elif model_type == "freematch":
+        model = AudioFixMatch(encoder_name="andrei-saceleanu/vit-base-freematch")
+        model.cls_head.load_weights("./checkpoints/audio_freematch")
+        preds, _ = model(inputs["pixel_values"], training=False)
+
+    elif model_type == "mixmatch":
+        model = AudioMixMatch(encoder_name="andrei-saceleanu/vit-base-mixmatch")
+        model.cls_head.load_weights("./checkpoints/audio_mixmatch")
+        preds = model(inputs["pixel_values"], training=False)
+
+    probs = list(preds[0].numpy())
+
+    d = {}
+    for k, v in zip(audio_label_names, probs):
+        d[k] = float(v)
+    return d
+
+with gr.Blocks() as ssl_interface:
+
+    with gr.Tab("Text (RO-Offense)"):
+        with gr.Row():
+            with gr.Column():
+                in_text = gr.Textbox(label="Input text")
+                model_list = gr.Dropdown(
+                    choices=["fixmatch", "freematch", "mixmatch", "contrastive_reg", "label_propagation"],
+                    max_choices=1,
+                    label="Training method",
+                    allow_custom_value=False,
+                    info="Select trained model according to different SSL techniques from paper",
+                )
+
+                with gr.Row():
+                    clear_btn = gr.Button(value="Clear")
+                    submit_btn = gr.Button(value="Submit")
+
+            with gr.Column():
+                out_field = gr.Label(num_top_classes=4, label="Prediction")
+
+        submit_btn.click(
+            fn=ssl_predict,
+            inputs=[in_text, model_list],
+            outputs=[out_field]
+        )
+
+        clear_btn.click(
+            fn=lambda: [None for _ in range(2)],
+            inputs=None,
+            outputs=[in_text, out_field]
+        )
+    with gr.Tab("Audio (VocalSound)"):
+        with gr.Row():
+            with gr.Column():
+                audio_file = gr.File(
+                    label="Input audio",
+                    file_count="single",
+                    file_types=["audio"]
+                )
+                model_list2 = gr.Dropdown(
+                    choices=["fixmatch", "freematch", "mixmatch"],
+                    max_choices=1,
+                    label="Training method",
+                    allow_custom_value=False,
+                    info="Select trained model according to different SSL techniques from paper",
+                )
+
+                with gr.Row():
+                    clear_btn2 = gr.Button(value="Clear")
+                    submit_btn2 = gr.Button(value="Submit")
+
+            with gr.Column():
+                out_field2 = gr.Label(num_top_classes=6, label="Prediction")
+
+        submit_btn2.click(
+            fn=ssl_predict2,
+            inputs=[audio_file, model_list2],
+            outputs=[out_field2]
+        )
+
+        clear_btn2.click(
+            fn=lambda: [None for _ in range(2)],
+            inputs=None,
+            outputs=[audio_file, out_field2]
+        )
 
 ssl_interface.launch(server_name="0.0.0.0", server_port=7860)
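
For intuition, the new audio path in ssl_predict2 boils down to a short, self-contained pipeline. The sketch below mirrors it under stated assumptions (the file name example.wav is hypothetical; the frame count assumes librosa's default hop length of 512):

import librosa
import numpy as np

signal, _ = librosa.load("example.wav", sr=16000)        # mono, resampled to 16 kHz
length = 5 * 16000                                       # every clip is padded/cropped to 5 s
signal = np.pad(signal, (0, max(0, length - len(signal))))[:length]

spec = librosa.feature.melspectrogram(y=signal, sr=16000, n_mels=128)
spec = librosa.power_to_db(S=spec, ref=np.max)           # log-mel spectrogram in dB
spec = (spec - spec.min()) / (spec.max() - spec.min())   # min-max normalize to [0, 1]
print(spec.shape)                                        # (128, 157) for a 5 s clip

The (128, T) array is then given a batch axis and repeated to three identical channels, so ViTImageProcessor can treat the spectrogram as an ordinary 224x224 RGB image batch for the ViT encoder.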
checkpoints/audio_fixmatch.data-00000-of-00001 ADDED
Binary file (856 kB)

checkpoints/audio_fixmatch.index ADDED
Binary file (518 Bytes)

checkpoints/audio_freematch.data-00000-of-00001 ADDED
Binary file (856 kB)

checkpoints/audio_freematch.index ADDED
Binary file (518 Bytes)

checkpoints/audio_mixmatch.data-00000-of-00001 ADDED
Binary file (856 kB)

checkpoints/audio_mixmatch.index ADDED
Binary file (518 Bytes)
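
The .index / .data-00000-of-00001 pairs above are standard TensorFlow checkpoint shards. They are loaded by common prefix rather than by file name, which is why app.py refers to them without an extension:

# loads both audio_fixmatch.index and audio_fixmatch.data-00000-of-00001
model.cls_head.load_weights("./checkpoints/audio_fixmatch")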
 
models.py CHANGED
@@ -1,6 +1,7 @@
 """Model definitions"""
 import tensorflow as tf
-from transformers import TFAutoModel
+from transformers import TFAutoModel, TFViTModel
+from kapre.augmentation import SpecAugment
 
 
 class FixMatchTune(tf.keras.Model):
@@ -82,4 +83,73 @@ class LPModel(tf.keras.Model):
 
         embeds = self.bert(input_ids=ids, attention_mask=mask, training=training).pooler_output
 
-        return self.cls_head(embeds, training=training)
+        return self.cls_head(embeds, training=training)
+
+
+class AudioFixMatch(tf.keras.Model):
+    def __init__(self, encoder_name='google/vit-base-patch16-224', num_classes=6, **kwargs):
+        super(AudioFixMatch, self).__init__(**kwargs)
+        self.vit = TFViTModel.from_pretrained(encoder_name)
+        self.num_classes = num_classes
+        self.cls_head = tf.keras.Sequential([
+            tf.keras.layers.Dense(256, activation="relu"),
+            tf.keras.layers.Dropout(0.2),
+            tf.keras.layers.Dense(64, activation="relu"),
+            tf.keras.layers.Dense(self.num_classes, activation="softmax")
+        ])
+        # strong/weak augmentations differ only in mask width, as in FixMatch
+        self.strong_augment = SpecAugment(
+            freq_mask_param=8,
+            time_mask_param=8,
+            n_freq_masks=2,
+            n_time_masks=2,
+            mask_value=0.0,
+            data_format="channels_first"
+        )
+        self.weak_augment = SpecAugment(
+            freq_mask_param=2,
+            time_mask_param=2,
+            n_freq_masks=2,
+            n_time_masks=2,
+            mask_value=0.0,
+            data_format="channels_first"
+        )
+
+    def call(self, inputs, training):
+        # augment a single channel, then repeat it to the 3 channels ViT expects
+        strong = self.strong_augment(inputs[:, 0, :, :][:, tf.newaxis, :, :], training=training)
+        weak = self.weak_augment(inputs[:, 0, :, :][:, tf.newaxis, :, :], training=training)
+        embeds_strong = self.vit(pixel_values=tf.repeat(strong, 3, axis=1), training=training).pooler_output
+        embeds_weak = self.vit(pixel_values=tf.repeat(weak, 3, axis=1), training=training).pooler_output
+
+        return self.cls_head(embeds_weak), self.cls_head(embeds_strong)
+
+
+class AudioMixMatch(tf.keras.Model):
+    def __init__(self, encoder_name='google/vit-base-patch16-224', num_classes=6, **kwargs):
+        super(AudioMixMatch, self).__init__(**kwargs)
+        self.vit = TFViTModel.from_pretrained(encoder_name)
+        self.num_classes = num_classes
+        self.cls_head = tf.keras.Sequential([
+            tf.keras.layers.Dense(256, activation="relu"),
+            tf.keras.layers.Dropout(0.2),
+            tf.keras.layers.Dense(64, activation="relu"),
+            tf.keras.layers.Dense(self.num_classes, activation="softmax")
+        ])
+        self.augment = SpecAugment(
+            freq_mask_param=3,
+            time_mask_param=3,
+            n_freq_masks=2,
+            n_time_masks=2,
+            mask_value=0.0,
+            data_format="channels_first"
+        )
+
+    def aug_features(self, inputs, training):
+        aug = self.augment(inputs[:, 0, :, :][:, tf.newaxis, :, :], training=training)
+        embeds = self.vit(pixel_values=tf.repeat(aug, 3, axis=1), training=training).pooler_output
+        return embeds
+
+    def call(self, inputs, training):
+        aug = self.augment(inputs[:, 0, :, :][:, tf.newaxis, :, :], training=training)
+        embeds = self.vit(pixel_values=tf.repeat(aug, 3, axis=1), training=training).pooler_output
+
+        return self.cls_head(embeds)
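
As a quick shape check, the new audio models can be exercised on a dummy batch. This is a minimal sketch, assuming the pretrained ViT weights are downloadable and using the channels-first (batch, 3, 224, 224) layout that app.py produces:

import tensorflow as tf
from models import AudioFixMatch

model = AudioFixMatch(encoder_name="google/vit-base-patch16-224", num_classes=6)
dummy = tf.zeros((1, 3, 224, 224))           # channels-first pixel_values batch
weak_probs, strong_probs = model(dummy, training=False)
print(weak_probs.shape, strong_probs.shape)  # (1, 6) (1, 6)

With training=False the SpecAugment layers should act as a pass-through (masking is a training-time augmentation), so both heads see the same unmasked input.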