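# Residue from the hosting site's file viewer (not part of the module):
#   avatar alt text: "FredZhang7's picture"
#   commit: "Upload model" (f75d25f)
#   viewer links: raw, history, blame
#   file size: 3.21 kB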
from transformers import PreTrainedModel
import torch
import os
class InceptionV3ModelForImageClassification(PreTrainedModel):
    """Expose a TorchScript image classifier through the ``PreTrainedModel`` API.

    The network loaded with ``torch.jit.load`` is stored on ``self.model``.
    Most ``nn.Module`` methods overridden below simply delegate to it, so
    parameters, state dicts and device placement all refer to the scripted
    module rather than to this wrapper.
    """

    def __init__(self, config):
        """Load the scripted weights selected by ``config.model_name``.

        The weights are read from ``google-safesearch-mini.bin`` in the
        current working directory and are downloaded there first if the file
        does not exist yet.

        Args:
            config: configuration object; ``config.model_name`` is read here.

        Raises:
            ValueError: if ``config.model_name`` is not
                ``"google-safesearch-mini"``.
            requests.HTTPError: if the weight download answers with an HTTP
                error status.
        """
        super().__init__(config)
        model_path = "google-safesearch-mini.bin"
        if self.config.model_name != "google-safesearch-mini":
            raise ValueError(f"Model {self.config.model_name} not found.")
        if not os.path.exists(model_path):
            import requests

            url = "https://huggingface.co/FredZhang7/google-safesearch-mini/resolve/main/pytorch_model.bin"
            r = requests.get(url, allow_redirects=True)
            # Fail loudly instead of caching an HTML error page as "weights",
            # which would make every later torch.jit.load fail.
            r.raise_for_status()
            with open(model_path, "wb") as weights_file:
                weights_file.write(r.content)
        self.model = torch.jit.load(model_path)

    def forward(self, input_ids):
        """Run the scripted model once and return a 2-tuple.

        Args:
            input_ids: tensor handed unchanged to ``self.model`` (``predict``
                passes a transformed image batch here).

        Returns:
            ``(output, aux)`` where ``aux`` is ``None`` when
            ``config.model_name`` is ``"inception_v3"`` and otherwise the very
            same ``output`` object.
        """
        # The model is evaluated a single time; the second tuple slot reuses
        # that result instead of paying for another forward pass.
        output = self.model(input_ids)
        aux = None if self.config.model_name == "inception_v3" else output
        return output, aux

    def freeze(self):
        """Disable gradients for every parameter of the scripted model."""
        for param in self.model.parameters():
            param.requires_grad = False

    def unfreeze(self):
        """Enable gradients for every parameter of the scripted model."""
        for param in self.model.parameters():
            param.requires_grad = True

    def train(self, mode=True):
        """Set training mode on the wrapper and on the scripted model.

        Returns:
            ``self``, so calls can be chained as with ``nn.Module.train``.
        """
        super().train(mode)
        self.model.train(mode)
        return self

    def eval(self):
        """Switch to evaluation mode and return ``self``."""
        return self.train(False)

    def to(self, device):
        """Move the scripted model to ``device`` and return ``self``."""
        self.model.to(device)
        return self

    def cuda(self, device=None):
        """Move the scripted model to a CUDA device and return ``self``.

        Args:
            device: ``None`` for the default ``"cuda"`` device, an ``int``
                device index, or anything else accepted by ``to``.
        """
        if device is None:
            return self.to("cuda")
        if isinstance(device, int):
            return self.to(f"cuda:{device}")
        return self.to(device)

    def cpu(self):
        """Move the scripted model to the CPU and return ``self``."""
        return self.to("cpu")

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        """Return the scripted model's state dict."""
        # Keyword arguments are accepted by both the older positional and the
        # newer keyword-only ``nn.Module.state_dict`` signatures.
        return self.model.state_dict(
            destination=destination, prefix=prefix, keep_vars=keep_vars
        )

    def load_state_dict(self, state_dict, strict=True):
        """Load ``state_dict`` into the scripted model."""
        return self.model.load_state_dict(state_dict, strict)

    def parameters(self, recurse=True):
        """Iterate over the scripted model's parameters."""
        return self.model.parameters(recurse)

    def named_parameters(self, prefix='', recurse=True):
        """Iterate over ``(name, parameter)`` pairs of the scripted model."""
        return self.model.named_parameters(prefix, recurse)

    def children(self):
        """Iterate over the scripted model's direct child modules."""
        return self.model.children()

    def named_children(self):
        """Iterate over ``(name, module)`` pairs of the scripted model's children."""
        return self.model.named_children()

    def modules(self):
        """Iterate over all modules of the scripted model."""
        return self.model.modules()

    def named_modules(self, memo=None, prefix=''):
        """Iterate over ``(name, module)`` pairs of the scripted model."""
        return self.model.named_modules(memo, prefix)

    def zero_grad(self, set_to_none=False):
        """Reset the gradients of the scripted model's parameters."""
        return self.model.zero_grad(set_to_none)

    def share_memory(self):
        """Move the scripted model's tensors to shared memory; return ``self``."""
        self.model.share_memory()
        return self

    def transform(self, image):
        """Convert a PIL image into a normalized tensor.

        Applies ``Resize(299)``, ``ToTensor()`` and ``Normalize`` using
        ``config.mean`` and ``config.std``.

        Args:
            image: the image to preprocess (``predict`` passes the RGB PIL
                image produced by ``open_image``).

        Returns:
            The transformed tensor, without a batch dimension.
        """
        from torchvision import transforms

        pipeline = transforms.Compose([
            transforms.Resize(299),
            transforms.ToTensor(),
            transforms.Normalize(mean=self.config.mean, std=self.config.std),
        ])
        return pipeline(image)

    def open_image(self, path):
        """Open an image from a local path or an HTTP(S) URL as RGB.

        Args:
            path: filesystem path, or a URL starting with ``http://`` or
                ``https://``.

        Returns:
            A PIL image converted to ``"RGB"`` mode.

        Raises:
            requests.HTTPError: if a URL answers with an HTTP error status.
        """
        from PIL import Image

        # ``str()`` lets path-like objects reach the local-file branch.
        if str(path).startswith(('http://', 'https://')):
            import requests
            from io import BytesIO

            response = requests.get(path)
            # Surface HTTP failures directly rather than as an image-decoding
            # error on the error page's body.
            response.raise_for_status()
            return Image.open(BytesIO(response.content)).convert('RGB')

        # ``convert`` returns a new, fully loaded image, so the file handle
        # can be released immediately afterwards.
        with Image.open(path) as source:
            return source.convert('RGB')

    def predict(self, path, device="cuda", print_tensor=True):
        """Classify the image at ``path`` and return its class label.

        The wrapper's train/eval mode is left as it is; only gradient
        tracking is disabled for the forward pass.

        Args:
            path: local path or HTTP(S) URL of the image.
            device: device on which to run; both the image batch and the
                scripted model are moved there.
            print_tensor: when true, print the raw model output.

        Returns:
            ``config.classes[str(index)]`` for the index of the largest logit.
        """
        image = self.transform(self.open_image(path)).unsqueeze(0)
        # Honour the requested device as given (e.g. "cuda:1") instead of
        # treating everything other than the literal "cuda" as CPU.
        image = image.to(device)
        self.to(device)
        with torch.no_grad():
            out, _ = self(image)
        if print_tensor:
            print(out)
        # The model output is expected to expose a ``logits`` attribute.
        _, predicted = torch.max(out.logits, 1)
        return self.config.classes[str(predicted.item())]