# Source snapshot: commit 4730cdc (2,653 bytes).
# Prediction interface for Cog ⚙️
# https://github.com/replicate/cog/blob/main/docs/python.md
import copy
import os
import shutil

from omegaconf import OmegaConf
from cog import BasePredictor, Input, Path

from sampler import ResShiftSampler
class Predictor(BasePredictor):
    """Cog predictor that runs ResShift super-resolution for the ``realsr`` and ``bicsr`` tasks."""

    def setup(self) -> None:
        """Load the per-task YAML configs once so each prediction does not re-read them."""
        # These objects are shared by every prediction; predict() deep-copies the
        # selected one before applying per-request overrides.
        self.configs = {
            "realsr": OmegaConf.load('./configs/realsr_swinunet_realesrgan256_journal.yaml'),
            "bicsr": OmegaConf.load('./configs/bicx4_swinunet_lpips.yaml'),
        }

    def predict(
        self,
        # NOTE(review): "Grayscale" may be inaccurate for a general SR model — confirm.
        image: Path = Input(description="Grayscale input image"),
        scale: int = Input(description="Factor to scale image by.", default=4),
        chop_size: int = Input(
            choices=[512, 256], description="Chopping forward.", default=512
        ),
        task: str = Input(
            choices=["realsr", "bicsr"],
            description="Choose a task",
            default="realsr",
        ),
        # NOTE(review): because default=12345, omitting the seed yields 12345, not a
        # random seed as the description says; the random branch below only runs if
        # a caller passes None explicitly.
        seed: int = Input(
            description="Random seed. Leave blank to randomize the seed.", default=12345
        ),
    ) -> Path:
        """Run a single prediction on the model.

        Args:
            image: Path of the input image handed to the sampler.
            scale: Upscaling factor; written to ``diffusion.params.sf`` and passed as ``sf``.
            chop_size: Patch size for chopped inference (512 or 256).
            task: ``"realsr"`` or ``"bicsr"``; selects the config and checkpoint.
            seed: Sampler seed; ``None`` draws a random seed from ``os.urandom``.

        Returns:
            Path to the result, copied to ``/tmp/out.png``.

        Raises:
            RuntimeError: if the sampler wrote no file into the output directory.
        """
        if seed is None:
            seed = int.from_bytes(os.urandom(2), "big")
        print(f"Using seed: {seed}")

        # Deep-copy so per-request overrides never mutate the shared config in self.configs.
        configs = copy.deepcopy(self.configs[task])
        if task == 'realsr':
            configs.model.ckpt_path = "weights/resshift_realsrx4_s4_v3.pth"
        else:
            configs.model.ckpt_path = "weights/resshift_bicsrx4_s4.pth"
        configs.diffusion.params.steps = 4
        configs.diffusion.params.sf = scale
        configs.autoencoder.ckpt_path = "weights/autoencoder_vq_f4.pth"

        # Stride is 7/8 of the chop size, so neighbouring chops overlap.
        chop_stride = 448 if chop_size == 512 else 224

        resshift_sampler = ResShiftSampler(
            configs,
            sf=scale,
            chop_size=chop_size,
            chop_stride=chop_stride,
            chop_bs=1,
            use_amp=True,
            seed=seed,
            padding_offset=configs.model.params.get('lq_size', 64),
        )

        out_path = "out_dir"
        # Remove results of a previous prediction so the file picked below is from this run.
        if os.path.exists(out_path):
            shutil.rmtree(out_path)
        resshift_sampler.inference(
            str(image),
            out_path,
            mask_path=None,
            bs=1,
            noise_repeat=False
        )

        # Sorted for a deterministic pick; fail loudly instead of an opaque IndexError.
        produced = sorted(os.listdir(out_path)) if os.path.isdir(out_path) else []
        if not produced:
            raise RuntimeError(f"ResShift produced no output in {out_path!r}")
        out = "/tmp/out.png"
        shutil.copy(os.path.join(out_path, produced[0]), out)
        return Path(out)
|