import torch

import act_mem
import layers

if __name__ == "__main__":
    # Toy configuration: a single short sequence through one attention layer.
    batch_size, seq_len, d_model, n_heads = 1, 128, 1024, 32
    print(
        f"Batch size: {batch_size}, sequence length: {seq_len}, "
        f"d_model: {d_model}, n_heads: {n_heads}"
    )
    dtype = torch.bfloat16
    inputs = torch.randn(
        batch_size,
        seq_len,
        d_model,
        device="cuda",
        requires_grad=True,
        dtype=dtype,
    )
    attn = layers.Attention(
        d_model=d_model,
        n_heads=n_heads,
        device="cuda",
        dtype=dtype,
    )
    # Track the CUDA memory allocated during the forward pass and the tensors
    # autograd saves for backward, excluding the layer's own parameters.
    with act_mem.AllocatedMemContext() as mem, act_mem.SavedTensorContext(
        ignored_tensors=attn.parameters()
    ) as saved:
        out = attn(inputs)
    stm = saved.saved_tensor_mem
    print(f'{mem.delta["current"]=}')
    print(f"{stm=}")
    # Saved-tensor bytes per output element.
    print(f"{stm/out.numel()=}")