DionTimmer committed on
Commit 69943db
1 Parent(s): b15da3c

Upload 2 files

Files changed (2)
  1. README.md +77 -1
  2. model_class.py +202 -0
README.md CHANGED
@@ -1,3 +1,79 @@
  ---
- license: mit
+ license: cc-by-nc-4.0
+ language:
+ - en
+ library_name: transformers
  ---
+ # Whisper Multitask Analyzer
+
+ A transformer encoder-decoder model for automatic audio captioning. As opposed to speech-to-text, captioning describes the content and features of audio clips.
+
+ - **Model, codebase & card adapted from:** MU-NLPC/whisper-small-audio-captioning
+ - **Model type:** Whisper encoder-decoder transformer
+ - **Language(s) (NLP):** en
+ - **License:** cc-by-nc-4.0
+ - **Parent Model:** openai/whisper-small
+
+ ## Usage
+
+ The model expects an audio clip (up to 30 seconds) as input to the encoder, plus information about the desired caption style as a forced prefix to the decoder.
+ The forced prefix is an integer task ID that selects one of the captioning tasks below. The mapping is defined in the model config and can be retrieved from the config or from the loaded model's `task_mapping` / `named_task_mapping` properties (see the sketch after the table and the minimal example below).
+
+ The task mapping of the current model is:
+
+ | Task | ID | Description |
+ | -------- | -- | ------------------------------------------------------ |
+ | tags | 0 | General descriptions; can include genres and features. |
+ | genre | 1 | Estimated musical genres. |
+ | mood | 2 | Estimated emotional feeling. |
+ | movement | 3 | Estimated audio pace and expression. |
+ | theme | 4 | Estimated audio usage (not very accurate). |
+
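+ Because the mapping lives in the model config, it can be inspected without loading the model weights. A minimal sketch (assuming the mapping is saved in this repository's `config.json`, as `model_class.py` implies):
+
+ ```python
+ import transformers
+
+ # Load only the configuration; `task_mapping` is the name-to-ID dictionary used by the model class
+ config = transformers.WhisperConfig.from_pretrained("DionTimmer/whisper-small-multitask-analyzer")
+ print(config.task_mapping)  # e.g. {'tags': 0, 'genre': 1, 'mood': 2, 'movement': 3, 'theme': 4}
+ ```
+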
+ Minimal example:
+
+ ```python
+ import librosa
+ import transformers
+
+ # WhisperForAudioCaptioning is defined in model_class.py in this repository;
+ # download the file and make sure it is importable (e.g. place it next to your script)
+ from model_class import WhisperForAudioCaptioning
+
+ # Load model
+ checkpoint = "DionTimmer/whisper-small-multitask-analyzer"
+ model = WhisperForAudioCaptioning.from_pretrained(checkpoint)
+ tokenizer = transformers.WhisperTokenizer.from_pretrained(checkpoint, language="en", task="transcribe")
+ feature_extractor = transformers.WhisperFeatureExtractor.from_pretrained(checkpoint)
+
+ # Load and preprocess audio
+ input_file = "..."
+ audio, sampling_rate = librosa.load(input_file, sr=feature_extractor.sampling_rate)
+ features = feature_extractor(audio, sampling_rate=sampling_rate, return_tensors="pt").input_features
+
+ # Task mappings by ID
+ print(model.task_mapping)  # {0: 'tags', 1: 'genre', 2: 'mood', 3: 'movement', 4: 'theme'}
+
+ # Inverted (name to ID)
+ print(model.named_task_mapping)  # {'tags': 0, 'genre': 1, 'mood': 2, 'movement': 3, 'theme': 4}
+
+ # Prepare caption style prefix, e.g. "0: " for the "tags" task
+ style_prefix = f"{model.named_task_mapping['tags']}: "
+ style_prefix_tokens = tokenizer("", text_target=style_prefix, return_tensors="pt", add_special_tokens=False).labels
+
+ # Generate caption
+ model.eval()
+ outputs = model.generate(
+     inputs=features.to(model.device),
+     forced_ac_decoder_ids=style_prefix_tokens,
+     max_length=100,
+ )
+
+ print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
+ ```
+
+ Example output:
+ *0: advertising, beautiful, beauty, bright, cinematic, commercial, corporate, emotional, epic, film, heroic, hopeful, inspiration, inspirational, inspiring, love, love story, movie, orchestra, orchestral, piano, positive, presentation, romantic, sentimental*
+
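+ To get a full analysis of a clip, the same features can be decoded once per task. A minimal sketch, reusing `model`, `tokenizer`, and `features` from the example above:
+
+ ```python
+ # Generate one caption per task defined in the model config
+ for task_name, task_id in model.named_task_mapping.items():
+     prefix_tokens = tokenizer("", text_target=f"{task_id}: ", return_tensors="pt", add_special_tokens=False).labels
+     outputs = model.generate(
+         inputs=features.to(model.device),
+         forced_ac_decoder_ids=prefix_tokens,
+         max_length=100,
+     )
+     print(task_name, "->", tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
+ ```
+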
+ The WhisperTokenizer must be initialized with `language="en"` and `task="transcribe"`.
+
+ The model class `WhisperForAudioCaptioning` can be found in the git repository or here on the HuggingFace Hub in the model repository. The class overrides the default Whisper `generate` method to support forcing a decoder prefix.
+
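+ One possible way to obtain the class is to download `model_class.py` from this repository and import it dynamically, for example with `huggingface_hub`. A minimal sketch:
+
+ ```python
+ import importlib.util
+ from huggingface_hub import hf_hub_download
+
+ # Fetch model_class.py from the model repository and import WhisperForAudioCaptioning from it
+ path = hf_hub_download("DionTimmer/whisper-small-multitask-analyzer", "model_class.py")
+ spec = importlib.util.spec_from_file_location("model_class", path)
+ model_class = importlib.util.module_from_spec(spec)
+ spec.loader.exec_module(model_class)
+
+ WhisperForAudioCaptioning = model_class.WhisperForAudioCaptioning
+ ```
+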
+ ## Licence
+
+ The model weights are published under the non-commercial license CC BY-NC 4.0, as the model was fine-tuned on a dataset intended for non-commercial use.
model_class.py ADDED
@@ -0,0 +1,202 @@
+ import transformers
+ import torch
+ from typing import Optional, Tuple, Union
+ from transformers.modeling_outputs import Seq2SeqLMOutput
+ from transformers.generation.logits_process import WhisperTimeStampLogitsProcessor
+ from transformers.models.whisper.tokenization_whisper import TASK_IDS, TO_LANGUAGE_CODE
+
+
+ class WhisperForAudioCaptioning(transformers.WhisperForConditionalGeneration):
+     def __init__(self, config):
+         super().__init__(config)
+
+     @property
+     def task_mapping(self):
+         return {v: k for k, v in self.config.task_mapping.items()}
+
+     @property
+     def named_task_mapping(self):
+         return self.config.task_mapping
+
+     def forward(
+         self,
+         input_features: Optional[torch.FloatTensor] = None,
+         attention_mask: Optional[torch.LongTensor] = None,
+         decoder_input_ids: Optional[torch.LongTensor] = None,
+         decoder_position_ids: Optional[torch.LongTensor] = None,
+         decoder_attention_mask: Optional[torch.LongTensor] = None,
+         head_mask: Optional[torch.Tensor] = None,
+         decoder_head_mask: Optional[torch.Tensor] = None,
+         cross_attn_head_mask: Optional[torch.Tensor] = None,
+         encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+         past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+         decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None,
+         labels: Optional[torch.LongTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         forced_ac_decoder_ids: Optional[
+             torch.LongTensor
+         ] = None,  # added to be ignored when passed from trainer
+     ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
+         return super().forward(
+             input_features=input_features,
+             attention_mask=attention_mask,
+             decoder_input_ids=decoder_input_ids,
+             decoder_position_ids=decoder_position_ids,
+             decoder_attention_mask=decoder_attention_mask,
+             head_mask=head_mask,
+             decoder_head_mask=decoder_head_mask,
+             cross_attn_head_mask=cross_attn_head_mask,
+             encoder_outputs=encoder_outputs,
+             past_key_values=past_key_values,
+             decoder_inputs_embeds=decoder_inputs_embeds,
+             labels=labels,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+     # copy-pasted and adapted from transformers.WhisperForConditionalGeneration.generate
+     def generate(
+         self,
+         inputs: Optional[torch.Tensor] = None,
+         forced_ac_decoder_ids: Optional[torch.Tensor] = None,
+         generation_config=None,
+         logits_processor=None,
+         stopping_criteria=None,
+         prefix_allowed_tokens_fn=None,
+         synced_gpus=False,
+         return_timestamps=None,
+         task="transcribe",
+         language="english",
+         **kwargs,
+     ):
+         if generation_config is None:
+             generation_config = self.generation_config
+
+         if return_timestamps is not None:
+             if not hasattr(generation_config, "no_timestamps_token_id"):
+                 raise ValueError(
+                     "You are trying to return timestamps, but the generation config is not properly set. "
+                     "Make sure to initialize the generation config with the correct attributes that are needed such as `no_timestamps_token_id`. "
+                     "For more details on how to generate the appropriate config, refer to https://github.com/huggingface/transformers/issues/21878#issuecomment-1451902363"
+                 )
+
+             generation_config.return_timestamps = return_timestamps
+         else:
+             generation_config.return_timestamps = False
+
+         if language is not None:
+             generation_config.language = language
+         if task is not None:
+             generation_config.task = task
+
+         forced_decoder_ids = []
+         if task is not None or language is not None:
+             if hasattr(generation_config, "language"):
+                 if generation_config.language in generation_config.lang_to_id.keys():
+                     language_token = generation_config.language
+                 elif generation_config.language in TO_LANGUAGE_CODE.keys():
+                     language_token = (
+                         f"<|{TO_LANGUAGE_CODE[generation_config.language]}|>"
+                     )
+                 else:
+                     raise ValueError(
+                         f"Unsupported language: {language}. Language should be one of:"
+                         f" {list(TO_LANGUAGE_CODE.keys()) if generation_config.language in TO_LANGUAGE_CODE.keys() else list(TO_LANGUAGE_CODE.values())}."
+                     )
+                 forced_decoder_ids.append(
+                     (1, generation_config.lang_to_id[language_token])
+                 )
+             else:
+                 forced_decoder_ids.append(
+                     (1, None)
+                 )  # automatically detect the language
+
+             if hasattr(generation_config, "task"):
+                 if generation_config.task in TASK_IDS:
+                     forced_decoder_ids.append(
+                         (2, generation_config.task_to_id[generation_config.task])
+                     )
+                 else:
+                     raise ValueError(
+                         f"The `{generation_config.task}` task is not supported. The task should be one of `{TASK_IDS}`"
+                     )
+             else:
+                 forced_decoder_ids.append(
+                     (2, generation_config.task_to_id["transcribe"])
+                 )  # defaults to transcribe
+             if (
+                 hasattr(generation_config, "no_timestamps_token_id")
+                 and not generation_config.return_timestamps
+             ):
+                 idx = forced_decoder_ids[-1][0] + 1 if forced_decoder_ids else 1
+                 forced_decoder_ids.append(
+                     (idx, generation_config.no_timestamps_token_id)
+                 )
+
+         # Legacy code for backward compatibility
+         elif (
+             hasattr(self.config, "forced_decoder_ids")
+             and self.config.forced_decoder_ids is not None
+         ):
+             forced_decoder_ids = self.config.forced_decoder_ids
+         elif (
+             hasattr(self.generation_config, "forced_decoder_ids")
+             and self.generation_config.forced_decoder_ids is not None
+         ):
+             forced_decoder_ids = self.generation_config.forced_decoder_ids
+
+         if generation_config.return_timestamps:
+             logits_processor = [WhisperTimeStampLogitsProcessor(generation_config)]
+
+         decoder_input_ids = None
+
+         if len(forced_decoder_ids) > 0:
+             # get the token sequence coded in forced_decoder_ids
+             forced_decoder_ids.sort()
+             if min(forced_decoder_ids)[0] != 0:
+                 forced_decoder_ids = [
+                     (0, self.config.decoder_start_token_id)
+                 ] + forced_decoder_ids
+
+             position_indices, decoder_input_ids = zip(*forced_decoder_ids)
+             assert tuple(position_indices) == tuple(
+                 range(len(position_indices))
+             ), "forced_decoder_ids is not a (continuous) prefix, we can't handle that"
+
+             device = self.get_decoder().device
+
+             if forced_ac_decoder_ids is None:
+                 forced_ac_decoder_ids = torch.tensor(
+                     [[]], device=device, dtype=torch.long
+                 )
+
+             # enrich every sample's forced_ac_decoder_ids with Whisper's forced_decoder_ids
+             batch_size = forced_ac_decoder_ids.shape[0]
+             fluff_len = len(decoder_input_ids)
+             decoder_input_ids = torch.tensor(
+                 decoder_input_ids, device=device, dtype=torch.long
+             )
+             decoder_input_ids = decoder_input_ids.expand((batch_size, fluff_len))
+             decoder_input_ids = torch.cat(
+                 [decoder_input_ids, forced_ac_decoder_ids], dim=1
+             )
+
+             generation_config.forced_decoder_ids = forced_decoder_ids
+
+         return super(
+             transformers.WhisperPreTrainedModel, self
+         ).generate(  # changed by adam (calling grandparent)
+             inputs,
+             generation_config,
+             logits_processor,
+             stopping_criteria,
+             prefix_allowed_tokens_fn,
+             synced_gpus,
+             decoder_input_ids=decoder_input_ids,
+             **kwargs,
+         )