alexnasa committed
Commit 24ee135 · verified · 1 Parent(s): 80bb1dc

Update app.py

Files changed (1): app.py (+76, -54)
app.py CHANGED
@@ -6,6 +6,7 @@ from pathlib import Path
 from PIL import Image
 import spaces
 
+
 # -----------------------------------------------------------------------------
 # CONFIGURE THESE PATHS TO MATCH YOUR PROJECT STRUCTURE
 # -----------------------------------------------------------------------------
@@ -18,13 +19,13 @@ OUTPUT_DIR = "inference_results/coz_vlmprompt"
 # -----------------------------------------------------------------------------
 
 @spaces.GPU()
-def run_with_upload(uploaded_image_path):
+def run_with_upload(uploaded_image_path, upscale_option):
     """
-    1) Clear out INPUT_DIR (so old samples don't linger).
-    2) Copy the uploaded image into INPUT_DIR.
-    3) Run your inference_coz.py command (which reads from -i INPUT_DIR).
-    4) After it finishes, find the most recently modified PNG in OUTPUT_DIR.
-    5) Return a PIL.Image, which Gradio will display.
+    1) Clear INPUT_DIR.
+    2) Save the uploaded file as input.png in INPUT_DIR.
+    3) Read `upscale_option` (e.g. "1x", "2x", "4x") and turn it into "1", "2", or "4".
+    4) Call inference_coz.py with `--upscale <that_value>`.
+    5) Stitch together (or otherwise collect) the resulting 1.png–4.png outputs.
     """
 
     # 1) Make sure INPUT_DIR exists; if it does, delete everything inside.
@@ -61,13 +62,15 @@ def run_with_upload(uploaded_image_path):
 
     # 3) Build and run your inference_coz.py command.
     #    This will block until it completes.
+    upscale_value = upscale_option.replace("x", "")  # e.g. "2x" → "2"
+
     cmd = [
         "python", "inference_coz.py",
         "-i", INPUT_DIR,
         "-o", OUTPUT_DIR,
         "--rec_type", "recursive_multiscale",
         "--prompt_type", "vlm",
-        "--upscale", "2",
+        "--upscale", upscale_value,
         "--lora_path", "ckpt/SR_LoRA/model_20001.pkl",
         "--vae_path", "ckpt/SR_VAE/vae_encoder_20001.pt",
         "--pretrained_model_name_or_path", "stabilityai/stable-diffusion-3-medium-diffusers",
@@ -81,41 +84,49 @@ def run_with_upload(uploaded_image_path):
         print("Inference failed:", err)
         return None
 
-    # 4) After it finishes, scan OUTPUT_DIR for .png files.
-
-    RECUSIVE_DIR = f'{OUTPUT_DIR}/recursive'
-
-    if not os.path.isdir(RECUSIVE_DIR):
-        return None
+    # -------------------------------------------------------------------------
+    # 4) After inference, look for the four numbered PNGs and stitch them
+    # -------------------------------------------------------------------------
+    per_sample_dir = os.path.join(OUTPUT_DIR, "per-sample", "input")
+    expected_files = [os.path.join(per_sample_dir, f"{i}.png") for i in range(1, 5)]
+    pil_images = []
+    for fp in expected_files:
+        if not os.path.isfile(fp):
+            print(f"Warning: expected file not found: {fp}")
+            return None
+        try:
+            img = Image.open(fp).convert("RGB")
+            pil_images.append(img)
+        except Exception as e:
+            print(f"Error opening {fp}: {e}")
+            return None
 
-    png_files = [
-        os.path.join(RECUSIVE_DIR, fn)
-        for fn in os.listdir(RECUSIVE_DIR)
-        if fn.lower().endswith(".png")
-    ]
-    if not png_files:
+    if len(pil_images) != 4:
+        print(f"Error: found {len(pil_images)} images, but need 4.")
         return None
 
-    # 5) Pick the most recently modified PNG
-    latest_png = max(png_files, key=os.path.getmtime)
+    widths, heights = zip(*(im.size for im in pil_images))
+    w, h = widths[0], heights[0]
 
-    # 6) Open and return a PIL.Image. Gradio will display it automatically.
-    try:
-        img = Image.open(latest_png).convert("RGB")
-    except Exception as e:
-        print(f"Error opening {latest_png}: {e}")
-        return None
+    grid_w = w * 2
+    grid_h = h * 2
+    # composite = Image.new("RGB", (grid_w, grid_h))
 
-    return img
+    # composite.paste(pil_images[0], (0, 0))
+    # composite.paste(pil_images[1], (w, 0))
+    # composite.paste(pil_images[2], (0, h))
+    # composite.paste(pil_images[3], (w, h))
 
-# -----------------------------------------------------------------------------
+    return [pil_images[0], pil_images[1], pil_images[2], pil_images[3]]
+
+# -------------------------------------------------------------
 # BUILD THE GRADIO INTERFACE
 # -----------------------------------------------------------------------------
 
 css="""
 #col-container {
     margin: 0 auto;
-    max-width: 720px;
+    max-width: 1024px;
 }
 """
 
@@ -138,32 +149,43 @@ with gr.Blocks(css=css) as demo:
 
     with gr.Column(elem_id="col-container"):
 
-        # 1) Image upload component. We set type="filepath" so the callback
-        #    (run_with_upload) will receive a local path to the uploaded file.
-        upload_image = gr.Image(
-            label="Upload your input image",
-            type="filepath"
-        )
-
-        # 2) A button that the user will click to launch inference.
-        run_button = gr.Button("Run Inference")
-
-        # 3) An output <Image> where we will show the final PNG.
-        output_image = gr.Image(
-            label="Inference Result",
-            type="pil"   # because run_with_upload() returns a PIL.Image
-        )
-
-        # Wire the button: when clicked, call run_with_upload(upload_image), put
-        #    its return value into output_image.
-        run_button.click(
-            fn=run_with_upload,
-            inputs=upload_image,
-            outputs=output_image
-        )
+        with gr.Row():
+
+            with gr.Column():
+                # 1) Image upload component. We set type="filepath" so the callback
+                #    (run_with_upload) will receive a local path to the uploaded file.
+                upload_image = gr.Image(
+                    label="Upload your input image",
+                    type="filepath"
+                )
+                # 2) Radio for choosing 1× / 2× / 4× upscaling
+                upscale_radio = gr.Radio(
+                    choices=["1x", "2x", "4x"],
+                    value="2x",
+                    show_label=False
+                )
+
+                # 3) A button that the user will click to launch inference.
+                run_button = gr.Button("Chain-of-Zoom it")
+
+            # 4) Gallery to display the multiple output images
+            output_gallery = gr.Gallery(
+                label="Inference Results",
+                show_label=True,
+                elem_id="gallery",
+                columns=[2], rows=[2]
+            )
+
+        # Wire the button: when clicked, call run_with_upload(upload_image, upscale_radio)
+        # and put its return value into output_gallery.
+        run_button.click(
+            fn=run_with_upload,
+            inputs=[upload_image, upscale_radio],
+            outputs=output_gallery
+        )
 
 # -----------------------------------------------------------------------------
 # START THE GRADIO SERVER
 # -----------------------------------------------------------------------------
 
-demo.launch(share=True)
+demo.launch(share=True)
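
Note that the new return path hands the four per-scale PNGs to the gr.Gallery as a list and leaves the 2×2 composite (the commented-out composite.paste lines) unused. For reference, a minimal sketch of that stitch, assuming four equally sized RGB images as loaded above; stitch_2x2 is a hypothetical helper name, not part of this commit:

    from PIL import Image

    def stitch_2x2(pil_images):
        # Paste four equally sized PIL images into one 2x2 composite
        # (hypothetical helper; the committed code returns the list instead).
        w, h = pil_images[0].size
        composite = Image.new("RGB", (w * 2, h * 2))
        composite.paste(pil_images[0], (0, 0))  # top-left
        composite.paste(pil_images[1], (w, 0))  # top-right
        composite.paste(pil_images[2], (0, h))  # bottom-left
        composite.paste(pil_images[3], (w, h))  # bottom-right
        return composite

Returning such a composite from run_with_upload would also mean switching the output component back to a single gr.Image instead of the gallery.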