cheeseman182 committed
Commit b488552 · verified · 1 Parent(s): b5dac16
Files changed (1):
  1. media.py +78 -50
media.py CHANGED
@@ -1,3 +1,6 @@
+
+# --- START OF FILE media.py (FINAL WITH LIVE PROGRESS & FIXES) ---
+
 # --- LIBRARIES ---
 import torch
 import gradio as gr
@@ -11,9 +14,10 @@ import numpy as np
 import threading
 from queue import Queue, Empty as QueueEmpty
 from PIL import Image
+import os
 from huggingface_hub import login
 
-# --- DYNAMIC HARDWARE DETECTION & AUTH ---
+# --- DYNAMIC HARDWARE DETECTION ---
 if torch.cuda.is_available():
     device = "cuda"
     torch_dtype = torch.float16
@@ -21,13 +25,22 @@ if torch.cuda.is_available():
 else:
     device = "cpu"
     torch_dtype = torch.float32
-    print("⚠️ No GPU detected.")
+    print("⚠️ No GPU detected. Using CPU.")
 
-HF_TOKEN = os.getenv("HF_TOKEN") # Will read the token from Space secrets
-if HF_TOKEN is None:
-    raise ValueError("❌ HF_TOKEN is not set in the environment variables!")
-
-login(token=HF_TOKEN)
+
+HF_TOKEN = os.environ.get('HF_TOKEN')
+
+if HF_TOKEN:
+    print("✅ Found HF_TOKEN secret. Logging in...")
+    try:
+        login(token=HF_TOKEN)
+        print("✅ Hugging Face Authentication successful.")
+    except Exception as e:
+        print(f"❌ Hugging Face login failed: {e}")
+else:
+    # This message will show when you run the app locally, which is fine.
+    print("⚠️ No HF_TOKEN secret found. This is normal for local testing.")
+    print("   The deployed app will use the secret you set on Hugging Face.")
 
 # --- CONFIGURATION & STATE ---
 available_models = {
@@ -38,34 +51,46 @@ available_models = {
 }
 model_state = { "current_pipe": None, "loaded_model_name": None }
 
-# --- THE FINAL, STABLE GENERATION FUNCTION ---
-def generate_media_with_progress(model_key, prompt, negative_prompt, steps, cfg_scale, width, height, seed, num_frames):
+# --- THE FINAL GENERATION FUNCTION WITH LIVE PROGRESS & FIXES ---
+def generate_media_live_progress(model_key, prompt, negative_prompt, steps, cfg_scale, width, height, seed, num_frames):
     global model_state
-
-    # --- Model Loading ---
+
+    # --- Model Loading & Cleanup ---
     if model_state.get("loaded_model_name") != model_key:
         yield {output_image: None, output_video: None, status_textbox: f"Loading {model_key}..."}
-        if model_state.get("current_pipe"):
-            pipe_to_delete = model_state.pop("current_pipe", None)
-            if pipe_to_delete: del pipe_to_delete
-            gc.collect()
-            torch.cuda.empty_cache()
 
+        # --- More Aggressive & Explicit Cleanup ---
+        pipe_to_delete = model_state.pop("current_pipe", None)
+        if pipe_to_delete:
+            # FIX: Explicitly move the model to CPU before deleting to free VRAM.
+            print("Offloading previous model to CPU...")
+            pipe_to_delete.to("cpu")
+            del pipe_to_delete
+            print("Previous model deleted.")
+
+        # Explicitly run garbage collection and empty CUDA cache.
+        gc.collect()
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
+        # Load the new pipeline
         model_id = available_models[model_key]
         if "Video" in model_key:
-            pipe = TextToVideoSDPipeline.from_pretrained(model_id, torch_dtype=torch_dtype)
+            pipe = TextToVideoSDPipeline.from_pretrained(model_id, torch_dtype=torch_dtype, variant="fp16")
         else:
             pipe = AutoPipelineForText2Image.from_pretrained(model_id, torch_dtype=torch_dtype, variant="fp16")
-
+
         pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
         pipe.to(device)
-
+
         if device == "cuda":
-            if "Video" not in model_key: pipe.enable_model_cpu_offload()
-            pipe.enable_vae_slicing()
+            if "Video" not in model_key:
+                pipe.enable_model_cpu_offload()
+            pipe.enable_vae_slicing()
+
         model_state["current_pipe"] = pipe
         model_state["loaded_model_name"] = model_key
-        print(f"✅ Model loaded on {device.upper()}.")
+        print(f"✅ Model '{model_key}' loaded on {device.upper()}.")
 
     pipe = model_state["current_pipe"]
     generator = torch.Generator(device).manual_seed(seed)
@@ -73,29 +98,32 @@ def generate_media_with_progress(model_key, prompt, negative_prompt, steps, cfg_
     # --- Generation Logic ---
     if "Video" in model_key:
         yield {output_image: None, output_video: None, status_textbox: "Generating video..."}
-        # (Your working video code)
-        video_frames = pipe(prompt=prompt, num_inference_steps=int(steps), height=320, width=576, num_frames=int(num_frames), generator=generator).frames
-        video_frames_5d = np.array(video_frames)
-        video_frames_4d = np.squeeze(video_frames_5d)
-        video_uint8 = (video_frames_4d * 255).astype(np.uint8)
-        list_of_frames = [frame for frame in video_uint8]
-        video_path = f"video_{seed}.mp4"
-        imageio.mimsave(video_path, list_of_frames, fps=12)
-        yield {output_image: None, output_video: video_path, status_textbox: f"Video saved! Seed: {seed}"}
-
-    else: # Image Generation with your brilliant text-based progress bar
+        try:
+            video_frames = pipe(prompt=prompt, num_inference_steps=int(steps), height=320, width=576, num_frames=int(num_frames), generator=generator).frames
+
+            # FIX: More memory-efficient video saving
+            video_path = f"video_{seed}.mp4"
+            with imageio.get_writer(video_path, fps=12) as writer:
+                for frame in video_frames:
+                    writer.append_data((frame * 255).astype(np.uint8))
+
+            yield {output_image: None, output_video: video_path, status_textbox: f"Video saved! Seed: {seed}"}
+        except Exception as e:
+            print(f"An error occurred during video generation: {e}")
+            yield {status_textbox: f"Error during video generation: {e}"}
+
+    else: # Image Generation with Live Progress
         progress_queue = Queue()
 
         def run_pipe():
             start_time = time.time()
 
-            # This callback correctly accepts all arguments
-            def progress_callback(step, timestep, latents, **kwargs):
+            def progress_callback(pipe, step, timestep, callback_kwargs):
                 elapsed_time = time.time() - start_time
                 if elapsed_time > 0:
                     its_per_sec = (step + 1) / elapsed_time
-                    progress_queue.put(("progress", step + 1, its_per_sec))
-                return kwargs
+                    progress_queue.put(("progress", (step + 1, its_per_sec)))
+                return callback_kwargs
 
             try:
                 final_image = pipe(
@@ -104,7 +132,7 @@ def generate_media_with_progress(model_key, prompt, negative_prompt, steps, cfg_
                     generator=generator,
                     callback_on_step_end=progress_callback
                 ).images[0]
-                progress_queue.put(("final", final_image))
+                progress_queue.put(("result", final_image))
             except Exception as e:
                 print(f"An error occurred in the generation thread: {e}")
                 progress_queue.put(("error", str(e)))
@@ -113,19 +141,17 @@ def generate_media_with_progress(model_key, prompt, negative_prompt, steps, cfg_
         thread.start()
 
         total_steps = int(steps)
-        final_image_result = None
        yield {status_textbox: "Generating..."}
 
         while True:
             try:
-                update_type, data = progress_queue.get(timeout=1.0)
+                update_type, payload = progress_queue.get(timeout=1.0)
 
-                if update_type == "final":
-                    final_image_result = data
-                    yield {output_image: final_image_result, status_textbox: f"Generation complete! Seed: {seed}"}
+                if update_type == "result":
+                    yield {output_image: payload, status_textbox: f"Generation complete! Seed: {seed}"}
                     break
                 elif update_type == "progress":
-                    current_step, its_per_sec = data
+                    current_step, its_per_sec = payload
                     progress_percent = (current_step / total_steps) * 100
                     steps_remaining = total_steps - current_step
                     eta_seconds = steps_remaining / its_per_sec if its_per_sec > 0 else 0
@@ -137,18 +163,19 @@ def generate_media_with_progress(model_key, prompt, negative_prompt, steps, cfg_
                     )
                     yield {status_textbox: status_text}
                 elif update_type == "error":
-                    yield {status_textbox: f"Error: {data}"}
+                    yield {status_textbox: f"Error: {payload}. Check console."}
                     break
             except QueueEmpty:
                 if not thread.is_alive():
+                    print("⚠️ Generation thread finished unexpectedly.")
                    yield {status_textbox: "Generation failed. Check console for details."}
                     break
 
         thread.join()
+        print("Generation thread joined.")
 
-# --- GRADIO UI (Unchanged) ---
+# --- GRADIO UI ---
 with gr.Blocks(theme='gradio/soft') as demo:
-    # (Your UI code is perfect)
     gr.Markdown("# The Generative Media Suite")
     gr.Markdown("Create fast images, high-quality images, or short videos. Created by cheeseman182. (note: the speed on the status bar is wrong)")
     seed_state = gr.State(-1)
@@ -185,15 +212,16 @@ with gr.Blocks(theme='gradio/soft') as demo:
     }
     model_selector.change(update_ui_on_model_change, model_selector, [steps_slider, cfg_slider, width_slider, height_slider, num_frames_slider, output_image, output_video])
 
-    click_event = generate_button.click(
-        fn=lambda s: (s if s != -1 else random.randint(0, 2**32 - 1)),
+    generate_button.click(
+        fn=lambda s: s if s != -1 else random.randint(0, 2**32 - 1),
         inputs=seed_input,
         outputs=seed_state,
         queue=False
     ).then(
-        fn=generate_media_with_progress,
+        fn=generate_media_live_progress,
         inputs=[model_selector, prompt_input, negative_prompt_input, steps_slider, cfg_slider, width_slider, height_slider, seed_state, num_frames_slider],
         outputs=[output_image, output_video, status_textbox]
     )
 
-demo.launch(share=True)
+if __name__ == "__main__":
+    demo.launch(share=True, debug=True)
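The headline fix in this diff is the progress-callback signature. With callback_on_step_end, recent diffusers releases invoke the callback as callback(pipeline, step_index, timestep, callback_kwargs) and expect the kwargs dict back; the legacy (step, timestep, latents) form the old code passed would fail once the pipeline called it. A minimal sketch of the contract, assuming a diffusers version that supports callback_on_step_end:

def progress_callback(pipeline, step, timestep, callback_kwargs):
    # Called once per denoising step; returning the dict lets the
    # pipeline keep (or swap out) tensors such as the latents.
    print(f"finished step {step} (timestep {timestep})")
    return callback_kwargs

# Hypothetical usage, with pipe and prompt set up as in media.py:
# image = pipe(prompt, callback_on_step_end=progress_callback).images[0]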
 
 
98
  # --- Generation Logic ---
99
  if "Video" in model_key:
100
  yield {output_image: None, output_video: None, status_textbox: "Generating video..."}
101
+ try:
102
+ video_frames = pipe(prompt=prompt, num_inference_steps=int(steps), height=320, width=576, num_frames=int(num_frames), generator=generator).frames
103
+
104
+ # FIX: More memory-efficient video saving
105
+ video_path = f"video_{seed}.mp4"
106
+ with imageio.get_writer(video_path, fps=12) as writer:
107
+ for frame in video_frames:
108
+ writer.append_data((frame * 255).astype(np.uint8))
109
+
110
+ yield {output_image: None, output_video: video_path, status_textbox: f"Video saved! Seed: {seed}"}
111
+ except Exception as e:
112
+ print(f"An error occurred during video generation: {e}")
113
+ yield {status_textbox: f"Error during video generation: {e}"}
114
+
115
+ else: # Image Generation with Live Progress
116
  progress_queue = Queue()
117
 
118
  def run_pipe():
119
  start_time = time.time()
120
 
121
+ def progress_callback(pipe, step, timestep, callback_kwargs):
 
122
  elapsed_time = time.time() - start_time
123
  if elapsed_time > 0:
124
  its_per_sec = (step + 1) / elapsed_time
125
+ progress_queue.put(("progress", (step + 1, its_per_sec)))
126
+ return callback_kwargs
127
 
128
  try:
129
  final_image = pipe(
 
132
  generator=generator,
133
  callback_on_step_end=progress_callback
134
  ).images[0]
135
+ progress_queue.put(("result", final_image))
136
  except Exception as e:
137
  print(f"An error occurred in the generation thread: {e}")
138
  progress_queue.put(("error", str(e)))
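The video branch also changes how frames reach disk: instead of materializing the whole clip as one uint8 array (np.array, np.squeeze, then imageio.mimsave), it appends frames one at a time through imageio.get_writer, so only a single converted frame is held at once. A standalone sketch with synthetic frames in place of pipeline output; writing .mp4 this way assumes the imageio-ffmpeg backend is available:

import imageio
import numpy as np

# Hypothetical stand-in for pipe(...).frames: 24 float RGB frames in [0, 1].
frames = [np.random.rand(320, 576, 3).astype(np.float32) for _ in range(24)]

# Stream each frame to the encoder as it is converted, rather than
# building one large uint8 array for the whole clip first.
with imageio.get_writer("out.mp4", fps=12) as writer:
    for frame in frames:
        writer.append_data((frame * 255).astype(np.uint8))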
 
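The model-swap cleanup is worth isolating too: dropping the last Python reference to a pipeline that still lives on the GPU does not reliably free VRAM, so the diff moves the old pipeline to the CPU before deleting it, then runs garbage collection and empties the CUDA cache. A sketch of that sequence under the same model_state convention as the file (release_pipeline is a hypothetical helper, not part of the commit):

import gc
import torch

def release_pipeline(state):
    old = state.pop("current_pipe", None)
    if old is not None:
        old.to("cpu")                 # move weights off the GPU first
        del old                       # then drop the Python reference
    gc.collect()                      # let Python free the module tree
    if torch.cuda.is_available():
        torch.cuda.empty_cache()      # return cached CUDA blocks to the driver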