|
from transformers import Pipeline |
|
from PIL import Image |
|
import torchvision.transforms as v2 |
|
import torch |
|
import torch.nn.functional as F |
|
|
|
class TrashClassificationPipeline(Pipeline):
    """Image-classification pipeline that maps one image to one label string.

    The pipeline follows the usual four-stage custom-pipeline layout:
    ``_sanitize_parameters`` -> ``preprocess`` -> ``_forward`` -> ``postprocess``.
    It takes a single image, runs ``self.model`` on it and returns the label
    that ``self.model.config.id2label`` associates with the highest-scoring
    class.
    """

    def __init__(self, **kwargs):
        """Build the pipeline and the fixed image transform.

        Args:
            **kwargs: Forwarded unchanged to ``Pipeline.__init__``.
        """
        super().__init__(**kwargs)

        # Fixed preprocessing: 224x224 center crop, uint8 -> float32 in
        # [0, 1], then per-channel normalization with the standard ImageNet
        # mean/std. There is no resize step, so only the central 224x224
        # region of a larger image is seen by the model.
        self.transform = v2.Compose([
            v2.CenterCrop(size=(224, 224)),
            v2.PILToTensor(),
            v2.ConvertImageDtype(torch.float32),
            v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def _sanitize_parameters(self, **kwargs):
        """Return the per-stage keyword arguments.

        This pipeline exposes no tunable parameters, so any ``kwargs`` are
        ignored and three empty dicts (preprocess, forward, postprocess) are
        returned.
        """
        return {}, {}, {}

    def preprocess(self, inputs):
        """Turn one image into a batched, normalized tensor.

        Args:
            inputs: A ``PIL.Image.Image``, or anything ``PIL.Image.open``
                accepts (a file path or a binary file object).

        Returns:
            The transformed image with a leading batch dimension of size 1.
        """
        if not isinstance(inputs, Image.Image):
            # Not an image yet: let Pillow load it. ``convert`` forces the
            # pixel data to be read before the file handle is released.
            with Image.open(inputs) as opened:
                inputs = opened.convert("RGB")
        elif inputs.mode != "RGB":
            # Normalize() is configured for exactly three channels, so
            # grayscale / RGBA / palette images must be converted first.
            inputs = inputs.convert("RGB")

        tensor = self.transform(inputs)
        return tensor.unsqueeze(0)

    def _forward(self, tensor):
        """Run the model on a preprocessed batch and return its logits.

        Args:
            tensor: The batched tensor produced by ``preprocess``.

        Returns:
            The ``"logits"`` entry of the model output.
        """
        # Make sure dropout / batch-norm layers are in inference mode.
        self.model.eval()

        with torch.no_grad():
            out = self.model(tensor)["logits"]

        return out

    def postprocess(self, out):
        """Map logits to the label of the highest-scoring class.

        Args:
            out: Logits with the class scores along dimension 1.

        Returns:
            The ``id2label`` entry for the arg-max class of the first
            (only) item in the batch.

        Raises:
            KeyError: If the predicted index is missing from ``id2label``
                under both its string and integer form.
        """
        # softmax is monotonic, so the arg-max of the raw logits is the same
        # class without the extra computation.
        index = int(out.argmax(dim=1)[0])

        id2label = self.model.config.id2label
        # Prefer the string key (the original lookup), but fall back to the
        # integer key: Hugging Face configs typically normalize id2label
        # keys to int when the config is loaded.
        text_key = str(index)
        if text_key in id2label:
            return id2label[text_key]
        return id2label[index]