File size: 123 Bytes
c4ea5b9
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
from dataclasses import dataclass

import torch


@dataclass
class Transformer1DModelOutput:
    """Container for the output of a 1D transformer model.

    Attributes:
        sample: The tensor produced by the model. This class does not
            constrain its shape, dtype or device; see the producing model
            for those details.
    """

    # Annotated as the generic ``torch.Tensor`` rather than the legacy
    # ``torch.FloatTensor`` (which denotes a float32 CPU tensor only), since
    # nothing here restricts the tensor to that dtype/device.
    sample: torch.Tensor