File size: 2,653 Bytes
4730cdc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
# Prediction interface for Cog ⚙️
# https://github.com/replicate/cog/blob/main/docs/python.md

import os
import shutil
from omegaconf import OmegaConf
from cog import BasePredictor, Input, Path

from sampler import ResShiftSampler

class Predictor(BasePredictor):
    """Cog predictor that runs the ResShift sampler on a single input image."""

    def setup(self) -> None:
        """Parse the per-task YAML configs once so every prediction can reuse them."""
        self.configs = {
            "realsr": OmegaConf.load('./configs/realsr_swinunet_realesrgan256_journal.yaml'),
            "bicsr": OmegaConf.load('./configs/bicx4_swinunet_lpips.yaml'),
        }

    def predict(
        self,
        image: Path = Input(description="Grayscale input image"),
        scale: int = Input(description="Factor to scale image by.", default=4),
        chop_size: int = Input(
            choices=[512, 256], description="Chopping forward.", default=512
        ),
        task: str = Input(
            choices=["realsr", "bicsr"],
            description="Choose a task",
            default="realsr",
        ),
        seed: int = Input(
            description="Random seed. Leave blank to randomize the seed.", default=12345
        ),
    ) -> Path:
        """Run a single prediction on the model.

        Args:
            image: Path of the image handed to the sampler.
            scale: Scale factor; written to ``diffusion.params.sf`` and passed
                to the sampler as ``sf``.
            chop_size: Chop size forwarded to the sampler (512 or 256).
            task: Key into the configs loaded by ``setup``; also selects the
                model checkpoint.
            seed: Seed forwarded to the sampler. A random 16-bit seed is drawn
                only when ``None`` is passed explicitly, because the declared
                default is 12345.

        Returns:
            Path to ``/tmp/out.png``, a copy of the file the sampler wrote.

        Raises:
            RuntimeError: If the sampler left the output directory empty.
        """
        if seed is None:
            seed = int.from_bytes(os.urandom(2), "big")
        print(f"Using seed: {seed}")

        # NOTE: this is the shared config object from setup(); every field
        # mutated below is overwritten again on each call.
        configs = self.configs[task]

        if task == 'realsr':
            ckpt_path = "weights/resshift_realsrx4_s4_v3.pth"
        else:
            ckpt_path = "weights/resshift_bicsrx4_s4.pth"
        configs.model.ckpt_path = ckpt_path
        configs.diffusion.params.steps = 4
        configs.diffusion.params.sf = scale
        configs.autoencoder.ckpt_path = "weights/autoencoder_vq_f4.pth"

        # The stride is 64 less than the 512 chop size and 32 less than the 256 one.
        chop_stride = 448 if chop_size == 512 else 224

        resshift_sampler = ResShiftSampler(
            configs,
            sf=scale,
            chop_size=chop_size,
            chop_stride=chop_stride,
            chop_bs=1,
            use_amp=True,
            seed=seed,
            padding_offset=configs.model.params.get('lq_size', 64),
        )

        # Start from a clean directory so a previous prediction's result
        # cannot be picked up below.
        out_path = "out_dir"
        if os.path.exists(out_path):
            shutil.rmtree(out_path)
        resshift_sampler.inference(
            str(image),
            out_path,
            mask_path=None,
            bs=1,
            noise_repeat=False,
        )

        # Sort so the chosen file does not depend on directory listing order,
        # and fail with a clear message instead of a bare IndexError.
        produced = sorted(os.listdir(out_path))
        if not produced:
            raise RuntimeError(
                f"ResShift sampler produced no output in '{out_path}' for input '{image}'"
            )

        out = "/tmp/out.png"
        shutil.copy(os.path.join(out_path, produced[0]), out)

        return Path(out)