import tensorflow as tf
from transformers import T5TokenizerFast
import seqio
import logging
import numpy as np
import pyarrow.parquet as pq
from concurrent.futures import ProcessPoolExecutor
import multiprocessing

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Load the T5 tokenizer
tokenizer = T5TokenizerFast.from_pretrained("google/t5-v1_1-base")

# Define the maximum input and target lengths
MAX_INPUT_LENGTH = 1024
MAX_TARGET_LENGTH = 200

# Batch size for processing
BATCH_SIZE = 32


def tokenize_batch(batch):
    """Lowercases a batch of texts and tokenizes it twice: once padded/truncated
    to MAX_INPUT_LENGTH for the inputs and once to MAX_TARGET_LENGTH for the targets."""
    inputs = [text.lower() for text in batch]
    model_inputs = tokenizer(
        inputs,
        max_length=MAX_INPUT_LENGTH,
        padding="max_length",
        truncation=True,
        return_tensors="np"
    )
    model_targets = tokenizer(
        inputs,
        max_length=MAX_TARGET_LENGTH,
        padding="max_length",
        truncation=True,
        return_tensors="np"
    )
    return model_inputs.input_ids, model_targets.input_ids


def create_tf_dataset(split, shuffle_files=False, seed=None):
    """Creates a TensorFlow Dataset from a single Parquet file."""
    file_path = split  # We'll pass the file path as the 'split' parameter
    try:
        logging.info(f"Loading dataset from file: {file_path}")
        table = pq.read_table(file_path)
        dataset = table.to_pandas()

        logging.info("Starting to process examples")

        # Create batches
        batches = [dataset['text'][i:i + BATCH_SIZE] for i in range(0, len(dataset), BATCH_SIZE)]

        # Use multiprocessing to tokenize batches
        with ProcessPoolExecutor(max_workers=multiprocessing.cpu_count()) as executor:
            results = list(executor.map(tokenize_batch, batches))

        input_ids = np.concatenate([r[0] for r in results])
        target_ids = np.concatenate([r[1] for r in results])

        logging.info(f"Finished processing {len(input_ids)} examples")

        return tf.data.Dataset.from_tensor_slices({
            "inputs": input_ids.astype(np.int32),
            "targets": target_ids.astype(np.int32)
        })
    except Exception as e:
        logging.error(f"Error in create_tf_dataset: {str(e)}")
        raise


# Define the SeqIO Task
seqio.TaskRegistry.add(
    "hf_inference_task",
    source=seqio.FunctionDataSource(
        dataset_fn=create_tf_dataset,
        splits=["train"]  # We'll use this to pass the file path
    ),
    preprocessors=[],
    output_features={
        "inputs": seqio.Feature(
            vocabulary=seqio.SentencePieceVocabulary('gs://t5-data/vocabs/cc_all.32000.100extra/sentencepiece.model'),
            add_eos=True,
            dtype=tf.int32
        ),
        "targets": seqio.Feature(
            vocabulary=seqio.SentencePieceVocabulary('gs://t5-data/vocabs/cc_all.32000.100extra/sentencepiece.model'),
            add_eos=True,
            dtype=tf.int32
        )
    },
    metric_fns=[]
)
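

# --- Usage sketch (assumption, not part of the original pipeline) ---
# A minimal smoke test of `create_tf_dataset`, assuming a hypothetical local
# Parquet file "data/train.parquet" with a 'text' column. It calls the dataset
# function directly (bypassing SeqIO, since the task's `split` argument is
# repurposed as a file path) and logs the shape of one example. The
# `if __name__ == "__main__"` guard also keeps ProcessPoolExecutor safe on
# platforms that start worker processes by re-importing this module.
if __name__ == "__main__":
    ds = create_tf_dataset("data/train.parquet")  # hypothetical path
    for example in ds.take(1):
        logging.info(
            f"inputs shape: {example['inputs'].shape}, "
            f"targets shape: {example['targets'].shape}"
        )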