File size: 3,397 Bytes
eaacc51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9e69b49
eaacc51
 
 
 
 
 
2a32736
 
 
 
eaacc51
 
9e69b49
 
eaacc51
86d6675
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
from transformers import PreTrainedModel
import torch
import os

# Known checkpoint download locations, keyed by model name.
# NOTE(review): nothing in this file (as seen) reads url_map; the class below
# loads "inception_v3" via torch.hub instead. Confirm whether it is still used.
url_map = dict(
	inception_v3="https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth",
)

class InceptionV3ModelForImageClassification(PreTrainedModel):
	"""Image classifier wrapping a single underlying torch model stored in ``self.model``.

	The wrapped model is selected by ``config.model_name`` (see ``__init__``).
	Most ``torch.nn.Module`` introspection / device / state methods are forwarded
	to ``self.model``. As a result, parameters, modules and state-dict keys are
	those of the wrapped model (no ``model.`` prefix).

	Config attributes read here:
	``model_name``; ``use_jit`` (only for Hub repo ids); ``mean`` / ``std``
	(normalisation in :meth:`transform`); and ``classes`` (looked up with the
	stringified predicted index in :meth:`predict`).
	"""

	def __init__(self, config):
		"""Build the wrapper and load the model selected by ``config.model_name``.

		- ``"google-safesearch-mini"``: TorchScript file
		  ``google-safesearch-mini.bin`` loaded from the working directory
		  (not downloaded here).
		- ``"inception_v3"``: torchvision's pretrained Inception v3 via ``torch.hub``.
		- anything else: treated as a Hugging Face Hub repo id. Its
		  ``pytorch_model.bin`` is downloaded once to ``<repo id, "/" -> "_">.bin``
		  and loaded with ``torch.jit.load`` if ``config.use_jit`` else ``torch.load``.
		"""
		super().__init__(config)

		# Local cache file; "/" in a Hub repo id would otherwise create subdirectories.
		model_path = f"{self.config.model_name}.bin".replace("/", "_")

		if self.config.model_name == "google-safesearch-mini":
			self.model = torch.jit.load(model_path)
		elif self.config.model_name == "inception_v3":
			self.model = torch.hub.load('pytorch/vision:v0.6.0', 'inception_v3', pretrained=True)
		else:
			if not os.path.exists(model_path):
				from urllib.request import urlretrieve
				# Download to a temporary name and rename only on success, so an
				# interrupted download is not mistaken for a cached model next time.
				tmp_path = model_path + ".part"
				urlretrieve(f"https://huggingface.co/{self.config.model_name}/resolve/main/pytorch_model.bin", tmp_path)
				os.replace(tmp_path, model_path)
			# SECURITY NOTE(review): torch.load unpickles the downloaded file, which can
			# execute arbitrary code. Only use model names from trusted repositories.
			self.model = torch.jit.load(model_path) if self.config.use_jit else torch.load(model_path)

	def forward(self, input_ids):
		"""Run the wrapped model on ``input_ids`` (the image batch) and return ``(out, aux)``.

		If the wrapped model returns a tuple/list (e.g. torchvision's
		``InceptionOutputs`` in training mode), it is unpacked into ``(out, aux)``.
		If it returns a single tensor (torchvision Inception v3 in eval mode),
		``aux`` is ``None``.
		"""
		result = self.model(input_ids)
		if isinstance(result, (tuple, list)):
			out, aux = result
		else:
			# Tuple-unpacking a bare tensor would split it along the batch dimension.
			out, aux = result, None
		return out, aux

	def freeze(self):
		"""Disable gradients for every parameter of the wrapped model."""
		for param in self.model.parameters():
			param.requires_grad = False

	def unfreeze(self):
		"""Enable gradients for every parameter of the wrapped model."""
		for param in self.model.parameters():
			param.requires_grad = True

	def train(self, mode=True):
		"""Set training mode on this wrapper and the wrapped model; return ``self`` like ``nn.Module``."""
		super().train(mode)
		self.model.train(mode)
		return self

	def eval(self):
		"""Switch to evaluation mode; return ``self`` so ``model = model.eval()`` works."""
		return self.train(False)

	def to(self, *args, **kwargs):
		"""Move/cast the wrapped model (same arguments as ``nn.Module.to``) and return ``self``."""
		self.model.to(*args, **kwargs)
		return self

	def cuda(self, device=None):
		"""Move the wrapped model to CUDA, honouring an optional device index; return ``self``."""
		self.model.cuda(device)
		return self

	def cpu(self):
		"""Move the wrapped model to CPU and return ``self``."""
		return self.to("cpu")

	def state_dict(self, destination=None, prefix='', keep_vars=False):
		"""Return the wrapped model's state dict (keys carry no ``model.`` prefix)."""
		# Keyword arguments: positional args to Module.state_dict are deprecated.
		return self.model.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)

	def load_state_dict(self, state_dict, strict=True):
		"""Load ``state_dict`` into the wrapped model."""
		return self.model.load_state_dict(state_dict, strict=strict)

	def parameters(self, recurse=True):
		"""Iterate the wrapped model's parameters."""
		return self.model.parameters(recurse)

	def named_parameters(self, prefix='', recurse=True):
		"""Iterate ``(name, parameter)`` pairs of the wrapped model."""
		return self.model.named_parameters(prefix, recurse)

	def children(self):
		"""Iterate the wrapped model's immediate child modules."""
		return self.model.children()

	def named_children(self):
		"""Iterate ``(name, module)`` pairs of the wrapped model's immediate children."""
		return self.model.named_children()

	def modules(self):
		"""Iterate all modules of the wrapped model (including itself)."""
		return self.model.modules()

	def named_modules(self, memo=None, prefix=''):
		"""Iterate ``(name, module)`` pairs for all modules of the wrapped model."""
		return self.model.named_modules(memo, prefix)

	def zero_grad(self, set_to_none=False):
		"""Reset gradients of the wrapped model's parameters."""
		return self.model.zero_grad(set_to_none=set_to_none)

	def share_memory(self):
		"""Move the wrapped model's storage to shared memory."""
		return self.model.share_memory()

	def transform(self, image):
		"""Resize (shorter side to 299), convert to tensor and normalise with ``config.mean``/``config.std``."""
		from torchvision import transforms
		transform = transforms.Compose([
			transforms.Resize(299),
			transforms.ToTensor(),
			transforms.Normalize(mean=self.config.mean, std=self.config.std)
		])
		return transform(image)

	def open_image(self, path):
		"""Open ``path`` (a local file or an http(s) URL) as an RGB PIL image.

		Raises:
			requests.HTTPError: if a URL fetch returns an error status.
		"""
		from PIL import Image
		if path.startswith(('http://', 'https://')):
			import requests
			from io import BytesIO
			response = requests.get(path)
			# Fail with the HTTP error instead of an opaque PIL decode error.
			response.raise_for_status()
			return Image.open(BytesIO(response.content)).convert('RGB')
		# Close the underlying file; convert() returns an independent image.
		with Image.open(path) as img:
			return img.convert('RGB')

	def predict(self, path, device="cuda", print_tensor=True):
		"""Classify the single image at ``path`` and return its label.

		Args:
			path: local file path or http(s) URL of the image.
			device: any torch device spec (``"cuda"``, ``"cuda:1"``, ``"cpu"``, ...).
			print_tensor: if true, print the raw output logits.

		Returns:
			``self.config.classes[str(index)]`` for the arg-max class index.
		"""
		image = self.transform(self.open_image(path)).unsqueeze(0)
		self.eval()
		# Move model and input to the same requested device (not only exact "cuda"/"cpu").
		self.to(device)
		image = image.to(device)
		with torch.no_grad():
			out, aux = self(image)
			if print_tensor:
				print(out)
			_, predicted = torch.max(out, 1)
		return self.config.classes[str(predicted.item())]