|
import h5py |
|
import torch |
|
import logging |
|
from tqdm import tqdm |
|
from transformers import AutoTokenizer, AutoModel, Swinv2Model |
|
|
|
logger = logging.getLogger(__name__)


def _read_text(dataset):
    """Return the scalar string held by an h5py *dataset* as a ``str``.

    Depending on the stored dtype and the h5py version, ``dataset[()]`` yields
    either raw ``bytes`` (including ``numpy.bytes_``, a ``bytes`` subclass) or
    an already-decoded ``str``; only the former needs decoding.
    """
    value = dataset[()]
    if isinstance(value, bytes):
        return value.decode("utf-8")
    return str(value)


def _ordered_sample_keys(h5_file):
    """Return the top-level names of *h5_file* in a stable, numeric-friendly order.

    Sorting by ``(len(name), name)`` orders plain decimal names such as
    "0", "1", ..., "10" numerically while still accepting arbitrary names,
    so the input does not have to use exactly the keys ``"0".."N-1"``.
    """
    return sorted(h5_file.keys(), key=lambda name: (len(name), name))


def _encode_texts(texts, kind, tokenizer, text_encoder, device, max_length):
    """Tokenize *texts* to a fixed length and run them through *text_encoder*.

    Args:
        texts (list[str]): Batch of raw strings.
        kind (str): Short label ("claim"/"doc") used only in the error message.
        tokenizer: Hugging Face tokenizer matching *text_encoder*.
        text_encoder: Hugging Face model exposing ``last_hidden_state``.
        device (str): Device the tokenized inputs are moved to.
        max_length (int): Sequence length every input is truncated/padded to.

    Returns:
        tuple: ``(last_hidden_state, attention_mask)`` for the batch. The mask
        is 1 for real tokens and 0 for padding; it is needed downstream because
        every sequence is padded to ``max_length``.

    Raises:
        RuntimeError: If the encoder output length differs from *max_length*
            (guards against a tokenizer/encoder ignoring the requested padding).
    """
    inputs = tokenizer(
        texts,
        truncation=True,
        padding="max_length",
        return_tensors="pt",
        max_length=max_length,
    ).to(device)
    embeds = text_encoder(**inputs).last_hidden_state
    if embeds.shape[1] != max_length:
        raise RuntimeError(f"Unexpected {kind} text shape: {embeds.shape}")
    return embeds, inputs["attention_mask"]


