kvaishnavi committed on
Commit 073038f · 1 Parent(s): 98aa2b7

Upload Phi-4-multimodal-instruct scripts to make ONNX models

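The builder script below exports each Phi-4-multimodal-instruct component (vision, speech, embedding, text, and the LoRA adapters) to ONNX. A minimal invocation sketch, based only on the argparse options defined in onnx/builder.py; the input and output paths are placeholders that must point to the downloaded Hugging Face checkpoint and the desired output folder.

import subprocess
import sys

# Hypothetical invocation of onnx/builder.py; paths are placeholders.
# The flags mirror the argparse definitions in the script below.
subprocess.run(
    [
        sys.executable, "onnx/builder.py",
        "--input", "./Phi-4-multimodal-instruct",  # folder with the Hugging Face config, weights, tokenizer
        "--output", "./phi-4-mm-onnx",             # destination for the ONNX models and adapter files
        "--precision", "fp16",                     # or "fp32"
        "--execution_provider", "cuda",            # one of: cpu, cuda, dml
        "--cache_dir", "./cache_dir",
    ],
    check=True,
)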
onnx/builder.py ADDED
@@ -0,0 +1,628 @@
1
+ import argparse
2
+ import numpy as np
3
+ import onnx
4
+ import onnxruntime as ort
5
+ import onnxscript
6
+ import os
7
+ import requests
8
+ import shutil
9
+ import soundfile
10
+ import subprocess
11
+ import sys
12
+ import torch
13
+
14
+ from onnx import helper, numpy_helper, TensorProto
15
+ from onnxruntime_genai.models.builder import create_model
16
+ from onnxruntime.transformers.dynamo_onnx_helper import DynamoOnnxHelper
17
+ from onnxscript import ir
18
+ from PIL import Image
19
+ from transformers import AutoConfig, AutoProcessor, AutoModelForCausalLM
20
+
21
+
22
+ def build_vision(args):
23
+ # Many images:
24
+ prompt = f"{user_prompt}<|image_1|>\n<|image_2|>\n<|image_3|>\n<|image_4|>\nWhat is shown in these four images?{prompt_suffix}{assistant_prompt}"
25
+ url = "https://www.ilankelman.org/stopsigns/australia.jpg"
26
+ image_1 = Image.open(requests.get(url, stream=True).raw)
27
+ url = "https://img.freepik.com/free-photo/painting-mountain-lake-with-mountain-background_188544-9126.jpg?w=2000"
28
+ image_2 = Image.open(requests.get(url, stream=True).raw)
29
+ url = "https://th.bing.com/th/id/OIP.gCvQ1vmPVJmrq1nnzM3ZHQHaEo?rs=1&pid=ImgDetMain"
30
+ image_3 = Image.open(requests.get(url, stream=True).raw)
31
+ url = "https://wallpaper.dog/large/10809054.jpg"
32
+ image_4 = Image.open(requests.get(url, stream=True).raw)
33
+ images = [image_1, image_2, image_3, image_4]
34
+ inputs = processor(prompt, images=images, return_tensors="pt").to(args.execution_provider.replace("dml", "cuda"))
35
+ inputs["input_image_embeds"] = inputs["input_image_embeds"].to(args.precision)
36
+ inputs["image_attention_mask"] = inputs["image_attention_mask"].to(args.precision)
37
+
38
+ # TorchScript export
39
+ dummy_inputs = (
40
+ inputs["input_image_embeds"], # image_embeds: torch.FloatTensor
41
+ inputs["image_attention_mask"], # image_attention_mask: torch.FloatTensor
42
+ inputs["image_sizes"], # image_sizes: torch.LongTensor
43
+ )
44
+ dynamic_axes = {
45
+ "pixel_values": {0: "num_images", 1: "max_num_crops", 3: "height", 4: "width"},
46
+ "image_attention_mask": {0: "num_images", 1: "max_num_crops"},
47
+ "image_sizes": {0: "num_images"},
48
+ "image_features": {0: "num_image_tokens"},
49
+ }
50
+ filename = "phi-4-mm-vision.onnx"
51
+
52
+ temp_folder_1 = os.path.join(args.output, "vision_init_export")
53
+ os.makedirs(temp_folder_1, exist_ok=True)
54
+
55
+ fpath_1 = os.path.join(temp_folder_1, filename)
56
+ torch.onnx.export(
57
+ model.model.embed_tokens_extend.image_embed,
58
+ args=dummy_inputs,
59
+ f=fpath_1,
60
+ export_params=True,
61
+ input_names=["pixel_values", "image_attention_mask", "image_sizes"],
62
+ output_names=["image_features"],
63
+ dynamic_axes=dynamic_axes,
64
+ opset_version=14,
65
+ do_constant_folding=True,
66
+ )
67
+
68
+ onnx.checker.check_model(fpath_1)
69
+ onnx.shape_inference.infer_shapes_path(fpath_1)
70
+ onnx_model = onnx.load_model(fpath_1, load_external_data=True)
71
+
72
+ temp_folder_2 = os.path.join(args.output, "vision_after_export")
73
+ os.makedirs(temp_folder_2, exist_ok=True)
74
+
75
+ fpath_2 = os.path.join(temp_folder_2, filename)
76
+ onnx.save_model(
77
+ onnx_model,
78
+ fpath_2,
79
+ save_as_external_data=True,
80
+ all_tensors_to_one_file=True,
81
+ location=f"{filename}.data",
82
+ size_threshold=0,
83
+ convert_attribute=False,
84
+ )
85
+ shutil.rmtree(temp_folder_1)
86
+
87
+ # ORT transformer optimizer
88
+ temp_folder_3 = os.path.join(args.output, "vision_after_opt")
89
+ fpath_3 = os.path.join(temp_folder_3, filename)
90
+ subprocess.run(
91
+ [
92
+ f"{sys.executable}", "-m", "onnxruntime.transformers.optimizer",
93
+ "--input", fpath_2,
94
+ "--output", fpath_3,
95
+ "--model_type", "clip",
96
+ "--num_heads", str(16),
97
+ "--hidden_size", str(1152),
98
+ "--use_external_data_format",
99
+ "--opt_level", str(0),
100
+ "--disable_shape_inference",
101
+ ]
102
+ )
103
+ shutil.rmtree(temp_folder_2)
104
+
105
+ # ORT 4-bits quantizer
106
+ fpath_4 = os.path.join(args.output, filename)
107
+ cmd = [
108
+ f"{sys.executable}", "-m", "onnxruntime.quantization.matmul_4bits_quantizer",
109
+ "--input_model", fpath_3,
110
+ "--output_model", fpath_4,
111
+ "--block_size", str(32),
112
+ ]
113
+ if args.precision == torch.float32: cmd.extend(["--accuracy_level", str(4)])
114
+ subprocess.run(cmd)
115
+ shutil.rmtree(temp_folder_3)
116
+
117
+
118
+ def build_speech(args):
119
+ # Speech file:
120
+ prompt = f"{user_prompt}<|audio_1|>\n<|audio_2|>\nWhat are the stories that these audios come from?{prompt_suffix}{assistant_prompt}"
121
+ audio1 = soundfile.read(os.path.join(args.input, "examples", "1272-128104-0004.wav"))
122
+ audio2 = soundfile.read(os.path.join(args.input, "examples", "1272-128104-0009.wav"))
123
+ inputs = processor(prompt, audios=[audio1, audio2], return_tensors="pt").to(args.execution_provider.replace("dml", "cuda"))
124
+ inputs["input_audio_embeds"] = inputs["input_audio_embeds"].to(args.precision)
125
+
126
+ # TorchScript export
127
+ dummy_inputs = (
128
+ inputs["input_audio_embeds"], # audio_embeds: torch.FloatTensor
129
+ inputs["audio_attention_mask"], # audio_attention_mask: torch.BoolTensor
130
+ inputs["audio_embed_sizes"], # audio_sizes: torch.FloatTensor
131
+ inputs["input_mode"], # audio_projection_mode: int
132
+ )
133
+ dynamic_axes = {
134
+ "audio_embeds": {0: "num_audios", 1: "num_frames", 2: "feature_size"},
135
+ "audio_attention_mask": {0: "num_audios", 1: "num_frames"},
136
+ "audio_sizes": {0: "num_audios"},
137
+ "audio_features": {0: "num_audio_tokens"},
138
+ }
139
+ filename = "phi-4-mm-speech.onnx"
140
+
141
+ temp_folder_1 = os.path.join(args.output, "speech_init_export")
142
+ os.makedirs(temp_folder_1, exist_ok=True)
143
+
144
+ fpath_1 = os.path.join(temp_folder_1, filename)
145
+ torch._dynamo.config.capture_scalar_outputs = True
146
+ ep = torch.export.export(
147
+ model.model.embed_tokens_extend.audio_embed, args=dummy_inputs, strict=False,
148
+ dynamic_shapes=[
149
+ {0: torch.export.Dim.AUTO, 1: torch.export.Dim.AUTO, 2: torch.export.Dim.AUTO},
150
+ {0: torch.export.Dim.AUTO, 1: torch.export.Dim.AUTO},
151
+ {0: torch.export.Dim.AUTO},
152
+ {0: torch.export.Dim.AUTO},
153
+ ]
154
+ )
155
+ onnx_program = torch.onnx.export(ep, (), input_names=["audio_embeds", "audio_attention_mask", "audio_sizes", "audio_projection_mode"], output_names=["audio_features"])
156
+ onnx_program.optimize()
157
+ onnx_program.save(fpath_1, external_data=True)
158
+
159
+ onnx.checker.check_model(fpath_1)
160
+ onnx.shape_inference.infer_shapes_path(fpath_1)
161
+ onnx_model = onnx.load_model(fpath_1, load_external_data=True)
162
+
163
+ temp_folder_2 = os.path.join(args.output, "speech_after_export")
164
+ os.makedirs(temp_folder_2, exist_ok=True)
165
+
166
+ fpath_2 = os.path.join(temp_folder_2, filename)
167
+ onnx.save_model(
168
+ onnx_model,
169
+ fpath_2,
170
+ save_as_external_data=True,
171
+ all_tensors_to_one_file=True,
172
+ location=f"{filename}.data",
173
+ size_threshold=0,
174
+ convert_attribute=False,
175
+ )
176
+ shutil.rmtree(temp_folder_1)
177
+
178
+ # ONNX/ORT rewriter
179
+ temp_folder_3 = os.path.join(args.output, "speech_after_rewrite")
180
+ os.makedirs(temp_folder_3, exist_ok=True)
181
+
182
+ onnx_model = ir.load(fpath_2)
183
+ DynamoOnnxHelper.fold_transpose_initializers(onnx_model)
184
+ onnxscript.rewriter.rewrite(onnx_model)
185
+ onnxscript.optimizer.optimize(onnx_model, onnx_shape_inference=False, input_size_limit=4*2048*2048, output_size_limit=4*2048*2048)
186
+
187
+ fpath_3 = os.path.join(temp_folder_3, filename)
188
+ ir.save(onnx_model, fpath_3, external_data=f"{filename}.data")
189
+ shutil.rmtree(temp_folder_2)
190
+
191
+ onnx_model = onnx.load_model(fpath_3, load_external_data=True)
192
+ # Fix labels of dynamic axes since they can't be specified during Dynamo export currently
193
+ onnx_model.graph.input[0].type.tensor_type.shape.dim[0].dim_param = "num_audios"
194
+ onnx_model.graph.input[0].type.tensor_type.shape.dim[1].dim_param = "num_frames"
195
+ onnx_model.graph.input[1].type.tensor_type.shape.dim[0].dim_param = "num_audios"
196
+ onnx_model.graph.input[1].type.tensor_type.shape.dim[1].dim_param = "num_frames"
197
+ onnx_model.graph.input[2].type.tensor_type.shape.dim[0].dim_param = "num_audios"
198
+ onnx_model.graph.output[0].type.tensor_type.shape.dim[0].dim_param = "num_audio_tokens"
199
+
200
+ onnx_model = DynamoOnnxHelper(onnx_model)
201
+ onnx_model.convert_constants_to_initializers()
202
+ onnx_model.clear_metadata()
203
+
204
+ os.remove(fpath_3)
205
+ os.remove(fpath_3 + ".data")
206
+ onnx_model.model.save_model_to_file(fpath_3, use_external_data_format=True, all_tensors_to_one_file=True, convert_attribute=True) # convert_attribute = True needed because of ONNX/ORT rewriter
207
+
208
+ # ORT transformer optimizer
209
+ temp_folder_4 = os.path.join(args.output, "speech_after_opt")
210
+ fpath_4 = os.path.join(temp_folder_4, filename)
211
+ subprocess.run(
212
+ [
213
+ f"{sys.executable}", "-m", "onnxruntime.transformers.optimizer",
214
+ "--input", fpath_3,
215
+ "--output", fpath_4,
216
+ "--model_type", "conformer",
217
+ "--num_heads", str(16),
218
+ "--hidden_size", str(1024),
219
+ "--use_external_data_format",
220
+ "--opt_level", str(0),
221
+ "--disable_shape_inference",
222
+ "--convert_attribute",
223
+ ]
224
+ )
225
+ shutil.rmtree(temp_folder_3)
226
+
227
+ # ORT 4-bits quantizer
228
+ fpath_5 = os.path.join(args.output, filename)
229
+ cmd = [
230
+ f"{sys.executable}", "-m", "onnxruntime.quantization.matmul_4bits_quantizer",
231
+ "--input_model", fpath_4,
232
+ "--output_model", fpath_5,
233
+ "--block_size", str(32),
234
+ ]
235
+ if args.precision == torch.float32: cmd.extend(["--accuracy_level", str(4)])
236
+ subprocess.run(cmd)
237
+ shutil.rmtree(temp_folder_4)
238
+
239
+
240
+ def build_embedding(args):
241
+ # TorchScript export
242
+ batch_size, sequence_length, num_image_tokens, num_audio_tokens = 2, 8, 2, 2
243
+ inputs = {
244
+ "input_ids": torch.randint(low=0, high=config.vocab_size, size=(batch_size, sequence_length), device=args.execution_provider.replace("dml", "cuda"), dtype=torch.int64),
245
+ "image_features": torch.randn(num_image_tokens, config.hidden_size, device=args.execution_provider.replace("dml", "cuda"), dtype=args.precision),
246
+ "audio_features": torch.randn(num_audio_tokens, config.hidden_size, device=args.execution_provider.replace("dml", "cuda"), dtype=args.precision),
247
+ }
248
+ inputs["input_ids"][0][0] = -1
249
+ inputs["input_ids"][0][1] = -1
250
+ inputs["input_ids"][0][2] = -10000
251
+ inputs["input_ids"][0][3] = -10000
252
+ dummy_inputs = (
253
+ inputs["input_ids"], # input_ids: torch.LongTensor
254
+ inputs["image_features"], # image_features: Optional[torch.FloatTensor] = None,
255
+ inputs["audio_features"], # audio_features: Optional[torch.FloatTensor] = None,
256
+ )
257
+ dynamic_axes = {
258
+ "input_ids": {0: "batch_size", 1: "sequence_length"},
259
+ "image_features": {0: "num_image_tokens"},
260
+ "audio_features": {0: "num_audio_tokens"},
261
+ "inputs_embeds": {0: "batch_size", 1: "sequence_length"},
262
+ }
263
+ filename = "phi-4-mm-embedding.onnx"
264
+
265
+ temp_folder_1 = os.path.join(args.output, "embedding_init_export")
266
+ os.makedirs(temp_folder_1, exist_ok=True)
267
+
268
+ fpath_1 = os.path.join(temp_folder_1, filename)
269
+ torch.onnx.export(
270
+ model.model.combined_embed,
271
+ args=dummy_inputs,
272
+ f=fpath_1,
273
+ export_params=True,
274
+ input_names=["input_ids", "image_features", "audio_features"],
275
+ output_names=["inputs_embeds"],
276
+ dynamic_axes=dynamic_axes,
277
+ opset_version=14,
278
+ do_constant_folding=True,
279
+ )
280
+
281
+ onnx.checker.check_model(fpath_1)
282
+ onnx.shape_inference.infer_shapes_path(fpath_1)
283
+ onnx_model = onnx.load_model(fpath_1, load_external_data=True)
284
+
285
+ fpath_2 = os.path.join(args.output, filename)
286
+ onnx.save_model(
287
+ onnx_model,
288
+ fpath_2,
289
+ save_as_external_data=True,
290
+ all_tensors_to_one_file=True,
291
+ location=f"{filename}.data",
292
+ size_threshold=0,
293
+ convert_attribute=False,
294
+ )
295
+ shutil.rmtree(temp_folder_1)
296
+
297
+
298
+ def build_text(args):
299
+ # Create ONNX model
300
+ model_name = None
301
+ precision = "int4"
302
+ extra_options = {
303
+ "exclude_embeds": "true",
304
+ "filename": "phi-4-mm-text.onnx",
305
+ }
306
+ if args.precision == torch.float32: extra_options["int4_accuracy_level"] = 4
307
+ create_model(model_name, args.input, args.output, precision, args.execution_provider, args.cache_dir, **extra_options)
308
+
309
+
310
+ def build_adapters(args):
311
+ # setattr(args, 'use_ortvalue', True)
312
+ # build_float_adapters(args)
313
+
314
+ setattr(args, 'use_ortvalue', False)
315
+ build_quantized_adapters(args)
316
+
317
+
318
+ def extract_adapters_from_torch(args):
319
+ # Extract LoRAs from PyTorch model
320
+ hidden_size = config.hidden_size
321
+ num_kv_heads = config.num_key_value_heads
322
+ num_attn_heads = config.num_attention_heads
323
+ head_size = hidden_size // num_attn_heads
324
+
325
+ q_size = num_attn_heads * head_size
326
+ kv_size = num_kv_heads * head_size
327
+ intermediate_size = config.intermediate_size
328
+
329
+ vision_scaling = config.vision_lora["lora_alpha"] / config.vision_lora["r"]
330
+ speech_scaling = config.speech_lora["lora_alpha"] / config.speech_lora["r"]
331
+
332
+ vision_adapters = {}
333
+ speech_adapters = {}
334
+ for key, val in model.state_dict().items():
335
+ # Map name in graph as key
336
+ new_dict = {}
337
+ key = key.replace("self_attn", "attn").replace("lora_A", "lora_A.MatMul").replace("lora_B", "lora_B.MatMul")
338
+
339
+ if "lora_A" in key:
340
+ # LoRA_A is shared across projections
341
+ if "qkv_proj" in key:
342
+ new_dict[key.replace("qkv_proj", "q_proj")] = val
343
+ new_dict[key.replace("qkv_proj", "k_proj")] = val
344
+ new_dict[key.replace("qkv_proj", "v_proj")] = val
345
+ elif "gate_up_proj" in key:
346
+ new_dict[key.replace("gate_up_proj", "gate_proj")] = val
347
+ new_dict[key.replace("gate_up_proj", "up_proj")] = val
348
+ else:
349
+ new_dict[key] = val
350
+
351
+ elif "lora_B" in key:
352
+ # LoRA_B is split across projections
353
+ if "qkv_proj" in key:
354
+ new_dict[key.replace("qkv_proj", "q_proj")] = val[: q_size, :]
355
+ new_dict[key.replace("qkv_proj", "k_proj")] = val[q_size : q_size + kv_size, :]
356
+ new_dict[key.replace("qkv_proj", "v_proj")] = val[q_size + kv_size :, :]
357
+ elif "gate_up_proj" in key:
358
+ new_dict[key.replace("gate_up_proj", "gate_proj")] = val[: intermediate_size, :]
359
+ new_dict[key.replace("gate_up_proj", "up_proj")] = val[intermediate_size :, :]
360
+ else:
361
+ new_dict[key] = val
362
+
363
+ else:
364
+ continue
365
+
366
+ for new_key, new_val in new_dict.items():
367
+ new_key = new_key.replace(".vision", "").replace(".speech", "")
368
+ if "vision" in key:
369
+ np_data = new_val.detach().cpu().to(args.precision).numpy().transpose()
370
+ if "lora_B" in key:
371
+ np_data *= vision_scaling
372
+ vision_adapters[new_key] = ort.OrtValue.ortvalue_from_numpy(np_data) if args.use_ortvalue else np_data
373
+ elif "speech" in key:
374
+ np_data = new_val.detach().cpu().to(args.precision).numpy().transpose()
375
+ if "lora_B" in key:
376
+ np_data *= speech_scaling
377
+ speech_adapters[new_key] = ort.OrtValue.ortvalue_from_numpy(np_data) if args.use_ortvalue else np_data
378
+ else:
379
+ raise ValueError(f"Unknown LoRA key found: {key}")
380
+
381
+ return vision_adapters, speech_adapters
382
+
383
+
384
+ def build_onnx_adapters(vision_adapters, speech_adapters):
385
+ # Convert vision LoRAs
386
+ adapter_format = ort.AdapterFormat()
387
+ adapter_format.set_adapter_version(1)
388
+ adapter_format.set_model_version(1)
389
+ adapter_format.set_parameters(vision_adapters)
390
+ adapter_format.export_adapter(os.path.join(args.output, "phi-4-mm-vision.onnx_adapter"))
391
+
392
+ # Convert speech LoRAs
393
+ adapter_format = ort.AdapterFormat()
394
+ adapter_format.set_adapter_version(1)
395
+ adapter_format.set_model_version(1)
396
+ adapter_format.set_parameters(speech_adapters)
397
+ adapter_format.export_adapter(os.path.join(args.output, "phi-4-mm-speech.onnx_adapter"))
398
+
399
+ # Convert LoRA weights in ONNX model to inputs
400
+ filename = "phi-4-mm-text.onnx"
401
+ fpath = os.path.join(args.output, filename)
402
+ onnx_model = onnx.load_model(fpath)
403
+
404
+ to_proto = {
405
+ "tensor(int8)": TensorProto.INT8,
406
+ "tensor(uint8)": TensorProto.UINT8,
407
+ "tensor(float16)": TensorProto.FLOAT16,
408
+ "tensor(float)": TensorProto.FLOAT,
409
+ }
410
+ for key, val in vision_adapters.items():
411
+ # Handle different sized feature dimensions between adapters by using dynamic axes
412
+ shape = val.shape()
413
+ if "lora_A.MatMul.weight_Q4" in key:
414
+ shape[0] = "out_features"
415
+ elif "lora_B.MatMul.weight_Q4" in key:
416
+ shape[1] = "(in_features + block_size - 1) // block_size"
417
+ elif "lora_A.MatMul.weight_scales" in key or "lora_B.MatMul.weight_scales" in key:
418
+ shape[0] = "in_features * out_features / block_size"
419
+ elif "lora_A.MatMul.weight" in key:
420
+ shape[1] = "out_features"
421
+ elif "lora_B.MatMul.weight" in key:
422
+ shape[0] = "in_features"
423
+
424
+ new_input = helper.make_tensor_value_info(key, to_proto[val.data_type()], shape)
425
+ onnx_model.graph.input.extend([new_input])
426
+ for initializer in onnx_model.graph.initializer:
427
+ if initializer.name == key:
428
+ # Add 0-filled static initializer for when LoRA isn't used
429
+ # since the size of the inner dims in the LoRA path doesn't matter
430
+ zero_initializer = helper.make_tensor(
431
+ name=initializer.name,
432
+ data_type=initializer.data_type,
433
+ dims=val.shape(),
434
+ vals=np.zeros(val.shape(), dtype=helper.tensor_dtype_to_np_dtype(initializer.data_type)).flatten(),
435
+ )
436
+ onnx_model.graph.initializer.remove(initializer)
437
+ onnx_model.graph.initializer.append(zero_initializer)
438
+ break
439
+
440
+ os.remove(fpath)
441
+ os.remove(fpath + ".data")
442
+ onnx.save_model(
443
+ onnx_model,
444
+ fpath,
445
+ save_as_external_data=True,
446
+ all_tensors_to_one_file=True,
447
+ location=f"{filename}.data",
448
+ size_threshold=0,
449
+ convert_attribute=False,
450
+ )
451
+
452
+
453
+ def build_float_adapters(args):
454
+ vision_adapters, speech_adapters = extract_adapters_from_torch(args)
455
+ build_onnx_adapters(vision_adapters, speech_adapters)
456
+
457
+
458
+ def build_adapter_only_onnx_model(args, adapters, filename, fpath):
459
+ inputs, outputs, initializers, value_infos, nodes = [], [], [], [], []
460
+ dtype = TensorProto.FLOAT16 if args.precision == torch.float16 else TensorProto.FLOAT
461
+ for key, val in adapters.items():
462
+ # Create input and output
463
+ inputs.append(helper.make_tensor_value_info(f"input.{key}", dtype, ["batch_size", "sequence_length", val.shape[0]]))
464
+ outputs.append(helper.make_tensor_value_info(f"output.{key}", dtype, ["batch_size", "sequence_length", val.shape[1]]))
465
+
466
+ # Create initializer data
467
+ tensor = numpy_helper.from_array(val)
468
+ tensor.name = key
469
+ initializers.append(tensor)
470
+
471
+ # Create MatMul node
472
+ matmul_node = helper.make_node(
473
+ "MatMul",
474
+ inputs=[inputs[-1].name, tensor.name],
475
+ outputs=[outputs[-1].name],
476
+ name=f"node.{key}",
477
+ )
478
+ nodes.append(matmul_node)
479
+
480
+ model = helper.make_model(
481
+ opset_imports=[helper.make_operatorsetid('', 14)],
482
+ ir_version=7,
483
+ producer_name="onnxruntime-genai",
484
+ producer_version="0.0.0",
485
+ graph=helper.make_graph(
486
+ name="main_graph",
487
+ inputs=inputs,
488
+ outputs=outputs,
489
+ initializer=initializers,
490
+ value_info=value_infos,
491
+ nodes=nodes,
492
+ )
493
+ )
494
+ onnx.save_model(
495
+ model,
496
+ fpath,
497
+ save_as_external_data=True,
498
+ all_tensors_to_one_file=True,
499
+ location=f"{filename}.data",
500
+ size_threshold=0,
501
+ convert_attribute=False,
502
+ )
503
+
504
+
505
+ def extract_adapters_from_onnx(args, fpath):
506
+ adapters = {}
507
+ model = onnx.load_model(fpath)
508
+ for initializer in model.graph.initializer:
509
+ val = numpy_helper.to_array(initializer)
510
+ adapters[initializer.name] = ort.OrtValue.ortvalue_from_numpy(val)
511
+ return adapters
512
+
513
+
514
+ def build_quantized_adapters(args):
515
+ # 1. Extract LoRAs from PyTorch model
516
+ vision_adapters, speech_adapters = extract_adapters_from_torch(args)
517
+
518
+ # 2. Put LoRAs into separate ONNX models
519
+ filename = "phi-4-mm-lora-vision.onnx"
520
+ fpath_1 = os.path.join(args.output, filename)
521
+ vision_model = build_adapter_only_onnx_model(args, vision_adapters, filename, fpath_1)
522
+
523
+ filename = "phi-4-mm-lora-speech.onnx"
524
+ fpath_2 = os.path.join(args.output, filename)
525
+ speech_model = build_adapter_only_onnx_model(args, speech_adapters, filename, fpath_2)
526
+
527
+ # 3. Quantize ONNX models to int4
528
+ filename = "phi-4-mm-qlora-vision.onnx"
529
+ fpath_3 = os.path.join(args.output, filename)
530
+ cmd = [
531
+ f"{sys.executable}", "-m", "onnxruntime.quantization.matmul_4bits_quantizer",
532
+ "--input_model", fpath_1,
533
+ "--output_model", fpath_3,
534
+ "--block_size", str(32),
535
+ ]
536
+ if args.precision == torch.float32: cmd.extend(["--accuracy_level", str(4)])
537
+ subprocess.run(cmd)
538
+
539
+ filename = "phi-4-mm-qlora-speech.onnx"
540
+ fpath_4 = os.path.join(args.output, filename)
541
+ cmd = [
542
+ f"{sys.executable}", "-m", "onnxruntime.quantization.matmul_4bits_quantizer",
543
+ "--input_model", fpath_2,
544
+ "--output_model", fpath_4,
545
+ "--block_size", str(32),
546
+ ]
547
+ if args.precision == torch.float32: cmd.extend(["--accuracy_level", str(4)])
548
+ subprocess.run(cmd)
549
+
550
+ os.remove(fpath_1)
551
+ os.remove(fpath_1 + ".data")
552
+ os.remove(fpath_2)
553
+ os.remove(fpath_2 + ".data")
554
+
555
+ # 4. Extract quantized LoRAs from ONNX models
556
+ vision_adapters = extract_adapters_from_onnx(args, fpath_3)
557
+ speech_adapters = extract_adapters_from_onnx(args, fpath_4)
558
+
559
+ # 5. Store quantized LoRAs in adapter files
560
+ build_onnx_adapters(vision_adapters, speech_adapters)
561
+
562
+ os.remove(fpath_3)
563
+ os.remove(fpath_3 + ".data")
564
+ os.remove(fpath_4)
565
+ os.remove(fpath_4 + ".data")
566
+
567
+
568
+ def get_args():
569
+ parser = argparse.ArgumentParser()
570
+
571
+ parser.add_argument(
572
+ "-i",
573
+ "--input",
574
+ required=True,
575
+ help="Path to folder on disk containing the Hugging Face config, model, tokenizer, etc.",
576
+ )
577
+
578
+ parser.add_argument(
579
+ "-o",
580
+ "--output",
581
+ required=True,
582
+ help="Path to folder to store ONNX model and additional files (e.g. GenAI config, external data files, etc.)",
583
+ )
584
+
585
+ parser.add_argument(
586
+ "-p",
587
+ "--precision",
588
+ required=True,
589
+ choices=["fp16", "fp32"],
590
+ help="Precision to export PyTorch components with",
591
+ )
592
+
593
+ parser.add_argument(
594
+ "-e",
595
+ "--execution_provider",
596
+ required=True,
597
+ choices=["cpu", "cuda", "dml"],
598
+ help="Execution provider for Phi-4 multimodal components",
599
+ )
600
+
601
+ parser.add_argument(
602
+ "-c",
603
+ "--cache_dir",
604
+ required=False,
605
+ default=os.path.join('.', 'cache_dir'),
606
+ help="Cache directory for Hugging Face files and temporary ONNX external data files",
607
+ )
608
+
609
+ args = parser.parse_args()
610
+ args.precision = torch.float16 if args.precision == "fp16" else torch.float32
611
+ return args
612
+
613
+ if __name__ == "__main__":
614
+ user_prompt = '<|user|>\n'
615
+ assistant_prompt = '<|assistant|>\n'
616
+ prompt_suffix = "<|end|>\n"
617
+
618
+ args = get_args()
619
+ config = AutoConfig.from_pretrained(args.input, trust_remote_code=True)
620
+ processor = AutoProcessor.from_pretrained(args.input, trust_remote_code=True)
621
+ model = AutoModelForCausalLM.from_pretrained(args.input, trust_remote_code=True, torch_dtype=args.precision).to(args.execution_provider.replace("dml", "cuda"))
622
+
623
+ # Build model components
624
+ build_vision(args)
625
+ build_speech(args)
626
+ build_embedding(args)
627
+ build_text(args)
628
+ build_adapters(args)
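Once build_vision finishes, the quantized phi-4-mm-vision.onnx can be sanity-checked with a plain onnxruntime session. A minimal sketch, assuming builder.py was run with --output ./phi-4-mm-onnx; it only inspects the graph inputs and outputs declared in the export above.

import onnxruntime as ort

# Path assumes builder.py was run with --output ./phi-4-mm-onnx.
session = ort.InferenceSession(
    "./phi-4-mm-onnx/phi-4-mm-vision.onnx",
    providers=["CPUExecutionProvider"],
)
# Expected inputs: pixel_values, image_attention_mask, image_sizes
for inp in session.get_inputs():
    print("input :", inp.name, inp.shape, inp.type)
# Expected output: image_features
for out in session.get_outputs():
    print("output:", out.name, out.shape, out.type)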
onnx/config.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:16fb355ba07bea3ffdf794f297f2005aee4f4ee6aba9742e264ad4471535e966
3
+ size 4585
onnx/modeling_phio.py ADDED
The diff for this file is too large to render. See raw diff
 
onnx/processing_phio.py ADDED
@@ -0,0 +1,732 @@
1
+ # Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ Processor class for PhiO
17
+ """
18
+ import re
19
+ from typing import List, Optional, Tuple, Union
20
+ import math
21
+ from enum import Enum
22
+
23
+ import numpy as np
24
+ import scipy
25
+ import torch
26
+ import torchvision
27
+
28
+ from transformers import AutoFeatureExtractor, AutoImageProcessor
29
+ from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
30
+ from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
31
+ from transformers.image_utils import (
32
+ ImageInput,
33
+ make_list_of_images,
34
+ valid_images,
35
+ )
36
+ from transformers.processing_utils import ProcessorMixin
37
+ from transformers.tokenization_utils_base import PaddingStrategy, TextInput, TruncationStrategy
38
+ from transformers.utils import TensorType, logging
39
+ from torch.nn.utils.rnn import pad_sequence
40
+
41
+
42
+ logger = logging.get_logger(__name__)
43
+
44
+ # Special tokens
45
+ _COMPATIBLE_IMAGE_SPECIAL_TOKEN_PATTERN = r'<\|image_\d+\|>' # For backward compatibility
46
+ _COMPATIBLE_AUDIO_SPECIAL_TOKEN_PATTERN = r'<\|audio_\d+\|>' # For backward compatibility
47
+ _IMAGE_SPECIAL_TOKEN = '<|endoftext10|>'
48
+ _AUDIO_SPECIAL_TOKEN = '<|endoftext11|>'
49
+ _IMAGE_SPECIAL_TOKEN_ID = 200010 # '<|endoftext10|>', or we can better name it (in `tokenizer_config.json`)
50
+ _AUDIO_SPECIAL_TOKEN_ID = 200011 # '<|endoftext11|>'
51
+
52
+
53
+ class InputMode(Enum):
54
+ LANGUAGE = 0
55
+ VISION = 1
56
+ SPEECH = 2
57
+ VISION_SPEECH = 3
58
+
59
+
60
+ class PhiOImageProcessor(BaseImageProcessor):
61
+ r"""
62
+ Constructs a PhiO image processor.
63
+ """
64
+ model_input_names = ["input_image_embeds", "image_sizes", "image_attention_mask"]
65
+
66
+ def __init__(
67
+ self,
68
+ dynamic_hd,
69
+ **kwargs,
70
+ ) -> None:
71
+ super().__init__(**kwargs)
72
+ self.dynamic_hd = dynamic_hd
73
+
74
+ def find_closest_aspect_ratio(self, aspect_ratio, target_ratios, width, height, image_size):
75
+ best_ratio_diff = float('inf')
76
+ best_ratio = (1, 1)
77
+ area = width * height
78
+ for ratio in target_ratios:
79
+ target_aspect_ratio = ratio[0] / ratio[1]
80
+ ratio_diff = abs(aspect_ratio - target_aspect_ratio)
81
+ if ratio_diff < best_ratio_diff:
82
+ best_ratio_diff = ratio_diff
83
+ best_ratio = ratio
84
+ elif ratio_diff == best_ratio_diff:
85
+ if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
86
+ best_ratio = ratio
87
+ return best_ratio
88
+
89
+ def dynamic_preprocess(self, image, min_num=1, max_num=12, image_size=384, mask_size=27, use_thumbnail=True):
90
+ orig_width, orig_height = image.size
91
+
92
+ w_crop_num = math.ceil(orig_width/float(image_size))
93
+ h_crop_num = math.ceil(orig_height/float(image_size))
94
+ if w_crop_num * h_crop_num > max_num:
95
+
96
+ aspect_ratio = orig_width / orig_height
97
+
98
+ # calculate the existing image aspect ratio
99
+ target_ratios = set(
100
+ (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
101
+ i * j <= max_num and i * j >= min_num)
102
+ target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
103
+
104
+ # find the closest aspect ratio to the target
105
+ target_aspect_ratio = self.find_closest_aspect_ratio(
106
+ aspect_ratio, target_ratios, orig_width, orig_height, image_size)
107
+
108
+ # calculate the target width and height
109
+ target_width = image_size * target_aspect_ratio[0]
110
+ target_height = image_size * target_aspect_ratio[1]
111
+ print(target_aspect_ratio)
112
+ else:
113
+ target_width = image_size * w_crop_num
114
+ target_height = image_size * h_crop_num
115
+ target_aspect_ratio = (w_crop_num, h_crop_num)
116
+
117
+ # Calculate the ratio
118
+ ratio_width = target_width / orig_width
119
+ ratio_height = target_height / orig_height
120
+ if ratio_width < ratio_height:
121
+ new_size = (target_width, int(orig_height * ratio_width))
122
+ padding_width = 0
123
+ padding_height = target_height - int(orig_height * ratio_width)
124
+ else:
125
+ new_size = (int(orig_width * ratio_height), target_height)
126
+ padding_width = target_width - int(orig_width * ratio_height)
127
+ padding_height = 0
128
+
129
+ attention_mask = torch.ones((int(mask_size*target_aspect_ratio[1]), int(mask_size*target_aspect_ratio[0])))
130
+ if padding_width >= 14:
131
+ attention_mask[:, -math.floor(padding_width/14):] = 0
132
+ if padding_height >= 14:
133
+ attention_mask[-math.floor(padding_height/14):,:] = 0
134
+ assert attention_mask.sum() > 0
135
+
136
+ if min(new_size[1], target_height) < 10 or min(new_size[0], target_width) < 10:
137
+ raise ValueError(f'the aspect ratio is very extreme {new_size}')
138
+
139
+ image = torchvision.transforms.functional.resize(image, [new_size[1], new_size[0]],)
140
+
141
+ resized_img = torchvision.transforms.functional.pad(image, [0, 0, padding_width, padding_height], fill=[255,255,255])
142
+
143
+ return resized_img, attention_mask
144
+
145
+ def pad_to_max_num_crops(self, images, max_crops=5):
146
+ """
147
+ images: B x 3 x H x W, B<=max_crops
148
+ """
149
+ B, _, H, W = images.shape
150
+ if B < max_crops:
151
+ pad = torch.zeros(max_crops - B, 3, H, W, dtype=images.dtype, device=images.device)
152
+ images = torch.cat([images, pad], dim=0)
153
+ return images
154
+
155
+ def pad_mask_to_max_num_crops(self, masks, max_crops=5):
156
+ B, H, W = masks.shape
157
+ if B < max_crops:
158
+ pad = torch.ones(max_crops - B, H, W, dtype=masks.dtype, device=masks.device)
159
+ masks = torch.cat([masks, pad], dim=0)
160
+ return masks
161
+
162
+ def preprocess(
163
+ self,
164
+ images: ImageInput,
165
+ return_tensors: Optional[Union[str, TensorType]] = None,
166
+ ):
167
+ """
168
+ Args:
169
+ images (`ImageInput`):
170
+ Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
171
+ passing in images with pixel values between 0 and 1, set `do_rescale=False`.
172
+ return_tensors (`str` or `TensorType`, *optional*):
173
+ The type of tensors to return. Can be one of:
174
+ - Unset: Return a list of `np.ndarray`.
175
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
176
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
177
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
178
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
179
+ """
180
+ images = make_list_of_images(images)
181
+
182
+ if not valid_images(images):
183
+ raise ValueError(
184
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
185
+ "torch.Tensor, tf.Tensor or jax.ndarray."
186
+ )
187
+
188
+ # Basic settings.
189
+ img_processor = torchvision.transforms.Compose([
190
+ torchvision.transforms.ToTensor(),
191
+ torchvision.transforms.Normalize(
192
+ (0.5, 0.5, 0.5),
193
+ (0.5, 0.5, 0.5)
194
+ ),
195
+ ])
196
+ dyhd_base_resolution = 448
197
+
198
+ # Dynamic HD
199
+ base_resolution = dyhd_base_resolution
200
+ images = [image.convert('RGB') for image in images]
201
+ # cover 384 and 448 resolution
202
+ mask_resolution = base_resolution // 14
203
+ elems, image_attention_masks = [], []
204
+ for im in images:
205
+ elem, attention_mask = self.dynamic_preprocess(im, max_num=self.dynamic_hd, image_size=base_resolution, mask_size=mask_resolution)
206
+ elems.append(elem)
207
+ image_attention_masks.append(attention_mask)
208
+ hd_images = [img_processor(im) for im in elems]
209
+ global_image = [torch.nn.functional.interpolate(im.unsqueeze(0).float(), size=(base_resolution, base_resolution), mode='bicubic',).to(im.dtype) for im in hd_images]
210
+ shapes = [[im.size(1), im.size(2)] for im in hd_images]
211
+ mask_shapes = [[mask.size(0), mask.size(1)] for mask in image_attention_masks]
212
+ global_attention_mask = [torch.ones((1, mask_resolution, mask_resolution)) for _ in hd_images]
213
+ hd_images_reshape = [im.reshape(1, 3,
214
+ h//base_resolution,
215
+ base_resolution,
216
+ w//base_resolution,
217
+ base_resolution
218
+ ).permute(0,2,4,1,3,5).reshape(-1, 3, base_resolution, base_resolution).contiguous() for im, (h, w) in zip(hd_images, shapes)]
219
+ attention_masks_reshape = [mask.reshape(1,
220
+ h//mask_resolution,
221
+ mask_resolution,
222
+ w//mask_resolution,
223
+ mask_resolution
224
+ ).permute(0,1,3,2,4).reshape(-1, mask_resolution, mask_resolution).contiguous() for mask, (h, w) in zip(image_attention_masks, mask_shapes)]
225
+ downsample_attention_masks = [mask[:,0::2,0::2].reshape(1,
226
+ h//mask_resolution,
227
+ w//mask_resolution,
228
+ mask_resolution//2+mask_resolution%2,
229
+ mask_resolution//2+mask_resolution%2
230
+ ).permute(0,1,3,2,4) for mask, (h,w) in zip(attention_masks_reshape, mask_shapes)]
231
+ downsample_attention_masks = [mask.reshape(mask.size(1)*mask.size(2), mask.size(3)*mask.size(4))for mask in downsample_attention_masks]
232
+ num_img_tokens = [256 + 1 + int(mask.sum().item()) + int(mask[:,0].sum().item()) + 16 for mask in downsample_attention_masks]
233
+
234
+ hd_images_reshape = [torch.cat([_global_image] + [_im], dim=0) for _global_image, _im in zip(global_image, hd_images_reshape)]
235
+ hd_masks_reshape = [torch.cat([_global_mask] + [_mask], dim=0) for _global_mask, _mask in zip(global_attention_mask, attention_masks_reshape)]
236
+ max_crops = max([img.size(0) for img in hd_images_reshape])
237
+ image_transformed = [self.pad_to_max_num_crops(im, max_crops) for im in hd_images_reshape]
238
+ image_transformed = torch.stack(image_transformed, dim=0)
239
+ mask_transformed = [self.pad_mask_to_max_num_crops(mask, max_crops) for mask in hd_masks_reshape]
240
+ mask_transformed = torch.stack(mask_transformed, dim=0)
241
+
242
+ returned_input_image_embeds = image_transformed
243
+ returned_image_sizes = torch.tensor(shapes, dtype=torch.long)
244
+ returned_image_attention_mask = mask_transformed
245
+ returned_num_img_tokens = num_img_tokens
246
+
247
+ data = {
248
+ "input_image_embeds": returned_input_image_embeds,
249
+ "image_sizes": returned_image_sizes,
250
+ "image_attention_mask": returned_image_attention_mask,
251
+ "num_img_tokens": returned_num_img_tokens,
252
+ }
253
+
254
+ return BatchFeature(data=data, tensor_type=return_tensors)
255
+
256
+
257
+ AudioInput = Tuple[Union[np.ndarray, torch.Tensor], int]
258
+ AudioInputs = List[AudioInput]
259
+
260
+
261
+ def speechlib_mel(sample_rate, n_fft, n_mels, fmin=None, fmax=None):
262
+ """Create a Mel filter-bank the same as SpeechLib FbankFC.
263
+
264
+ Args:
265
+ sample_rate (int): Sample rate in Hz. number > 0 [scalar]
266
+ n_fft (int): FFT size. int > 0 [scalar]
267
+ n_mels (int): Mel filter size. int > 0 [scalar]
268
+ fmin (float): lowest frequency (in Hz). If None use 0.0.
269
+ float >= 0 [scalar]
270
+ fmax: highest frequency (in Hz). If None use sample_rate / 2.
271
+ float >= 0 [scalar]
272
+
273
+ Returns
274
+ out (numpy.ndarray): Mel transform matrix
275
+ [shape=(n_mels, 1 + n_fft/2)]
276
+ """
277
+
278
+ bank_width = int(n_fft // 2 + 1)
279
+ if fmax is None:
280
+ fmax = sample_rate / 2
281
+ if fmin is None:
282
+ fmin = 0
283
+ assert fmin >= 0, "fmin cannot be negative"
284
+ assert fmin < fmax <= sample_rate / 2, "fmax must be between (fmin, samplerate / 2]"
285
+
286
+ def mel(f):
287
+ return 1127.0 * np.log(1.0 + f / 700.0)
288
+
289
+ def bin2mel(fft_bin):
290
+ return 1127.0 * np.log(1.0 + fft_bin * sample_rate / (n_fft * 700.0))
291
+
292
+ def f2bin(f):
293
+ return int((f * n_fft / sample_rate) + 0.5)
294
+
295
+ # Spec 1: FFT bin range [f2bin(fmin) + 1, f2bin(fmax) - 1]
296
+ klo = f2bin(fmin) + 1
297
+ khi = f2bin(fmax)
298
+
299
+ khi = max(khi, klo)
300
+
301
+ # Spec 2: SpeechLib uses triangles in Mel space
302
+ mlo = mel(fmin)
303
+ mhi = mel(fmax)
304
+ m_centers = np.linspace(mlo, mhi, n_mels + 2)
305
+ ms = (mhi - mlo) / (n_mels + 1)
306
+
307
+ matrix = np.zeros((n_mels, bank_width), dtype=np.float32)
308
+ for m in range(0, n_mels):
309
+ left = m_centers[m]
310
+ center = m_centers[m + 1]
311
+ right = m_centers[m + 2]
312
+ for fft_bin in range(klo, khi):
313
+ mbin = bin2mel(fft_bin)
314
+ if left < mbin < right:
315
+ matrix[m, fft_bin] = 1.0 - abs(center - mbin) / ms
316
+
317
+ return matrix
318
+
319
+
320
+ class PhiOAudioFeatureExtractor(SequenceFeatureExtractor):
321
+ model_input_names = ["input_audio_embeds", "audio_embed_sizes", "audio_attention_mask"]
322
+
323
+ def __init__(self, audio_compression_rate, audio_downsample_rate, audio_feat_stride, **kwargs):
324
+ feature_size = 80
325
+ sampling_rate = 16000
326
+ padding_value = 0.0
327
+ super().__init__(feature_size, sampling_rate, padding_value, **kwargs)
328
+
329
+ self.compression_rate = audio_compression_rate
330
+ self.qformer_compression_rate = audio_downsample_rate
331
+ self.feat_stride = audio_feat_stride
332
+
333
+ self._eightk_method = "fillzero"
334
+ self._mel = speechlib_mel(16000, 512, 80, fmin=None, fmax=7690).T
335
+
336
+ self._hamming400 = np.hamming(400) # for 16k audio
337
+ self._hamming200 = np.hamming(200) # for 8k audio
338
+
339
+ def duration_to_frames(self, duration):
340
+ """Duration in seconds; returns the estimated number of frames."""
341
+ frame_rate = 10
342
+
343
+ num_frames = duration * 1000 // frame_rate
344
+ return num_frames
345
+
346
+ def __call__(
347
+ self,
348
+ audios: List[AudioInput],
349
+ return_tensors: Optional[Union[str, TensorType]] = None,
350
+ ):
351
+ # Ref: https://github.com/huggingface/transformers/blob/v4.47.0/src/transformers/models/audio_spectrogram_transformer/feature_extraction_audio_spectrogram_transformer.py#L161
352
+ returned_input_audio_embeds = []
353
+ returned_audio_embed_sizes = []
354
+ audio_frames_list = []
355
+ # import pdb; pdb.set_trace()
356
+
357
+ for audio_data, sample_rate in audios:
358
+ audio_embeds = self._extract_features(audio_data, sample_rate)
359
+ audio_frames = len(audio_embeds) * self.feat_stride
360
+ audio_embed_size = self._compute_audio_embed_size(audio_frames)
361
+
362
+ returned_input_audio_embeds.append(torch.tensor(audio_embeds))
363
+ returned_audio_embed_sizes.append(torch.tensor(audio_embed_size).long())
364
+ audio_frames_list.append(audio_frames)
365
+
366
+ returned_input_audio_embeds = pad_sequence(
367
+ returned_input_audio_embeds, batch_first=True
368
+ )
369
+ returned_audio_embed_sizes = torch.stack(returned_audio_embed_sizes, dim=0)
370
+ audio_frames = torch.tensor(audio_frames_list)
371
+ returned_audio_attention_mask = torch.arange(0, audio_frames.max()).unsqueeze(0) < audio_frames.unsqueeze(1) if len(audios) > 1 else None
372
+
373
+ data = {
374
+ "input_audio_embeds": returned_input_audio_embeds,
375
+ "audio_embed_sizes": returned_audio_embed_sizes,
376
+ }
377
+ if returned_audio_attention_mask is not None:
378
+ data["audio_attention_mask"] = returned_audio_attention_mask
379
+
380
+ return BatchFeature(data=data, tensor_type=return_tensors)
381
+
382
+ def _extract_spectrogram(self, wav, fs):
383
+ """Extract spectrogram features from waveform.
384
+ Args:
385
+ wav (1D array): waveform of the input
386
+ fs (int): sampling rate of the waveform, 16000 or 8000.
387
+ If fs=8000, the waveform will be resampled to 16000Hz.
388
+ Output:
389
+ log_fbank (2D array): a TxD matrix of log Mel filterbank features.
390
+ D=80, and T is the number of frames.
391
+ """
392
+ if wav.ndim > 1:
393
+ wav = np.squeeze(wav)
394
+
395
+ # by default, we extract the mean if stereo
396
+ if len(wav.shape) == 2:
397
+ wav = wav.mean(1)
398
+
399
+ # Resample to 16000 or 8000 if needed
400
+ if fs > 16000:
401
+ wav = scipy.signal.resample_poly(wav, 1, fs // 16000)
402
+ fs = 16000
403
+ elif 8000 < fs < 16000:
404
+ wav = scipy.signal.resample_poly(wav, 1, fs // 8000)
405
+ fs = 8000
406
+ elif fs < 8000:
407
+ raise RuntimeError(f"Unsupported sample rate {fs}")
408
+
409
+ if fs == 8000:
410
+ if self._eightk_method == "resample":
411
+ # Input audio is 8 kHz. Convert to 16 kHz before feature
412
+ # extraction
413
+ wav = scipy.signal.resample_poly(wav, 2, 1)
414
+ fs = 16000
415
+ # Do nothing here for fillzero method
416
+ elif fs != 16000:
417
+ # Input audio is not a supported sample rate.
418
+ raise RuntimeError(f"Input data using an unsupported sample rate: {fs}")
419
+
420
+ preemphasis = 0.97
421
+
422
+ if fs == 8000:
423
+ n_fft = 256
424
+ win_length = 200
425
+ hop_length = 80
426
+ fft_window = self._hamming200
427
+ elif fs == 16000:
428
+ n_fft = 512
429
+ win_length = 400
430
+ hop_length = 160
431
+ fft_window = self._hamming400
432
+
433
+ # Spec 1: SpeechLib cut remaining sample insufficient for a hop
434
+ n_batch = (wav.shape[0] - win_length) // hop_length + 1
435
+ # Here we don't use stride_tricks since the input array may not satisfy
436
+ # memory layout requirement and we need writeable output
437
+ # Here we only use a list of views before copying to the destination
438
+ # so it is more efficient than broadcasting
439
+ y_frames = np.array(
440
+ [wav[_stride : _stride + win_length] for _stride in range(0, hop_length * n_batch, hop_length)],
441
+ dtype=np.float32,
442
+ )
443
+
444
+ # Spec 2: SpeechLib applies preemphasis within each batch
445
+ y_frames_prev = np.roll(y_frames, 1, axis=1)
446
+ y_frames_prev[:, 0] = y_frames_prev[:, 1]
447
+ y_frames = (y_frames - preemphasis * y_frames_prev) * 32768
448
+
449
+ S = np.fft.rfft(fft_window * y_frames, n=n_fft, axis=1).astype(np.complex64)
450
+
451
+ if fs == 8000:
452
+ # Need to pad the output to look like 16 kHz data but with zeros in
453
+ # the 4 to 8 kHz bins.
454
+ frames, bins = S.shape
455
+ padarray = np.zeros((frames, bins))
456
+ S = np.concatenate((S[:, 0:-1], padarray), axis=1) # Nyquist bin gets set to zero
457
+
458
+ spec = np.abs(S).astype(np.float32)
459
+ return spec
460
+
461
+ def _extract_features(self, wav, fs):
462
+ """Extract log filterbank features from waveform.
463
+ Args:
464
+ wav (1D array): waveform of the input
465
+ fs (int): sampling rate of the waveform, 16000 or 8000.
466
+ If fs=8000, the waveform will be resampled to 16000Hz.
467
+ Output:
468
+ log_fbank (2D array): a TxD matrix of log Mel filterbank features.
469
+ D=80, and T is the number of frames.
470
+ """
471
+ spec = self._extract_spectrogram(wav, fs)
472
+ spec_power = spec**2
473
+
474
+ fbank_power = np.clip(spec_power.dot(self._mel), 1.0, None)
475
+ log_fbank = np.log(fbank_power).astype(np.float32)
476
+
477
+ return log_fbank
478
+
479
+ def _compute_audio_embed_size(self, audio_frames):
480
+ integer = audio_frames // self.compression_rate
481
+ remainder = audio_frames % self.compression_rate
482
+
483
+ result = integer if remainder == 0 else integer + 1
484
+
485
+ integer = result // self.qformer_compression_rate
486
+ remainder = result % self.qformer_compression_rate
487
+ result = integer if remainder == 0 else integer + 1 # qformer compression
488
+
489
+ return result
490
+
491
+
492
+ class PhiOProcessor(ProcessorMixin):
493
+ r"""
494
+ Constructs a PhiO processor which wraps an image processor, an audio processor, and a GPT tokenizer into a single processor.
495
+
496
+ [`PhiOProcessor`] offers all the functionalities of [`PhiOImageProcessor`] and [`GPT2Tokenizer`]. See the
497
+ [`~PhiOProcessor.__call__`] and [`~PhiOProcessor.decode`] for more information.
498
+
499
+ Args:
500
+ image_processor ([`PhiOImageProcessor`], *optional*):
501
+ The image processor is a required input.
502
+ tokenizer ([`GPT2Tokenizer`], *optional*):
503
+ The tokenizer is a required input.
504
+ """
505
+
506
+ attributes = ["image_processor", "audio_processor", "tokenizer"]
507
+ tokenizer_class = "GPT2TokenizerFast"
508
+ image_processor_class = "AutoImageProcessor" # PhiOImageProcessor will be registered later
509
+ audio_processor_class = "AutoFeatureExtractor" # PhiOAudioFeatureExtractor will be registered later
510
+
511
+ def __init__(self, image_processor, audio_processor, tokenizer):
512
+ self.image_processor = image_processor
513
+ self.audio_processor = audio_processor
514
+ self.tokenizer = tokenizer
515
+
516
+ def __call__(
517
+ self,
518
+ text: Union[TextInput, List[TextInput]],
519
+ images: Optional[ImageInput] = None,
520
+ audios: Optional[AudioInputs] = None,
521
+ padding: Union[bool, str, PaddingStrategy] = False,
522
+ truncation: Optional[Union[bool, str, TruncationStrategy]] = None,
523
+ max_length=None,
524
+ return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
525
+ ) -> BatchFeature:
526
+ """
527
+ Main method to prepare one or several sequence(s) and image(s) for the model. This method forwards the `text`
528
+ and `kwargs` arguments to GPT2Tokenizer's [`~GPT2Tokenizer.__call__`] if `text` is not `None` to encode
529
+ the text. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to
530
+ PhiOImageProcessor's [`~PhiOImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring
531
+ of the above two methods for more information.
532
+
533
+ Args:
534
+ text (`str`, `List[str]`, `List[List[str]]`):
535
+ The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
536
+ (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
537
+ `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
538
+ images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
539
+ The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
540
+ tensor. Both channels-first and channels-last formats are supported.
541
+ padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
542
+ Select a strategy to pad the returned sequences (according to the model's padding side and padding
543
+ index) among:
544
+ - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
545
+ sequence is provided).
546
+ - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
547
+ acceptable input length for the model if that argument is not provided.
548
+ - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
549
+ lengths).
550
+ max_length (`int`, *optional*):
551
+ Maximum length of the returned list and optionally padding length (see above).
552
+ truncation (`bool`, *optional*):
553
+ Activates truncation to cut input sequences longer than `max_length` to `max_length`.
554
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
555
+ If set, will return tensors of a particular framework. Acceptable values are:
556
+
557
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
558
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
559
+ - `'np'`: Return NumPy `np.ndarray` objects.
560
+ - `'jax'`: Return JAX `jnp.ndarray` objects.
561
+
562
+ Returns:
563
+ [`BatchFeature`]: A [`BatchFeature`] with the following fields:
564
+
565
+ - **input_ids** -- List of token ids to be fed to a model.
566
+ - **input_image_embeds** -- Pixel values to be fed to a model.
567
+ - **image_sizes** -- List of tuples specifying the size of each image in `input_image_embeds`.
568
+ - **image_attention_mask** -- List of attention masks for each image in `input_image_embeds`.
569
+ - **input_audio_embeds** -- Audio embeddings to be fed to a model.
570
+ - **audio_embed_sizes** -- List of integers specifying the size of each audio in `input_audio_embeds`.
571
+ - **audio_attention_mask** -- List of attention masks for each audio in `input_audio_embeds`.
572
+ - **attention_mask** -- List of indices specifying which tokens should be attended to by the model.
573
+ """
574
+ image_inputs = self.image_processor(images, return_tensors=return_tensors) if images is not None else {}
575
+ audio_inputs = self.audio_processor(audios, return_tensors=return_tensors) if audios is not None else {}
576
+ inputs = self._convert_images_audios_text_to_inputs(
577
+ image_inputs,
578
+ audio_inputs,
579
+ text,
580
+ padding=padding,
581
+ truncation=truncation,
582
+ max_length=max_length,
583
+ return_tensors=return_tensors,
584
+ )
585
+
586
+ # identify the input mode
587
+ if len(image_inputs) > 0 and len(audio_inputs) > 0:
588
+ input_mode = InputMode.VISION_SPEECH
589
+ elif len(image_inputs) > 0:
590
+ input_mode = InputMode.VISION
591
+ elif len(audio_inputs) > 0:
592
+ input_mode = InputMode.SPEECH
593
+ else:
594
+ input_mode = InputMode.LANGUAGE
595
+ inputs["input_mode"] = torch.tensor([input_mode.value], dtype=torch.long)
596
+
597
+ return inputs
598
+
599
+ @property
600
+ def special_image_token_id(self):
601
+ return self.tokenizer.convert_tokens_to_ids(self.special_image_token)
602
+
603
+ def get_special_image_token_id(self):
604
+ return self.tokenizer.convert_tokens_to_ids(self.special_image_token)
605
+
606
+ def _convert_images_audios_text_to_inputs(
607
+ self, images, audios, text, padding=False, truncation=None, max_length=None, return_tensors=None
608
+ ):
609
+ # prepare image id to image input ids
610
+ if len(images) > 0:
611
+ input_image_embeds = images["input_image_embeds"]
612
+ image_sizes = images["image_sizes"]
613
+ image_attention_mask = images["image_attention_mask"]
614
+ num_img_tokens = images['num_img_tokens']
615
+ else:
616
+ input_image_embeds = torch.tensor([])
617
+ image_sizes = torch.tensor([])
618
+ image_attention_mask = torch.tensor([])
619
+ num_img_tokens = []
620
+
621
+ # prepare audio id to audio input ids
622
+ if len(audios) > 0:
623
+ input_audio_embeds = audios["input_audio_embeds"]
624
+ audio_embed_sizes = audios["audio_embed_sizes"]
625
+ audio_attention_mask = audios.get("audio_attention_mask", torch.tensor([]))
626
+ else:
627
+ input_audio_embeds = torch.tensor([])
628
+ audio_embed_sizes = torch.tensor([])
629
+ audio_attention_mask = torch.tensor([])
630
+
631
+ # Replace certain special tokens for compatibility
632
+ # Ref: https://stackoverflow.com/questions/11475885/python-replace-regex
633
+ if isinstance(text, str):
634
+ text = [text]
635
+ assert isinstance(text, list)
636
+ processed_text = [re.sub(_COMPATIBLE_IMAGE_SPECIAL_TOKEN_PATTERN, _IMAGE_SPECIAL_TOKEN, t) for t in text]
637
+ processed_text = [re.sub(_COMPATIBLE_AUDIO_SPECIAL_TOKEN_PATTERN, _AUDIO_SPECIAL_TOKEN, t) for t in processed_text]
638
+
639
+ input_ids_list = [self.tokenizer(t).input_ids for t in processed_text]
640
+
641
+ img_cnt, audio_cnt = 0, 0 # only needed for later assertion
642
+ image_token_count_iter = iter(num_img_tokens)
643
+ audio_embed_size_iter = iter(audio_embed_sizes.tolist())
644
+ new_input_ids_list = []
645
+ for input_ids in input_ids_list:
646
+ i = 0
647
+ while i < len(input_ids):
648
+ token_id = input_ids[i]
649
+ if token_id == _AUDIO_SPECIAL_TOKEN_ID:
650
+ token_count = next(audio_embed_size_iter)
651
+ audio_cnt += 1
652
+ elif token_id == _IMAGE_SPECIAL_TOKEN_ID:
653
+ token_count = next(image_token_count_iter)
654
+ img_cnt += 1
655
+ else:
656
+ i += 1
657
+ continue
658
+ tokens = [token_id] * token_count
659
+ input_ids = input_ids[:i] + tokens + input_ids[i + 1:]
660
+ i += token_count
661
+ input_ids = torch.tensor(input_ids, dtype=torch.long)
662
+ new_input_ids_list.append(input_ids)
663
+ lengths = torch.tensor([len(input_ids) for input_ids in new_input_ids_list])
664
+ max_len = lengths.max()
665
+ input_ids = input_ids.new_full((len(new_input_ids_list), max_len), self.tokenizer.pad_token_id)
666
+ # batched inference requires left padding
667
+ for i in range(len(new_input_ids_list)):
668
+ input_ids[i, max_len - len(new_input_ids_list[i]):] = new_input_ids_list[i]
669
+
670
+ # If the below assertion fails, it might be that input pure-text
671
+ # messages contain image/audio special tokens literally
672
+ # (<|endoftext10|>, <|endoftext11|>).
673
+ assert (
674
+ img_cnt == len(num_img_tokens)
675
+ ), (
676
+ f"Number of image tokens in prompt_token_ids ({img_cnt}) "
677
+ f"does not match number of images ({len(num_img_tokens)})"
678
+ )
679
+ assert (
680
+ audio_cnt == len(audio_embed_sizes)
681
+ ), (
682
+ f"Number of audio tokens in prompt_token_ids ({audio_cnt}) "
683
+ f"does not match number of audios ({len(audio_embed_sizes)})"
684
+ )
685
+
686
+ # prepare attention mask
687
+ seq_range = torch.arange(max_len - 1, -1, -1)
688
+ attention_mask = seq_range.unsqueeze(0) < lengths.unsqueeze(1)
689
+
690
+ # prepare batch feature
691
+ data = {
692
+ "input_ids": input_ids,
693
+ "input_image_embeds": input_image_embeds,
694
+ "image_sizes": image_sizes,
695
+ "image_attention_mask": image_attention_mask,
696
+ "input_audio_embeds": input_audio_embeds,
697
+ "audio_embed_sizes": audio_embed_sizes,
698
+ "audio_attention_mask": audio_attention_mask,
699
+ "attention_mask": attention_mask,
700
+ }
701
+
702
+ return BatchFeature(
703
+ data=data
704
+ )
705
+
706
+ # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
707
+ def batch_decode(self, *args, **kwargs):
708
+ """
709
+ This method forwards all its arguments to GPT2Tokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please
710
+ refer to the docstring of this method for more information.
711
+ """
712
+ return self.tokenizer.batch_decode(*args, **kwargs)
713
+
714
+ # Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Llama
715
+ def decode(self, *args, **kwargs):
716
+ """
717
+ This method forwards all its arguments to GPT2Tokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to
718
+ the docstring of this method for more information.
719
+ """
720
+ return self.tokenizer.decode(*args, **kwargs)
721
+
722
+ @property
723
+ # Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names
724
+ def model_input_names(self):
725
+ tokenizer_input_names = self.tokenizer.model_input_names
726
+ image_processor_input_names = self.image_processor.model_input_names
727
+ audio_processor_input_names = self.audio_processor.model_input_names
728
+ return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names + audio_processor_input_names))
729
+
730
+
731
+ AutoImageProcessor.register("PhiOImageProcessor", PhiOImageProcessor)
732
+ AutoFeatureExtractor.register("PhiOAudioFeatureExtractor", PhiOAudioFeatureExtractor)
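The two `register` calls above make the custom image and audio processors discoverable through the transformers Auto classes. Below is a minimal usage sketch, not part of this commit: it assumes the processor code ships with the Phi-4-multimodal-instruct repository and is loaded with `trust_remote_code=True`; the repository ID, chat markers, and image path are illustrative assumptions.

```python
from PIL import Image
from transformers import AutoProcessor

# Assumed repository ID and prompt markers; adjust to your checkpoint.
model_id = "microsoft/Phi-4-multimodal-instruct"
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)

prompt = "<|user|><|image_1|>Describe this image.<|end|><|assistant|>"
image = Image.open("example.jpg")  # any local image

# The __call__ above expands <|image_1|> into `num_img_tokens` placeholder ids and
# returns torch tensors directly (input_ids, attention_mask, input_image_embeds, ...).
inputs = processor(text=prompt, images=[image])
print({k: tuple(v.shape) for k, v in inputs.items() if hasattr(v, "shape")})
```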
onnx/speech_conformer_encoder.py ADDED
The diff for this file is too large to render. See raw diff
 
onnx/vision_siglip_navit.py ADDED
@@ -0,0 +1,1721 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ Siglip model configuration"""
16
+
17
+ import os
18
+ from typing import Union
19
+
20
+ from transformers.configuration_utils import PretrainedConfig
21
+ from transformers.utils import logging
22
+
23
+
24
+ logger = logging.get_logger(__name__)
25
+
26
+ SIGLIP_PRETRAINED_CONFIG_ARCHIVE_MAP = {
27
+ "google/siglip-base-patch16-224": "https://huggingface.co/google/siglip-base-patch16-224/resolve/main/config.json",
28
+ }
29
+
30
+
31
+ class SiglipTextConfig(PretrainedConfig):
32
+ r"""
33
+ This is the configuration class to store the configuration of a [`SiglipTextModel`]. It is used to instantiate a
34
+ Siglip text encoder according to the specified arguments, defining the model architecture. Instantiating a
35
+ configuration with the defaults will yield a similar configuration to that of the text encoder of the Siglip
36
+ [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture.
37
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
38
+ documentation from [`PretrainedConfig`] for more information.
39
+ Args:
40
+ vocab_size (`int`, *optional*, defaults to 32000):
41
+ Vocabulary size of the Siglip text model. Defines the number of different tokens that can be represented by
42
+ the `inputs_ids` passed when calling [`SiglipModel`].
43
+ hidden_size (`int`, *optional*, defaults to 768):
44
+ Dimensionality of the encoder layers and the pooler layer.
45
+ intermediate_size (`int`, *optional*, defaults to 3072):
46
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
47
+ num_hidden_layers (`int`, *optional*, defaults to 12):
48
+ Number of hidden layers in the Transformer encoder.
49
+ num_attention_heads (`int`, *optional*, defaults to 12):
50
+ Number of attention heads for each attention layer in the Transformer encoder.
51
+ max_position_embeddings (`int`, *optional*, defaults to 64):
52
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
53
+ just in case (e.g., 512 or 1024 or 2048).
54
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
55
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
56
+ `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported.
57
+ layer_norm_eps (`float`, *optional*, defaults to 1e-06):
58
+ The epsilon used by the layer normalization layers.
59
+ attention_dropout (`float`, *optional*, defaults to 0.0):
60
+ The dropout ratio for the attention probabilities.
61
+ pad_token_id (`int`, *optional*, defaults to 1):
62
+ The id of the padding token in the vocabulary.
63
+ bos_token_id (`int`, *optional*, defaults to 49406):
64
+ The id of the beginning-of-sequence token in the vocabulary.
65
+ eos_token_id (`int`, *optional*, defaults to 49407):
66
+ The id of the end-of-sequence token in the vocabulary.
67
+ Example:
68
+ ```python
69
+ >>> from transformers import SiglipTextConfig, SiglipTextModel
70
+ >>> # Initializing a SiglipTextConfig with google/siglip-base-patch16-224 style configuration
71
+ >>> configuration = SiglipTextConfig()
72
+ >>> # Initializing a SiglipTextModel (with random weights) from the google/siglip-base-patch16-224 style configuration
73
+ >>> model = SiglipTextModel(configuration)
74
+ >>> # Accessing the model configuration
75
+ >>> configuration = model.config
76
+ ```"""
77
+
78
+ model_type = "siglip_text_model"
79
+
80
+ def __init__(
81
+ self,
82
+ vocab_size=32000,
83
+ hidden_size=768,
84
+ intermediate_size=3072,
85
+ num_hidden_layers=12,
86
+ num_attention_heads=12,
87
+ max_position_embeddings=64,
88
+ hidden_act="gelu_pytorch_tanh",
89
+ layer_norm_eps=1e-6,
90
+ attention_dropout=0.0,
91
+ # This differs from `CLIPTokenizer`'s default and from openai/siglip
92
+ # See https://github.com/huggingface/transformers/pull/24773#issuecomment-1632287538
93
+ pad_token_id=1,
94
+ bos_token_id=49406,
95
+ eos_token_id=49407,
96
+ _flash_attn_2_enabled=True,
97
+ **kwargs,
98
+ ):
99
+ super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
100
+
101
+ self.vocab_size = vocab_size
102
+ self.hidden_size = hidden_size
103
+ self.intermediate_size = intermediate_size
104
+ self.num_hidden_layers = num_hidden_layers
105
+ self.num_attention_heads = num_attention_heads
106
+ self.max_position_embeddings = max_position_embeddings
107
+ self.layer_norm_eps = layer_norm_eps
108
+ self.hidden_act = hidden_act
109
+ self.attention_dropout = attention_dropout
110
+ self._flash_attn_2_enabled = _flash_attn_2_enabled
111
+
112
+ @classmethod
113
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
114
+ cls._set_token_in_kwargs(kwargs)
115
+
116
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
117
+
118
+ # get the text config dict if we are loading from SiglipConfig
119
+ if config_dict.get("model_type") == "siglip":
120
+ config_dict = config_dict["text_config"]
121
+
122
+ if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
123
+ logger.warning(
124
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
125
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
126
+ )
127
+
128
+ return cls.from_dict(config_dict, **kwargs)
129
+
130
+
131
+ class SiglipVisionConfig(PretrainedConfig):
132
+ r"""
133
+ This is the configuration class to store the configuration of a [`SiglipVisionModel`]. It is used to instantiate a
134
+ Siglip vision encoder according to the specified arguments, defining the model architecture. Instantiating a
135
+ configuration with the defaults will yield a similar configuration to that of the vision encoder of the Siglip
136
+ [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture.
137
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
138
+ documentation from [`PretrainedConfig`] for more information.
139
+ Args:
140
+ hidden_size (`int`, *optional*, defaults to 768):
141
+ Dimensionality of the encoder layers and the pooler layer.
142
+ intermediate_size (`int`, *optional*, defaults to 3072):
143
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
144
+ num_hidden_layers (`int`, *optional*, defaults to 12):
145
+ Number of hidden layers in the Transformer encoder.
146
+ num_attention_heads (`int`, *optional*, defaults to 12):
147
+ Number of attention heads for each attention layer in the Transformer encoder.
148
+ num_channels (`int`, *optional*, defaults to 3):
149
+ Number of channels in the input images.
150
+ image_size (`int`, *optional*, defaults to 224):
151
+ The size (resolution) of each image.
152
+ patch_size (`int`, *optional*, defaults to 16):
153
+ The size (resolution) of each patch.
154
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
155
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
156
+ `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported.
157
+ layer_norm_eps (`float`, *optional*, defaults to 1e-06):
158
+ The epsilon used by the layer normalization layers.
159
+ attention_dropout (`float`, *optional*, defaults to 0.0):
160
+ The dropout ratio for the attention probabilities.
161
+ Example:
162
+ ```python
163
+ >>> from transformers import SiglipVisionConfig, SiglipVisionModel
164
+ >>> # Initializing a SiglipVisionConfig with google/siglip-base-patch16-224 style configuration
165
+ >>> configuration = SiglipVisionConfig()
166
+ >>> # Initializing a SiglipVisionModel (with random weights) from the google/siglip-base-patch16-224 style configuration
167
+ >>> model = SiglipVisionModel(configuration)
168
+ >>> # Accessing the model configuration
169
+ >>> configuration = model.config
170
+ ```"""
171
+
172
+ model_type = "siglip_vision_model"
173
+
174
+ def __init__(
175
+ self,
176
+ hidden_size=768,
177
+ intermediate_size=3072,
178
+ num_hidden_layers=12,
179
+ num_attention_heads=12,
180
+ num_channels=3,
181
+ image_size=224,
182
+ patch_size=16,
183
+ hidden_act="gelu_pytorch_tanh",
184
+ layer_norm_eps=1e-6,
185
+ attention_dropout=0.0,
186
+ _flash_attn_2_enabled=True,
187
+ **kwargs,
188
+ ):
189
+ super().__init__(**kwargs)
190
+
191
+ self.hidden_size = hidden_size
192
+ self.intermediate_size = intermediate_size
193
+ self.num_hidden_layers = num_hidden_layers
194
+ self.num_attention_heads = num_attention_heads
195
+ self.num_channels = num_channels
196
+ self.patch_size = patch_size
197
+ self.image_size = image_size
198
+ self.attention_dropout = attention_dropout
199
+ self.layer_norm_eps = layer_norm_eps
200
+ self.hidden_act = hidden_act
201
+ self._flash_attn_2_enabled = _flash_attn_2_enabled
202
+
203
+ @classmethod
204
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
205
+ cls._set_token_in_kwargs(kwargs)
206
+
207
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
208
+
209
+ # get the vision config dict if we are loading from SiglipConfig
210
+ if config_dict.get("model_type") == "siglip":
211
+ config_dict = config_dict["vision_config"]
212
+
213
+ if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
214
+ logger.warning(
215
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
216
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
217
+ )
218
+
219
+ return cls.from_dict(config_dict, **kwargs)
220
+
221
+
222
+ class SiglipConfig(PretrainedConfig):
223
+ r"""
224
+ [`SiglipConfig`] is the configuration class to store the configuration of a [`SiglipModel`]. It is used to
225
+ instantiate a Siglip model according to the specified arguments, defining the text model and vision model configs.
226
+ Instantiating a configuration with the defaults will yield a similar configuration to that of the Siglip
227
+ [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture.
228
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
229
+ documentation from [`PretrainedConfig`] for more information.
230
+ Args:
231
+ text_config (`dict`, *optional*):
232
+ Dictionary of configuration options used to initialize [`SiglipTextConfig`].
233
+ vision_config (`dict`, *optional*):
234
+ Dictionary of configuration options used to initialize [`SiglipVisionConfig`].
235
+ kwargs (*optional*):
236
+ Dictionary of keyword arguments.
237
+ Example:
238
+ ```python
239
+ >>> from transformers import SiglipConfig, SiglipModel
240
+ >>> # Initializing a SiglipConfig with google/siglip-base-patch16-224 style configuration
241
+ >>> configuration = SiglipConfig()
242
+ >>> # Initializing a SiglipModel (with random weights) from the google/siglip-base-patch16-224 style configuration
243
+ >>> model = SiglipModel(configuration)
244
+ >>> # Accessing the model configuration
245
+ >>> configuration = model.config
246
+ >>> # We can also initialize a SiglipConfig from a SiglipTextConfig and a SiglipVisionConfig
247
+ >>> from transformers import SiglipTextConfig, SiglipVisionConfig
248
+ >>> # Initializing a SiglipText and SiglipVision configuration
249
+ >>> config_text = SiglipTextConfig()
250
+ >>> config_vision = SiglipVisionConfig()
251
+ >>> config = SiglipConfig.from_text_vision_configs(config_text, config_vision)
252
+ ```"""
253
+
254
+ model_type = "siglip"
255
+
256
+ def __init__(self, text_config=None, vision_config=None, **kwargs):
257
+ super().__init__(**kwargs)
258
+
259
+ if text_config is None:
260
+ text_config = {}
261
+ logger.info("`text_config` is `None`. Initializing the `SiglipTextConfig` with default values.")
262
+
263
+ if vision_config is None:
264
+ vision_config = {}
265
+ logger.info("`vision_config` is `None`. initializing the `SiglipVisionConfig` with default values.")
266
+
267
+ self.text_config = SiglipTextConfig(**text_config)
268
+ self.vision_config = SiglipVisionConfig(**vision_config)
269
+
270
+ self.initializer_factor = 1.0
271
+
272
+ @classmethod
273
+ def from_text_vision_configs(cls, text_config: SiglipTextConfig, vision_config: SiglipVisionConfig, **kwargs):
274
+ r"""
275
+ Instantiate a [`SiglipConfig`] (or a derived class) from siglip text model configuration and siglip vision
276
+ model configuration.
277
+ Returns:
278
+ [`SiglipConfig`]: An instance of a configuration object
279
+ """
280
+
281
+ return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
282
+
283
+ # coding=utf-8
284
+ # Copyright 2024 Google AI and The HuggingFace Team. All rights reserved.
285
+ #
286
+ # Licensed under the Apache License, Version 2.0 (the "License");
287
+ # you may not use this file except in compliance with the License.
288
+ # You may obtain a copy of the License at
289
+ #
290
+ # http://www.apache.org/licenses/LICENSE-2.0
291
+ #
292
+ # Unless required by applicable law or agreed to in writing, software
293
+ # distributed under the License is distributed on an "AS IS" BASIS,
294
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
295
+ # See the License for the specific language governing permissions and
296
+ # limitations under the License.
297
+ """ PyTorch Siglip model."""
298
+
299
+
300
+ import math
301
+ import warnings
302
+ from dataclasses import dataclass
303
+ from typing import Any, Optional, Tuple, Union
304
+
305
+ import numpy as np
306
+ import torch
307
+ import torch.nn.functional as F
308
+ import torch.utils.checkpoint
309
+ from torch import nn
310
+ from torch.nn.init import _calculate_fan_in_and_fan_out
311
+
312
+ from transformers.activations import ACT2FN
313
+ from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
314
+ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
315
+ from transformers.modeling_utils import PreTrainedModel
316
+ from transformers.utils import (
317
+ ModelOutput,
318
+ add_start_docstrings,
319
+ add_start_docstrings_to_model_forward,
320
+ is_flash_attn_2_available,
321
+ logging,
322
+ replace_return_docstrings,
323
+ )
324
+
325
+ logger = logging.get_logger(__name__)
326
+
327
+ _CHECKPOINT_FOR_DOC = "google/siglip-base-patch16-224"
328
+
329
+ SIGLIP_PRETRAINED_MODEL_ARCHIVE_LIST = [
330
+ "google/siglip-base-patch16-224",
331
+ # See all SigLIP models at https://huggingface.co/models?filter=siglip
332
+ ]
333
+
334
+ if is_flash_attn_2_available():
335
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
336
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
337
+
338
+
339
+ # Copied from transformers.models.llama.modeling_llama._get_unpad_data
340
+ def _get_unpad_data(attention_mask):
341
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
342
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
343
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
344
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
345
+ return (
346
+ indices,
347
+ cu_seqlens,
348
+ max_seqlen_in_batch,
349
+ )
350
+
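As a side note, `_get_unpad_data` converts a padded `[batch, seq]` attention mask into the flattened layout that the varlen flash-attention kernels expect. A standalone toy illustration (values assumed) of what it computes:

```python
import torch
import torch.nn.functional as F

# Row 0 has 2 real tokens, row 1 has 3.
attention_mask = torch.tensor([[1, 1, 0],
                               [1, 1, 1]])
seqlens = attention_mask.sum(dim=-1, dtype=torch.int32)                       # tensor([2, 3])
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()   # tensor([0, 1, 3, 4, 5])
cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))   # tensor([0, 2, 5])
# flash_attn_varlen_func later consumes the gathered (unpadded) tokens plus these
# cumulative offsets instead of a dense [batch, seq, ...] layout (see _upad_input below).
```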
351
+
352
+ def _trunc_normal_(tensor, mean, std, a, b):
353
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
354
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
355
+ def norm_cdf(x):
356
+ # Computes standard normal cumulative distribution function
357
+ return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
358
+
359
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
360
+ warnings.warn(
361
+ "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
362
+ "The distribution of values may be incorrect.",
363
+ stacklevel=2,
364
+ )
365
+
366
+ # Values are generated by using a truncated uniform distribution and
367
+ # then using the inverse CDF for the normal distribution.
368
+ # Get upper and lower cdf values
369
+ l = norm_cdf((a - mean) / std)
370
+ u = norm_cdf((b - mean) / std)
371
+
372
+ # Uniformly fill tensor with values from [l, u], then translate to
373
+ # [2l-1, 2u-1].
374
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
375
+
376
+ # Use inverse cdf transform for normal distribution to get truncated
377
+ # standard normal
378
+ if tensor.dtype in [torch.float16, torch.bfloat16]:
379
+ # The `erfinv_` op is not (yet?) defined in float16+cpu, bfloat16+gpu
380
+ og_dtype = tensor.dtype
381
+ tensor = tensor.to(torch.float32)
382
+ tensor.erfinv_()
383
+ tensor = tensor.to(og_dtype)
384
+ else:
385
+ tensor.erfinv_()
386
+
387
+ # Transform to proper mean, std
388
+ tensor.mul_(std * math.sqrt(2.0))
389
+ tensor.add_(mean)
390
+
391
+ # Clamp to ensure it's in the proper range
392
+ if tensor.dtype == torch.float16:
393
+ # The `clamp_` op is not (yet?) defined in float16+cpu
394
+ tensor = tensor.to(torch.float32)
395
+ tensor.clamp_(min=a, max=b)
396
+ tensor = tensor.to(torch.float16)
397
+ else:
398
+ tensor.clamp_(min=a, max=b)
399
+
400
+
401
+ def trunc_normal_tf_(
402
+ tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0
403
+ ) -> torch.Tensor:
404
+ """Fills the input Tensor with values drawn from a truncated
405
+ normal distribution. The values are effectively drawn from the
406
+ normal distribution :math:`\\mathcal{N}(\\text{mean}, \\text{std}^2)`
407
+ with values outside :math:`[a, b]` redrawn until they are within
408
+ the bounds. The method used for generating the random values works
409
+ best when :math:`a \\leq \\text{mean} \\leq b`.
410
+ NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the
411
+ bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0
412
+ and the result is subsequently scaled and shifted by the mean and std args.
413
+ Args:
414
+ tensor: an n-dimensional `torch.Tensor`
415
+ mean: the mean of the normal distribution
416
+ std: the standard deviation of the normal distribution
417
+ a: the minimum cutoff value
418
+ b: the maximum cutoff value
419
+ """
420
+ with torch.no_grad():
421
+ _trunc_normal_(tensor, 0, 1.0, a, b)
422
+ tensor.mul_(std).add_(mean)
423
+
424
+
425
+ def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
426
+ fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
427
+ if mode == "fan_in":
428
+ denom = fan_in
429
+ elif mode == "fan_out":
430
+ denom = fan_out
431
+ elif mode == "fan_avg":
432
+ denom = (fan_in + fan_out) / 2
433
+
434
+ variance = scale / denom
435
+
436
+ if distribution == "truncated_normal":
437
+ # constant is stddev of standard normal truncated to (-2, 2)
438
+ trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
439
+ elif distribution == "normal":
440
+ with torch.no_grad():
441
+ tensor.normal_(std=math.sqrt(variance))
442
+ elif distribution == "uniform":
443
+ bound = math.sqrt(3 * variance)
444
+ with torch.no_grad():
445
+ tensor.uniform_(-bound, bound)
446
+ else:
447
+ raise ValueError(f"invalid distribution {distribution}")
448
+
449
+
450
+ def lecun_normal_(tensor):
451
+ variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")
452
+
453
+
454
+ def default_flax_embed_init(tensor):
455
+ variance_scaling_(tensor, mode="fan_in", distribution="normal")
456
+
457
+
458
+ @dataclass
459
+ # Copied from transformers.models.clip.modeling_clip.CLIPVisionModelOutput with CLIP->Siglip
460
+ class SiglipVisionModelOutput(ModelOutput):
461
+ """
462
+ Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.
463
+ Args:
464
+ image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
465
+ The image embeddings obtained by applying the projection layer to the pooler_output.
466
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
467
+ Sequence of hidden-states at the output of the last layer of the model.
468
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
469
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
470
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
471
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
472
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
473
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
474
+ sequence_length)`.
475
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
476
+ heads.
477
+ """
478
+
479
+ image_embeds: Optional[torch.FloatTensor] = None
480
+ last_hidden_state: torch.FloatTensor = None
481
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
482
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
483
+
484
+
485
+ @dataclass
486
+ # Copied from transformers.models.clip.modeling_clip.CLIPTextModelOutput with CLIP->Siglip
487
+ class SiglipTextModelOutput(ModelOutput):
488
+ """
489
+ Base class for text model's outputs that also contains a pooling of the last hidden states.
490
+ Args:
491
+ text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
492
+ The text embeddings obtained by applying the projection layer to the pooler_output.
493
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
494
+ Sequence of hidden-states at the output of the last layer of the model.
495
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
496
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
497
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
498
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
499
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
500
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
501
+ sequence_length)`.
502
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
503
+ heads.
504
+ """
505
+
506
+ text_embeds: Optional[torch.FloatTensor] = None
507
+ last_hidden_state: torch.FloatTensor = None
508
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
509
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
510
+
511
+
512
+ @dataclass
513
+ # Copied from transformers.models.clip.modeling_clip.CLIPOutput with CLIP->Siglip
514
+ class SiglipOutput(ModelOutput):
515
+ """
516
+ Args:
517
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
518
+ Contrastive loss for image-text similarity.
519
+ logits_per_image:(`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
520
+ The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
521
+ similarity scores.
522
+ logits_per_text:(`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
523
+ The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
524
+ similarity scores.
525
+ text_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim`):
526
+ The text embeddings obtained by applying the projection layer to the pooled output of [`SiglipTextModel`].
527
+ image_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim`):
528
+ The image embeddings obtained by applying the projection layer to the pooled output of [`SiglipVisionModel`].
529
+ text_model_output(`BaseModelOutputWithPooling`):
530
+ The output of the [`SiglipTextModel`].
531
+ vision_model_output(`BaseModelOutputWithPooling`):
532
+ The output of the [`SiglipVisionModel`].
533
+ """
534
+
535
+ loss: Optional[torch.FloatTensor] = None
536
+ logits_per_image: torch.FloatTensor = None
537
+ logits_per_text: torch.FloatTensor = None
538
+ text_embeds: torch.FloatTensor = None
539
+ image_embeds: torch.FloatTensor = None
540
+ text_model_output: BaseModelOutputWithPooling = None
541
+ vision_model_output: BaseModelOutputWithPooling = None
542
+
543
+ def to_tuple(self) -> Tuple[Any]:
544
+ return tuple(
545
+ self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
546
+ for k in self.keys()
547
+ )
548
+
549
+
550
+ @torch.jit.script_if_tracing
551
+ def filter_position_ids(patch_attention_mask: torch.Tensor, position_ids: torch.Tensor, boundaries: torch.Tensor, num_patches_per_side: int):
552
+ for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
553
+ nb_patches_h = p_attn_mask[:, 0].sum()
554
+ nb_patches_w = p_attn_mask[0].sum()
555
+
556
+ fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
557
+ fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)
558
+
559
+ bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
560
+ bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)
561
+
562
+ pos_ids = (bucket_coords_h[:, None] * num_patches_per_side + bucket_coords_w).flatten()
563
+ position_ids[batch_idx][p_attn_mask.view(-1)] = pos_ids
564
+ return position_ids
565
+
566
+
567
+ class SiglipVisionEmbeddings(nn.Module):
568
+ def __init__(self, config: SiglipVisionConfig):
569
+ super().__init__()
570
+ self.config = config
571
+ self.embed_dim = config.hidden_size
572
+ self.image_size = config.image_size
573
+ self.patch_size = config.patch_size
574
+
575
+ self.patch_embedding = nn.Conv2d(
576
+ in_channels=config.num_channels,
577
+ out_channels=self.embed_dim,
578
+ kernel_size=self.patch_size,
579
+ stride=self.patch_size,
580
+ padding="valid",
581
+ )
582
+
583
+ self.num_patches_per_side = self.image_size // self.patch_size
584
+ self.num_patches = self.num_patches_per_side**2
585
+ self.num_positions = self.num_patches
586
+ self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
587
+
588
+ def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor) -> torch.Tensor:
589
+ batch_size = pixel_values.size(0)
590
+
591
+ patch_embeds = self.patch_embedding(pixel_values)
592
+ embeddings = patch_embeds.flatten(2).transpose(1, 2)
593
+
594
+ max_im_h, max_im_w = pixel_values.size(2), pixel_values.size(3)
595
+ max_nb_patches_h, max_nb_patches_w = max_im_h // self.patch_size, max_im_w // self.patch_size
596
+ boundaries = torch.arange(1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side)
597
+ position_ids = torch.full(
598
+ size=(
599
+ batch_size,
600
+ max_nb_patches_h * max_nb_patches_w,
601
+ ),
602
+ fill_value=0,
603
+ )
604
+
605
+ position_ids = filter_position_ids(patch_attention_mask, position_ids, boundaries, self.num_patches_per_side)
606
+ position_ids = position_ids.to(self.position_embedding.weight.device)
607
+ embeddings = embeddings + self.position_embedding(position_ids)
608
+ return embeddings
609
+
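`SiglipVisionEmbeddings` supports variable-resolution inputs NaViT-style: each image's actual patch grid is mapped onto the fixed `num_patches_per_side` x `num_patches_per_side` position table by bucketizing fractional patch coordinates, which is what `filter_position_ids` above does per image. A standalone sketch with assumed toy sizes:

```python
import torch

# Assumed toy sizes: a 2x3 patch grid placed on a 4x4 position-embedding table.
num_patches_per_side = 4
nb_patches_h, nb_patches_w = 2, 3

boundaries = torch.arange(1 / num_patches_per_side, 1.0, 1 / num_patches_per_side)  # [0.25, 0.50, 0.75]
frac_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)        # [0.00, 0.50]
frac_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)        # [0.00, 0.33, 0.67]
bucket_h = torch.bucketize(frac_h, boundaries, right=True)  # [0, 2]
bucket_w = torch.bucketize(frac_w, boundaries, right=True)  # [0, 1, 2]

# Row-major ids into the 16-entry position embedding table.
pos_ids = (bucket_h[:, None] * num_patches_per_side + bucket_w).flatten()
print(pos_ids)  # tensor([ 0,  1,  2,  8,  9, 10])
```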
610
+
611
+ # Copied from transformers.models.clip.modeling_clip.CLIPTextEmbeddings with CLIP->Siglip
612
+ class SiglipTextEmbeddings(nn.Module):
613
+ def __init__(self, config: SiglipTextConfig):
614
+ super().__init__()
615
+ embed_dim = config.hidden_size
616
+
617
+ self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
618
+ self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)
619
+
620
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
621
+ self.register_buffer(
622
+ "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
623
+ )
624
+
625
+ def forward(
626
+ self,
627
+ input_ids: Optional[torch.LongTensor] = None,
628
+ position_ids: Optional[torch.LongTensor] = None,
629
+ inputs_embeds: Optional[torch.FloatTensor] = None,
630
+ ) -> torch.Tensor:
631
+ seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
632
+
633
+ if position_ids is None:
634
+ position_ids = self.position_ids[:, :seq_length]
635
+
636
+ if inputs_embeds is None:
637
+ inputs_embeds = self.token_embedding(input_ids)
638
+
639
+ position_embeddings = self.position_embedding(position_ids)
640
+ embeddings = inputs_embeds + position_embeddings
641
+
642
+ return embeddings
643
+
644
+
645
+ class SiglipAttention(nn.Module):
646
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
647
+
648
+ # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__
649
+ def __init__(self, config):
650
+ super().__init__()
651
+ self.config = config
652
+ self.embed_dim = config.hidden_size
653
+ self.num_heads = config.num_attention_heads
654
+ self.head_dim = self.embed_dim // self.num_heads
655
+ if self.head_dim * self.num_heads != self.embed_dim:
656
+ raise ValueError(
657
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
658
+ f" {self.num_heads})."
659
+ )
660
+ self.scale = self.head_dim**-0.5
661
+ self.dropout = config.attention_dropout
662
+
663
+ self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
664
+ self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
665
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
666
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
667
+
668
+ def forward(
669
+ self,
670
+ hidden_states: torch.Tensor,
671
+ attention_mask: Optional[torch.Tensor] = None,
672
+ output_attentions: Optional[bool] = False,
673
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
674
+ """Input shape: Batch x Time x Channel"""
675
+
676
+ batch_size, q_len, _ = hidden_states.size()
677
+
678
+ query_states = self.q_proj(hidden_states)
679
+ key_states = self.k_proj(hidden_states)
680
+ value_states = self.v_proj(hidden_states)
681
+
682
+ query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
683
+ key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
684
+ value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
685
+
686
+ k_v_seq_len = key_states.shape[-2]
687
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale
688
+
689
+ if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len):
690
+ raise ValueError(
691
+ f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is"
692
+ f" {attn_weights.size()}"
693
+ )
694
+
695
+ if attention_mask is not None:
696
+ if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len):
697
+ raise ValueError(
698
+ f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}"
699
+ )
700
+ attn_weights = attn_weights + attention_mask
701
+
702
+ # upcast attention to fp32
703
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
704
+ attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
705
+ attn_output = torch.matmul(attn_weights, value_states)
706
+
707
+ if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim):
708
+ raise ValueError(
709
+ f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is"
710
+ f" {attn_output.size()}"
711
+ )
712
+
713
+ attn_output = attn_output.transpose(1, 2).contiguous()
714
+ attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)
715
+
716
+ attn_output = self.out_proj(attn_output)
717
+
718
+ return attn_output, attn_weights
719
+
720
+
721
+ class SiglipFlashAttention2(SiglipAttention):
722
+ """
723
+ Siglip flash attention module. This module inherits from `SiglipAttention` as the weights of the module stay
724
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
725
+ flash attention and deal with padding tokens in case the input contains any of them.
726
+ """
727
+
728
+ def __init__(self, *args, **kwargs):
729
+ super().__init__(*args, **kwargs)
730
+ self.is_causal = False # Hack to make sure we don't use a causal mask
731
+
732
+ def forward(
733
+ self,
734
+ hidden_states: torch.Tensor,
735
+ attention_mask: Optional[torch.LongTensor] = None,
736
+ position_ids: Optional[torch.LongTensor] = None,
737
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
738
+ output_attentions: bool = False,
739
+ use_cache: bool = False,
740
+ **kwargs,
741
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
742
+ output_attentions = False
743
+
744
+ bsz, q_len, _ = hidden_states.size()
745
+
746
+ query_states = self.q_proj(hidden_states)
747
+ key_states = self.k_proj(hidden_states)
748
+ value_states = self.v_proj(hidden_states)
749
+
750
+ # Flash attention requires the input to have the shape
751
+ # batch_size x seq_length x head_dim x hidden_dim
752
+ # therefore we just need to keep the original shape
753
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
754
+ key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
755
+ value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
756
+
757
+ kv_seq_len = key_states.shape[-2]
758
+ if past_key_value is not None:
759
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
760
+ # cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
761
+ # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
762
+
763
+ # if past_key_value is not None:
764
+ # cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
765
+ # key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
766
+
767
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
768
+ # to be able to avoid many of these transpose/reshape/view.
769
+ query_states = query_states.transpose(1, 2)
770
+ key_states = key_states.transpose(1, 2)
771
+ value_states = value_states.transpose(1, 2)
772
+
773
+ dropout_rate = self.dropout if self.training else 0.0
774
+
775
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
776
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
777
+ # cast them back in the correct dtype just to be sure everything works as expected.
778
+ # This might slowdown training & inference so it is recommended to not cast the LayerNorms
779
+ # in fp32. (LlamaRMSNorm handles it correctly)
780
+
781
+ input_dtype = query_states.dtype
782
+ if input_dtype == torch.float32:
783
+ if torch.is_autocast_enabled():
784
+ target_dtype = torch.get_autocast_gpu_dtype()
785
+ # Handle the case where the model is quantized
786
+ elif hasattr(self.config, "_pre_quantization_dtype"):
787
+ target_dtype = self.config._pre_quantization_dtype
788
+ else:
789
+ target_dtype = self.q_proj.weight.dtype
790
+
791
+ logger.warning_once(
792
+ "The input hidden states seems to be silently casted in float32, this might be related to the fact"
793
+ " you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
794
+ f" {target_dtype}."
795
+ )
796
+
797
+ query_states = query_states.to(target_dtype)
798
+ key_states = key_states.to(target_dtype)
799
+ value_states = value_states.to(target_dtype)
800
+
801
+ attn_output = self._flash_attention_forward(
802
+ query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
803
+ )
804
+
805
+ attn_output = attn_output.reshape(bsz, q_len, self.embed_dim).contiguous()
806
+ attn_output = self.out_proj(attn_output)
807
+
808
+ if not output_attentions:
809
+ attn_weights = None
810
+
811
+ return attn_output, attn_weights
812
+
813
+ def _flash_attention_forward(
814
+ self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
815
+ ):
816
+ """
817
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
818
+ first unpad the input, then computes the attention scores and pad the final attention scores.
819
+ Args:
820
+ query_states (`torch.Tensor`):
821
+ Input query states to be passed to Flash Attention API
822
+ key_states (`torch.Tensor`):
823
+ Input key states to be passed to Flash Attention API
824
+ value_states (`torch.Tensor`):
825
+ Input value states to be passed to Flash Attention API
826
+ attention_mask (`torch.Tensor`):
827
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
828
+ position of padding tokens and 1 for the position of non-padding tokens.
829
+ dropout (`int`, *optional*):
830
+ Attention dropout
831
+ softmax_scale (`float`, *optional*):
832
+ The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
833
+ """
834
+
835
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
836
+ causal = self.is_causal and query_length != 1
837
+
838
+ # Contains at least one padding token in the sequence
839
+ if attention_mask is not None:
840
+ batch_size = query_states.shape[0]
841
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
842
+ query_states, key_states, value_states, attention_mask, query_length
843
+ )
844
+
845
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
846
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
847
+
848
+ attn_output_unpad = flash_attn_varlen_func(
849
+ query_states,
850
+ key_states,
851
+ value_states,
852
+ cu_seqlens_q=cu_seqlens_q,
853
+ cu_seqlens_k=cu_seqlens_k,
854
+ max_seqlen_q=max_seqlen_in_batch_q,
855
+ max_seqlen_k=max_seqlen_in_batch_k,
856
+ dropout_p=dropout,
857
+ softmax_scale=softmax_scale,
858
+ causal=causal,
859
+ )
860
+
861
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
862
+ else:
863
+ attn_output = flash_attn_func(
864
+ query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
865
+ )
866
+
867
+ return attn_output
868
+
869
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
870
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
871
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
872
+
873
+ key_layer = index_first_axis(
874
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
875
+ )
876
+ value_layer = index_first_axis(
877
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
878
+ )
879
+ if query_length == kv_seq_len:
880
+ query_layer = index_first_axis(
881
+ query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
882
+ )
883
+ cu_seqlens_q = cu_seqlens_k
884
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
885
+ indices_q = indices_k
886
+ elif query_length == 1:
887
+ max_seqlen_in_batch_q = 1
888
+ cu_seqlens_q = torch.arange(
889
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
890
+ ) # There is a memcpy here, that is very bad.
891
+ indices_q = cu_seqlens_q[:-1]
892
+ query_layer = query_layer.squeeze(1)
893
+ else:
894
+ # The -q_len: slice assumes left padding.
895
+ attention_mask = attention_mask[:, -query_length:]
896
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
897
+
898
+ return (
899
+ query_layer,
900
+ key_layer,
901
+ value_layer,
902
+ indices_q,
903
+ (cu_seqlens_q, cu_seqlens_k),
904
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
905
+ )
906
+
907
+
908
+ # Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Siglip
909
+ class SiglipMLP(nn.Module):
910
+ def __init__(self, config):
911
+ super().__init__()
912
+ self.config = config
913
+ self.activation_fn = ACT2FN[config.hidden_act]
914
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
915
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
916
+
917
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
918
+ hidden_states = self.fc1(hidden_states)
919
+ hidden_states = self.activation_fn(hidden_states)
920
+ hidden_states = self.fc2(hidden_states)
921
+ return hidden_states
922
+
923
+
924
+ # Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->Siglip
925
+ class SiglipEncoderLayer(nn.Module):
926
+ def __init__(self, config: SiglipConfig):
927
+ super().__init__()
928
+ self.embed_dim = config.hidden_size
929
+ self.self_attn = (
930
+ SiglipAttention(config)
931
+ if not getattr(config, "_flash_attn_2_enabled", False)
932
+ else SiglipFlashAttention2(config)
933
+ )
934
+ self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
935
+ self.mlp = SiglipMLP(config)
936
+ self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
937
+
938
+ def forward(
939
+ self,
940
+ hidden_states: torch.Tensor,
941
+ attention_mask: torch.Tensor,
942
+ output_attentions: Optional[bool] = False,
943
+ ) -> Tuple[torch.FloatTensor]:
944
+ """
945
+ Args:
946
+ hidden_states (`torch.FloatTensor`):
947
+ Input to the layer of shape `(batch, seq_len, embed_dim)`.
948
+ attention_mask (`torch.FloatTensor`):
949
+ Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values.
950
+ output_attentions (`bool`, *optional*, defaults to `False`):
951
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
952
+ returned tensors for more detail.
953
+ """
954
+ residual = hidden_states
955
+
956
+ hidden_states = self.layer_norm1(hidden_states)
957
+ hidden_states, attn_weights = self.self_attn(
958
+ hidden_states=hidden_states,
959
+ attention_mask=attention_mask,
960
+ output_attentions=output_attentions,
961
+ )
962
+ hidden_states = residual + hidden_states
963
+
964
+ residual = hidden_states
965
+ hidden_states = self.layer_norm2(hidden_states)
966
+ hidden_states = self.mlp(hidden_states)
967
+ hidden_states = residual + hidden_states
968
+
969
+ outputs = (hidden_states,)
970
+
971
+ if output_attentions:
972
+ outputs += (attn_weights,)
973
+
974
+ return outputs
975
+
976
+
977
+ class SiglipPreTrainedModel(PreTrainedModel):
978
+ """
979
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
980
+ models.
981
+ """
982
+
983
+ config_class = SiglipConfig
984
+ base_model_prefix = "siglip"
985
+ supports_gradient_checkpointing = True
986
+
987
+ def _init_weights(self, module):
988
+ """Initialize the weights"""
989
+
990
+ if isinstance(module, SiglipVisionEmbeddings):
991
+ width = (
992
+ self.config.vision_config.hidden_size
993
+ if isinstance(self.config, SiglipConfig)
994
+ else self.config.hidden_size
995
+ )
996
+ nn.init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width))
997
+ elif isinstance(module, nn.Embedding):
998
+ default_flax_embed_init(module.weight)
999
+ elif isinstance(module, SiglipAttention):
1000
+ nn.init.normal_(module.q_proj.weight)
1001
+ nn.init.normal_(module.k_proj.weight)
1002
+ nn.init.normal_(module.v_proj.weight)
1003
+ nn.init.normal_(module.out_proj.weight)
1004
+ nn.init.zeros_(module.q_proj.bias)
1005
+ nn.init.zeros_(module.k_proj.bias)
1006
+ nn.init.zeros_(module.v_proj.bias)
1007
+ nn.init.zeros_(module.out_proj.bias)
1008
+ elif isinstance(module, SiglipMLP):
1009
+ nn.init.normal_(module.fc1.weight)
1010
+ nn.init.normal_(module.fc2.weight)
1011
+ nn.init.normal_(module.fc1.bias, std=1e-6)
1012
+ nn.init.normal_(module.fc2.bias, std=1e-6)
1013
+ elif isinstance(module, SiglipMultiheadAttentionPoolingHead):
1014
+ nn.init.normal_(module.probe.data)
1015
+ nn.init.normal_(module.attention.in_proj_weight.data)
1016
+ nn.init.zeros_(module.attention.in_proj_bias.data)
1017
+ elif isinstance(module, SiglipModel):
1018
+ logit_scale_init = torch.tensor(0.0)
1019
+ module.logit_scale.data.fill_(logit_scale_init)
1020
+ module.logit_bias.data.zero_()
1021
+ elif isinstance(module, (nn.Linear, nn.Conv2d)):
1022
+ lecun_normal_(module.weight)
1023
+ if module.bias is not None:
1024
+ nn.init.zeros_(module.bias)
1025
+ elif isinstance(module, nn.LayerNorm):
1026
+ module.bias.data.zero_()
1027
+ module.weight.data.fill_(1.0)
1028
+
1029
+
1030
+ SIGLIP_START_DOCSTRING = r"""
1031
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
1032
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
1033
+ etc.)
1034
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
1035
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
1036
+ and behavior.
1037
+ Parameters:
1038
+ config ([`SiglipConfig`]): Model configuration class with all the parameters of the model.
1039
+ Initializing with a config file does not load the weights associated with the model, only the
1040
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
1041
+ """
1042
+
1043
+ SIGLIP_TEXT_INPUTS_DOCSTRING = r"""
1044
+ Args:
1045
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1046
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
1047
+ it.
1048
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1049
+ [`PreTrainedTokenizer.__call__`] for details.
1050
+ [What are input IDs?](../glossary#input-ids)
1051
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1052
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
1053
+ - 1 for tokens that are **not masked**,
1054
+ - 0 for tokens that are **masked**.
1055
+ [What are attention masks?](../glossary#attention-mask)
1056
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1057
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
1058
+ config.max_position_embeddings - 1]`.
1059
+ [What are position IDs?](../glossary#position-ids)
1060
+ output_attentions (`bool`, *optional*):
1061
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
1062
+ tensors for more detail.
1063
+ output_hidden_states (`bool`, *optional*):
1064
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
1065
+ more detail.
1066
+ return_dict (`bool`, *optional*):
1067
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1068
+ """
1069
+
1070
+ SIGLIP_VISION_INPUTS_DOCSTRING = r"""
1071
+ Args:
1072
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
1073
+ Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
1074
+ [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
1075
+ output_attentions (`bool`, *optional*):
1076
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
1077
+ tensors for more detail.
1078
+ output_hidden_states (`bool`, *optional*):
1079
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
1080
+ more detail.
1081
+ return_dict (`bool`, *optional*):
1082
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1083
+ """
1084
+
1085
+ SIGLIP_INPUTS_DOCSTRING = r"""
1086
+ Args:
1087
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1088
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
1089
+ it.
1090
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1091
+ [`PreTrainedTokenizer.__call__`] for details.
1092
+ [What are input IDs?](../glossary#input-ids)
1093
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1094
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
1095
+ - 1 for tokens that are **not masked**,
1096
+ - 0 for tokens that are **masked**.
1097
+ [What are attention masks?](../glossary#attention-mask)
1098
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1099
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
1100
+ config.max_position_embeddings - 1]`.
1101
+ [What are position IDs?](../glossary#position-ids)
1102
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
1103
+ Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
1104
+ [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
1105
+ return_loss (`bool`, *optional*):
1106
+ Whether or not to return the contrastive loss.
1107
+ output_attentions (`bool`, *optional*):
1108
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
1109
+ tensors for more detail.
1110
+ output_hidden_states (`bool`, *optional*):
1111
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
1112
+ more detail.
1113
+ return_dict (`bool`, *optional*):
1114
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1115
+ """
1116
+
1117
+
1118
+ # Copied from transformers.models.clip.modeling_clip.CLIPEncoder with CLIP->Siglip
1119
+ class SiglipEncoder(nn.Module):
1120
+ """
1121
+ Transformer encoder consisting of `config.num_hidden_layers` self-attention layers. Each layer is a
1122
+ [`SiglipEncoderLayer`].
1123
+ Args:
1124
+ config: SiglipConfig
1125
+ """
1126
+
1127
+ def __init__(self, config: SiglipConfig):
1128
+ super().__init__()
1129
+ self.config = config
1130
+ self.layers = nn.ModuleList([SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)])
1131
+ self.gradient_checkpointing = False
1132
+
1133
+ # Ignore copy
1134
+ def forward(
1135
+ self,
1136
+ inputs_embeds,
1137
+ attention_mask: Optional[torch.Tensor] = None,
1138
+ output_attentions: Optional[bool] = None,
1139
+ output_hidden_states: Optional[bool] = None,
1140
+ return_dict: Optional[bool] = None,
1141
+ ) -> Union[Tuple, BaseModelOutput]:
1142
+ r"""
1143
+ Args:
1144
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
1145
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
1146
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
1147
+ than the model's internal embedding lookup matrix.
1148
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1149
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
1150
+ - 1 for tokens that are **not masked**,
1151
+ - 0 for tokens that are **masked**.
1152
+ [What are attention masks?](../glossary#attention-mask)
1153
+ output_attentions (`bool`, *optional*):
1154
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
1155
+ returned tensors for more detail.
1156
+ output_hidden_states (`bool`, *optional*):
1157
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
1158
+ for more detail.
1159
+ return_dict (`bool`, *optional*):
1160
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1161
+ """
1162
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1163
+ output_hidden_states = (
1164
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1165
+ )
1166
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1167
+
1168
+ encoder_states = () if output_hidden_states else None
1169
+ all_attentions = () if output_attentions else None
1170
+
1171
+ hidden_states = inputs_embeds
1172
+ for encoder_layer in self.layers:
1173
+ if output_hidden_states:
1174
+ encoder_states = encoder_states + (hidden_states,)
1175
+ if self.gradient_checkpointing and self.training:
1176
+ layer_outputs = self._gradient_checkpointing_func(
1177
+ encoder_layer.__call__,
1178
+ hidden_states,
1179
+ attention_mask,
1180
+ output_attentions,
1181
+ )
1182
+ else:
1183
+ layer_outputs = encoder_layer(
1184
+ hidden_states,
1185
+ attention_mask,
1186
+ output_attentions=output_attentions,
1187
+ )
1188
+
1189
+ hidden_states = layer_outputs[0]
1190
+
1191
+ if output_attentions:
1192
+ all_attentions = all_attentions + (layer_outputs[1],)
1193
+
1194
+ if output_hidden_states:
1195
+ encoder_states = encoder_states + (hidden_states,)
1196
+
1197
+ if not return_dict:
1198
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
1199
+ return BaseModelOutput(
1200
+ last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
1201
+ )
1202
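
A minimal usage sketch of the encoder's return structure, assuming a small `SiglipVisionConfig` instance named `config` is in scope (the variable is illustrative, not part of this file): with `output_hidden_states=True`, the returned `hidden_states` holds the input embeddings plus the output of every layer.

```python
import torch

encoder = SiglipEncoder(config)
dummy_embeds = torch.randn(1, 4, config.hidden_size)  # (batch, seq_len, hidden_size)
outputs = encoder(inputs_embeds=dummy_embeds, output_hidden_states=True, return_dict=True)

# last_hidden_state keeps the input shape; hidden_states has num_hidden_layers + 1 entries
assert outputs.last_hidden_state.shape == dummy_embeds.shape
assert len(outputs.hidden_states) == config.num_hidden_layers + 1
```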
+
1203
+
1204
+ class SiglipTextTransformer(nn.Module):
1205
+ def __init__(self, config: SiglipTextConfig):
1206
+ super().__init__()
1207
+ self.config = config
1208
+ embed_dim = config.hidden_size
1209
+ self.embeddings = SiglipTextEmbeddings(config)
1210
+ self.encoder = SiglipEncoder(config)
1211
+ self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
1212
+
1213
+ self.head = nn.Linear(embed_dim, embed_dim)
1214
+
1215
+ @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING)
1216
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipTextConfig)
1217
+ def forward(
1218
+ self,
1219
+ input_ids: Optional[torch.Tensor] = None,
1220
+ attention_mask: Optional[torch.Tensor] = None,
1221
+ position_ids: Optional[torch.Tensor] = None,
1222
+ output_attentions: Optional[bool] = None,
1223
+ output_hidden_states: Optional[bool] = None,
1224
+ return_dict: Optional[bool] = None,
1225
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
1226
+ r"""
1227
+ Returns:
1228
+ """
1229
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1230
+ output_hidden_states = (
1231
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1232
+ )
1233
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1234
+
1235
+ if input_ids is None:
1236
+ raise ValueError("You have to specify input_ids")
1237
+
1238
+ input_shape = input_ids.size()
1239
+ input_ids = input_ids.view(-1, input_shape[-1])
1240
+
1241
+ hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)
1242
+
1243
+ # note: SigLIP's text model does not use a causal mask, unlike the original CLIP model.
1244
+ # expand attention_mask
1245
+ if attention_mask is not None:
1246
+ # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len]
1247
+ attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
1248
+
1249
+ encoder_outputs = self.encoder(
1250
+ inputs_embeds=hidden_states,
1251
+ attention_mask=attention_mask,
1252
+ output_attentions=output_attentions,
1253
+ output_hidden_states=output_hidden_states,
1254
+ return_dict=return_dict,
1255
+ )
1256
+
1257
+ last_hidden_state = encoder_outputs[0]
1258
+ last_hidden_state = self.final_layer_norm(last_hidden_state)
1259
+
1260
+ # Assuming "sticky" EOS tokenization, last token is always EOS.
1261
+ pooled_output = last_hidden_state[:, -1, :]
1262
+ pooled_output = self.head(pooled_output)
1263
+
1264
+ if not return_dict:
1265
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
1266
+
1267
+ return BaseModelOutputWithPooling(
1268
+ last_hidden_state=last_hidden_state,
1269
+ pooler_output=pooled_output,
1270
+ hidden_states=encoder_outputs.hidden_states,
1271
+ attentions=encoder_outputs.attentions,
1272
+ )
1273
+
1274
+
1275
+ @add_start_docstrings(
1276
+ """The text model from SigLIP without any head or projection on top.""",
1277
+ SIGLIP_START_DOCSTRING,
1278
+ )
1279
+ class SiglipTextModel(SiglipPreTrainedModel):
1280
+ config_class = SiglipTextConfig
1281
+
1282
+ _no_split_modules = ["SiglipTextEmbeddings", "SiglipEncoderLayer"]
1283
+
1284
+ def __init__(self, config: SiglipTextConfig):
1285
+ super().__init__(config)
1286
+ self.text_model = SiglipTextTransformer(config)
1287
+ # Initialize weights and apply final processing
1288
+ self.post_init()
1289
+
1290
+ def get_input_embeddings(self) -> nn.Module:
1291
+ return self.text_model.embeddings.token_embedding
1292
+
1293
+ def set_input_embeddings(self, value):
1294
+ self.text_model.embeddings.token_embedding = value
1295
+
1296
+ @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING)
1297
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipTextConfig)
1298
+ def forward(
1299
+ self,
1300
+ input_ids: Optional[torch.Tensor] = None,
1301
+ attention_mask: Optional[torch.Tensor] = None,
1302
+ position_ids: Optional[torch.Tensor] = None,
1303
+ output_attentions: Optional[bool] = None,
1304
+ output_hidden_states: Optional[bool] = None,
1305
+ return_dict: Optional[bool] = None,
1306
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
1307
+ r"""
1308
+ Returns:
1309
+ Examples:
1310
+ ```python
1311
+ >>> from transformers import AutoTokenizer, SiglipTextModel
1312
+ >>> model = SiglipTextModel.from_pretrained("google/siglip-base-patch16-224")
1313
+ >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224")
1314
+ >>> # important: make sure to set padding="max_length" as that's how the model was trained
1315
+ >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt")
1316
+ >>> outputs = model(**inputs)
1317
+ >>> last_hidden_state = outputs.last_hidden_state
1318
+ >>> pooled_output = outputs.pooler_output # pooled (EOS token) states
1319
+ ```"""
1320
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1321
+
1322
+ return self.text_model(
1323
+ input_ids=input_ids,
1324
+ attention_mask=attention_mask,
1325
+ position_ids=position_ids,
1326
+ output_attentions=output_attentions,
1327
+ output_hidden_states=output_hidden_states,
1328
+ return_dict=return_dict,
1329
+ )
1330
+
1331
+
1332
+ class SiglipVisionTransformer(nn.Module):
1333
+ def __init__(self, config: SiglipVisionConfig):
1334
+ super().__init__()
1335
+ self.config = config
1336
+ embed_dim = config.hidden_size
1337
+
1338
+ self.embeddings = SiglipVisionEmbeddings(config)
1339
+ self.encoder = SiglipEncoder(config)
1340
+ self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
1341
+ self.head = SiglipMultiheadAttentionPoolingHead(config)
1342
+
1343
+ @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING)
1344
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipVisionConfig)
1345
+ def forward(
1346
+ self,
1347
+ pixel_values,
1348
+ patch_attention_mask: Optional[torch.BoolTensor] = None,
1349
+ output_attentions: Optional[bool] = None,
1350
+ output_hidden_states: Optional[bool] = None,
1351
+ return_dict: Optional[bool] = None,
1352
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
1353
+ r"""
1354
+ Returns:
1355
+ """
1356
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1357
+ output_hidden_states = (
1358
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1359
+ )
1360
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1361
+
1362
+ batch_size = pixel_values.size(0)
1363
+ if patch_attention_mask is None:
1364
+ patch_attention_mask = torch.ones(
1365
+ size=(
1366
+ batch_size,
1367
+ pixel_values.size(2) // self.config.patch_size,
1368
+ pixel_values.size(3) // self.config.patch_size,
1369
+ ),
1370
+ dtype=torch.bool,
1371
+ device=pixel_values.device,
1372
+ )
1373
+
1374
+ hidden_states = self.embeddings(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask)
1375
+
1376
+ patch_attention_mask = patch_attention_mask.view(batch_size, -1)
1377
+ # The call to `_upad_input` in `_flash_attention_forward` is expensive
1378
+ # So when the `patch_attention_mask` is full of 1s (i.e. attending to the whole sequence),
1379
+ # avoid passing the attention_mask, which is equivalent to attending to the full sequence
1380
+ if not torch.any(~patch_attention_mask):
1381
+ attention_mask = None
1382
+ else:
1383
+ attention_mask = (
1384
+ _prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype)
1385
+ if not self.config._flash_attn_2_enabled
1386
+ else patch_attention_mask
1387
+ )
1388
+
1389
+ encoder_outputs = self.encoder(
1390
+ inputs_embeds=hidden_states,
1391
+ attention_mask=attention_mask,
1392
+ output_attentions=output_attentions,
1393
+ output_hidden_states=output_hidden_states,
1394
+ return_dict=return_dict,
1395
+ )
1396
+
1397
+ last_hidden_state = encoder_outputs[0]
1398
+ last_hidden_state = self.post_layernorm(last_hidden_state)
1399
+
1400
+ pooled_output = self.head(
1401
+ hidden_state=last_hidden_state,
1402
+ attention_mask=patch_attention_mask,
1403
+ )
1404
+
1405
+ if not return_dict:
1406
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
1407
+
1408
+ return BaseModelOutputWithPooling(
1409
+ last_hidden_state=last_hidden_state,
1410
+ pooler_output=pooled_output,
1411
+ hidden_states=encoder_outputs.hidden_states,
1412
+ attentions=encoder_outputs.attentions,
1413
+ )
1414
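
For reference, the default `patch_attention_mask` built in the forward above is simply an all-ones boolean grid over patch positions, flattened per image; a shape-only sketch with illustrative sizes:

```python
import torch

patch_size = 14
pixel_values = torch.randn(2, 3, 448, 448)                 # (batch, channels, height, width)
patch_attention_mask = torch.ones(
    2, 448 // patch_size, 448 // patch_size, dtype=torch.bool
)                                                           # (batch, 32, 32)
patch_attention_mask = patch_attention_mask.view(2, -1)     # (batch, 1024); all True, so the mask is skipped
```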
+
1415
+
1416
+ class SiglipMultiheadAttentionPoolingHead(nn.Module):
1417
+ """Multihead Attention Pooling."""
1418
+
1419
+ def __init__(self, config: SiglipVisionConfig):
1420
+ super().__init__()
1421
+
1422
+ self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
1423
+ self.attention = torch.nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, batch_first=True)
1424
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
1425
+ self.mlp = SiglipMLP(config)
1426
+
1427
+ def forward(self, hidden_state, attention_mask):
1428
+ batch_size = hidden_state.shape[0]
1429
+ probe = self.probe.repeat(batch_size, 1, 1)
1430
+
1431
+ hidden_state = self.attention(
1432
+ query=probe, key=hidden_state, value=hidden_state, key_padding_mask=~attention_mask
1433
+ )[0]
1434
+
1435
+ residual = hidden_state
1436
+ hidden_state = self.layernorm(hidden_state)
1437
+ hidden_state = residual + self.mlp(hidden_state)
1438
+
1439
+ return hidden_state[:, 0]
1440
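
A shape-only sketch of the probe-based pooling above (the layernorm/MLP residual step is omitted), with hypothetical sizes: a single learned query attends over all patch tokens and the lone query position becomes the pooled vector.

```python
import torch

batch_size, seq_len, hidden_size, num_heads = 2, 1024, 1152, 16
probe = torch.randn(1, 1, hidden_size).repeat(batch_size, 1, 1)   # (batch, 1, hidden)
patches = torch.randn(batch_size, seq_len, hidden_size)           # (batch, seq, hidden)

attn = torch.nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
pooled, _ = attn(query=probe, key=patches, value=patches)          # (batch, 1, hidden)
pooled = pooled[:, 0]                                               # (batch, hidden)
```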
+
1441
+
1442
+ @add_start_docstrings(
1443
+ """The vision model from SigLIP without any head or projection on top.""",
1444
+ SIGLIP_START_DOCSTRING,
1445
+ )
1446
+ class SiglipVisionModel(SiglipPreTrainedModel):
1447
+ config_class = SiglipVisionConfig
1448
+ main_input_name = "pixel_values"
1449
+
1450
+ def __init__(self, config: SiglipVisionConfig):
1451
+ super().__init__(config)
1452
+
1453
+ self.vision_model = SiglipVisionTransformer(config)
1454
+
1455
+ # Initialize weights and apply final processing
1456
+ self.post_init()
1457
+
1458
+ def get_input_embeddings(self) -> nn.Module:
1459
+ return self.vision_model.embeddings.patch_embedding
1460
+
1461
+ @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING)
1462
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipVisionConfig)
1463
+ def forward(
1464
+ self,
1465
+ pixel_values,
1466
+ patch_attention_mask: Optional[torch.BoolTensor] = None,
1467
+ output_attentions: Optional[bool] = None,
1468
+ output_hidden_states: Optional[bool] = None,
1469
+ return_dict: Optional[bool] = None,
1470
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
1471
+ r"""
1472
+ Returns:
1473
+ Examples:
1474
+ ```python
1475
+ >>> from PIL import Image
1476
+ >>> import requests
1477
+ >>> from transformers import AutoProcessor, SiglipVisionModel
1478
+ >>> model = SiglipVisionModel.from_pretrained("google/siglip-base-patch16-224")
1479
+ >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
1480
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1481
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1482
+ >>> inputs = processor(images=image, return_tensors="pt")
1483
+ >>> outputs = model(**inputs)
1484
+ >>> last_hidden_state = outputs.last_hidden_state
1485
+ >>> pooled_output = outputs.pooler_output # pooled features
1486
+ ```"""
1487
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1488
+
1489
+ return self.vision_model(
1490
+ pixel_values=pixel_values,
1491
+ patch_attention_mask=patch_attention_mask,
1492
+ output_attentions=output_attentions,
1493
+ output_hidden_states=output_hidden_states,
1494
+ return_dict=return_dict,
1495
+ )
1496
+
1497
+
1498
+ @add_start_docstrings(SIGLIP_START_DOCSTRING)
1499
+ class SiglipModel(SiglipPreTrainedModel):
1500
+ config_class = SiglipConfig
1501
+
1502
+ def __init__(self, config: SiglipConfig):
1503
+ super().__init__(config)
1504
+
1505
+ if not isinstance(config.text_config, SiglipTextConfig):
1506
+ raise ValueError(
1507
+ "config.text_config is expected to be of type SiglipTextConfig but is of type"
1508
+ f" {type(config.text_config)}."
1509
+ )
1510
+
1511
+ if not isinstance(config.vision_config, SiglipVisionConfig):
1512
+ raise ValueError(
1513
+ "config.vision_config is expected to be of type SiglipVisionConfig but is of type"
1514
+ f" {type(config.vision_config)}."
1515
+ )
1516
+
1517
+ text_config = config.text_config
1518
+ vision_config = config.vision_config
1519
+
1520
+ self.text_model = SiglipTextTransformer(text_config)
1521
+ self.vision_model = SiglipVisionTransformer(vision_config)
1522
+
1523
+ self.logit_scale = nn.Parameter(torch.randn(1))
1524
+ self.logit_bias = nn.Parameter(torch.randn(1))
1525
+
1526
+ # Initialize weights and apply final processing
1527
+ self.post_init()
1528
+
1529
+ @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING)
1530
+ def get_text_features(
1531
+ self,
1532
+ input_ids: Optional[torch.Tensor] = None,
1533
+ attention_mask: Optional[torch.Tensor] = None,
1534
+ position_ids: Optional[torch.Tensor] = None,
1535
+ output_attentions: Optional[bool] = None,
1536
+ output_hidden_states: Optional[bool] = None,
1537
+ return_dict: Optional[bool] = None,
1538
+ ) -> torch.FloatTensor:
1539
+ r"""
1540
+ Returns:
1541
+ text_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The text embeddings obtained by
1542
+ applying the projection layer to the pooled output of [`SiglipTextModel`].
1543
+ Examples:
1544
+ ```python
1545
+ >>> from transformers import AutoTokenizer, AutoModel
1546
+ >>> import torch
1547
+ >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
1548
+ >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224")
1549
+ >>> # important: make sure to set padding="max_length" as that's how the model was trained
1550
+ >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt")
1551
+ >>> with torch.no_grad():
1552
+ ... text_features = model.get_text_features(**inputs)
1553
+ ```"""
1554
+ # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components.
1555
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1556
+ output_hidden_states = (
1557
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1558
+ )
1559
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1560
+
1561
+ text_outputs = self.text_model(
1562
+ input_ids=input_ids,
1563
+ attention_mask=attention_mask,
1564
+ position_ids=position_ids,
1565
+ output_attentions=output_attentions,
1566
+ output_hidden_states=output_hidden_states,
1567
+ return_dict=return_dict,
1568
+ )
1569
+
1570
+ pooled_output = text_outputs[1]
1571
+
1572
+ return pooled_output
1573
+
1574
+ @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING)
1575
+ def get_image_features(
1576
+ self,
1577
+ pixel_values: Optional[torch.FloatTensor] = None,
1578
+ output_attentions: Optional[bool] = None,
1579
+ output_hidden_states: Optional[bool] = None,
1580
+ return_dict: Optional[bool] = None,
1581
+ ) -> torch.FloatTensor:
1582
+ r"""
1583
+ Returns:
1584
+ image_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The image embeddings obtained by
1585
+ applying the projection layer to the pooled output of [`SiglipVisionModel`].
1586
+ Examples:
1587
+ ```python
1588
+ >>> from PIL import Image
1589
+ >>> import requests
1590
+ >>> from transformers import AutoProcessor, AutoModel
1591
+ >>> import torch
1592
+ >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
1593
+ >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
1594
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1595
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1596
+ >>> inputs = processor(images=image, return_tensors="pt")
1597
+ >>> with torch.no_grad():
1598
+ ... image_features = model.get_image_features(**inputs)
1599
+ ```"""
1600
+ # Use SiglipModel's config for some fields (if specified) instead of those of vision & text components.
1601
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1602
+ output_hidden_states = (
1603
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1604
+ )
1605
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1606
+
1607
+ vision_outputs = self.vision_model(
1608
+ pixel_values=pixel_values,
1609
+ output_attentions=output_attentions,
1610
+ output_hidden_states=output_hidden_states,
1611
+ return_dict=return_dict,
1612
+ )
1613
+
1614
+ pooled_output = vision_outputs[1]
1615
+
1616
+ return pooled_output
1617
+
1618
+ @add_start_docstrings_to_model_forward(SIGLIP_INPUTS_DOCSTRING)
1619
+ @replace_return_docstrings(output_type=SiglipOutput, config_class=SiglipConfig)
1620
+ def forward(
1621
+ self,
1622
+ input_ids: Optional[torch.LongTensor] = None,
1623
+ pixel_values: Optional[torch.FloatTensor] = None,
1624
+ attention_mask: Optional[torch.Tensor] = None,
1625
+ position_ids: Optional[torch.LongTensor] = None,
1626
+ return_loss: Optional[bool] = None,
1627
+ output_attentions: Optional[bool] = None,
1628
+ output_hidden_states: Optional[bool] = None,
1629
+ return_dict: Optional[bool] = None,
1630
+ ) -> Union[Tuple, SiglipOutput]:
1631
+ r"""
1632
+ Returns:
1633
+ Examples:
1634
+ ```python
1635
+ >>> from PIL import Image
1636
+ >>> import requests
1637
+ >>> from transformers import AutoProcessor, AutoModel
1638
+ >>> import torch
1639
+ >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
1640
+ >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
1641
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1642
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1643
+ >>> texts = ["a photo of 2 cats", "a photo of 2 dogs"]
1644
+ >>> # important: we pass `padding=max_length` since the model was trained with this
1645
+ >>> inputs = processor(text=texts, images=image, padding="max_length", return_tensors="pt")
1646
+ >>> with torch.no_grad():
1647
+ ... outputs = model(**inputs)
1648
+ >>> logits_per_image = outputs.logits_per_image
1649
+ >>> probs = torch.sigmoid(logits_per_image) # these are the probabilities
1650
+ >>> print(f"{probs[0][0]:.1%} that image 0 is '{texts[0]}'")
1651
+ 31.9% that image 0 is 'a photo of 2 cats'
1652
+ ```"""
1653
+ # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components.
1654
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1655
+ output_hidden_states = (
1656
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1657
+ )
1658
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1659
+
1660
+ vision_outputs = self.vision_model(
1661
+ pixel_values=pixel_values,
1662
+ output_attentions=output_attentions,
1663
+ output_hidden_states=output_hidden_states,
1664
+ return_dict=return_dict,
1665
+ )
1666
+
1667
+ text_outputs = self.text_model(
1668
+ input_ids=input_ids,
1669
+ attention_mask=attention_mask,
1670
+ position_ids=position_ids,
1671
+ output_attentions=output_attentions,
1672
+ output_hidden_states=output_hidden_states,
1673
+ return_dict=return_dict,
1674
+ )
1675
+
1676
+ image_embeds = vision_outputs[1]
1677
+ text_embeds = text_outputs[1]
1678
+
1679
+ # normalized features
1680
+ image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
1681
+ text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
1682
+
1683
+ # cosine similarity as logits
1684
+ logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * self.logit_scale.exp() + self.logit_bias
1685
+ logits_per_image = logits_per_text.t()
1686
+
1687
+ loss = None
1688
+ if return_loss:
1689
+ raise NotImplementedError("SigLIP loss to be implemented")
1690
+
1691
+ if not return_dict:
1692
+ output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
1693
+ return ((loss,) + output) if loss is not None else output
1694
+
1695
+ return SiglipOutput(
1696
+ loss=loss,
1697
+ logits_per_image=logits_per_image,
1698
+ logits_per_text=logits_per_text,
1699
+ text_embeds=text_embeds,
1700
+ image_embeds=image_embeds,
1701
+ text_model_output=text_outputs,
1702
+ vision_model_output=vision_outputs,
1703
+ )
1704
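
Unlike CLIP, the pairwise logits computed above are mapped to probabilities with a sigmoid, so every text/image pair is scored independently; a minimal numeric sketch with made-up embeddings:

```python
import torch
import torch.nn.functional as F

text_embeds = F.normalize(torch.randn(2, 4), dim=-1)    # (num_texts, dim), unit norm
image_embeds = F.normalize(torch.randn(3, 4), dim=-1)   # (num_images, dim), unit norm
logit_scale, logit_bias = torch.tensor(1.0), torch.tensor(0.0)

logits_per_text = text_embeds @ image_embeds.t() * logit_scale.exp() + logit_bias  # (2, 3)
probs = torch.sigmoid(logits_per_text)  # per-pair probabilities; rows need not sum to 1
```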
+
1705
+
1706
+ def get_siglip_vision_model(_flash_attn_2_enabled=True, **kwargs):
1707
+ siglip_vision_config = {
1708
+ "hidden_size": 1152,
1709
+ "image_size": 448,
1710
+ "intermediate_size": 4304,
1711
+ "model_type": "siglip_vision_model",
1712
+ "num_attention_heads": 16,
1713
+ "num_hidden_layers": 27,
1714
+ "patch_size": 14,
1715
+ }
1716
+
1717
+ model_config = SiglipVisionConfig(**siglip_vision_config, _flash_attn_2_enabled=_flash_attn_2_enabled, **kwargs)
1718
+
1719
+ vision_model = SiglipVisionModel(model_config).vision_model
1720
+
1721
+ return vision_model
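
A hedged usage sketch of the helper above, e.g. when tracing the vision tower for ONNX export: passing `_flash_attn_2_enabled=False` falls back to the standard attention path, and the shapes follow the config defined in the function (weights here are randomly initialized, so outputs are untrained).

```python
import torch

vision_model = get_siglip_vision_model(_flash_attn_2_enabled=False).eval()

pixel_values = torch.randn(1, 3, 448, 448)               # (batch, channels, height, width)
with torch.no_grad():
    outputs = vision_model(pixel_values=pixel_values)

# 448 // 14 = 32 patches per side -> 32 * 32 = 1024 tokens of width 1152
print(outputs.last_hidden_state.shape)                    # torch.Size([1, 1024, 1152])
print(outputs.pooler_output.shape)                        # torch.Size([1, 1152])
```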