File size: 5,065 Bytes
5004324
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import torch
import numpy as np
import PIL
import torchvision.transforms as T
import torch.nn.functional as F
from KandiSuperRes.model.unet import UNet
from KandiSuperRes.model.unet_sr import UNet as UNet_sr
from KandiSuperRes.movq import MoVQ
from KandiSuperRes.model.diffusion_sr import DPMSolver
from KandiSuperRes.model.diffusion_refine import BaseDiffusion, get_named_beta_schedule
from KandiSuperRes.model.diffusion_sr_turbo import BaseDiffusion as BaseDiffusion_turbo


class KandiSuperResPipeline:
    """Upscale a PIL image by ``scale`` using a diffusion super-resolution model.

    Two modes are supported, selected by ``flash``:

    * ``flash=False``: the SR model is sampled with ``DPMSolver`` via
      ``generate_panorama``. ``device`` and ``dtype`` are passed through to
      it as single values.
    * ``flash=True``: the SR model is sampled with the turbo diffusion
      ``p_sample_loop`` and, optionally, the result is refined through
      ``movq.encode`` -> ``refine_tiled`` -> ``movq.decode``. In this mode
      ``device`` is indexed with the key ``'sr_model'`` and ``dtype`` with
      the keys ``'sr_model'``, ``'movq'`` and ``'refiner'``, so both must be
      mappings despite the ``str`` annotations.
    """

    def __init__(
        self,
        scale: int,
        device: str,
        dtype: str,
        flash: bool,
        sr_model: UNet_sr,
        movq: MoVQ = None,
        refiner: UNet = None,
    ):
        """Store the models and build the PIL <-> tensor converters.

        Args:
            scale: Integer upscaling factor applied to both image sides.
            device: Device for the non-flash mode, or a mapping with the key
                ``'sr_model'`` for the flash mode.
            dtype: Dtype for the non-flash mode, or a mapping with the keys
                ``'sr_model'``, ``'movq'`` and ``'refiner'`` for the flash
                mode.
            flash: Select the flash (turbo) mode when true.
            sr_model: Super-resolution UNet.
            movq: Autoencoder used by the flash-mode refinement stage.
            refiner: UNet used by the flash-mode refinement stage.
        """
        self.device = device
        self.dtype = dtype
        self.scale = scale
        self.flash = flash
        self.to_pil = T.ToPILImage()
        # Force RGB, convert to a tensor, then rescale with 2 * x - 1.
        self.image_transform = T.Compose([
            T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
            T.ToTensor(),
            T.Lambda(lambda img: 2. * img - 1.),
        ])

        self.sr_model = sr_model
        self.movq = movq
        self.refiner = refiner

    def __call__(
        self,
        pil_image: PIL.Image.Image = None,
        steps: int = 5,
        view_batch_size: int = 15,
        seed: int = 0,
        refine: bool = True
    ) -> PIL.Image.Image:
        """Upscale ``pil_image`` and return the result as a PIL image.

        The returned image is ``scale`` times the input size on each side.

        Args:
            pil_image: Input image. Required despite the ``None`` default.
            steps: Number of solver steps (non-flash mode only).
            view_batch_size: Forwarded to ``generate_panorama`` (non-flash
                mode only).
            seed: Forwarded to ``generate_panorama`` (non-flash mode only).
            refine: Run the refinement stage (flash mode only).

        Raises:
            ValueError: If ``pil_image`` is ``None``; in flash mode, if a
                side of the image is shorter than 32 pixels, or if
                ``refine`` is true but ``movq`` or ``refiner`` is missing.
        """
        if pil_image is None:
            raise ValueError('pil_image must be a PIL image, got None')

        if self.flash:
            return self._upscale_flash(pil_image, refine)
        return self._upscale_dpm(pil_image, steps, view_batch_size, seed)

    def _upscale_flash(self, pil_image, refine):
        """Run the turbo sampler, optionally followed by the refiner."""
        # Fail before the expensive sampling rather than after it.
        if refine and (self.movq is None or self.refiner is None):
            raise ValueError('refine=True requires both `movq` and `refiner` to be provided')

        old_width, old_height = pil_image.size
        # The input is shrunk to the nearest lower multiple of 32 per side.
        height = int(old_height - np.mod(old_height, 32))
        width = int(old_width - np.mod(old_width, 32))
        if height == 0 or width == 0:
            raise ValueError(
                f'flash mode requires an image of at least 32x32 pixels, got {old_width}x{old_height}'
            )

        betas_turbo = get_named_beta_schedule('linear', 1000)
        base_diffusion_sr = BaseDiffusion_turbo(betas_turbo)

        pil_image = pil_image.resize((width, height))
        lr_image = self.image_transform(pil_image).unsqueeze(0).to(self.device['sr_model'])

        sr_image = base_diffusion_sr.p_sample_loop(
            self.sr_model,
            (1, 3, height * self.scale, width * self.scale),
            self.device['sr_model'],
            self.dtype['sr_model'],
            lowres_img=lr_image,
        )

        if refine:
            image = self._refine(sr_image)
        else:
            image = torch.clip((sr_image + 1.) / 2., 0., 1.)

        # Undo the multiple-of-32 shrink so the output matches the original
        # aspect and size times ``scale``.
        return self._to_pil_at_size(image, old_height * self.scale, old_width * self.scale)

    def _refine(self, sr_image):
        """Encode with MoVQ, refine the latent tile-wise, decode, and map to [0, 1]."""
        betas = get_named_beta_schedule('cosine', 1000)
        base_diffusion = BaseDiffusion(betas, 0.99)

        # The whole stage is inference-only, so nothing here needs autograd.
        with torch.no_grad():
            with torch.cuda.amp.autocast(dtype=self.dtype['movq']):
                latent = self.movq.encode(sr_image)

            # NOTE(review): paths are relative to the current working
            # directory, and torch.load unpickles the files - only use
            # trusted weights.
            context = torch.load('weights/context.pt').to(self.dtype['refiner'])
            context_mask = torch.load('weights/context_mask.pt').to(self.dtype['refiner'])

            with torch.cuda.amp.autocast(dtype=self.dtype['refiner']):
                refined = base_diffusion.refine_tiled(self.refiner, latent, context, context_mask)

            with torch.cuda.amp.autocast(dtype=self.dtype['movq']):
                refined = self.movq.decode(refined)
                refined = torch.clip((refined + 1.) / 2., 0., 1.)

        return refined

    def _upscale_dpm(self, pil_image, steps, view_batch_size, seed):
        """Run the DPM-Solver sampler over the image via ``generate_panorama``."""
        base_diffusion = DPMSolver(steps)

        lr_image = self.image_transform(pil_image).unsqueeze(0).to(self.device)

        old_width, old_height = pil_image.size
        # The sampler is asked for even source dimensions, times ``scale``.
        height = int(old_height + np.mod(old_height, 2)) * self.scale
        width = int(old_width + np.mod(old_width, 2)) * self.scale

        sr_image = base_diffusion.generate_panorama(
            height, width, self.device, self.dtype, steps,
            self.sr_model, lowres_img=lr_image,
            view_batch_size=view_batch_size, eta=0.0, seed=seed,
        )

        sr_image = torch.clip((sr_image + 1.) / 2., 0., 1.)
        return self._to_pil_at_size(sr_image, old_height * self.scale, old_width * self.scale)

    def _to_pil_at_size(self, image, target_height, target_width):
        """Resize a batched image tensor to the target size if needed and return item 0 as PIL."""
        if image.shape[2] != target_height or image.shape[3] != target_width:
            image = F.interpolate(
                image, [target_height, target_width], mode='bilinear', align_corners=True
            )
        return self.to_pil(image[0])