|
from transformers import PreTrainedModel |
|
import torch |
|
import os |
|
|
|
# Known checkpoint download locations, keyed by model name ("inception_v3"
# matches the `config.model_name` value the classifier class checks for).
# NOTE(review): nothing in the visible code of this module reads this mapping;
# the "inception_v3" path loads through torch.hub instead. Confirm whether
# another module uses it before changing or removing it.
url_map = dict(
    inception_v3="https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth",
)
|
|
|
class InceptionV3ModelForImageClassification(PreTrainedModel):
    """Image classifier that wraps a pretrained torch model in a ``PreTrainedModel``.

    The wrapped network lives in ``self.model``; most ``nn.Module`` methods are
    overridden to forward to it. Which network is loaded depends on
    ``config.model_name`` (see ``__init__``).

    Config attributes read by this class:
        model_name: selects / locates the underlying model.
        use_jit: for non-built-in names, load with ``torch.jit.load`` vs ``torch.load``.
        mean, std: normalization constants used by ``transform``.
        classes: index -> label lookup used by ``predict``.
    """

    def __init__(self, config):
        """Build the wrapper and load the underlying model.

        * ``"google-safesearch-mini"``: TorchScript archive read from the local
          file ``google-safesearch-mini.bin`` (nothing is downloaded).
        * ``"inception_v3"``: pretrained Inception v3 from
          ``torch.hub`` (``pytorch/vision:v0.6.0``).
        * anything else: ``pytorch_model.bin`` from the Hugging Face Hub repo
          named ``config.model_name``, cached in the current directory as
          ``<model_name with "/" replaced by "_">.bin`` and loaded with
          ``torch.jit.load`` when ``config.use_jit`` is true, else ``torch.load``.
        """
        super().__init__(config)

        # "org/name" -> "org_name.bin": a flat, relative cache file name.
        model_path = f"{self.config.model_name}.bin".replace("/", "_")

        if self.config.model_name == "google-safesearch-mini":
            self.model = torch.jit.load(model_path)
        elif self.config.model_name == "inception_v3":
            self.model = torch.hub.load('pytorch/vision:v0.6.0', 'inception_v3', pretrained=True)
        else:
            if not os.path.exists(model_path):
                self._download_weights(
                    f"https://huggingface.co/{self.config.model_name}/resolve/main/pytorch_model.bin",
                    model_path,
                )
            # NOTE(review): torch.load unpickles the downloaded file, which can
            # execute arbitrary code -- only use with trusted repositories.
            # Hub "pytorch_model.bin" files are often state dicts rather than
            # pickled modules, and newer torch releases default torch.load to
            # weights_only=True; confirm what the expected repos ship.
            self.model = torch.jit.load(model_path) if self.config.use_jit else torch.load(model_path)

    @staticmethod
    def _download_weights(url, dest):
        """Download ``url`` to ``dest`` via a temporary ``<dest>.part`` file.

        The file is renamed into place only after the download completes, so an
        interrupted transfer never leaves a truncated ``dest`` that a later run
        would treat as a valid cache hit.
        """
        from urllib.request import urlretrieve

        tmp_path = f"{dest}.part"
        try:
            urlretrieve(url, tmp_path)
        except BaseException:
            # Drop the partial file, then propagate the original error.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise
        os.replace(tmp_path, dest)

    def forward(self, input_ids):
        """Run the wrapped model and return ``(logits, aux_logits)``.

        torchvision's Inception v3 returns a ``(logits, aux_logits)`` tuple only
        in training mode with auxiliary logits enabled, and a bare logits tensor
        otherwise (e.g. after ``eval()``). A non-tuple result is therefore
        returned as ``(result, None)`` rather than being unpacked row-by-row.

        Args:
            input_ids: image batch passed unchanged to ``self.model``
                (presumably ``(N, 3, H, W)`` -- TODO confirm).
        """
        result = self.model(input_ids)
        if isinstance(result, (tuple, list)):
            out, aux = result
        else:
            out, aux = result, None
        return out, aux

    def _set_model_requires_grad(self, flag):
        """Set ``requires_grad`` to ``flag`` on every wrapped-model parameter."""
        for param in self.model.parameters():
            param.requires_grad = flag

    def freeze(self):
        """Disable gradient tracking for all wrapped-model parameters."""
        self._set_model_requires_grad(False)

    def unfreeze(self):
        """Re-enable gradient tracking for all wrapped-model parameters."""
        self._set_model_requires_grad(True)

    def train(self, mode=True):
        """Set training mode on this wrapper and on ``self.model``; return ``self``.

        ``self.model`` is switched explicitly because ``children()`` is
        overridden in this class to yield the wrapped model's children, so
        ``nn.Module.train`` alone would never reach ``self.model`` itself.
        """
        super().train(mode)
        self.model.train(mode)
        return self

    def eval(self):
        """Switch to inference mode; returns ``self`` like ``nn.Module.eval``."""
        return self.train(False)

    def to(self, *args, **kwargs):
        """Move/cast the wrapped model (same arguments as ``nn.Module.to``); return ``self``."""
        self.model.to(*args, **kwargs)
        return self

    def cuda(self, device=None):
        """Move the wrapped model to CUDA ``device`` (current device if None); return ``self``."""
        self.model.cuda(device)
        return self

    def cpu(self):
        """Move the wrapped model to the CPU; return ``self``."""
        return self.to("cpu")

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        """Return the wrapped model's state dict (keys carry no ``model.`` prefix)."""
        # Keyword arguments: positional ones are deprecated in recent torch.
        return self.model.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)

    def load_state_dict(self, state_dict, strict=True):
        """Load ``state_dict`` directly into the wrapped model."""
        return self.model.load_state_dict(state_dict, strict=strict)

    def parameters(self, recurse=True):
        """Iterate over the wrapped model's parameters."""
        return self.model.parameters(recurse=recurse)

    def named_parameters(self, prefix='', recurse=True):
        """Iterate over ``(name, parameter)`` pairs of the wrapped model."""
        return self.model.named_parameters(prefix=prefix, recurse=recurse)

    def children(self):
        """Iterate over the wrapped model's immediate child modules."""
        return self.model.children()

    def named_children(self):
        """Iterate over ``(name, module)`` pairs of the wrapped model's children."""
        return self.model.named_children()

    def modules(self):
        """Iterate over all modules of the wrapped model (itself included)."""
        return self.model.modules()

    def named_modules(self, memo=None, prefix=''):
        """Iterate over ``(name, module)`` pairs of the wrapped model's modules."""
        return self.model.named_modules(memo=memo, prefix=prefix)

    def zero_grad(self, set_to_none=False):
        """Reset gradients of the wrapped model's parameters."""
        return self.model.zero_grad(set_to_none=set_to_none)

    def share_memory(self):
        """Move the wrapped model's tensors to shared memory; return ``self``."""
        self.model.share_memory()
        return self

    def transform(self, image):
        """Convert a PIL image into a normalized ``(C, H, W)`` float tensor.

        ``Resize(299)`` scales the shorter side to 299 px and keeps the aspect
        ratio, so non-square images give non-square tensors.
        NOTE(review): torchvision's Inception v3 example also applies
        ``CenterCrop(299)``; consider it if a fixed 299x299 input is required.
        """
        from torchvision import transforms

        pipeline = transforms.Compose([
            transforms.Resize(299),
            transforms.ToTensor(),
            transforms.Normalize(mean=self.config.mean, std=self.config.std),
        ])
        return pipeline(image)

    def open_image(self, path, timeout=30):
        """Load an RGB ``PIL.Image`` from a local path or an http(s) URL.

        Args:
            path: filesystem path, or a string URL starting with ``http://`` or
                ``https://``.
            timeout: seconds to wait for the HTTP response (URLs only).

        Raises:
            requests.HTTPError: if the server responds with an error status.
        """
        from PIL import Image

        if isinstance(path, str) and path.startswith(('http://', 'https://')):
            import requests
            from io import BytesIO

            response = requests.get(path, timeout=timeout)
            # Report the HTTP failure instead of a confusing PIL decode error.
            response.raise_for_status()
            source = BytesIO(response.content)
        else:
            source = path

        # The context manager closes the underlying file; convert() returns a
        # new, fully loaded image that remains usable afterwards.
        with Image.open(source) as img:
            return img.convert('RGB')

    def predict(self, path, device="cuda"):
        """Classify a single image and return its label from ``config.classes``.

        Switches to eval mode and moves both the model and the input to
        ``device`` (e.g. ``"cuda"``, ``"cuda:1"``, ``"cpu"``) so they agree.

        Args:
            path: local path or http(s) URL, as accepted by ``open_image``.
            device: device to run inference on.
        """
        image = self.transform(self.open_image(path)).unsqueeze(0)  # add batch dim

        self.eval()
        self.to(device)
        image = image.to(device)

        with torch.no_grad():
            out, _ = self(image)

        predicted = out.argmax(dim=1)
        return self.config.classes[predicted.item()]
|
|