File size: 1,045 Bytes
24cdfa1 2073823 24cdfa1 e77d408 9413287 24cdfa1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 |
from transformers import Pipeline
from PIL import Image
import torchvision.transforms as v2
import torch
import torch.nn.functional as F
class TrashClassificationPipeline(Pipeline):
    """Image-classification pipeline that returns a single label per image.

    Each call runs ``preprocess`` -> ``_forward`` -> ``postprocess``: the
    image is center-cropped and normalized into a batch of one, the model's
    ``"logits"`` output is computed without gradients, and the arg-max class
    index of the first batch item is mapped to a name via
    ``model.config.id2label``.
    """

    def __init__(self, **kwargs):
        """Initialise the base ``Pipeline`` and build the image transform.

        Args:
            **kwargs: Forwarded unchanged to ``transformers.Pipeline``.
        """
        super().__init__(**kwargs)
        # NOTE: ``v2`` is an alias for ``torchvision.transforms`` (the v1 API),
        # not ``torchvision.transforms.v2``.
        # NOTE(review): there is no Resize before CenterCrop, so images larger
        # than 224x224 are cropped (not scaled) and smaller ones are padded;
        # confirm this matches the preprocessing used at training time.
        self.transform = v2.Compose([
            v2.CenterCrop(size=(224, 224)),
            v2.PILToTensor(),  # PIL image -> integer tensor, channels first
            v2.ConvertImageDtype(torch.float32),  # rescale to float32
            # Commonly used ImageNet mean/std; presumably what the model was
            # trained with — confirm.
            v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def _sanitize_parameters(self, **kwargs):
        """Accept no call-time options: every stage receives empty kwargs."""
        # (preprocess_kwargs, forward_kwargs, postprocess_kwargs)
        return {}, {}, {}

    def preprocess(self, inputs):
        """Turn one image into a normalized batch of size one.

        Args:
            inputs: A ``PIL.Image.Image`` or a filesystem path (``str``) to an
                image file. Any other value is handed to the transform as-is.

        Returns:
            A float32 tensor with a leading batch dimension of 1.
        """
        image = inputs
        if isinstance(image, str):
            # Load from disk; the context manager releases the file handle once
            # ``convert`` has produced an in-memory copy.
            with Image.open(image) as opened:
                image = opened.convert("RGB")
        elif isinstance(image, Image.Image):
            # Grayscale, palette, RGBA or CMYK images would otherwise yield 1 or
            # 4 channels and fail against the 3-channel Normalize below.
            image = image.convert("RGB")
        return self.transform(image).unsqueeze(0)

    def _forward(self, tensor):
        """Run the model on a preprocessed batch and return its raw logits."""
        self.model.eval()  # ensure inference-mode behaviour (e.g. dropout off)
        with torch.no_grad():
            out = self.model(tensor)["logits"]
        return out

    def postprocess(self, out):
        """Map the top-scoring class of the first batch item to its label.

        Args:
            out: Logits from ``_forward``; presumably shaped
                (batch, num_labels) — class scores are read along dim 1.

        Returns:
            The ``model.config.id2label`` entry for the highest-scoring class.

        Raises:
            KeyError: If the predicted index is absent from ``id2label`` under
                both its integer and its string form.
        """
        # Softmax is monotonic, so the arg-max of the raw logits is the same class.
        pred = int(out.argmax(dim=1)[0])
        id2label = self.model.config.id2label
        # transformers' PretrainedConfig stores int keys, but a config whose
        # mapping was assigned with JSON-style string keys would not; accept both.
        if pred in id2label:
            return id2label[pred]
        return id2label[str(pred)]