from transformers import PreTrainedModel
import torch
import os

# Mapping of model name to the URL of its pretrained checkpoint file.
url_map = dict(
	inception_v3="https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth",
)

class InceptionV3ModelForImageClassification(PreTrainedModel):
	"""Image classifier wrapping an Inception-v3-style network behind ``PreTrainedModel``.

	The wrapped network is stored in ``self.model`` and most ``nn.Module`` methods
	below are forwarded to it. The config must provide ``model_name``. It must
	provide ``use_jit`` for model names other than the two built-in ones, plus
	``mean``/``std`` (used by :meth:`transform`) and ``classes`` (index -> label,
	used by :meth:`predict`).
	"""

	def __init__(self, config):
		super().__init__(config)

		# Local checkpoint file name, e.g. "org/name" -> "org_name.bin".
		model_path = f"{self.config.model_name}.bin".replace("/", "_")

		if self.config.model_name == "google-safesearch-mini":
			# Loaded from the local file only; no download is attempted for this name.
			self.model = torch.jit.load(model_path)
		elif self.config.model_name == "inception_v3":
			self.model = torch.hub.load('pytorch/vision:v0.6.0', 'inception_v3', pretrained=True)
		else:
			if not os.path.exists(model_path):
				from urllib.request import urlretrieve
				urlretrieve(f"https://huggingface.co/{self.config.model_name}/resolve/main/pytorch_model.bin", model_path)
			# SECURITY: torch.load unpickles the file, which can execute arbitrary code.
			# The file may have just been downloaded from a user-supplied repo name, so
			# only use model names from trusted sources.
			self.model = torch.jit.load(model_path) if self.config.use_jit else torch.load(model_path)

	def forward(self, input_ids):
		"""Run the wrapped network.

		Args:
			input_ids: batch of preprocessed images. The name is kept for interface
				compatibility; presumably shaped (N, 3, H, W) — TODO confirm.

		Returns:
			``(logits, aux_logits)``. ``aux_logits`` is ``None`` when the network
			returns a single tensor (torchvision's Inception-v3 does so outside
			training mode).
		"""
		result = self.model(input_ids)
		# A (logits, aux) pair, e.g. torchvision's InceptionOutputs namedtuple, is
		# unpacked. Unpacking a bare tensor instead would split it along the batch
		# dimension (or fail for batch size 1), so treat it as "no aux logits".
		if isinstance(result, tuple):
			out, aux = result
		else:
			out, aux = result, None
		return out, aux

	def freeze(self):
		"""Disable gradients for every parameter of the wrapped network."""
		for param in self.model.parameters():
			param.requires_grad = False

	def unfreeze(self):
		"""Enable gradients for every parameter of the wrapped network."""
		for param in self.model.parameters():
			param.requires_grad = True

	def train(self, mode=True):
		"""Set training mode on the wrapper and the wrapped network; return ``self``."""
		super().train(mode)
		# nn.Module.train() recurses through self.children(). That is overridden below
		# to yield the wrapped network's children, so self.model's own flag is never
		# set there. Set it explicitly.
		self.model.train(mode)
		return self

	def eval(self):
		"""Switch to inference mode; return ``self`` (like ``nn.Module.eval``)."""
		return self.train(False)

	def to(self, *args, **kwargs):
		"""Move/cast the wrapped network; accepts the same arguments as ``nn.Module.to``."""
		self.model.to(*args, **kwargs)
		return self

	def cuda(self, device=None):
		"""Move the wrapped network to CUDA; ``device`` may be an index, string or torch.device."""
		if device is None:
			device = "cuda"
		elif isinstance(device, int):
			device = torch.device("cuda", device)
		return self.to(device)

	def cpu(self):
		"""Move the wrapped network to the CPU."""
		return self.to("cpu")

	# The methods below delegate directly to the wrapped network, so the wrapper's
	# own state is never reported in addition to the network's.

	def state_dict(self, destination=None, prefix='', keep_vars=False):
		# Keyword arguments: positional use is deprecated in nn.Module.state_dict.
		return self.model.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)

	def load_state_dict(self, state_dict, strict=True):
		return self.model.load_state_dict(state_dict, strict=strict)

	def parameters(self, recurse=True):
		return self.model.parameters(recurse=recurse)

	def named_parameters(self, prefix='', recurse=True):
		return self.model.named_parameters(prefix=prefix, recurse=recurse)

	def children(self):
		return self.model.children()

	def named_children(self):
		return self.model.named_children()

	def modules(self):
		return self.model.modules()

	def named_modules(self, memo=None, prefix=''):
		return self.model.named_modules(memo=memo, prefix=prefix)

	def zero_grad(self, set_to_none=False):
		return self.model.zero_grad(set_to_none=set_to_none)

	def share_memory(self):
		return self.model.share_memory()

	def transform(self, image):
		"""Convert a PIL image into a normalized tensor.

		``Resize(299)`` with an int scales the shorter edge to 299 and keeps the
		aspect ratio, so the output is not necessarily square. Normalization uses
		``config.mean`` / ``config.std``.
		"""
		from torchvision import transforms
		transform = transforms.Compose([
			transforms.Resize(299),
			transforms.ToTensor(),
			transforms.Normalize(mean=self.config.mean, std=self.config.std),
		])
		return transform(image)

	def open_image(self, path, timeout=30):
		"""Load an image from a local path or an http(s) URL as an RGB PIL image.

		Args:
			path: filesystem path, or URL starting with ``http://`` / ``https://``.
			timeout: seconds to wait for the HTTP request (URLs only).

		Raises:
			requests.HTTPError: if the URL responds with an error status.
		"""
		from PIL import Image
		if path.startswith(('http://', 'https://')):
			import requests
			from io import BytesIO
			response = requests.get(path, timeout=timeout)
			response.raise_for_status()
			image = Image.open(BytesIO(response.content)).convert('RGB')
		else:
			# convert() returns a fully loaded copy, so the file can be closed right away.
			with Image.open(path) as img:
				image = img.convert('RGB')
		return image

	def predict(self, path, device="cuda"):
		"""Classify a single image and return its label from ``config.classes``.

		Args:
			path: local path or http(s) URL of the image.
			device: device the input is moved to. The model itself is not moved and
				must already be on this device.
		"""
		image = self.transform(self.open_image(path)).unsqueeze(0)
		self.eval()
		image = image.to(device)
		with torch.no_grad():
			out, _ = self(image)
		predicted = out.argmax(dim=1)
		return self.config.classes[predicted.item()]