alexnasa committed
Commit b488f86 · verified · 1 Parent(s): ba815e8

Update inference_coz_single.py

Files changed (1):
  1. inference_coz_single.py +57 -25
inference_coz_single.py CHANGED

@@ -71,7 +71,7 @@ def _generate_vlm_prompt(
         return_tensors="pt",
     ).to(device)
 
-    # (4) Generate tokens→decode
+    # (4) Generate tokens → decode
     generated = vlm_model.generate(**inputs, max_new_tokens=128)
     # strip off the prompt tokens from each generated sequence:
     trimmed = [
@@ -86,28 +86,46 @@ def _generate_vlm_prompt(
 
 
 # -------------------------------------------------------------------
-# Main Function: recursive_multiscale_sr
+# Main Function: recursive_multiscale_sr (with multiple centers)
 # -------------------------------------------------------------------
 def recursive_multiscale_sr(
     input_png_path: str,
     upscale: int,
-) -> list[Image.Image]:
+    rec_num: int = 4,
+    centers: list[tuple[float, float]] = None,
+) -> tuple[list[Image.Image], list[str]]:
     """
-    Perform exactly four recursive_multiscale super-resolution steps on a single PNG.
+    Perform `rec_num` recursive_multiscale super-resolution steps on a single PNG.
     - input_png_path: path to a single .png file on disk.
     - upscale: integer up-scale factor per recursion (e.g. 4).
-    Returns a list of 4 PIL.Image objects, corresponding to each SR output
-    at recursion steps 1, 2, 3, 4 (in that order).
-
-    All other parameters (model checkpoints, prompt model, process size, etc.)
-    are hard-coded exactly as in your command-line example.
+    - rec_num: how many recursion steps to perform.
+    - centers: a list of normalized (x, y) tuples in [0, 1], one per recursion step,
+      indicating where to center the low-res crop for each step. The list
+      length must equal rec_num. If centers is None, defaults to center=(0.5, 0.5)
+      for all steps.
+
+    Returns a tuple (sr_pil_list, prompt_list), where:
+    - sr_pil_list: list of PIL.Image outputs [SR1, SR2, …, SR_rec_num] in order.
+    - prompt_list: list of the VLM prompts generated at each recursion.
     """
+    ###############################
+    # 0. Validate / fill default centers
+    ###############################
+    if centers is None:
+        # Default: use center (0.5, 0.5) for every recursion
+        centers = [(0.5, 0.5) for _ in range(rec_num)]
+    else:
+        if not isinstance(centers, (list, tuple)) or len(centers) != rec_num:
+            raise ValueError(
+                f"`centers` must be a list of {rec_num} (x,y) tuples, but got length {len(centers)}."
+            )
+
     ###############################
     # 1. Fixed hyper-parameters
     ###############################
     device = "cuda"
     process_size = 512  # same as args.process_size
-    rec_num = 4  # fixed to 4 recursions
+
     # model checkpoint paths (hard-coded to your example)
     LORA_PATH = "ckpt/SR_LoRA/model_20001.pkl"
     VAE_PATH = "ckpt/SR_VAE/vae_encoder_20001.pt"
@@ -142,7 +160,7 @@ def recursive_multiscale_sr(
     ###############################
     # 3.1 Instantiate the underlying SD3-Euler UNet/VAE/text encoders
     sd3 = SD3Euler()
-    # move all text encoders+transformer+VAE to CUDA:
+    # move all text encoders + transformer + VAE to CUDA:
    sd3.text_enc_1.to(device)
    sd3.text_enc_2.to(device)
    sd3.text_enc_3.to(device)
@@ -163,7 +181,7 @@ def recursive_multiscale_sr(
     # (by default, “model_test(...)” takes (lq_tensor, prompt=str) and returns a list[tensor])
 
     ###############################
-    # 4. Load the VLM (Qwen2.5-VL)
+    # 4. Load the VLM (Qwen2.5-VL)
     ###############################
     vlm_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
         VLM_NAME,
@@ -173,7 +191,7 @@ def recursive_multiscale_sr(
     vlm_processor = AutoProcessor.from_pretrained(VLM_NAME)
 
     ###############################
-    # 5. Pre-allocate a Temporary Directory
+    # 5. Pre-allocate a Temporary Directory
     #    to hold intermediate JPEG/PNG files
     ###############################
     unique_id = uuid.uuid4().hex
@@ -193,23 +211,37 @@ def recursive_multiscale_sr(
     prev_path = os.path.join(td, "step0_prev.png")
     img0.save(prev_path)
 
-    # We will maintain a list of PIL outputs here:
+    # We will maintain lists of PIL outputs and prompts:
     sr_pil_list: list[Image.Image] = []
-    prompt_list = []
+    prompt_list: list[str] = []
 
     ###############################
-    # 7. Recursion loop (exactly 4 times)
+    # 7. Recursion loop (now up to rec_num times)
     ###############################
     for rec in range(rec_num):
-        # (A) Crop + upsample the “prev” image to obtain this step’s input zoomed
+        # (A) Load the previous SR output (or original) and compute crop window
         prev_pil = Image.open(prev_path).convert("RGB")
-        w, h = prev_pil.size  # should be (512×512) each time
+        w, h = prev_pil.size  # should be (512×512) each time
+
+        # (1) Compute the “low-res” window size:
         new_w, new_h = w // upscale, h // upscale  # e.g. 128×128 for upscale=4
-        # center-crop region:
-        left = (w - new_w) // 2
-        top = (h - new_h) // 2
+
+        # (2) Map normalized center → pixel center, then clamp so crop stays in bounds:
+        cx_norm, cy_norm = centers[rec]
+        cx = int(cx_norm * w)
+        cy = int(cy_norm * h)
+        half_w = new_w // 2
+        half_h = new_h // 2
+
+        # If center in pixels is too close to left/top, clamp so left=0 or top=0; same on right/bottom
+        left = cx - half_w
+        top = cy - half_h
+        # clamp left ∈ [0, w - new_w], top ∈ [0, h - new_h]
+        left = max(0, min(left, w - new_w))
+        top = max(0, min(top, h - new_h))
         right = left + new_w
         bottom = top + new_h
+
         cropped = prev_pil.crop((left, top, right, bottom))
 
         # (B) Resize that crop back up to (512×512) via BICUBIC → zoomed
@@ -228,7 +260,7 @@ def recursive_multiscale_sr(
         )
         # (By default, no extra user prompt is appended.)
 
-        # (D) Prepare the low-res tensor for SR: convert zoomed→Tensor→[0,1]→[−1,1]
+        # (D) Prepare the low-res tensor for SR: convert zoomed → Tensor → [0,1] → [−1,1]
         to_tensor = transforms.ToTensor()
         lq = to_tensor(zoomed).unsqueeze(0).to(device)  # shape (1,3,512,512)
         lq = (lq * 2.0) - 1.0
@@ -252,7 +284,7 @@ def recursive_multiscale_sr(
     # end for(rec)
 
     ###############################
-    # 8. Return the four SR‐PILs
+    # 8. Return the SR outputs & prompts
     ###############################
-    # The list sr_pil_list = [ SR1, SR2, SR3, SR4 ] in order.
-    return sr_pil_list, prompt_list
+    # The list sr_pil_list = [ SR1, SR2, …, SR_rec_num ] in order.
+    return sr_pil_list, prompt_list
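For reference, a minimal usage sketch of the updated signature (the import path, input filename, and centers values here are illustrative assumptions, not part of the commit):

from inference_coz_single import recursive_multiscale_sr

# Four recursions, each zooming toward a different normalized (x, y) point
# of the current 512×512 frame; values must lie in [0, 1].
centers = [(0.5, 0.5), (0.4, 0.6), (0.55, 0.45), (0.5, 0.5)]

sr_images, prompts = recursive_multiscale_sr(
    "input.png",
    upscale=4,
    rec_num=4,
    centers=centers,
)

for i, (img, prompt) in enumerate(zip(sr_images, prompts), start=1):
    img.save(f"sr_step{i}.png")
    print(f"step {i} prompt: {prompt}")

Omitting rec_num and centers keeps the defaults (4 steps, all centers at (0.5, 0.5)), which reproduces the previous pure center-crop behaviour.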
 
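The crop-window arithmetic added in step (2) can be sanity-checked in isolation. This standalone sketch (with a hypothetical helper name) reproduces the same mapping and clamping:

def clamp_crop_window(cx_norm: float, cy_norm: float,
                      w: int, h: int, upscale: int) -> tuple[int, int, int, int]:
    # Map the normalized center to pixels, then clamp so the
    # (w // upscale, h // upscale) box stays fully inside the frame.
    new_w, new_h = w // upscale, h // upscale
    left = int(cx_norm * w) - new_w // 2
    top = int(cy_norm * h) - new_h // 2
    left = max(0, min(left, w - new_w))
    top = max(0, min(top, h - new_h))
    return left, top, left + new_w, top + new_h

# A center near the right edge is clamped so the crop stays in bounds:
assert clamp_crop_window(0.99, 0.5, 512, 512, 4) == (384, 192, 512, 320)
# The default center (0.5, 0.5) gives the old pure center crop:
assert clamp_crop_window(0.5, 0.5, 512, 512, 4) == (192, 192, 320, 320)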