g-ronimo committed on
Commit
c0e1760
1 Parent(s): d6748e0

Update app.py

Files changed (1)
  app.py +235 -115
app.py CHANGED
@@ -1,62 +1,237 @@
- import gradio as gr
+ import os
+ import cv2
+ import torch
+ import torchvision
+ import warnings
import numpy as np
+ from PIL import Image, ImageSequence
+ from moviepy.editor import VideoFileClip
+ import imageio
+ import uuid
+
+ import gradio as gr
import random

# import spaces #[uncomment to use ZeroGPU]
from diffusers import DiffusionPipeline
- import torch

- device = "cuda" if torch.cuda.is_available() else "cpu"
- model_repo_id = "stabilityai/sdxl-turbo" # Replace to the model you would like to use
+ from diffusers import (
+     TextToVideoSDPipeline,
+     AutoencoderKL,
+     DDPMScheduler,
+     DDIMScheduler,
+     UNet3DConditionModel,
+ )
+ import time
+ from transformers import CLIPTokenizer, CLIPTextModel
+
+ from diffusers.utils import export_to_video
+ from gifs_filter import filter
+ from invert_utils import ddim_inversion as dd_inversion
+ from text2vid_modded_full import TextToVideoSDPipelineModded
+
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+ dtype = torch.bfloat16
+ LORA_CHECKPOINT = "checkpoint-2500"
+
+ def cleanup_old_files(directory, age_in_seconds = 600):
+     """
+     Deletes files older than a certain age in the specified directory.
+
+     Args:
+         directory (str): The directory to clean up.
+         age_in_seconds (int): The age in seconds; files older than this will be deleted.
+     """
+     now = time.time()
+     for filename in os.listdir(directory):
+         file_path = os.path.join(directory, filename)
+         # Only delete files (not directories)
+         if os.path.isfile(file_path):
+             file_age = now - os.path.getmtime(file_path)
+             if file_age > age_in_seconds:
+                 try:
+                     os.remove(file_path)
+                     print(f"Deleted old file: {file_path}")
+                 except Exception as e:
+                     print(f"Error deleting file {file_path}: {e}")
+
+ def load_frames(image: Image, mode='RGBA'):
+     return np.array([np.array(frame.convert(mode)) for frame in ImageSequence.Iterator(image)])

- if torch.cuda.is_available():
-     torch_dtype = torch.float16
- else:
-     torch_dtype = torch.float32
+ def save_gif(frames, path):
+     imageio.mimsave(path, [frame.astype(np.uint8) for frame in frames], format='GIF', duration=1/10)

- pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
- pipe = pipe.to(device)
+ def load_image(imgname, target_size=None):
+     pil_img = Image.open(imgname).convert('RGB')
+     if target_size:
+         if isinstance(target_size, int):
+             target_size = (target_size, target_size)
+         pil_img = pil_img.resize(target_size, Image.Resampling.LANCZOS)
+     return torchvision.transforms.ToTensor()(pil_img).unsqueeze(0) # Add batch dimension

- MAX_SEED = np.iinfo(np.int32).max
- MAX_IMAGE_SIZE = 1024
+ def prepare_latents(pipe, x_aug):
+     with torch.cuda.amp.autocast():
+         batch_size, num_frames, channels, height, width = x_aug.shape
+         x_aug = x_aug.reshape(batch_size * num_frames, channels, height, width)
+         latents = pipe.vae.encode(x_aug).latent_dist.sample()
+         latents = latents.view(batch_size, num_frames, -1, latents.shape[2], latents.shape[3])
+         latents = latents.permute(0, 2, 1, 3, 4)
+         return pipe.vae.config.scaling_factor * latents
+
+ @torch.no_grad()
+ def invert(pipe, inv, load_name, device="cuda", dtype=torch.bfloat16):
+     input_img = [load_image(load_name, 256).to(device, dtype=dtype).unsqueeze(1)] * 5
+     input_img = torch.cat(input_img, dim=1)
+     latents = prepare_latents(pipe, input_img).to(torch.bfloat16)
+     inv.set_timesteps(25)
+     id_latents = dd_inversion(pipe, inv, video_latent=latents, num_inv_steps=25, prompt="")[-1].to(dtype)
+     return torch.mean(id_latents, dim=2, keepdim=True)
+
+ def load_primary_models(pretrained_model_path):
+     return (
+         DDPMScheduler.from_config(pretrained_model_path, subfolder=LORA_CHECKPOINT + "/scheduler"),
+         CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder=LORA_CHECKPOINT + "/tokenizer"),
+         CLIPTextModel.from_pretrained(pretrained_model_path, subfolder=LORA_CHECKPOINT + "/text_encoder"),
+         AutoencoderKL.from_pretrained(pretrained_model_path, subfolder=LORA_CHECKPOINT + "/vae"),
+         UNet3DConditionModel.from_pretrained(pretrained_model_path, subfolder=LORA_CHECKPOINT + "/unet"),
+     )
+
+
+ def initialize_pipeline(model: str, device: str = "cuda"):
+     with warnings.catch_warnings():
+         warnings.simplefilter("ignore")
+         scheduler, tokenizer, text_encoder, vae, unet = load_primary_models(model)
+     pipe = TextToVideoSDPipeline.from_pretrained(
+         pretrained_model_name_or_path="damo-vilab/text-to-video-ms-1.7b",
+         scheduler=scheduler,
+         tokenizer=tokenizer,
+         text_encoder=text_encoder.to(device=device, dtype=torch.bfloat16),
+         vae=vae.to(device=device, dtype=torch.bfloat16),
+         unet=unet.to(device=device, dtype=torch.bfloat16),
+     )
+     pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
+     return pipe, pipe.scheduler
+
+ @torch.no_grad()
+ def process(num_frames, num_seeds, generator, exp_dir, load_name, caption, lambda_):
+     pipe_inversion.to(device)
+     id_latents = invert(pipe_inversion, inv, load_name).to(device, dtype=dtype)
+     latents = id_latents.repeat(num_seeds, 1, 1, 1, 1)
+     generator = [torch.Generator(device="cuda").manual_seed(i) for i in range(num_seeds)]
+     video_frames = pipe(
+         prompt=caption,
+         negative_prompt="",
+         num_frames=num_frames,
+         num_inference_steps=25,
+         inv_latents=latents,
+         guidance_scale=9,
+         generator=generator,
+         lambda_=lambda_,
+     ).frames
+     try:
+         load_name = load_name.split("/")[-1]
+     except:
+         pass
+     gifs = []
+     for seed in range(num_seeds):
+         vid_name = f"{exp_dir}/mp4_logs/vid_{load_name[:-4]}-rand{seed}.mp4"
+         gif_name = f"{exp_dir}/gif_logs/vid_{load_name[:-4]}-rand{seed}.gif"
+         video_path = export_to_video(video_frames[seed], output_video_path=vid_name)
+         VideoFileClip(vid_name).write_gif(gif_name)
+         with Image.open(gif_name) as im:
+             frames = load_frames(im)
+
+         frames_collect = np.empty((0, 1024, 1024), int)
+         for frame in frames:
+             frame = cv2.resize(frame, (1024, 1024))[:, :, :3]
+             frame = cv2.cvtColor(255 - frame, cv2.COLOR_RGB2GRAY)
+
+             _, frame = cv2.threshold(255 - frame, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
+
+             frames_collect = np.append(frames_collect, [frame], axis=0)
+
+         save_gif(frames_collect, gif_name)
+         gifs.append(gif_name)
+
+     return gifs
+
+
+ def generate_gifs(filepath, prompt, num_seeds=5, lambda_=0):
+     exp_dir = "static/app_tmp"
+     os.makedirs(exp_dir, exist_ok=True)
+     gifs = process(
+         num_frames=10,
+         num_seeds=num_seeds,
+         generator=None,
+         exp_dir=exp_dir,
+         load_name=filepath,
+         caption=prompt,
+         lambda_=lambda_
+     )
+     return gifs
+
+ pipe_inversion, inv = initialize_pipeline("Hmrishav/t2v_sketch-lora", device)
+ pipe = TextToVideoSDPipelineModded.from_pretrained(
+     pretrained_model_name_or_path="damo-vilab/text-to-video-ms-1.7b",
+     scheduler=pipe_inversion.scheduler,
+     tokenizer=pipe_inversion.tokenizer,
+     text_encoder=pipe_inversion.text_encoder,
+     vae=pipe_inversion.vae,
+     unet=pipe_inversion.unet,
+ ).to(device)


# @spaces.GPU #[uncomment to use ZeroGPU]
def infer(
    prompt,
-     negative_prompt,
-     seed,
-     randomize_seed,
-     width,
-     height,
-     guidance_scale,
-     num_inference_steps,
+     image,
+     num_gifs,
+     num_frames,
+     lambda_value,
    progress=gr.Progress(track_tqdm=True),
):
-     if randomize_seed:
-         seed = random.randint(0, MAX_SEED)
-
-     generator = torch.Generator().manual_seed(seed)
-
-     image = pipe(
-         prompt=prompt,
-         negative_prompt=negative_prompt,
-         guidance_scale=guidance_scale,
-         num_inference_steps=num_inference_steps,
-         width=width,
-         height=height,
-         generator=generator,
-     ).images[0]
+     if image is None:
+         raise gr.Error("Please provide an image to animate.")
+     directories_to_clean = [
+         'static/app_tmp/mp4_logs',
+         'static/app_tmp/gif_logs',
+         'static/app_tmp/png_logs'
+     ]
+
+     # Perform cleanup
+     os.makedirs('static/app_tmp', exist_ok=True)
+     for directory in directories_to_clean:
+         os.makedirs(directory, exist_ok=True) # Ensure the directory exists
+         cleanup_old_files(directory)

-     return image, seed
+     # Save the uploaded image
+     unique_id = str(uuid.uuid4())
+     os.makedirs('upload', exist_ok=True)
+     filepath = os.path.join("upload", f"{unique_id}_uploaded_image.png")
+     image.save(filepath)

+     exp_dir = "static/app_tmp"
+     os.makedirs(exp_dir, exist_ok=True)
+     generated_gifs = process(
+         num_frames=num_frames,
+         num_seeds=num_gifs,
+         generator=None,
+         exp_dir=exp_dir,
+         load_name=filepath,
+         caption=prompt,
+         lambda_=lambda_value
+     )

- examples = [
-     "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
-     "An astronaut riding a green horse",
-     "A delicious ceviche cheesecake slice",
- ]
+     unique_id = str(uuid.uuid4())
+     for i in range(len(generated_gifs)):
+         os.rename(generated_gifs[i], f"{generated_gifs[i].split('.')[0]}_{unique_id}.gif")
+         generated_gifs[i] = f"{generated_gifs[i].split('.')[0]}_{unique_id}.gif"
+     # Move the generated gifs to the static folder

+     filtered_gifs = filter(generated_gifs, filepath)
+     print(filtered_gifs)
+     return filtered_gifs[0]
+
css = """
#col-container {
    margin: 0 auto;
@@ -66,89 +241,34 @@ css = """

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
-         gr.Markdown(" # Text-to-Image Gradio Template")
+         gr.Markdown(" # FlipSketch")

        with gr.Row():
-             prompt = gr.Text(
-                 label="Prompt",
-                 show_label=False,
-                 max_lines=1,
-                 placeholder="Enter your prompt",
-                 container=False,
-             )
-
-             run_button = gr.Button("Run", scale=0, variant="primary")
-
-         result = gr.Image(label="Result", show_label=False)
-
-         with gr.Accordion("Advanced Settings", open=False):
-             negative_prompt = gr.Text(
-                 label="Negative prompt",
-                 max_lines=1,
-                 placeholder="Enter a negative prompt",
-                 visible=False,
-             )
-
-             seed = gr.Slider(
-                 label="Seed",
-                 minimum=0,
-                 maximum=MAX_SEED,
-                 step=1,
-                 value=0,
-             )
-
-             randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
-
-             with gr.Row():
-                 width = gr.Slider(
-                     label="Width",
-                     minimum=256,
-                     maximum=MAX_IMAGE_SIZE,
-                     step=32,
-                     value=1024, # Replace with defaults that work for your model
-                 )
-
-                 height = gr.Slider(
-                     label="Height",
-                     minimum=256,
-                     maximum=MAX_IMAGE_SIZE,
-                     step=32,
-                     value=1024, # Replace with defaults that work for your model
-                 )
-
-             with gr.Row():
-                 guidance_scale = gr.Slider(
-                     label="Guidance scale",
-                     minimum=0.0,
-                     maximum=10.0,
-                     step=0.1,
-                     value=0.0, # Replace with defaults that work for your model
+             with gr.Column():
+                 image = gr.Image(label="Upload your image", type="pil")
+                 prompt = gr.Text(
+                     label="Prompt",
+                     show_label=False,
+                     max_lines=1,
+                     placeholder="Enter your prompt",
+                     container=False,
                )
+
+                 with gr.Accordion("Advanced options", open=False):
+                     num_gifs = gr.Slider(label="num_gifs", value=3, minimum=1, maximum=10, step=1)
+                     num_frames = gr.Slider(label="num_frames", value=10, minimum=5, maximum=50, step=1)
+                     lambda_value = gr.Slider(label="lambda", value=0, minimum=0, maximum=1, step=0.1)
+
+                 run_button = gr.Button("Run", scale=0, variant="primary")

-                 num_inference_steps = gr.Slider(
-                     label="Number of inference steps",
-                     minimum=1,
-                     maximum=50,
-                     step=1,
-                     value=2, # Replace with defaults that work for your model
-                 )
+             result = gr.Image(label="Result", elem_id="result", show_label=False, visible=True, type="filepath")

-         gr.Examples(examples=examples, inputs=[prompt])
+         # gr.Examples(examples=examples, inputs=[prompt])
        gr.on(
            triggers=[run_button.click, prompt.submit],
            fn=infer,
-             inputs=[
-                 prompt,
-                 negative_prompt,
-                 seed,
-                 randomize_seed,
-                 width,
-                 height,
-                 guidance_scale,
-                 num_inference_steps,
-             ],
-             outputs=[result, seed],
+             inputs=[prompt, image, num_gifs, num_frames, lambda_value],
+             outputs=[result],
        )

- if __name__ == "__main__":
-     demo.launch()
+ demo.launch(share=False)
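
For reference, a minimal sketch of how the entry points added in this commit could be exercised outside the Gradio UI. It assumes generate_gifs and filter (re-exported here from gifs_filter) are in scope as defined in the new app.py, and that the model weights are available; the sketch path and caption below are hypothetical placeholders, not files shipped with the Space.

# Hypothetical smoke test for the FlipSketch flow wired into infer() above:
# 1) DDIM-invert the input sketch and sample a few candidate animations,
# 2) keep the candidates that gifs_filter.filter() selects against the input sketch.
sketch_path = "upload/example_sketch.png"   # hypothetical input file
caption = "a sketch of a cat stretching"    # hypothetical prompt

candidate_gifs = generate_gifs(sketch_path, caption, num_seeds=3, lambda_=0)
best_gifs = filter(candidate_gifs, sketch_path)  # same call infer() makes
print(best_gifs[0])                              # path of the GIF the UI would display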