broadfield-dev committed
Commit b2c1451 · verified · 1 Parent(s): 815019f

Update app.py

Files changed (1):
  1. app.py +57 -19
app.py CHANGED
@@ -7,6 +7,11 @@ import tempfile
 from PIL import Image
 from huggingface_hub import hf_hub_download, login
 import torch
+import logging
+
+# --- Logging Setup ---
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+logger = logging.getLogger(__name__)
 
 # --- Configuration ---
 LORA_PATH = "ckpt/SR_LoRA/model_20001.pkl"
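Note on the new logging setup: `logging.basicConfig` is a no-op when the root logger already has handlers attached (a host platform or an earlier import can do this), in which case the format above silently fails to apply. On Python 3.8+, `force=True` replaces any existing handlers. A minimal sketch of that variant:

```python
import logging

# force=True (Python 3.8+) removes handlers already attached to the root
# logger, so this format takes effect even if something configured
# logging first.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    force=True,
)
logger = logging.getLogger(__name__)
```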
@@ -21,26 +26,26 @@ CHECKPOINT_FILES_CONFIG = {
 
 # --- Device Detection ---
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
-print(f"Using device: {DEVICE}")
+logger.info(f"Using device: {DEVICE}")
 
 # --- Hugging Face Token ---
 HF_AUTH_TOKEN = os.environ.get("HF_TOKEN")
 if HF_AUTH_TOKEN:
     try:
         login(token=HF_AUTH_TOKEN)
-        print("Successfully logged in to Hugging Face Hub.")
+        logger.info("Successfully logged in to Hugging Face Hub.")
     except Exception as e:
-        print(f"Warning: Hugging Face login failed: {e}")
+        logger.warning(f"Hugging Face login failed: {e}")
 else:
-    print("Warning: HF_TOKEN not found. Downloads of gated models may fail.")
+    logger.warning("HF_TOKEN not found. Downloads of gated models may fail.")
 
 # --- Model Download Function ---
 def download_coz_support_models():
-    print("Checking and downloading CoZ support models...")
+    logger.info("Checking and downloading CoZ support models...")
     for model_key, model_info in CHECKPOINT_FILES_CONFIG.items():
         target_file_path = Path(model_info["target_path"])
         if not target_file_path.exists():
-            print(f"Downloading {model_key} from {model_info['repo_id']}...")
+            logger.info(f"Downloading {model_key} from {model_info['repo_id']}...")
             target_file_path.parent.mkdir(parents=True, exist_ok=True)
             try:
                 cached_file_path = hf_hub_download(
@@ -48,25 +53,48 @@ def download_coz_support_models():
                     filename=model_info['filename'],
                     token=HF_AUTH_TOKEN,
                     cache_dir="./hf_cache",
-                    force_filename=Path(model_info['filename']).name
+                    force_download=False,
+                    resume_download=True
                 )
                 shutil.copy(cached_file_path, target_file_path)
-                print(f"{model_key} downloaded to {target_file_path}")
+                logger.info(f"{model_key} downloaded to {target_file_path}")
             except Exception as e:
-                print(f"Error downloading {model_key}: {e}")
+                logger.error(f"Error downloading {model_key}: {e}")
+                raise
         else:
-            print(f"{model_key} already exists at {target_file_path}")
-    print("All CoZ support models checked.")
+            logger.info(f"{model_key} already exists at {target_file_path}")
+    logger.info("All CoZ support models checked.")
 
 # Download models at startup
-download_coz_support_models()
+try:
+    logger.info("Starting model download...")
+    download_coz_support_models()
+    logger.info("Model download completed.")
+except Exception as e:
+    logger.error(f"Failed to download models: {e}")
+    raise
+
+# --- Preload Stable Diffusion Model ---
+logger.info("Preloading Stable Diffusion model configuration...")
+try:
+    from diffusers import StableDiffusionPipeline
+    StableDiffusionPipeline.from_pretrained(
+        "stabilityai/stable-diffusion-3-medium-diffusers",
+        use_auth_token=HF_AUTH_TOKEN,
+        cache_dir="./hf_cache"
+    )
+    logger.info("Stable Diffusion model configuration preloaded.")
+except Exception as e:
+    logger.error(f"Failed to preload Stable Diffusion model: {e}")
+    raise
 
 # --- Main Inference Function ---
 def run_chain_of_zoom(input_image: Image.Image, magnification: int, caption: str, seed: int):
     if input_image is None:
+        logger.error("No input image provided.")
         raise gr.Error("Please upload an image.")
 
-    global HF_AUTH_TOKEN
+    logger.info(f"Starting inference with magnification={magnification}, seed={seed}, caption={caption}")
     with tempfile.TemporaryDirectory() as temp_base_str:
         temp_base_dir = Path(temp_base_str)
         input_img_parent_dir = temp_base_dir / "input_images_root"
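Two details on the `hf_hub_download` change: dropping `force_filename` tracks its removal from recent `huggingface_hub` releases, while the replacement flags are largely redundant there. `force_download` already defaults to `False`, and `resume_download` has been deprecated since `huggingface_hub` 0.22, where interrupted downloads always resume and the flag only emits a deprecation warning. A minimal sketch of the equivalent call on a current release, reusing this app's `CHECKPOINT_FILES_CONFIG` entries:

```python
from huggingface_hub import hf_hub_download

# On huggingface_hub >= 0.22 neither flag is needed: force_download
# defaults to False and interrupted downloads resume automatically.
cached_file_path = hf_hub_download(
    repo_id=model_info["repo_id"],    # entry from CHECKPOINT_FILES_CONFIG
    filename=model_info["filename"],
    token=HF_AUTH_TOKEN,
    cache_dir="./hf_cache",
)
```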
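The new preload block may also need attention: `stabilityai/stable-diffusion-3-medium-diffusers` is a Stable Diffusion 3 repository, which recent `diffusers` releases load via `StableDiffusion3Pipeline` (or the auto-resolving `DiffusionPipeline`) rather than the SD 1.x `StableDiffusionPipeline`, and `use_auth_token` is deprecated there in favor of `token`. If the intent is only to warm the cache at startup, `DiffusionPipeline.download` fetches the repo without instantiating the pipeline in memory. A hedged sketch under those assumptions:

```python
from diffusers import DiffusionPipeline

# Warm the local cache without loading multi-GB weights into RAM.
# DiffusionPipeline resolves the concrete pipeline class (SD3 here)
# from the repo's model_index.json.
DiffusionPipeline.download(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    token=HF_AUTH_TOKEN,  # `token` replaces the deprecated `use_auth_token`
    cache_dir="./hf_cache",
)
```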
@@ -74,12 +102,14 @@ def run_chain_of_zoom(input_image: Image.Image, magnification: int, caption: str
         input_image_filename = "source_image.png"
         input_image_path = input_img_parent_dir / input_image_filename
         input_image.save(input_image_path, "PNG")
+        logger.info(f"Input image saved to {input_image_path}")
 
         output_img_dir = temp_base_dir / "output_data"
         output_img_dir.mkdir(parents=True, exist_ok=True)
 
         # Check if inference_coz.py exists
         if not Path("inference_coz.py").exists():
+            logger.error("inference_coz.py not found in repository.")
             raise gr.Error("inference_coz.py not found in repository. Please check the Chain-of-Zoom repository.")
 
         command = [
@@ -107,7 +137,7 @@ def run_chain_of_zoom(input_image: Image.Image, magnification: int, caption: str
         if HF_AUTH_TOKEN:
             command.extend(["--hf_token", HF_AUTH_TOKEN])
 
-        print(f"Running command: {' '.join(command)}")
+        logger.info(f"Running command: {' '.join(command)}")
         process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, bufsize=1)
 
         stdout_lines = []
@@ -117,7 +147,7 @@ def run_chain_of_zoom(input_image: Image.Image, magnification: int, caption: str
         # Stream stdout
         if process.stdout:
             for line in iter(process.stdout.readline, ""):
-                print(f"[CoZ STDOUT] {line}", end="")
+                logger.info(f"[CoZ STDOUT] {line.strip()}")
                 stdout_lines.append(line)
                 if "Saving image to" in line:
                     try:
@@ -128,14 +158,14 @@ def run_chain_of_zoom(input_image: Image.Image, magnification: int, caption: str
         # Stream stderr
         if process.stderr:
             for line in iter(process.stderr.readline, ""):
-                print(f"[CoZ STDERR] {line}", end="")
+                logger.warning(f"[CoZ STDERR] {line.strip()}")
                 stderr_lines.append(line)
 
         process.wait()
 
         if process.returncode != 0:
             error_message = f"Chain-of-Zoom failed.\nSTDOUT:\n{''.join(stdout_lines[-5:])}\nSTDERR:\n{''.join(stderr_lines[-5:])}"
-            print(error_message)
+            logger.error(error_message)
             raise gr.Error(f"Processing failed: {error_message}")
 
         # Find output image
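One caveat the logging swap doesn't address: the code drains stdout to EOF before reading stderr, so a child process that writes enough to stderr in the meantime can fill the OS pipe buffer and deadlock both processes (the standard caveat from the `subprocess` docs). A common fix is merging stderr into stdout so a single loop drains both streams; a sketch under that assumption, keeping this app's `command` and `logger`:

```python
import subprocess

# Merging stderr into stdout avoids the two-pipe deadlock: one loop
# drains everything the child writes, interleaved in output order.
process = subprocess.Popen(
    command,
    stdout=subprocess.PIPE,
    stderr=subprocess.STDOUT,
    text=True,
    bufsize=1,
)
output_lines = []
for line in iter(process.stdout.readline, ""):
    logger.info(f"[CoZ] {line.strip()}")
    output_lines.append(line)
process.wait()
```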
@@ -150,10 +180,11 @@ def run_chain_of_zoom(input_image: Image.Image, magnification: int, caption: str
 
         if not final_output_image_path or not final_output_image_path.exists():
             all_files = list(output_img_dir.rglob("*"))
-            print(f"Files in output directory: {all_files}")
+            logger.error(f"Output image not found in {output_img_dir}. Files found: {all_files}")
             raise gr.Error(f"Output image not found in {output_img_dir}. Files found: {all_files}")
 
         output_image = Image.open(final_output_image_path)
+        logger.info(f"Output image generated: {final_output_image_path}")
         return output_image
 
 # --- Gradio Interface ---
@@ -171,6 +202,7 @@ Optimized for CPU and GPU environments. Ensure HF_TOKEN is set in Space secrets
 """
 article = "<p style='text-align: center;'><a href='https://github.com/bryanswkim/Chain-of-Zoom' target='_blank'>Chain-of-Zoom GitHub</a></p>"
 
+logger.info("Initializing Gradio interface...")
 with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
     gr.Markdown(f"<h1 style='text-align: center'>{title}</h1>")
     gr.Markdown(description)
@@ -194,4 +226,10 @@ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
     )
 
 if __name__ == "__main__":
-    demo.launch()
+    logger.info("Launching Gradio app...")
+    try:
+        demo.launch(server_name="0.0.0.0", server_port=7860)
+        logger.info("Gradio app launched successfully.")
+    except Exception as e:
+        logger.error(f"Failed to launch Gradio app: {e}")
+        raise
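Worth knowing about the new launch block: `demo.launch()` blocks until the server shuts down, so the "launched successfully" line only logs after the app exits, and the `except` clause won't catch errors raised inside request handlers. On Spaces, Gradio also reads `GRADIO_SERVER_NAME`/`GRADIO_SERVER_PORT` from the environment, so the explicit host and port are usually redundant there. A minimal sketch that logs before the blocking call:

```python
if __name__ == "__main__":
    logger.info("Launching Gradio app on 0.0.0.0:7860...")
    # launch() blocks until the server stops, so log before it, not after.
    demo.launch(server_name="0.0.0.0", server_port=7860)
    logger.info("Gradio server stopped.")
```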