Spaces:
Running
Running
fix error
Browse files
media.py
CHANGED
@@ -1,3 +1,6 @@
|
|
|
|
|
|
|
|
1 |
# --- LIBRARIES ---
|
2 |
import torch
|
3 |
import gradio as gr
|
@@ -11,9 +14,10 @@ import numpy as np
|
|
11 |
import threading
|
12 |
from queue import Queue, Empty as QueueEmpty
|
13 |
from PIL import Image
|
|
|
14 |
from huggingface_hub import login
|
15 |
|
16 |
-
# --- DYNAMIC HARDWARE DETECTION
|
17 |
if torch.cuda.is_available():
|
18 |
device = "cuda"
|
19 |
torch_dtype = torch.float16
|
@@ -21,13 +25,22 @@ if torch.cuda.is_available():
|
|
21 |
else:
|
22 |
device = "cpu"
|
23 |
torch_dtype = torch.float32
|
24 |
-
print("⚠️ No GPU detected.")
|
25 |
|
26 |
-
HF_TOKEN = os.getenv("HF_TOKEN") # Will read the token from Space secrets
|
27 |
-
if HF_TOKEN is None:
|
28 |
-
raise ValueError("❌ HF_TOKEN is not set in the environment variables!")
|
29 |
|
30 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
|
32 |
# --- CONFIGURATION & STATE ---
|
33 |
available_models = {
|
@@ -38,34 +51,46 @@ available_models = {
|
|
38 |
}
|
39 |
model_state = { "current_pipe": None, "loaded_model_name": None }
|
40 |
|
41 |
-
# --- THE FINAL
|
42 |
-
def
|
43 |
global model_state
|
44 |
-
|
45 |
-
# --- Model Loading ---
|
46 |
if model_state.get("loaded_model_name") != model_key:
|
47 |
yield {output_image: None, output_video: None, status_textbox: f"Loading {model_key}..."}
|
48 |
-
if model_state.get("current_pipe"):
|
49 |
-
pipe_to_delete = model_state.pop("current_pipe", None)
|
50 |
-
if pipe_to_delete: del pipe_to_delete
|
51 |
-
gc.collect()
|
52 |
-
torch.cuda.empty_cache()
|
53 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
model_id = available_models[model_key]
|
55 |
if "Video" in model_key:
|
56 |
-
pipe = TextToVideoSDPipeline.from_pretrained(model_id, torch_dtype=torch_dtype)
|
57 |
else:
|
58 |
pipe = AutoPipelineForText2Image.from_pretrained(model_id, torch_dtype=torch_dtype, variant="fp16")
|
59 |
-
|
60 |
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
|
61 |
pipe.to(device)
|
62 |
-
|
63 |
if device == "cuda":
|
64 |
-
|
65 |
-
|
|
|
|
|
66 |
model_state["current_pipe"] = pipe
|
67 |
model_state["loaded_model_name"] = model_key
|
68 |
-
print(f"✅ Model loaded on {device.upper()}.")
|
69 |
|
70 |
pipe = model_state["current_pipe"]
|
71 |
generator = torch.Generator(device).manual_seed(seed)
|
@@ -73,29 +98,32 @@ def generate_media_with_progress(model_key, prompt, negative_prompt, steps, cfg_
|
|
73 |
# --- Generation Logic ---
|
74 |
if "Video" in model_key:
|
75 |
yield {output_image: None, output_video: None, status_textbox: "Generating video..."}
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
|
|
|
|
|
|
|
|
87 |
progress_queue = Queue()
|
88 |
|
89 |
def run_pipe():
|
90 |
start_time = time.time()
|
91 |
|
92 |
-
|
93 |
-
def progress_callback(step, timestep, latents, **kwargs):
|
94 |
elapsed_time = time.time() - start_time
|
95 |
if elapsed_time > 0:
|
96 |
its_per_sec = (step + 1) / elapsed_time
|
97 |
-
progress_queue.put(("progress", step + 1, its_per_sec))
|
98 |
-
return
|
99 |
|
100 |
try:
|
101 |
final_image = pipe(
|
@@ -104,7 +132,7 @@ def generate_media_with_progress(model_key, prompt, negative_prompt, steps, cfg_
|
|
104 |
generator=generator,
|
105 |
callback_on_step_end=progress_callback
|
106 |
).images[0]
|
107 |
-
progress_queue.put(("
|
108 |
except Exception as e:
|
109 |
print(f"An error occurred in the generation thread: {e}")
|
110 |
progress_queue.put(("error", str(e)))
|
@@ -113,19 +141,17 @@ def generate_media_with_progress(model_key, prompt, negative_prompt, steps, cfg_
|
|
113 |
thread.start()
|
114 |
|
115 |
total_steps = int(steps)
|
116 |
-
final_image_result = None
|
117 |
yield {status_textbox: "Generating..."}
|
118 |
|
119 |
while True:
|
120 |
try:
|
121 |
-
update_type,
|
122 |
|
123 |
-
if update_type == "
|
124 |
-
|
125 |
-
yield {output_image: final_image_result, status_textbox: f"Generation complete! Seed: {seed}"}
|
126 |
break
|
127 |
elif update_type == "progress":
|
128 |
-
current_step, its_per_sec =
|
129 |
progress_percent = (current_step / total_steps) * 100
|
130 |
steps_remaining = total_steps - current_step
|
131 |
eta_seconds = steps_remaining / its_per_sec if its_per_sec > 0 else 0
|
@@ -137,18 +163,19 @@ def generate_media_with_progress(model_key, prompt, negative_prompt, steps, cfg_
|
|
137 |
)
|
138 |
yield {status_textbox: status_text}
|
139 |
elif update_type == "error":
|
140 |
-
yield {status_textbox: f"Error: {
|
141 |
break
|
142 |
except QueueEmpty:
|
143 |
if not thread.is_alive():
|
|
|
144 |
yield {status_textbox: "Generation failed. Check console for details."}
|
145 |
break
|
146 |
|
147 |
thread.join()
|
|
|
148 |
|
149 |
-
# --- GRADIO UI
|
150 |
with gr.Blocks(theme='gradio/soft') as demo:
|
151 |
-
# (Your UI code is perfect)
|
152 |
gr.Markdown("# The Generative Media Suite")
|
153 |
gr.Markdown("Create fast images, high-quality images, or short videos. Created by cheeseman182. (note: the speed on the status bar is wrong)")
|
154 |
seed_state = gr.State(-1)
|
@@ -185,15 +212,16 @@ with gr.Blocks(theme='gradio/soft') as demo:
|
|
185 |
}
|
186 |
model_selector.change(update_ui_on_model_change, model_selector, [steps_slider, cfg_slider, width_slider, height_slider, num_frames_slider, output_image, output_video])
|
187 |
|
188 |
-
|
189 |
-
fn=lambda s:
|
190 |
inputs=seed_input,
|
191 |
outputs=seed_state,
|
192 |
queue=False
|
193 |
).then(
|
194 |
-
fn=
|
195 |
inputs=[model_selector, prompt_input, negative_prompt_input, steps_slider, cfg_slider, width_slider, height_slider, seed_state, num_frames_slider],
|
196 |
outputs=[output_image, output_video, status_textbox]
|
197 |
)
|
198 |
|
199 |
-
|
|
|
|
1 |
+
|
2 |
+
# --- START OF FILE media.py (FINAL WITH LIVE PROGRESS & FIXES) ---
|
3 |
+
|
4 |
# --- LIBRARIES ---
|
5 |
import torch
|
6 |
import gradio as gr
|
|
|
14 |
import threading
|
15 |
from queue import Queue, Empty as QueueEmpty
|
16 |
from PIL import Image
|
17 |
+
import os
|
18 |
from huggingface_hub import login
|
19 |
|
20 |
+
# --- DYNAMIC HARDWARE DETECTION ---
|
21 |
if torch.cuda.is_available():
|
22 |
device = "cuda"
|
23 |
torch_dtype = torch.float16
|
|
|
25 |
else:
|
26 |
device = "cpu"
|
27 |
torch_dtype = torch.float32
|
28 |
+
print("⚠️ No GPU detected. Using CPU.")
|
29 |
|
|
|
|
|
|
|
30 |
|
31 |
+
HF_TOKEN = os.environ.get('HF_TOKEN')
|
32 |
+
|
33 |
+
if HF_TOKEN:
|
34 |
+
print("✅ Found HF_TOKEN secret. Logging in...")
|
35 |
+
try:
|
36 |
+
login(token=HF_TOKEN)
|
37 |
+
print("✅ Hugging Face Authentication successful.")
|
38 |
+
except Exception as e:
|
39 |
+
print(f"❌ Hugging Face login failed: {e}")
|
40 |
+
else:
|
41 |
+
# This message will show when you run the app locally, which is fine.
|
42 |
+
print("⚠️ No HF_TOKEN secret found. This is normal for local testing.")
|
43 |
+
print(" The deployed app will use the secret you set on Hugging Face.")
|
44 |
|
45 |
# --- CONFIGURATION & STATE ---
|
46 |
available_models = {
|
|
|
51 |
}
|
52 |
model_state = { "current_pipe": None, "loaded_model_name": None }
|
53 |
|
54 |
+
# --- THE FINAL GENERATION FUNCTION WITH LIVE PROGRESS & FIXES ---
|
55 |
+
def generate_media_live_progress(model_key, prompt, negative_prompt, steps, cfg_scale, width, height, seed, num_frames):
|
56 |
global model_state
|
57 |
+
|
58 |
+
# --- Model Loading & Cleanup ---
|
59 |
if model_state.get("loaded_model_name") != model_key:
|
60 |
yield {output_image: None, output_video: None, status_textbox: f"Loading {model_key}..."}
|
|
|
|
|
|
|
|
|
|
|
61 |
|
62 |
+
# --- More Aggressive & Explicit Cleanup ---
|
63 |
+
pipe_to_delete = model_state.pop("current_pipe", None)
|
64 |
+
if pipe_to_delete:
|
65 |
+
# FIX: Explicitly move the model to CPU before deleting to free VRAM.
|
66 |
+
print("Offloading previous model to CPU...")
|
67 |
+
pipe_to_delete.to("cpu")
|
68 |
+
del pipe_to_delete
|
69 |
+
print("Previous model deleted.")
|
70 |
+
|
71 |
+
# Explicitly run garbage collection and empty CUDA cache.
|
72 |
+
gc.collect()
|
73 |
+
if torch.cuda.is_available():
|
74 |
+
torch.cuda.empty_cache()
|
75 |
+
|
76 |
+
# Load the new pipeline
|
77 |
model_id = available_models[model_key]
|
78 |
if "Video" in model_key:
|
79 |
+
pipe = TextToVideoSDPipeline.from_pretrained(model_id, torch_dtype=torch_dtype, variant="fp16")
|
80 |
else:
|
81 |
pipe = AutoPipelineForText2Image.from_pretrained(model_id, torch_dtype=torch_dtype, variant="fp16")
|
82 |
+
|
83 |
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
|
84 |
pipe.to(device)
|
85 |
+
|
86 |
if device == "cuda":
|
87 |
+
if "Video" not in model_key:
|
88 |
+
pipe.enable_model_cpu_offload()
|
89 |
+
pipe.enable_vae_slicing()
|
90 |
+
|
91 |
model_state["current_pipe"] = pipe
|
92 |
model_state["loaded_model_name"] = model_key
|
93 |
+
print(f"✅ Model '{model_key}' loaded on {device.upper()}.")
|
94 |
|
95 |
pipe = model_state["current_pipe"]
|
96 |
generator = torch.Generator(device).manual_seed(seed)
|
|
|
98 |
# --- Generation Logic ---
|
99 |
if "Video" in model_key:
|
100 |
yield {output_image: None, output_video: None, status_textbox: "Generating video..."}
|
101 |
+
try:
|
102 |
+
video_frames = pipe(prompt=prompt, num_inference_steps=int(steps), height=320, width=576, num_frames=int(num_frames), generator=generator).frames
|
103 |
+
|
104 |
+
# FIX: More memory-efficient video saving
|
105 |
+
video_path = f"video_{seed}.mp4"
|
106 |
+
with imageio.get_writer(video_path, fps=12) as writer:
|
107 |
+
for frame in video_frames:
|
108 |
+
writer.append_data((frame * 255).astype(np.uint8))
|
109 |
+
|
110 |
+
yield {output_image: None, output_video: video_path, status_textbox: f"Video saved! Seed: {seed}"}
|
111 |
+
except Exception as e:
|
112 |
+
print(f"An error occurred during video generation: {e}")
|
113 |
+
yield {status_textbox: f"Error during video generation: {e}"}
|
114 |
+
|
115 |
+
else: # Image Generation with Live Progress
|
116 |
progress_queue = Queue()
|
117 |
|
118 |
def run_pipe():
|
119 |
start_time = time.time()
|
120 |
|
121 |
+
def progress_callback(pipe, step, timestep, callback_kwargs):
|
|
|
122 |
elapsed_time = time.time() - start_time
|
123 |
if elapsed_time > 0:
|
124 |
its_per_sec = (step + 1) / elapsed_time
|
125 |
+
progress_queue.put(("progress", (step + 1, its_per_sec)))
|
126 |
+
return callback_kwargs
|
127 |
|
128 |
try:
|
129 |
final_image = pipe(
|
|
|
132 |
generator=generator,
|
133 |
callback_on_step_end=progress_callback
|
134 |
).images[0]
|
135 |
+
progress_queue.put(("result", final_image))
|
136 |
except Exception as e:
|
137 |
print(f"An error occurred in the generation thread: {e}")
|
138 |
progress_queue.put(("error", str(e)))
|
|
|
141 |
thread.start()
|
142 |
|
143 |
total_steps = int(steps)
|
|
|
144 |
yield {status_textbox: "Generating..."}
|
145 |
|
146 |
while True:
|
147 |
try:
|
148 |
+
update_type, payload = progress_queue.get(timeout=1.0)
|
149 |
|
150 |
+
if update_type == "result":
|
151 |
+
yield {output_image: payload, status_textbox: f"Generation complete! Seed: {seed}"}
|
|
|
152 |
break
|
153 |
elif update_type == "progress":
|
154 |
+
current_step, its_per_sec = payload
|
155 |
progress_percent = (current_step / total_steps) * 100
|
156 |
steps_remaining = total_steps - current_step
|
157 |
eta_seconds = steps_remaining / its_per_sec if its_per_sec > 0 else 0
|
|
|
163 |
)
|
164 |
yield {status_textbox: status_text}
|
165 |
elif update_type == "error":
|
166 |
+
yield {status_textbox: f"Error: {payload}. Check console."}
|
167 |
break
|
168 |
except QueueEmpty:
|
169 |
if not thread.is_alive():
|
170 |
+
print("⚠️ Generation thread finished unexpectedly.")
|
171 |
yield {status_textbox: "Generation failed. Check console for details."}
|
172 |
break
|
173 |
|
174 |
thread.join()
|
175 |
+
print("Generation thread joined.")
|
176 |
|
177 |
+
# --- GRADIO UI ---
|
178 |
with gr.Blocks(theme='gradio/soft') as demo:
|
|
|
179 |
gr.Markdown("# The Generative Media Suite")
|
180 |
gr.Markdown("Create fast images, high-quality images, or short videos. Created by cheeseman182. (note: the speed on the status bar is wrong)")
|
181 |
seed_state = gr.State(-1)
|
|
|
212 |
}
|
213 |
model_selector.change(update_ui_on_model_change, model_selector, [steps_slider, cfg_slider, width_slider, height_slider, num_frames_slider, output_image, output_video])
|
214 |
|
215 |
+
generate_button.click(
|
216 |
+
fn=lambda s: s if s != -1 else random.randint(0, 2**32 - 1),
|
217 |
inputs=seed_input,
|
218 |
outputs=seed_state,
|
219 |
queue=False
|
220 |
).then(
|
221 |
+
fn=generate_media_live_progress,
|
222 |
inputs=[model_selector, prompt_input, negative_prompt_input, steps_slider, cfg_slider, width_slider, height_slider, seed_state, num_frames_slider],
|
223 |
outputs=[output_image, output_video, status_textbox]
|
224 |
)
|
225 |
|
226 |
+
if __name__ == "__main__":
|
227 |
+
demo.launch(share=True, debug=True)
|