import logging

import h5py
import torch
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer, Swinv2Model

logger = logging.getLogger(__name__)


def _encode_texts(tokenizer, encoder, texts, device, name, max_length=512):
    """Tokenize *texts* to a fixed length and return token-level embeddings.

    Args:
        tokenizer: HuggingFace tokenizer used to encode the batch.
        encoder: Text model whose ``last_hidden_state`` is returned.
        texts (list[str]): Raw input strings.
        device (str): Device the input tensors are moved to.
        name (str): Human-readable label used in the error message
            (e.g. ``"claim text"``).
        max_length (int): Fixed sequence length; every sample is padded or
            truncated to exactly this many tokens so downstream consumers can
            rely on a constant shape.

    Returns:
        torch.Tensor: ``(batch, max_length, hidden)`` embeddings.

    Raises:
        ValueError: If the encoder output does not have ``max_length`` tokens.
    """
    inputs = tokenizer(
        texts,
        truncation=True,
        padding="max_length",  # pad to max_length so shapes are constant
        return_tensors="pt",
        max_length=max_length,
    ).to(device)
    embeds = encoder(**inputs).last_hidden_state
    # Explicit raise instead of assert: asserts are stripped under `python -O`.
    if embeds.shape[1] != max_length:
        raise ValueError(f"Unexpected {name} shape: {embeds.shape}")
    return embeds


def _load_batch(in_f, batch_indices):
    """Read one batch of raw samples from the open input H5 file.

    Returns:
        tuple: (claim_texts, doc_texts, claim_images, doc_images, labels)
        where the image entries are lists of CPU tensors.
    """
    claim_texts, doc_texts = [], []
    claim_images, doc_images, labels = [], [], []
    for idx in batch_indices:
        sample = in_f[str(idx)]
        claim_texts.append(sample["claim"][()].decode())
        doc_texts.append(sample["document"][()].decode())
        claim_images.append(torch.from_numpy(sample["claim_image"][()]))
        doc_images.append(torch.from_numpy(sample["document_image"][()]))
        labels.append(sample["labels"][()])
    return claim_texts, doc_texts, claim_images, doc_images, labels


def _store_sample(out_f, idx, batch_idx, claim_text_embeds, doc_text_embeds,
                  claim_image_embeds, doc_image_embeds, labels):
    """Write one sample's embeddings and label into the output H5 file."""
    sample_group = out_f.create_group(str(idx))
    sample_group.create_dataset(
        "claim_text_embeds", data=claim_text_embeds[batch_idx].cpu().numpy()
    )
    sample_group.create_dataset(
        "doc_text_embeds", data=doc_text_embeds[batch_idx].cpu().numpy()
    )
    sample_group.create_dataset(
        "claim_image_embeds", data=claim_image_embeds[batch_idx].cpu().numpy()
    )
    sample_group.create_dataset(
        "doc_image_embeds", data=doc_image_embeds[batch_idx].cpu().numpy()
    )
    sample_group.create_dataset("labels", data=labels[batch_idx])


@torch.no_grad()
def create_embeddings_h5(input_h5_path, output_h5_path, batch_size=32, device="cuda"):
    """
    Create a new H5 file with pre-computed embeddings from text and images.

    The input file is expected to contain one group per sample, keyed by the
    stringified sample index ("0", "1", ...), each holding the datasets
    ``claim``, ``document``, ``claim_image``, ``document_image`` and
    ``labels``.  Text is embedded with DeBERTa-v3-xsmall (fixed 512-token
    sequences) and images with SwinV2; token/patch-level last hidden states
    are stored per sample in the output file.

    Args:
        input_h5_path (str): Path to input H5 file with raw data.
        output_h5_path (str): Path where to save the new H5 file with embeddings.
        batch_size (int): Batch size for processing.
        device (str): Device to use for computation.

    Raises:
        ValueError: If an encoder produces an unexpected sequence length.
    """
    # Lazy %-args: the message is only formatted if the level is enabled.
    logger.info("Creating embeddings H5 file from %s", input_h5_path)

    # Initialize models and move them to the target device.
    tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-v3-xsmall")
    text_encoder = AutoModel.from_pretrained("microsoft/deberta-v3-xsmall").to(device)
    image_encoder = Swinv2Model.from_pretrained(
        "microsoft/swinv2-base-patch4-window8-256"
    ).to(device)

    # Eval mode: disable dropout etc. for deterministic embeddings.
    text_encoder.eval()
    image_encoder.eval()

    with h5py.File(input_h5_path, "r") as in_f, h5py.File(output_h5_path, "w") as out_f:
        total_samples = len(in_f)  # one top-level group per sample

        for batch_start in tqdm(range(0, total_samples, batch_size)):
            batch_end = min(batch_start + batch_size, total_samples)
            batch_indices = range(batch_start, batch_end)

            claim_texts, doc_texts, claim_images, doc_images, labels = _load_batch(
                in_f, batch_indices
            )

            # Images: stack into a single (batch, ...) tensor on the device.
            claim_images = torch.stack(claim_images).to(device)
            doc_images = torch.stack(doc_images).to(device)

            # Text embeddings with a fixed 512-token sequence length.
            claim_text_embeds = _encode_texts(
                tokenizer, text_encoder, claim_texts, device, name="claim text"
            )
            doc_text_embeds = _encode_texts(
                tokenizer, text_encoder, doc_texts, device, name="doc text"
            )

            # Image embeddings (patch-level last hidden states).
            claim_image_embeds = image_encoder(claim_images).last_hidden_state
            doc_image_embeds = image_encoder(doc_images).last_hidden_state

            # Persist embeddings and labels, one group per sample index.
            for batch_idx, idx in enumerate(batch_indices):
                _store_sample(
                    out_f,
                    idx,
                    batch_idx,
                    claim_text_embeds,
                    doc_text_embeds,
                    claim_image_embeds,
                    doc_image_embeds,
                    labels,
                )

    logger.info("Created embeddings H5 file at %s", output_h5_path)


if __name__ == "__main__":
    # Set up logging
    logging.basicConfig(level=logging.INFO)

    # Example usage
    create_embeddings_h5(
        input_h5_path="data/preprocessed/train.h5",
        output_h5_path="data/preprocessed/train_embeddings.h5",
        batch_size=32,
        device="cuda:0",
    )