MultiMatrix committed
Commit e46ed6e · verified · 1 parent: cfb9ad6

Upload inference.py

Files changed (1)
  1. inference.py +86 -0
inference.py ADDED
@@ -0,0 +1,86 @@
+ from argparse import ArgumentParser, Namespace
+
+ import torch
+
+ from accelerate.utils import set_seed
+ from utils.inference import (
+     V1InferenceLoop,
+     BSRInferenceLoop, BFRInferenceLoop, BIDInferenceLoop, UnAlignedBFRInferenceLoop
+ )
+
+
+ def check_device(device: str) -> str:
+     if device == "cuda":
+         if not torch.cuda.is_available():
+             print("CUDA not available because the current PyTorch install was not "
+                   "built with CUDA enabled.")
+             device = "cpu"
+     else:
+         if device == "mps":
+             if not torch.backends.mps.is_available():
+                 if not torch.backends.mps.is_built():
+                     print("MPS not available because the current PyTorch install was not "
+                           "built with MPS enabled.")
+                     device = "cpu"
+                 else:
+                     print("MPS not available because the current MacOS version is not 12.3+ "
+                           "and/or you do not have an MPS-enabled device on this machine.")
+                     device = "cpu"
+     print(f"using device {device}")
+     return device
+
+
+ def parse_args() -> Namespace:
+     parser = ArgumentParser()
+     ### model parameters
+     parser.add_argument("--task", type=str, required=True, choices=["sr", "dn", "fr", "fr_bg"])
+     parser.add_argument("--upscale", type=float, required=True)
+     parser.add_argument("--version", type=str, default="v2", choices=["v1", "v2"])
+     ### sampling parameters
+     parser.add_argument("--steps", type=int, default=50)
+     parser.add_argument("--better_start", action="store_true")
+     parser.add_argument("--tiled", action="store_true")
+     parser.add_argument("--tile_size", type=int, default=512)
+     parser.add_argument("--tile_stride", type=int, default=256)
+     parser.add_argument("--pos_prompt", type=str, default="")
+     parser.add_argument("--neg_prompt", type=str, default="low quality, blurry, low-resolution, noisy, unsharp, weird textures")
+     parser.add_argument("--cfg_scale", type=float, default=4.0)
+     ### input parameters
+     parser.add_argument("--input", type=str, required=True)
+     parser.add_argument("--n_samples", type=int, default=1)
+     ### guidance parameters
+     parser.add_argument("--guidance", action="store_true")
+     parser.add_argument("--g_loss", type=str, default="w_mse", choices=["mse", "w_mse"])
+     parser.add_argument("--g_scale", type=float, default=0.0)
+     parser.add_argument("--g_start", type=int, default=1001)
+     parser.add_argument("--g_stop", type=int, default=-1)
+     parser.add_argument("--g_space", type=str, default="latent")
+     parser.add_argument("--g_repeat", type=int, default=1)
+     ### output parameters
+     parser.add_argument("--output", type=str, required=True)
+     ### common parameters
+     parser.add_argument("--seed", type=int, default=231)
+     parser.add_argument("--device", type=str, default="cuda", choices=["cpu", "cuda", "mps"])
+
+     return parser.parse_args()
+
+
+ def main():
+     args = parse_args()
+     args.device = check_device(args.device)
+     set_seed(args.seed)
+     if args.version == "v1":
+         V1InferenceLoop(args).run()
+     else:
+         supported_tasks = {
+             "sr": BSRInferenceLoop,
+             "dn": BIDInferenceLoop,
+             "fr": BFRInferenceLoop,
+             "fr_bg": UnAlignedBFRInferenceLoop
+         }
+         supported_tasks[args.task](args).run()
+     print("done!")
+
+
+ if __name__ == "__main__":
+     main()
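
Example invocation (a hypothetical command for illustration, not part of the commit; the input and output paths are placeholders, and any flag not shown falls back to the defaults defined in parse_args):

  # hypothetical invocation; the paths below are placeholders
  python inference.py --task sr --upscale 4 --version v2 --input path/to/lq_images --output path/to/results --device cuda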