from transformers import PreTrainedModel
import torch
import os

url_map = {
    "inception_v3": "https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth"
}


class InceptionV3ModelForImageClassification(PreTrainedModel):
    """Image classifier wrapping a torch model selected by ``config.model_name``.

    The wrapped network is stored in ``self.model``. Most ``nn.Module``
    state/introspection methods below delegate directly to it.

    Config attributes read by this class:
      * ``model_name`` - selects how the wrapped model is loaded;
      * ``use_jit``   - only for the Hugging Face download branch; chooses
        ``torch.jit.load`` over ``torch.load``;
      * ``mean`` / ``std`` - normalization used by :meth:`transform`;
      * ``classes``   - index -> label lookup used by :meth:`predict`.
    """

    def __init__(self, config):
        super().__init__(config)
        # Local checkpoint filename derived from the model name,
        # e.g. "org/name" -> "org_name.bin".
        model_path = f"{self.config.model_name}.bin".replace("/", "_")
        if self.config.model_name == "google-safesearch-mini":
            # Expects a TorchScript file to already exist at model_path.
            self.model = torch.jit.load(model_path)
        elif self.config.model_name == "inception_v3":
            self.model = torch.hub.load(
                "pytorch/vision:v0.6.0", "inception_v3", pretrained=True
            )
        else:
            if not os.path.exists(model_path):
                self._download_checkpoint(
                    f"https://huggingface.co/{self.config.model_name}"
                    "/resolve/main/pytorch_model.bin",
                    model_path,
                )
            # NOTE(security): torch.load unpickles the file, which can run
            # arbitrary code. Only use this branch with trusted repositories.
            self.model = (
                torch.jit.load(model_path)
                if self.config.use_jit
                else torch.load(model_path)
            )

    @staticmethod
    def _download_checkpoint(url, dest):
        """Download *url* to *dest* via a temporary ``.part`` file.

        The final path only appears once the download finished. An
        interrupted download therefore cannot leave a truncated checkpoint
        that a later ``os.path.exists`` check would mistake for a complete one.
        """
        from urllib.request import urlretrieve

        tmp_path = dest + ".part"
        try:
            urlretrieve(url, tmp_path)
            os.replace(tmp_path, dest)
        finally:
            # After a successful os.replace the temp file is gone, so this
            # only cleans up after a failed or interrupted download.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

    def forward(self, input_ids):
        """Run the wrapped model and return ``(logits, aux_logits)``.

        torchvision's Inception3 returns an ``(logits, aux)`` named tuple only
        in training mode with auxiliary logits enabled. In eval mode it
        returns a bare tensor. Both cases are normalized to a 2-tuple here;
        ``aux_logits`` is ``None`` when the model did not produce it.
        """
        outputs = self.model(input_ids)
        if isinstance(outputs, tuple):
            out = outputs[0]
            aux = outputs[1] if len(outputs) > 1 else None
        else:
            out, aux = outputs, None
        return out, aux

    def freeze(self):
        """Disable gradients for every parameter of the wrapped model."""
        for param in self.model.parameters():
            param.requires_grad = False

    def unfreeze(self):
        """Re-enable gradients for every parameter of the wrapped model."""
        for param in self.model.parameters():
            param.requires_grad = True

    def train(self, mode=True):
        """Set training mode on this wrapper and the wrapped model; return ``self``."""
        super().train(mode)
        self.model.train(mode)
        # nn.Module.train returns self. Returning it also makes eval() chainable.
        return self

    def eval(self):
        """Switch to evaluation mode; return ``self``."""
        return self.train(False)

    def to(self, device):
        """Move the wrapped model to *device* and return ``self``."""
        self.model.to(device)
        return self

    def cuda(self, device=None):
        """Move to CUDA: the default device, an integer index, or a device spec."""
        if device is None:
            return self.to("cuda")
        if isinstance(device, int):
            return self.to(f"cuda:{device}")
        return self.to(device)

    def cpu(self):
        """Move the wrapped model to the CPU and return ``self``."""
        return self.to("cpu")

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        """Return the wrapped model's state dict."""
        # Keyword arguments: positional use is deprecated in torch.
        return self.model.state_dict(
            destination=destination, prefix=prefix, keep_vars=keep_vars
        )

    def load_state_dict(self, state_dict, strict=True):
        """Load *state_dict* into the wrapped model."""
        return self.model.load_state_dict(state_dict, strict=strict)

    def parameters(self, recurse=True):
        """Iterate over the wrapped model's parameters."""
        return self.model.parameters(recurse=recurse)

    def named_parameters(self, prefix='', recurse=True):
        """Iterate over the wrapped model's ``(name, parameter)`` pairs."""
        return self.model.named_parameters(prefix=prefix, recurse=recurse)

    def children(self):
        """Iterate over the wrapped model's immediate child modules."""
        return self.model.children()

    def named_children(self):
        """Iterate over the wrapped model's ``(name, child)`` pairs."""
        return self.model.named_children()

    def modules(self):
        """Iterate over all modules of the wrapped model."""
        return self.model.modules()

    def named_modules(self, memo=None, prefix=''):
        """Iterate over the wrapped model's ``(name, module)`` pairs."""
        return self.model.named_modules(memo=memo, prefix=prefix)

    def zero_grad(self, set_to_none=False):
        """Reset gradients of the wrapped model's parameters."""
        return self.model.zero_grad(set_to_none=set_to_none)

    def share_memory(self):
        """Move the wrapped model's storage to shared memory; return ``self``."""
        self.model.share_memory()
        return self

    def transform(self, image):
        """Convert a PIL image to a normalized ``C x H x W`` float tensor.

        ``Resize(299)`` scales the shorter edge to 299 pixels and keeps the
        aspect ratio. Normalization uses ``config.mean`` / ``config.std``.
        """
        from torchvision import transforms

        transform = transforms.Compose([
            transforms.Resize(299),
            transforms.ToTensor(),
            transforms.Normalize(mean=self.config.mean, std=self.config.std),
        ])
        return transform(image)

    def open_image(self, path, timeout=30):
        """Open an image from a local path or an http(s) URL as RGB.

        :param path: local file path, or a URL starting with ``http://`` or
            ``https://``.
        :param timeout: seconds to wait for the HTTP server (URLs only).
        :raises requests.HTTPError: if the URL returns an error status.
        """
        from PIL import Image

        if str(path).startswith(("http://", "https://")):
            import requests
            from io import BytesIO

            response = requests.get(path, timeout=timeout)
            # Fail loudly instead of trying to decode an error page as an image.
            response.raise_for_status()
            with Image.open(BytesIO(response.content)) as image:
                return image.convert('RGB')
        with Image.open(path) as image:
            return image.convert('RGB')

    def predict(self, path, device="cuda"):
        """Classify the image at *path* and return its label from ``config.classes``.

        The input tensor is moved to *device*. The model itself is not
        moved: the caller must place it on the same device first, e.g. with
        ``.cuda()`` or ``.to(device)``. This puts the model in eval mode.
        """
        batch = self.transform(self.open_image(path)).unsqueeze(0).to(device)
        self.eval()
        with torch.no_grad():
            logits, _ = self(batch)
        predicted = torch.argmax(logits, dim=1)
        return self.config.classes[predicted.item()]