@torch.no_grad()
def create_embeddings_h5(
    input_h5_path,
    output_h5_path,
    batch_size=32,
    device="cuda",
    *,
    text_model_name="microsoft/deberta-v3-xsmall",
    image_model_name="microsoft/swinv2-base-patch4-window8-256",
    max_length=512,
):
    """
    Create a new H5 file with pre-computed embeddings from text and images.

    Every top-level group of the input file is treated as one sample and must
    contain the datasets ``claim`` and ``document`` (scalar strings),
    ``claim_image`` and ``document_image`` (arrays that are stacked and passed
    straight to the image encoder, so presumably already-preprocessed float
    pixel tensors; not verified here) and ``labels``.

    The output file gets one group per input sample, with the same name,
    holding:
        ``claim_text_embeds`` / ``doc_text_embeds``: text encoder hidden
            states, padded/truncated to ``max_length`` tokens.
        ``claim_text_mask`` / ``doc_text_mask``: uint8 attention masks
            (1 = real token, 0 = padding) for the text embeddings above.
        ``claim_image_embeds`` / ``doc_image_embeds``: image encoder
            ``last_hidden_state``.
        ``labels``: copied unchanged from the input.

    Args:
        input_h5_path (str): Path to input H5 file with raw data
        output_h5_path (str): Path where to save the new H5 file with embeddings
            (overwritten if it exists)
        batch_size (int): Batch size for processing (must be >= 1)
        device (str): Device to use for computation
        text_model_name (str): Hugging Face id of the tokenizer/text encoder
        image_model_name (str): Hugging Face id of the Swinv2 image encoder
        max_length (int): Token length every text is padded/truncated to

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
        RuntimeError: If a text encoder output does not have ``max_length``
            positions.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    logger.info("Creating embeddings H5 file from %s", input_h5_path)

    tokenizer = AutoTokenizer.from_pretrained(text_model_name)
    text_encoder = AutoModel.from_pretrained(text_model_name).to(device)
    image_encoder = Swinv2Model.from_pretrained(image_model_name).to(device)

    # Inference only: disable dropout etc. (gradients are off via no_grad).
    text_encoder.eval()
    image_encoder.eval()

    with h5py.File(input_h5_path, "r") as in_f, h5py.File(output_h5_path, "w") as out_f:
        sample_keys = _ordered_sample_keys(in_f)

        for batch_start in tqdm(range(0, len(sample_keys), batch_size)):
            batch_keys = sample_keys[batch_start : batch_start + batch_size]
            samples = [in_f[key] for key in batch_keys]

            claim_texts = [_read_text(sample["claim"]) for sample in samples]
            doc_texts = [_read_text(sample["document"]) for sample in samples]
            claim_images = torch.stack(
                [torch.from_numpy(sample["claim_image"][()]) for sample in samples]
            ).to(device)
            doc_images = torch.stack(
                [torch.from_numpy(sample["document_image"][()]) for sample in samples]
            ).to(device)
            labels = [sample["labels"][()] for sample in samples]

            claim_text_embeds, claim_text_mask = _encode_texts(
                claim_texts, "claim", tokenizer, text_encoder, device, max_length
            )
            doc_text_embeds, doc_text_mask = _encode_texts(
                doc_texts, "doc", tokenizer, text_encoder, device, max_length
            )
            claim_image_embeds = image_encoder(claim_images).last_hidden_state
            doc_image_embeds = image_encoder(doc_images).last_hidden_state

            # One device->host copy per tensor per batch instead of per sample.
            batch_arrays = {
                "claim_text_embeds": claim_text_embeds.cpu().numpy(),
                "doc_text_embeds": doc_text_embeds.cpu().numpy(),
                "claim_image_embeds": claim_image_embeds.cpu().numpy(),
                "doc_image_embeds": doc_image_embeds.cpu().numpy(),
                # Without the masks, padding positions in the text embeddings
                # would be indistinguishable from real tokens downstream.
                "claim_text_mask": claim_text_mask.to(torch.uint8).cpu().numpy(),
                "doc_text_mask": doc_text_mask.to(torch.uint8).cpu().numpy(),
            }

            for row, key in enumerate(batch_keys):
                sample_group = out_f.create_group(key)
                for name, array in batch_arrays.items():
                    sample_group.create_dataset(name, data=array[row])
                sample_group.create_dataset("labels", data=labels[row])

    logger.info("Created embeddings H5 file at %s", output_h5_path)
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: every option defaults to the values that were
    # previously hard-coded, so running without arguments behaves as before
    # while other splits/devices can be processed without editing the file.
    import argparse

    parser = argparse.ArgumentParser(
        description="Pre-compute text and image embeddings for an H5 dataset."
    )
    parser.add_argument(
        "--input-h5-path",
        default="data/preprocessed/train.h5",
        help="H5 file with the raw samples.",
    )
    parser.add_argument(
        "--output-h5-path",
        default="data/preprocessed/train_embeddings.h5",
        help="H5 file to write the embeddings to (overwritten).",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=32,
        help="Number of samples encoded per forward pass.",
    )
    parser.add_argument(
        "--device",
        default="cuda:0",
        help="Torch device used for the encoders.",
    )
    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO)

    create_embeddings_h5(
        input_h5_path=args.input_h5_path,
        output_h5_path=args.output_h5_path,
        batch_size=args.batch_size,
        device=args.device,
    )
|
|