import numpy as np
from datasets import Dataset

# The wildcard import is kept deliberately: other code in this file may rely on
# NVML names that it places in the module namespace.
from pynvml import *  # noqa: F401,F403


def print_gpu_utilization(device_index: int = 0) -> None:
    """Print the memory currently in use on one GPU, in whole mebibytes.

    Args:
        device_index: NVML index of the device to query. Defaults to 0,
            which is the device the original hard-coded call inspected.
    """
    nvmlInit()
    try:
        handle = nvmlDeviceGetHandleByIndex(device_index)
        info = nvmlDeviceGetMemoryInfo(handle)
    finally:
        # Pair every nvmlInit() with an nvmlShutdown() so repeated calls do
        # not keep accumulating NVML initialisation references.
        nvmlShutdown()

    used_mb = info.used // 1024**2  # bytes -> MiB, rounded down
    print(f"GPU memory occupied: {used_mb} MB.")


def print_summary(result) -> None:
    """Print training runtime, throughput and current GPU memory usage.

    Args:
        result: Object exposing a ``metrics`` mapping that contains the keys
            ``"train_runtime"`` and ``"train_samples_per_second"``
            (presumably the return value of a training run -- confirm
            against the caller).
    """
    metrics = result.metrics
    print(f"Time: {metrics['train_runtime']:.2f}")
    print(f"Samples/second: {metrics['train_samples_per_second']:.2f}")
    print_gpu_utilization()


# Synthetic dataset: `dataset_size` examples of `seq_len` random token ids.
seq_len, dataset_size = 512, 512
dummy_data = {
    "input_ids": np.random.randint(100, 30000, (dataset_size, seq_len)),
    # The upper bound of randint is exclusive, so (0, 2) is needed to draw
    # both 0 and 1; the previous (0, 1) always produced the constant label 0.
    "labels": np.random.randint(0, 2, dataset_size),
}
ds = Dataset.from_dict(dummy_data)
ds.set_format("pt")  # have the dataset return PyTorch tensors