Helw150
commited on
Commit
·
d41acc6
1
Parent(s):
11bcf74
Make More Generic; Reduce Config Size
Browse files- config.json +2 -128
- modeling_diva.py +29 -31
config.json
CHANGED
@@ -1,138 +1,12 @@
|
|
1 |
{
|
2 |
"model_type": "diva",
|
|
|
|
|
3 |
"architectures": [ "DiVAModel" ],
|
4 |
"auto_map": {
|
5 |
"AutoConfig": "configuring_diva.DiVAConfig",
|
6 |
"AutoModel": "modeling_diva.DiVAModel"
|
7 |
},
|
8 |
"vocab_size": 128256,
|
9 |
-
"decoder": {
|
10 |
-
"architectures": [
|
11 |
-
"LlamaForCausalLM"
|
12 |
-
],
|
13 |
-
"attention_bias": false,
|
14 |
-
"attention_dropout": 0,
|
15 |
-
"bos_token_id": 128000,
|
16 |
-
"eos_token_id": 128001,
|
17 |
-
"hidden_act": "silu",
|
18 |
-
"hidden_size": 4096,
|
19 |
-
"initializer_range": 0.02,
|
20 |
-
"intermediate_size": 14336,
|
21 |
-
"max_position_embeddings": 8192,
|
22 |
-
"model_type": "llama",
|
23 |
-
"num_attention_heads": 32,
|
24 |
-
"num_hidden_layers": 32,
|
25 |
-
"num_key_value_heads": 8,
|
26 |
-
"pretraining_tp": 1,
|
27 |
-
"rms_norm_eps": 1e-05,
|
28 |
-
"rope_scaling": null,
|
29 |
-
"rope_theta": 500000,
|
30 |
-
"tie_word_embeddings": false,
|
31 |
-
"torch_dtype": "bfloat16",
|
32 |
-
"transformers_version": "4.40.0.dev0",
|
33 |
-
"use_cache": true,
|
34 |
-
"vocab_size": 128256
|
35 |
-
},
|
36 |
-
"encoder": {
|
37 |
-
"_name_or_path": "openai/whisper-large-v3",
|
38 |
-
"activation_dropout": 0,
|
39 |
-
"activation_function": "gelu",
|
40 |
-
"add_cross_attention": false,
|
41 |
-
"apply_spec_augment": false,
|
42 |
-
"architectures": [
|
43 |
-
"WhisperForConditionalGeneration"
|
44 |
-
],
|
45 |
-
"attention_dropout": 0,
|
46 |
-
"bad_words_ids": null,
|
47 |
-
"begin_suppress_tokens": [
|
48 |
-
220,
|
49 |
-
50257
|
50 |
-
],
|
51 |
-
"bos_token_id": 50257,
|
52 |
-
"chunk_size_feed_forward": 0,
|
53 |
-
"classifier_proj_size": 256,
|
54 |
-
"cross_attention_hidden_size": null,
|
55 |
-
"d_model": 1280,
|
56 |
-
"decoder_attention_heads": 20,
|
57 |
-
"decoder_ffn_dim": 5120,
|
58 |
-
"decoder_layerdrop": 0,
|
59 |
-
"decoder_layers": 32,
|
60 |
-
"decoder_start_token_id": 50258,
|
61 |
-
"diversity_penalty": 0,
|
62 |
-
"do_sample": false,
|
63 |
-
"dropout": 0,
|
64 |
-
"early_stopping": false,
|
65 |
-
"encoder_attention_heads": 20,
|
66 |
-
"encoder_ffn_dim": 5120,
|
67 |
-
"encoder_layerdrop": 0,
|
68 |
-
"encoder_layers": 32,
|
69 |
-
"encoder_no_repeat_ngram_size": 0,
|
70 |
-
"eos_token_id": 50257,
|
71 |
-
"exponential_decay_length_penalty": null,
|
72 |
-
"finetuning_task": null,
|
73 |
-
"forced_bos_token_id": null,
|
74 |
-
"forced_eos_token_id": null,
|
75 |
-
"id2label": {
|
76 |
-
"0": "LABEL_0",
|
77 |
-
"1": "LABEL_1"
|
78 |
-
},
|
79 |
-
"init_std": 0.02,
|
80 |
-
"is_decoder": false,
|
81 |
-
"is_encoder_decoder": true,
|
82 |
-
"label2id": {
|
83 |
-
"LABEL_0": 0,
|
84 |
-
"LABEL_1": 1
|
85 |
-
},
|
86 |
-
"length_penalty": 1,
|
87 |
-
"mask_feature_length": 10,
|
88 |
-
"mask_feature_min_masks": 0,
|
89 |
-
"mask_feature_prob": 0,
|
90 |
-
"mask_time_length": 10,
|
91 |
-
"mask_time_min_masks": 2,
|
92 |
-
"mask_time_prob": 0.05,
|
93 |
-
"max_length": 448,
|
94 |
-
"max_source_positions": 1500,
|
95 |
-
"max_target_positions": 448,
|
96 |
-
"median_filter_width": 7,
|
97 |
-
"min_length": 0,
|
98 |
-
"model_type": "whisper",
|
99 |
-
"no_repeat_ngram_size": 0,
|
100 |
-
"num_beam_groups": 1,
|
101 |
-
"num_beams": 1,
|
102 |
-
"num_hidden_layers": 32,
|
103 |
-
"num_mel_bins": 128,
|
104 |
-
"num_return_sequences": 1,
|
105 |
-
"output_attentions": false,
|
106 |
-
"output_hidden_states": false,
|
107 |
-
"output_scores": false,
|
108 |
-
"pad_token_id": 50256,
|
109 |
-
"prefix": null,
|
110 |
-
"problem_type": null,
|
111 |
-
"pruned_heads": {},
|
112 |
-
"remove_invalid_values": false,
|
113 |
-
"repetition_penalty": 1,
|
114 |
-
"return_dict": true,
|
115 |
-
"return_dict_in_generate": false,
|
116 |
-
"scale_embedding": false,
|
117 |
-
"sep_token_id": null,
|
118 |
-
"suppress_tokens": null,
|
119 |
-
"task_specific_params": null,
|
120 |
-
"temperature": 1,
|
121 |
-
"tf_legacy_loss": false,
|
122 |
-
"tie_encoder_decoder": false,
|
123 |
-
"tie_word_embeddings": true,
|
124 |
-
"tokenizer_class": null,
|
125 |
-
"top_k": 50,
|
126 |
-
"top_p": 1,
|
127 |
-
"torch_dtype": "float16",
|
128 |
-
"torchscript": false,
|
129 |
-
"transformers_version": "4.38.2",
|
130 |
-
"typical_p": 1,
|
131 |
-
"use_bfloat16": false,
|
132 |
-
"use_cache": true,
|
133 |
-
"use_weighted_layer_sum": false,
|
134 |
-
"vocab_size": 51866
|
135 |
-
},
|
136 |
-
"time_dialation": 4,
|
137 |
"transformers_version": "4.38.2"
|
138 |
}
|
|
|
1 |
{
|
2 |
"model_type": "diva",
|
3 |
+
"reference_encoder": "openai/whisper-large-v3",
|
4 |
+
"reference_decoder": "meta-llama/Meta-Llama-3-8B-Instruct",
|
5 |
"architectures": [ "DiVAModel" ],
|
6 |
"auto_map": {
|
7 |
"AutoConfig": "configuring_diva.DiVAConfig",
|
8 |
"AutoModel": "modeling_diva.DiVAModel"
|
9 |
},
|
10 |
"vocab_size": 128256,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
"transformers_version": "4.38.2"
|
12 |
}
|
modeling_diva.py
CHANGED
@@ -10,13 +10,13 @@ import torch.nn.functional as F
|
|
10 |
from datasets import Audio
|
11 |
from safetensors.torch import load, load_model
|
12 |
from torch import nn
|
13 |
-
from
|
14 |
from transformers import (
|
15 |
AutoProcessor,
|
16 |
AutoTokenizer,
|
17 |
-
|
18 |
PreTrainedModel,
|
19 |
-
|
20 |
)
|
21 |
|
22 |
|
@@ -51,11 +51,9 @@ class DiVAModel(PreTrainedModel):
|
|
51 |
super().__init__(DiVAConfig.from_dict(config_dict))
|
52 |
if speech_encoder_device is None:
|
53 |
speech_encoder_device = "cuda:0"
|
54 |
-
whisper =
|
55 |
-
"openai/whisper-large-v3"
|
56 |
-
)
|
57 |
connector = WhisperConnector()
|
58 |
-
connector.decoder = copy.deepcopy(whisper.
|
59 |
if via_path is not None:
|
60 |
with open(via_path, "rb") as f:
|
61 |
sd = load(f.read())
|
@@ -83,25 +81,25 @@ class DiVAModel(PreTrainedModel):
|
|
83 |
)
|
84 |
|
85 |
self.connector = connector.to(speech_encoder_device)
|
86 |
-
self.whisper_encoder = whisper.
|
87 |
-
self.
|
88 |
-
"
|
89 |
device_map=device_map,
|
90 |
torch_dtype=torch.float16,
|
91 |
)
|
92 |
-
self.processor = AutoProcessor.from_pretrained("
|
93 |
self.tokenizer = AutoTokenizer.from_pretrained("WillHeld/via-llama")
|
94 |
self.prefix = torch.tensor([128000, 128006, 882, 128007, 271]).to(
|
95 |
-
self.
|
96 |
)
|
97 |
|
98 |
self.pre_user_suffix = torch.tensor(
|
99 |
self.tokenizer.encode(
|
100 |
"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n"
|
101 |
)
|
102 |
-
).to(self.
|
103 |
self.final_header = torch.tensor([128009, 128006, 78191, 128007, 271]).to(
|
104 |
-
self.
|
105 |
)
|
106 |
self.speech_encoder_device = speech_encoder_device
|
107 |
|
@@ -161,18 +159,18 @@ class DiVAModel(PreTrainedModel):
|
|
161 |
]
|
162 |
virt_tokens = self.connector(
|
163 |
hidden_states,
|
164 |
-
output_device=self.
|
165 |
).squeeze()
|
166 |
|
167 |
-
prefix_embed = self.
|
168 |
-
suffix_embed = self.
|
169 |
inputs_embeds = torch.cat(
|
170 |
[prefix_embed, virt_tokens, suffix_embed], axis=0
|
171 |
).unsqueeze(0)
|
172 |
|
173 |
-
outputs = self.
|
174 |
inputs_embeds=inputs_embeds.to(
|
175 |
-
self.
|
176 |
).half(),
|
177 |
return_dict=True,
|
178 |
output_hidden_states=True,
|
@@ -197,7 +195,7 @@ class DiVAModel(PreTrainedModel):
|
|
197 |
]
|
198 |
virt_tokens = self.connector(
|
199 |
hidden_states,
|
200 |
-
output_device=self.
|
201 |
)
|
202 |
bsz = virt_tokens.shape[0]
|
203 |
|
@@ -227,9 +225,9 @@ class DiVAModel(PreTrainedModel):
|
|
227 |
)
|
228 |
else:
|
229 |
prefix = self.prefix
|
230 |
-
prefix_embed = self.
|
231 |
suffix = self.final_header
|
232 |
-
suffix_embed = self.
|
233 |
inputs_embeds = torch.cat([prefix_embed, virt_tokens, suffix_embed], axis=1)
|
234 |
outs = [[] for i in range(bsz)]
|
235 |
complete = [False] * bsz
|
@@ -238,9 +236,9 @@ class DiVAModel(PreTrainedModel):
|
|
238 |
i = 0
|
239 |
while not all(complete) and len(outs[0]) < max_new_tokens:
|
240 |
past_key_values = outputs.past_key_values if outputs else None
|
241 |
-
outputs = self.
|
242 |
inputs_embeds=inputs_embeds.to(
|
243 |
-
self.
|
244 |
).half(),
|
245 |
return_dict=True,
|
246 |
output_hidden_states=True,
|
@@ -268,7 +266,7 @@ class DiVAModel(PreTrainedModel):
|
|
268 |
if out == 128009:
|
269 |
complete[token_index] = True
|
270 |
|
271 |
-
next_embed = self.
|
272 |
inputs_embeds = next_embed
|
273 |
return self.tokenizer.batch_decode(outs, skip_special_tokens=True)
|
274 |
|
@@ -287,7 +285,7 @@ class DiVAModel(PreTrainedModel):
|
|
287 |
]
|
288 |
virt_tokens = self.connector(
|
289 |
hidden_states,
|
290 |
-
output_device=self.
|
291 |
).squeeze()
|
292 |
|
293 |
if text_prompt != None and text_prompt != "":
|
@@ -300,9 +298,9 @@ class DiVAModel(PreTrainedModel):
|
|
300 |
)
|
301 |
else:
|
302 |
prefix = self.prefix
|
303 |
-
prefix_embed = self.
|
304 |
suffix = self.final_header
|
305 |
-
suffix_embed = self.
|
306 |
inputs_embeds = torch.cat(
|
307 |
[prefix_embed, virt_tokens, suffix_embed], axis=0
|
308 |
).unsqueeze(0)
|
@@ -312,9 +310,9 @@ class DiVAModel(PreTrainedModel):
|
|
312 |
i = 0
|
313 |
while greedy != 128009 and len(outs) < max_new_tokens:
|
314 |
past_key_values = outputs.past_key_values if outputs else None
|
315 |
-
outputs = self.
|
316 |
inputs_embeds=inputs_embeds.to(
|
317 |
-
self.
|
318 |
).half(),
|
319 |
return_dict=True,
|
320 |
output_hidden_states=True,
|
@@ -337,7 +335,7 @@ class DiVAModel(PreTrainedModel):
|
|
337 |
else:
|
338 |
greedy = next_token_logits.argmax()
|
339 |
outs.append(greedy)
|
340 |
-
next_embed = self.
|
341 |
inputs_embeds = next_embed
|
342 |
yield self.tokenizer.decode(outs, skip_special_tokens=True).replace(
|
343 |
"<|eot_id|>", ""
|
|
|
10 |
from datasets import Audio
|
11 |
from safetensors.torch import load, load_model
|
12 |
from torch import nn
|
13 |
+
from configuring_diva import DiVAConfig
|
14 |
from transformers import (
|
15 |
AutoProcessor,
|
16 |
AutoTokenizer,
|
17 |
+
AutoModelForCausalLM,
|
18 |
PreTrainedModel,
|
19 |
+
WhisperModel,
|
20 |
)
|
21 |
|
22 |
|
|
|
51 |
super().__init__(DiVAConfig.from_dict(config_dict))
|
52 |
if speech_encoder_device is None:
|
53 |
speech_encoder_device = "cuda:0"
|
54 |
+
whisper = WhisperModel.from_pretrained(config_dict["reference_encoder"])
|
|
|
|
|
55 |
connector = WhisperConnector()
|
56 |
+
connector.decoder = copy.deepcopy(whisper.decoder)
|
57 |
if via_path is not None:
|
58 |
with open(via_path, "rb") as f:
|
59 |
sd = load(f.read())
|
|
|
81 |
)
|
82 |
|
83 |
self.connector = connector.to(speech_encoder_device)
|
84 |
+
self.whisper_encoder = whisper.encoder.to(speech_encoder_device)
|
85 |
+
self.llm_decoder = AutoModelForCausalLM.from_pretrained(
|
86 |
+
config_dict["reference_decoder"],
|
87 |
device_map=device_map,
|
88 |
torch_dtype=torch.float16,
|
89 |
)
|
90 |
+
self.processor = AutoProcessor.from_pretrained(config_dict["reference_encoder"])
|
91 |
self.tokenizer = AutoTokenizer.from_pretrained("WillHeld/via-llama")
|
92 |
self.prefix = torch.tensor([128000, 128006, 882, 128007, 271]).to(
|
93 |
+
self.llm_decoder.model.embed_tokens.weight.device
|
94 |
)
|
95 |
|
96 |
self.pre_user_suffix = torch.tensor(
|
97 |
self.tokenizer.encode(
|
98 |
"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n"
|
99 |
)
|
100 |
+
).to(self.llm_decoder.model.embed_tokens.weight.device)
|
101 |
self.final_header = torch.tensor([128009, 128006, 78191, 128007, 271]).to(
|
102 |
+
self.llm_decoder.model.embed_tokens.weight.device
|
103 |
)
|
104 |
self.speech_encoder_device = speech_encoder_device
|
105 |
|
|
|
159 |
]
|
160 |
virt_tokens = self.connector(
|
161 |
hidden_states,
|
162 |
+
output_device=self.llm_decoder.model.embed_tokens.weight.device,
|
163 |
).squeeze()
|
164 |
|
165 |
+
prefix_embed = self.llm_decoder.model.embed_tokens(prefix_text_tokens)
|
166 |
+
suffix_embed = self.llm_decoder.model.embed_tokens(suffix_text_tokens)
|
167 |
inputs_embeds = torch.cat(
|
168 |
[prefix_embed, virt_tokens, suffix_embed], axis=0
|
169 |
).unsqueeze(0)
|
170 |
|
171 |
+
outputs = self.llm_decoder(
|
172 |
inputs_embeds=inputs_embeds.to(
|
173 |
+
self.llm_decoder.model.embed_tokens.weight.device
|
174 |
).half(),
|
175 |
return_dict=True,
|
176 |
output_hidden_states=True,
|
|
|
195 |
]
|
196 |
virt_tokens = self.connector(
|
197 |
hidden_states,
|
198 |
+
output_device=self.llm_decoder.model.embed_tokens.weight.device,
|
199 |
)
|
200 |
bsz = virt_tokens.shape[0]
|
201 |
|
|
|
225 |
)
|
226 |
else:
|
227 |
prefix = self.prefix
|
228 |
+
prefix_embed = self.llm_decoder.model.embed_tokens(prefix).expand(bsz, -1, -1)
|
229 |
suffix = self.final_header
|
230 |
+
suffix_embed = self.llm_decoder.model.embed_tokens(suffix).expand(bsz, -1, -1)
|
231 |
inputs_embeds = torch.cat([prefix_embed, virt_tokens, suffix_embed], axis=1)
|
232 |
outs = [[] for i in range(bsz)]
|
233 |
complete = [False] * bsz
|
|
|
236 |
i = 0
|
237 |
while not all(complete) and len(outs[0]) < max_new_tokens:
|
238 |
past_key_values = outputs.past_key_values if outputs else None
|
239 |
+
outputs = self.llm_decoder(
|
240 |
inputs_embeds=inputs_embeds.to(
|
241 |
+
self.llm_decoder.model.embed_tokens.weight.device
|
242 |
).half(),
|
243 |
return_dict=True,
|
244 |
output_hidden_states=True,
|
|
|
266 |
if out == 128009:
|
267 |
complete[token_index] = True
|
268 |
|
269 |
+
next_embed = self.llm_decoder.model.embed_tokens(greedy.reshape(-1, 1))
|
270 |
inputs_embeds = next_embed
|
271 |
return self.tokenizer.batch_decode(outs, skip_special_tokens=True)
|
272 |
|
|
|
285 |
]
|
286 |
virt_tokens = self.connector(
|
287 |
hidden_states,
|
288 |
+
output_device=self.llm_decoder.model.embed_tokens.weight.device,
|
289 |
).squeeze()
|
290 |
|
291 |
if text_prompt != None and text_prompt != "":
|
|
|
298 |
)
|
299 |
else:
|
300 |
prefix = self.prefix
|
301 |
+
prefix_embed = self.llm_decoder.model.embed_tokens(prefix)
|
302 |
suffix = self.final_header
|
303 |
+
suffix_embed = self.llm_decoder.model.embed_tokens(suffix)
|
304 |
inputs_embeds = torch.cat(
|
305 |
[prefix_embed, virt_tokens, suffix_embed], axis=0
|
306 |
).unsqueeze(0)
|
|
|
310 |
i = 0
|
311 |
while greedy != 128009 and len(outs) < max_new_tokens:
|
312 |
past_key_values = outputs.past_key_values if outputs else None
|
313 |
+
outputs = self.llm_decoder(
|
314 |
inputs_embeds=inputs_embeds.to(
|
315 |
+
self.llm_decoder.model.embed_tokens.weight.device
|
316 |
).half(),
|
317 |
return_dict=True,
|
318 |
output_hidden_states=True,
|
|
|
335 |
else:
|
336 |
greedy = next_token_logits.argmax()
|
337 |
outs.append(greedy)
|
338 |
+
next_embed = self.llm_decoder.model.embed_tokens(greedy.reshape(1, 1))
|
339 |
inputs_embeds = next_embed
|
340 |
yield self.tokenizer.decode(outs, skip_special_tokens=True).replace(
|
341 |
"<|eot_id|>", ""
|