Spaces:
mikitona
/
Running on Zero

mikitona commited on
Commit
3e96387
·
verified ·
1 Parent(s): 8d9c2f3

Delete testing_utils.py

Browse files
Files changed (1) hide show
  1. testing_utils.py +0 -211
testing_utils.py DELETED
@@ -1,211 +0,0 @@
1
- import argparse
2
- import json
3
- from PIL import Image
4
- from torchvision import transforms
5
- import torch.nn.functional as F
6
- from glob import glob
7
-
8
- import cv2
9
- import math
10
- import numpy as np
11
- import os
12
- import os.path as osp
13
- import random
14
- import time
15
- import torch
16
- from pathlib import Path
17
- from torch.utils import data as data
18
-
19
- from basicsr.utils import DiffJPEG, USMSharp
20
- from basicsr.utils.img_process_util import filter2D
21
- #from basicsr.data.transforms import paired_random_crop, triplet_random_crop
22
- from basicsr.data.transforms import paired_random_crop
23
- #from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt, random_add_speckle_noise_pt, random_add_saltpepper_noise_pt, bivariate_Gaussian
24
- from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt, bivariate_Gaussian
25
- from basicsr.data.degradations import circular_lowpass_kernel, random_mixed_kernels
26
- from basicsr.data.transforms import augment
27
- from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
28
- from basicsr.utils.registry import DATASET_REGISTRY
29
-
30
-
31
- def parse_args_paired_testing(input_args=None):
32
- """
33
- Parses command-line arguments used for configuring an paired session (pix2pix-Turbo).
34
- This function sets up an argument parser to handle various training options.
35
-
36
- Returns:
37
- argparse.Namespace: The parsed command-line arguments.
38
- """
39
- parser = argparse.ArgumentParser()
40
- parser.add_argument("--ref_path", type=str, default=None,)
41
- parser.add_argument("--base_config", default="./configs/sr_test.yaml", type=str)
42
- parser.add_argument("--tracker_project_name", type=str, default="train_pix2pix_turbo", help="The name of the wandb project to log to.")
43
-
44
- # details about the model architecture
45
- parser.add_argument("--sd_path")
46
- parser.add_argument("--de_net_path")
47
- parser.add_argument("--pretrained_path", type=str, default=None,)
48
- parser.add_argument("--revision", type=str, default=None,)
49
- parser.add_argument("--variant", type=str, default=None,)
50
- parser.add_argument("--tokenizer_name", type=str, default=None)
51
- parser.add_argument("--lora_rank_unet", default=32, type=int)
52
- parser.add_argument("--lora_rank_vae", default=16, type=int)
53
-
54
- parser.add_argument("--scale", type=int, default=4, help="Scale factor for SR.")
55
- parser.add_argument("--chop_size", type=int, default=128, choices=[512, 256, 128], help="Chopping forward.")
56
- parser.add_argument("--chop_stride", type=int, default=96, help="Chopping stride.")
57
- parser.add_argument("--padding_offset", type=int, default=32, help="padding offset.")
58
-
59
- parser.add_argument("--vae_decoder_tiled_size", type=int, default=224)
60
- parser.add_argument("--vae_encoder_tiled_size", type=int, default=1024)
61
- parser.add_argument("--latent_tiled_size", type=int, default=96)
62
- parser.add_argument("--latent_tiled_overlap", type=int, default=32)
63
-
64
- parser.add_argument("--align_method", type=str, default="wavelet")
65
-
66
- parser.add_argument("--pos_prompt", type=str, default="A high-resolution, 8K, ultra-realistic image with sharp focus, vibrant colors, and natural lighting.")
67
- parser.add_argument("--neg_prompt", type=str, default="oil painting, cartoon, blur, dirty, messy, low quality, deformation, low resolution, oversmooth")
68
-
69
- # training details
70
- parser.add_argument("--output_dir", type=str, default='output/')
71
- parser.add_argument("--cache_dir", default=None,)
72
- parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
73
- parser.add_argument("--resolution", type=int, default=512,)
74
- parser.add_argument("--checkpointing_steps", type=int, default=500,)
75
- parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Number of updates steps to accumulate before performing a backward/update pass.",)
76
- parser.add_argument("--gradient_checkpointing", action="store_true",)
77
-
78
- parser.add_argument("--dataloader_num_workers", type=int, default=0,)
79
- parser.add_argument("--allow_tf32", action="store_true",
80
- help=(
81
- "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
82
- " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
83
- ),
84
- )
85
- parser.add_argument("--report_to", type=str, default="wandb",
86
- help=(
87
- 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
88
- ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
89
- ),
90
- )
91
- parser.add_argument("--mixed_precision", type=str, default=None, choices=["no", "fp16", "bf16"],)
92
- parser.add_argument("--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers.")
93
- parser.add_argument("--set_grads_to_none", action="store_true",)
94
-
95
- parser.add_argument('--world_size', default=1, type=int,
96
- help='number of distributed processes')
97
- parser.add_argument('--local_rank', default=-1, type=int)
98
- parser.add_argument('--dist_url', default='env://',
99
- help='url used to set up distributed training')
100
-
101
- if input_args is not None:
102
- args = parser.parse_args(input_args)
103
- else:
104
- args = parser.parse_args()
105
-
106
- return args
107
-
108
-
109
- class PlainDataset(data.Dataset):
110
- """Modified dataset based on the dataset used for Real-ESRGAN model:
111
- Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.
112
-
113
- It loads gt (Ground-Truth) images, and augments them.
114
- It also generates blur kernels and sinc kernels for generating low-quality images.
115
- Note that the low-quality images are processed in tensors on GPUS for faster processing.
116
-
117
- Args:
118
- opt (dict): Config for train datasets. It contains the following keys:
119
- dataroot_gt (str): Data root path for gt.
120
- meta_info (str): Path for meta information file.
121
- io_backend (dict): IO backend type and other kwarg.
122
- use_hflip (bool): Use horizontal flips.
123
- use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
124
- Please see more options in the codes.
125
- """
126
-
127
- def __init__(self, opt):
128
- super(PlainDataset, self).__init__()
129
- self.opt = opt
130
- self.file_client = None
131
- self.io_backend_opt = opt['io_backend']
132
-
133
- if 'image_type' not in opt:
134
- opt['image_type'] = 'png'
135
-
136
- # support multiple type of data: file path and meta data, remove support of lmdb
137
- self.lr_paths = []
138
- if 'lr_path' in opt:
139
- if isinstance(opt['lr_path'], str):
140
- self.lr_paths.extend(sorted(
141
- [str(x) for x in Path(opt['lr_path']).glob('*.png')] +
142
- [str(x) for x in Path(opt['lr_path']).glob('*.jpg')] +
143
- [str(x) for x in Path(opt['lr_path']).glob('*.jpeg')]
144
- ))
145
- else:
146
- self.lr_paths.extend(sorted([str(x) for x in Path(opt['lr_path'][0]).glob('*.'+opt['image_type'])]))
147
- if len(opt['lr_path']) > 1:
148
- for i in range(len(opt['lr_path'])-1):
149
- self.lr_paths.extend(sorted([str(x) for x in Path(opt['lr_path'][i+1]).glob('*.'+opt['image_type'])]))
150
-
151
- def __getitem__(self, index):
152
- if self.file_client is None:
153
- self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
154
-
155
- # -------------------------------- Load gt images -------------------------------- #
156
- # Shape: (h, w, c); channel order: BGR; image range: [0, 1], float32.
157
- lr_path = self.lr_paths[index]
158
-
159
- # avoid errors caused by high latency in reading files
160
- retry = 3
161
- while retry > 0:
162
- try:
163
- lr_img_bytes = self.file_client.get(lr_path, 'gt')
164
- except (IOError, OSError) as e:
165
- # logger = get_root_logger()
166
- # logger.warn(f'File client error: {e}, remaining retry times: {retry - 1}')
167
- # change another file to read
168
- index = random.randint(0, self.__len__()-1)
169
- lr_path = self.lr_paths[index]
170
- time.sleep(1) # sleep 1s for occasional server congestion
171
- else:
172
- break
173
- finally:
174
- retry -= 1
175
-
176
- img_lr = imfrombytes(lr_img_bytes, float32=True)
177
-
178
- # BGR to RGB, HWC to CHW, numpy to tensor
179
- img_lr = img2tensor([img_lr], bgr2rgb=True, float32=True)[0]
180
-
181
- return_d = {'lr': img_lr, 'lr_path': lr_path}
182
- return return_d
183
-
184
- def __len__(self):
185
- return len(self.lr_paths)
186
-
187
-
188
- def lr_proc(config, batch, device):
189
- im_lr = batch['lr'].cuda()
190
- im_lr = im_lr.to(memory_format=torch.contiguous_format).float()
191
-
192
- ori_lr = im_lr
193
-
194
- im_lr = F.interpolate(
195
- im_lr,
196
- size=(im_lr.size(-2) * config.sf,
197
- im_lr.size(-1) * config.sf),
198
- mode='bicubic',
199
- )
200
-
201
- im_lr = im_lr.contiguous()
202
- im_lr = im_lr * 2 - 1.0
203
- im_lr = torch.clamp(im_lr, -1.0, 1.0)
204
-
205
- ori_h, ori_w = im_lr.size(-2), im_lr.size(-1)
206
-
207
- pad_h = (math.ceil(ori_h / 64)) * 64 - ori_h
208
- pad_w = (math.ceil(ori_w / 64)) * 64 - ori_w
209
- im_lr = F.pad(im_lr, pad=(0, pad_w, 0, pad_h), mode='reflect')
210
-
211
- return im_lr.to(device), ori_lr.to(device), (ori_h, ori_w)