Spaces:
Build error
Build error
File size: 869 Bytes
783053f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 |
import torch
import numpy as np
def fake_face_collator(batch):
    """Collate fake/real face examples into one batch for a vision transformer.

    Args:
        batch (list): A list of example dicts. Each example must provide
            ``'pixel_values'`` (an indexable whose first element is the image,
            given as a ``torch.Tensor``, ``np.ndarray`` or nested list) and
            ``'labels'`` (a scalar label such as an int, numpy scalar or
            0-d tensor).

    Returns:
        dict: ``{'pixel_values': Tensor, 'labels': Tensor}``, where each tensor
        stacks the per-example values along a new leading (batch) dimension.

    Raises:
        RuntimeError: If ``batch`` is empty (``torch.stack`` needs at least one
            tensor), or if the per-example images differ in shape.
    """
    # Only element 0 of each example's 'pixel_values' is used. Presumably the
    # image processor emitted a leading batch dim of size 1; confirm against
    # the dataset transform.
    # torch.as_tensor shares memory with ndarrays (like torch.from_numpy),
    # returns tensors unchanged, and also converts nested lists.
    pixel_values = [torch.as_tensor(example['pixel_values'][0]) for example in batch]
    # torch.as_tensor rather than torch.tensor: avoids the copy-construct
    # UserWarning (and a redundant copy) when a label is already a tensor;
    # torch.stack below copies into fresh storage anyway.
    labels = [torch.as_tensor(example['labels']) for example in batch]
    return {
        'pixel_values': torch.stack(pixel_values),
        'labels': torch.stack(labels),
    }
|