alexnasa committed
Commit 7f5e4af · verified · 1 Parent(s): 680b212

Create inference_coz_single.py

Files changed (1)
  1. inference_coz_single.py +258 -0
inference_coz_single.py ADDED
@@ -0,0 +1,258 @@
import os
import tempfile
import uuid
import torch
from PIL import Image
from torchvision import transforms
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
from osediff_sd3 import OSEDiff_SD3_TEST, SD3Euler

# -------------------------------------------------------------------
# Helper: Resize & center-crop to a fixed square
# -------------------------------------------------------------------
def resize_and_center_crop(img: Image.Image, size: int) -> Image.Image:
    w, h = img.size
    scale = size / min(w, h)
    new_w, new_h = int(w * scale), int(h * scale)
    img = img.resize((new_w, new_h), Image.LANCZOS)
    left = (new_w - size) // 2
    top = (new_h - size) // 2
    return img.crop((left, top, left + size, top + size))


# -------------------------------------------------------------------
# Helper: Generate a single VLM prompt for recursive_multiscale
# -------------------------------------------------------------------
def _generate_vlm_prompt(
    vlm_model,
    vlm_processor,
    process_vision_info,
    prev_image_path: str,
    zoomed_image_path: str,
    device: str = "cuda",
) -> str:
    """
    Given two image file paths:
      - prev_image_path: the "full" image from the previous recursion step.
      - zoomed_image_path: the cropped-and-resized (zoomed) image for this step.
    Builds a single "recursive_multiscale" prompt via Qwen2.5-VL and returns a
    comma-separated tag string such as "cat on sofa, pet, indoor, living room".
    """
    # (1) Define the system message for recursive_multiscale:
    message_text = (
        "The second image is a zoom-in of the first image. "
        "Based on this knowledge, what is in the second image? "
        "Give me a set of words."
    )

    # (2) Build the two-image "chat" payload:
    messages = [
        {"role": "system", "content": message_text},
        {
            "role": "user",
            "content": [
                {"type": "image", "image": prev_image_path},
                {"type": "image", "image": zoomed_image_path},
            ],
        },
    ]

    # (3) Run the chat template and vision pre-processing to get model inputs:
    text = vlm_processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = vlm_processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    ).to(device)

    # (4) Generate tokens, then decode:
    generated = vlm_model.generate(**inputs, max_new_tokens=128)
    # strip the prompt tokens from each generated sequence:
    trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated)
    ]
    out_text = vlm_processor.batch_decode(
        trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]

    # (5) Return just the bare words (no extra "," since no additional user prompt is appended):
    return out_text.strip()


# -------------------------------------------------------------------
# Main Function: recursive_multiscale_sr
# -------------------------------------------------------------------
def recursive_multiscale_sr(
    input_png_path: str,
    upscale: int,
) -> tuple[list[Image.Image], list[str]]:
    """
    Perform exactly four recursive_multiscale super-resolution steps on a single PNG.
      - input_png_path: path to a single .png file on disk.
      - upscale: integer up-scale factor per recursion (e.g. 4).
    Returns two lists of length 4: the SR outputs (PIL.Image) and the VLM prompts
    used at recursion steps 1, 2, 3, 4 (in that order).

    All other parameters (model checkpoints, prompt model, process size, etc.)
    are hard-coded to match the reference command-line invocation.
    """
    ###############################
    # 1. Fixed hyper-parameters
    ###############################
    device = "cuda"
    process_size = 512  # same as args.process_size
    rec_num = 4         # fixed to 4 recursions
    # model checkpoint paths (hard-coded to the reference example)
    LORA_PATH = "ckpt/SR_LoRA/model_20001.pkl"
    VAE_PATH = "ckpt/SR_VAE/vae_encoder_20001.pt"
    SD3_MODEL = "stabilityai/stable-diffusion-3-medium-diffusers"
    # VLM model name (hard-coded)
    VLM_NAME = "Qwen/Qwen2.5-VL-3B-Instruct"

    ###############################
    # 2. Build a dummy "args" namespace
    #    to satisfy the OSEDiff_SD3_TEST constructor.
    ###############################
    class _Args:
        pass

    args = _Args()
    args.upscale = upscale
    args.lora_path = LORA_PATH
    args.vae_path = VAE_PATH
    args.pretrained_model_name_or_path = SD3_MODEL
    args.merge_and_unload_lora = False
    args.lora_rank = 4
    args.vae_decoder_tiled_size = 224
    args.vae_encoder_tiled_size = 1024
    args.latent_tiled_size = 96
    args.latent_tiled_overlap = 32
    args.mixed_precision = "fp16"
    args.efficient_memory = False
    # (other flags are not used by OSEDiff_SD3_TEST, so we skip them)

    ###############################
    # 3. Load the SD3 SR model (non-efficient path)
    ###############################
    # 3.1 Instantiate the underlying SD3-Euler transformer/VAE/text encoders
    sd3 = SD3Euler()
    # move all text encoders + transformer + VAE to CUDA:
    sd3.text_enc_1.to(device)
    sd3.text_enc_2.to(device)
    sd3.text_enc_3.to(device)
    sd3.transformer.to(device, dtype=torch.float32)
    sd3.vae.to(device, dtype=torch.float32)
    # freeze all sub-modules:
    for p in (
        sd3.text_enc_1,
        sd3.text_enc_2,
        sd3.text_enc_3,
        sd3.transformer,
        sd3.vae,
    ):
        p.requires_grad_(False)

    # 3.2 Wrap in the OSEDiff_SD3_TEST helper:
    model_test = OSEDiff_SD3_TEST(args, sd3)
    # (model_test(...) takes (lq_tensor, prompt=str) and returns a list of tensors)

    ###############################
    # 4. Load the VLM (Qwen2.5-VL)
    ###############################
    vlm_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        VLM_NAME,
        torch_dtype="auto",
        device_map="auto",  # immediately dispatches layers onto available GPUs
    )
    vlm_processor = AutoProcessor.from_pretrained(VLM_NAME)

    ###############################
    # 5. Create a temporary directory
    #    to hold intermediate PNG files
    ###############################
    unique_id = uuid.uuid4().hex
    prefix = f"recms_{unique_id}_"

    with tempfile.TemporaryDirectory(prefix=prefix) as td:
        # (we write a "prev" and a "zoom" image at each step)

        ###############################
        # 6. Prepare the very first "full" image
        ###############################
        # 6.1 Load + center-crop → img0 is a (512×512) PIL image on CPU
        img0 = Image.open(input_png_path).convert("RGB")
        img0 = resize_and_center_crop(img0, process_size)

        # 6.2 Save it once so the VLM can read it as "prev"
        prev_path = os.path.join(td, "step0_prev.png")
        img0.save(prev_path)

        # Maintain the lists of SR outputs and prompts:
        sr_pil_list: list[Image.Image] = []
        prompt_list: list[str] = []

        ###############################
        # 7. Recursion loop (exactly 4 times)
        ###############################
        for rec in range(rec_num):
            # (A) Crop the "prev" image to obtain this step's input
            prev_pil = Image.open(prev_path).convert("RGB")
            w, h = prev_pil.size  # should be (512×512) each time
            new_w, new_h = w // upscale, h // upscale  # e.g. 128×128 for upscale=4
            # center-crop region:
            left = (w - new_w) // 2
            top = (h - new_h) // 2
            right = left + new_w
            bottom = top + new_h
            cropped = prev_pil.crop((left, top, right, bottom))

            # (B) Resize that crop back up to (512×512) via BICUBIC → zoomed
            zoomed = cropped.resize((w, h), Image.BICUBIC)
            zoom_path = os.path.join(td, f"step{rec+1}_zoom.png")
            zoomed.save(zoom_path)

            # (C) Generate a recursive_multiscale VLM "tag" prompt
            prompt_tag = _generate_vlm_prompt(
                vlm_model=vlm_model,
                vlm_processor=vlm_processor,
                process_vision_info=process_vision_info,
                prev_image_path=prev_path,
                zoomed_image_path=zoom_path,
                device=device,
            )
            # (By default, no extra user prompt is appended.)

            # (D) Prepare the low-res tensor for SR: zoomed → tensor in [0,1] → [-1,1]
            to_tensor = transforms.ToTensor()
            lq = to_tensor(zoomed).unsqueeze(0).to(device)  # shape (1,3,512,512)
            lq = (lq * 2.0) - 1.0

            # (E) Run SR inference:
            with torch.no_grad():
                out_tensor = model_test(lq, prompt=prompt_tag)[0]  # (3,512,512)
                out_tensor = out_tensor.clamp(-1.0, 1.0).cpu()
                # back to PIL in [0,1]:
                out_pil = transforms.ToPILImage()((out_tensor * 0.5) + 0.5)

            # (F) Save this step's SR output; it becomes "prev" for the next iteration:
            out_path = os.path.join(td, f"step{rec+1}_sr.png")
            out_pil.save(out_path)
            prev_path = out_path

            # (G) Append the outputs to the lists:
            sr_pil_list.append(out_pil)
            prompt_list.append(prompt_tag)

        # end for(rec)

        ###############################
        # 8. Return the four SR PILs and their prompts
        ###############################
        # sr_pil_list = [SR1, SR2, SR3, SR4]; prompt_list holds the matching VLM prompts.
        return sr_pil_list, prompt_list
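
For reference, a minimal usage sketch (not part of the committed file): it assumes the hard-coded checkpoints under ckpt/ and a CUDA GPU are available, and the input/output filenames below are purely illustrative.

# Minimal usage sketch (illustrative paths; assumes ckpt/ checkpoints and a CUDA GPU):
if __name__ == "__main__":
    sr_images, prompts = recursive_multiscale_sr("input.png", upscale=4)
    for i, (img, prompt) in enumerate(zip(sr_images, prompts), start=1):
        img.save(f"sr_step{i}.png")          # SR output of recursion step i
        print(f"step {i} prompt: {prompt}")  # VLM tag prompt used at step i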