akarshan11 committed on
Commit a0aa36c · verified · 1 Parent(s): bf2c23b

Create app.py

Files changed (1)
  1. app.py +245 -0
app.py ADDED
@@ -0,0 +1,245 @@
# -*- coding: utf-8 -*-
import os
import inspect
import torch
from diffusers import StableDiffusionPipeline
from PIL import Image
import numpy as np
from torch import autocast
import cv2
import gradio as gr

# -----------------------------------------------------------------------------
# 1. REQUIREMENTS & SETUP
# -----------------------------------------------------------------------------
# To set up the environment for this script, create a file named 'requirements.txt'
# with the following content and run 'pip install -r requirements.txt':
#
# torch>=2.0.0
# torchvision>=0.15.1
# diffusers>=0.20.2
# transformers>=4.30.2
# accelerate>=0.21.0
# gradio>=3.36.1
# opencv-python-headless>=4.8.0.74
# -----------------------------------------------------------------------------

# --- Automatic Device Detection ---
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
print("-------------------------------------------------")
print(f"INFO: Using device: {torch_device.upper()}")
if torch_device == "cpu":
    print("WARNING: CUDA (GPU) not detected. The script will run on the CPU.")
    print("         This will be extremely slow. For better performance,")
    print("         please ensure you have an NVIDIA GPU and the correct")
    print("         PyTorch version with CUDA support installed.")
print("-------------------------------------------------")


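# If an NVIDIA GPU is present but CUDA is still not detected, the CPU-only PyTorch
# wheel is probably installed. Reinstalling from a CUDA wheel index usually fixes
# this; the index URL / CUDA version below are only an example and should match your setup:
#   pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121
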
# --- Load the Model ---
print("Loading Stable Diffusion model... This may take a moment.")
try:
    # Load the pipeline and move it to the detected device
    pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1-base")
    pipe.to(torch_device)
    print("Model loaded successfully.")
except Exception as e:
    print(f"Error loading model: {e}")
    print("Please check your internet connection and ensure the model name is correct.")
    exit()

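# Optional memory savings (not part of the original setup): on CUDA, loading the
# pipeline in half precision and enabling attention slicing reduce VRAM usage.
# Both are standard diffusers APIs:
#   pipe = StableDiffusionPipeline.from_pretrained(
#       "stabilityai/stable-diffusion-2-1-base", torch_dtype=torch.float16
#   ).to(torch_device)
#   pipe.enable_attention_slicing()
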
# -----------------------------------------------------------------------------
# Helper Functions (slerp, diffuse)
# -----------------------------------------------------------------------------

@torch.no_grad()
def diffuse(
    pipe, cond_embeddings, cond_latents, num_inference_steps, guidance_scale, eta
):
    # Take the device from the input latents so this works on both CPU and GPU
    # (Tensor.get_device() returns -1 for CPU tensors, which would break .to()).
    device = cond_latents.device
    max_length = cond_embeddings.shape[1]
    uncond_input = pipe.tokenizer([""], padding="max_length", max_length=max_length, return_tensors="pt")
    uncond_embeddings = pipe.text_encoder(uncond_input.input_ids.to(device))[0]
    text_embeddings = torch.cat([uncond_embeddings, cond_embeddings])

    if "LMS" in pipe.scheduler.__class__.__name__:
        cond_latents = cond_latents * pipe.scheduler.sigmas[0]

    accepts_offset = "offset" in set(inspect.signature(pipe.scheduler.set_timesteps).parameters.keys())
    extra_set_kwargs = {}
    if accepts_offset:
        extra_set_kwargs["offset"] = 1
    pipe.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)

    accepts_eta = "eta" in set(inspect.signature(pipe.scheduler.step).parameters.keys())
    extra_step_kwargs = {}
    if accepts_eta:
        extra_step_kwargs["eta"] = eta

    for i, t in enumerate(pipe.scheduler.timesteps):
        latent_model_input = torch.cat([cond_latents] * 2)
        if "LMS" in pipe.scheduler.__class__.__name__:
            sigma = pipe.scheduler.sigmas[i]
            latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)

        # predict the noise residual for the unconditional and text-conditioned branches
        noise_pred = pipe.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]
        # classifier-free guidance: uncond + scale * (text - uncond)
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
        cond_latents = pipe.scheduler.step(noise_pred, t, cond_latents, **extra_step_kwargs)["prev_sample"]

    # undo the Stable Diffusion VAE scaling factor (0.18215) and decode to pixels
    cond_latents = 1 / 0.18215 * cond_latents
    image = pipe.vae.decode(cond_latents).sample
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.cpu().permute(0, 2, 3, 1).numpy()
    image = (image[0] * 255).astype(np.uint8)
    return image

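# For reference, slerp (spherical linear interpolation) moves along the great circle
# between two vectors instead of the straight line a plain lerp would take:
#   slerp(t, v0, v1) = sin((1 - t) * theta) / sin(theta) * v0 + sin(t * theta) / sin(theta) * v1
# where theta is the angle between the flattened vectors. When they are nearly
# parallel (|dot| > DOT_THRESHOLD), the code falls back to linear interpolation
# to avoid dividing by a vanishing sin(theta).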
def slerp(t, v0, v1, DOT_THRESHOLD=0.9995):
    # This function is device-agnostic
    inputs_are_torch = isinstance(v0, torch.Tensor)
    if inputs_are_torch:
        input_device = v0.device
        v0 = v0.cpu().numpy()
        v1 = v1.cpu().numpy()
    dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
    if np.abs(dot) > DOT_THRESHOLD:
        v2 = (1 - t) * v0 + t * v1
    else:
        theta_0 = np.arccos(dot)
        sin_theta_0 = np.sin(theta_0)
        theta_t = theta_0 * t
        sin_theta_t = np.sin(theta_t)
        s0 = np.sin(theta_0 - theta_t) / sin_theta_0
        s1 = sin_theta_t / sin_theta_0
        v2 = s0 * v0 + s1 * v1
    if inputs_are_torch:
        v2 = torch.from_numpy(v2).to(input_device)
    return v2

# -----------------------------------------------------------------------------
# Main Generator Function for Gradio
# -----------------------------------------------------------------------------
def generate_dream_video(
    prompt_1, prompt_2, seed_1, seed_2,
    width, height, num_steps, guidance_scale,
    num_inference_steps, eta, name
):
    # --- 1. SETUP ---
    yield {
        status_text: "Status: Preparing prompts and latents...",
        live_frame: None,
        output_video: None,
    }
    prompts = [prompt_1, prompt_2]
    seeds = [int(seed_1), int(seed_2)]
    # Gradio may deliver slider/number values as floats; cast where whole numbers are required.
    width, height = int(width), int(height)
    num_steps = int(num_steps)
    num_inference_steps = int(num_inference_steps)
    rootdir = './dreams'
    outdir = os.path.join(rootdir, name)
    os.makedirs(outdir, exist_ok=True)

    # --- 2. EMBEDDINGS AND LATENTS ---
    prompt_embeddings = []
    for prompt in prompts:
        text_input = pipe.tokenizer(prompt, padding="max_length", max_length=pipe.tokenizer.model_max_length, truncation=True, return_tensors="pt")
        # Move input_ids to the correct device before text encoding
        with torch.no_grad():
            embed = pipe.text_encoder(text_input.input_ids.to(torch_device))[0]
        prompt_embeddings.append(embed)

    prompt_embedding_a, prompt_embedding_b = prompt_embeddings

    # Use a device-specific generator for reproducibility
    generator_a = torch.Generator(device=torch_device).manual_seed(seeds[0])
    generator_b = torch.Generator(device=torch_device).manual_seed(seeds[1])

    init_a = torch.randn((1, pipe.unet.config.in_channels, height // 8, width // 8), device=torch_device, generator=generator_a)
    init_b = torch.randn((1, pipe.unet.config.in_channels, height // 8, width // 8), device=torch_device, generator=generator_b)

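    # Each frame interpolates both inputs at the same position t in [0, 1]: the two
    # prompt embeddings and the two initial noise latents are slerped, then denoised
    # into an image, so consecutive frames drift smoothly from prompt 1 / seed 1
    # towards prompt 2 / seed 2. (The latents are 1/8 of the image resolution because
    # the Stable Diffusion VAE downsamples by a factor of 8.)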
    # --- 3. GENERATION LOOP ---
    frame_paths = []
    for i, t in enumerate(np.linspace(0, 1, num_steps)):
        yield {
            status_text: f"Status: Generating frame {i + 1} of {num_steps} on {torch_device.upper()}...",
            live_frame: None,
            output_video: None,
        }

        cond_embedding = slerp(float(t), prompt_embedding_a, prompt_embedding_b)
        init = slerp(float(t), init_a, init_b)

        # Mixed precision (autocast) only helps on CUDA, so disable it on the CPU path.
        with autocast(torch_device, enabled=(torch_device == "cuda")):
            image = diffuse(pipe, cond_embedding, init, num_inference_steps, guidance_scale, eta)

        im = Image.fromarray(image)
        outpath = os.path.join(outdir, f'frame{i:06d}.jpg')
        im.save(outpath)
        frame_paths.append(outpath)

        yield { live_frame: im }

    # --- 4. VIDEO COMPILATION ---
    yield { status_text: "Status: Compiling video from frames..." }

    video_path = os.path.join(outdir, f"{name}.mp4")
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    video_writer = cv2.VideoWriter(video_path, fourcc, 15, (width, height))
    for frame_path in frame_paths:
        frame = cv2.imread(frame_path)
        video_writer.write(frame)
    video_writer.release()

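    # Note: 'mp4v' is broadly supported by OpenCV, but the resulting file may not
    # play inline in every browser; if the Gradio video player stays blank,
    # re-encoding the frames to H.264 (e.g. via ffmpeg) is a common workaround.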
    print(f"Video saved to {video_path}")
    yield {
        status_text: f"Status: Done! Video saved to {video_path}",
        output_video: video_path
    }

# -----------------------------------------------------------------------------
# Gradio UI
# -----------------------------------------------------------------------------
with gr.Blocks(theme=gr.themes.Soft(), css="footer {display: none !important}") as demo:
    gr.Markdown("# 🎥 Stable Diffusion Video Interpolation")
    gr.Markdown("Create smooth transition videos between two concepts. Configure the prompts and settings below, then click Generate.")

    with gr.Row():
        with gr.Column(scale=2):
            with gr.Accordion("1. Core Prompts & Seeds", open=True):
                prompt_1 = gr.Textbox(lines=2, label="Starting Prompt", value="ultrarealistic steam punk neural network machine in the shape of a brain, placed on a pedestal, covered with neurons made of gears.")
                seed_1 = gr.Number(label="Seed 1", value=243, precision=0, info="A specific number to control the starting noise pattern.")
                prompt_2 = gr.Textbox(lines=2, label="Ending Prompt", value="A bioluminescent, glowing jellyfish floating in a dark, deep abyss, surrounded by sparkling plankton.")
                seed_2 = gr.Number(label="Seed 2", value=523, precision=0, info="A specific number to control the ending noise pattern.")
                name = gr.Textbox(label="Output File Name", value="my_dream_video", info="The name for the output folder and .mp4 file.")

            with gr.Accordion("2. Generation Parameters", open=True):
                with gr.Row():
                    width = gr.Slider(label="Width", minimum=256, maximum=1024, value=512, step=64)
                    height = gr.Slider(label="Height", minimum=256, maximum=1024, value=512, step=64)
                num_steps = gr.Slider(label="Interpolation Frames", minimum=10, maximum=500, value=120, step=1, info="How many frames the final video will have. More frames = smoother video.")

            with gr.Accordion("3. Advanced Diffusion Settings", open=False):
                num_inference_steps = gr.Slider(label="Inference Steps per Frame", minimum=10, maximum=100, value=40, step=1, info="More steps can improve quality but will be much slower.")
                guidance_scale = gr.Slider(label="Guidance Scale (CFG)", minimum=1, maximum=20, value=7.5, step=0.5, info="How strongly the prompt guides the image generation.")
                eta = gr.Slider(label="ETA (for DDIM Scheduler)", minimum=0.0, maximum=1.0, value=0.0, step=0.1, info="A parameter for noise scheduling. 0.0 is deterministic.")

            run_button = gr.Button("Generate Video", variant="primary")

        with gr.Column(scale=3):
            status_text = gr.Textbox(label="Status", value="Ready", interactive=False)
            live_frame = gr.Image(label="Live Preview", type="pil")
            output_video = gr.Video(label="Final Video")

    run_button.click(
        fn=generate_dream_video,
        inputs=[
            prompt_1, prompt_2, seed_1, seed_2,
            width, height, num_steps, guidance_scale,
            num_inference_steps, eta, name
        ],
        outputs=[status_text, live_frame, output_video]
    )

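    # Note: generate_dream_video yields dicts keyed by output components, so every
    # component used as a key (status_text, live_frame, output_video) must appear
    # in the `outputs` list above for the partial updates to be accepted.
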
# --- Launch the App ---
if __name__ == "__main__":
    # Queueing is required for streaming (generator) outputs in Gradio 3.x and is
    # enabled by default in Gradio 4+, so calling it explicitly is safe either way.
    demo.queue()
    demo.launch(share=True, debug=True)