codermert commited on
Commit
bc1d7b8
·
verified ·
1 Parent(s): b209bd4

Create model.py

Browse files
Files changed (1) hide show
  1. RealESRGAN/model.py +93 -0
RealESRGAN/model.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from torch.nn import functional as F
4
+ from PIL import Image
5
+ import numpy as np
6
+ import cv2
7
+ from huggingface_hub import hf_hub_url, hf_hub_download, cached_download
8
+
9
+ from .rrdbnet_arch import RRDBNet
10
+ from .utils import pad_reflect, split_image_into_overlapping_patches, stich_together, \
11
+ unpad_image
12
+
13
+ HF_MODELS = {
14
+ 2: dict(
15
+ repo_id='sberbank-ai/Real-ESRGAN',
16
+ filename='RealESRGAN_x2.pth',
17
+ ),
18
+ 4: dict(
19
+ repo_id='sberbank-ai/Real-ESRGAN',
20
+ filename='RealESRGAN_x4.pth',
21
+ ),
22
+ 8: dict(
23
+ repo_id='sberbank-ai/Real-ESRGAN',
24
+ filename='RealESRGAN_x8.pth',
25
+ ),
26
+ }
27
+
28
+
29
+ class RealESRGAN:
30
+ def __init__(self, device, scale=4):
31
+ self.device = device
32
+ self.scale = scale
33
+ self.model = RRDBNet(
34
+ num_in_ch=3, num_out_ch=3, num_feat=64,
35
+ num_block=23, num_grow_ch=32, scale=scale
36
+ )
37
+
38
+ def load_weights(self, model_path, download=True):
39
+ if not os.path.exists(model_path) and download:
40
+ assert self.scale in [2, 4, 8], 'You can download models only with scales: 2, 4, 8'
41
+ config = HF_MODELS[self.scale]
42
+ cache_dir = os.path.dirname(model_path)
43
+ local_filename = os.path.basename(model_path)
44
+ config_file_url = hf_hub_url(repo_id=config['repo_id'], filename=config['filename'])
45
+ htr = hf_hub_download(repo_id=config['repo_id'], cache_dir=cache_dir, local_dir=cache_dir,
46
+ filename=config['filename'])
47
+ print(htr)
48
+ # cached_download(config_file_url, cache_dir=cache_dir, force_filename=local_filename)
49
+ print('Weights downloaded to:', os.path.join(cache_dir, local_filename))
50
+
51
+ loadnet = torch.load(model_path)
52
+ if 'params' in loadnet:
53
+ self.model.load_state_dict(loadnet['params'], strict=True)
54
+ elif 'params_ema' in loadnet:
55
+ self.model.load_state_dict(loadnet['params_ema'], strict=True)
56
+ else:
57
+ self.model.load_state_dict(loadnet, strict=True)
58
+ self.model.eval()
59
+ self.model.to(self.device)
60
+
61
+ # @torch.cuda.amp.autocast()
62
+ def predict(self, lr_image, batch_size=4, patches_size=192,
63
+ padding=24, pad_size=15):
64
+ torch.autocast(device_type=self.device.type)
65
+ scale = self.scale
66
+ device = self.device
67
+ lr_image = np.array(lr_image)
68
+ lr_image = pad_reflect(lr_image, pad_size)
69
+
70
+ patches, p_shape = split_image_into_overlapping_patches(
71
+ lr_image, patch_size=patches_size, padding_size=padding
72
+ )
73
+ img = torch.FloatTensor(patches / 255).permute((0, 3, 1, 2)).to(device).detach()
74
+
75
+ with torch.no_grad():
76
+ res = self.model(img[0:batch_size])
77
+ for i in range(batch_size, img.shape[0], batch_size):
78
+ res = torch.cat((res, self.model(img[i:i + batch_size])), 0)
79
+
80
+ sr_image = res.permute((0, 2, 3, 1)).cpu().clamp_(0, 1)
81
+ np_sr_image = sr_image.numpy()
82
+
83
+ padded_size_scaled = tuple(np.multiply(p_shape[0:2], scale)) + (3,)
84
+ scaled_image_shape = tuple(np.multiply(lr_image.shape[0:2], scale)) + (3,)
85
+ np_sr_image = stich_together(
86
+ np_sr_image, padded_image_shape=padded_size_scaled,
87
+ target_shape=scaled_image_shape, padding_size=padding * scale
88
+ )
89
+ sr_img = (np_sr_image * 255).astype(np.uint8)
90
+ sr_img = unpad_image(sr_img, pad_size * scale)
91
+ sr_img = Image.fromarray(sr_img)
92
+
93
+ return sr_img