import os
import cv2
import torch
import torchvision
import warnings
import numpy as np
from PIL import Image, ImageSequence
from moviepy.editor import VideoFileClip
import imageio
import uuid
import gradio as gr
import spaces  # [uncomment to use ZeroGPU]
from diffusers import (
    TextToVideoSDPipeline,
    AutoencoderKL,
    DDPMScheduler,
    DDIMScheduler,
    UNet3DConditionModel,
)
import time
from transformers import CLIPTokenizer, CLIPTextModel
from diffusers.utils import export_to_video
from gifs_filter import filter as filter_gifs  # renamed so it doesn't shadow the builtin filter()
from invert_utils import ddim_inversion as dd_inversion
from text2vid_modded_full import TextToVideoSDPipelineModded

device = 'cuda' if torch.cuda.is_available() else 'cpu'
dtype = torch.bfloat16
LORA_CHECKPOINT = "checkpoint-2500"

def cleanup_old_files(directory, age_in_seconds=600):
    """
    Deletes files older than a certain age in the specified directory.

    Args:
        directory (str): The directory to clean up.
        age_in_seconds (int): The age in seconds; files older than this will be deleted.
    """
    now = time.time()
    for filename in os.listdir(directory):
        file_path = os.path.join(directory, filename)
        # Only delete files (not directories)
        if os.path.isfile(file_path):
            file_age = now - os.path.getmtime(file_path)
            if file_age > age_in_seconds:
                try:
                    os.remove(file_path)
                    print(f"Deleted old file: {file_path}")
                except Exception as e:
                    print(f"Error deleting file {file_path}: {e}")
def load_frames(image: Image.Image, mode='RGBA'):
    return np.array([np.array(frame.convert(mode)) for frame in ImageSequence.Iterator(image)])


def save_gif(frames, path):
    # Write the frames out as a GIF at 10 fps.
    imageio.mimsave(path, [frame.astype(np.uint8) for frame in frames], format='GIF', duration=1 / 10)

def load_image(imgname, target_size=None):
    pil_img = Image.open(imgname).convert('RGB')
    if target_size:
        if isinstance(target_size, int):
            target_size = (target_size, target_size)
        pil_img = pil_img.resize(target_size, Image.Resampling.LANCZOS)
    return torchvision.transforms.ToTensor()(pil_img).unsqueeze(0)  # Add batch dimension
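
# Encode a (batch, frames, C, H, W) clip into VAE latents and return them in the
# (batch, channels, frames, h, w) layout the 3D UNet expects, pre-multiplied by
# the VAE scaling factor.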
def prepare_latents(pipe, x_aug):
    with torch.cuda.amp.autocast():
        batch_size, num_frames, channels, height, width = x_aug.shape
        x_aug = x_aug.reshape(batch_size * num_frames, channels, height, width)
        latents = pipe.vae.encode(x_aug).latent_dist.sample()
        latents = latents.view(batch_size, num_frames, -1, latents.shape[2], latents.shape[3])
        latents = latents.permute(0, 2, 1, 3, 4)
        return pipe.vae.config.scaling_factor * latents
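
# DDIM-invert the input sketch: the still image is repeated 5x along the frame
# axis, inverted to noise with an empty prompt, and averaged over the temporal
# dimension to give a single per-image "identity" latent.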
def invert(pipe, inv, load_name, device="cuda", dtype=torch.bfloat16):
    input_img = [load_image(load_name, 256).to(device, dtype=dtype).unsqueeze(1)] * 5
    input_img = torch.cat(input_img, dim=1)
    latents = prepare_latents(pipe, input_img).to(torch.bfloat16)
    inv.set_timesteps(25)
    id_latents = dd_inversion(pipe, inv, video_latent=latents, num_inv_steps=25, prompt="")[-1].to(dtype)
    return torch.mean(id_latents, dim=2, keepdim=True)
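
# Load the scheduler / tokenizer / text encoder / VAE / UNet from the subfolders
# of the finetuned checkpoint (LORA_CHECKPOINT) inside the given repo.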
def load_primary_models(pretrained_model_path):
    return (
        DDPMScheduler.from_config(pretrained_model_path, subfolder=LORA_CHECKPOINT + "/scheduler"),
        CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder=LORA_CHECKPOINT + "/tokenizer"),
        CLIPTextModel.from_pretrained(pretrained_model_path, subfolder=LORA_CHECKPOINT + "/text_encoder"),
        AutoencoderKL.from_pretrained(pretrained_model_path, subfolder=LORA_CHECKPOINT + "/vae"),
        UNet3DConditionModel.from_pretrained(pretrained_model_path, subfolder=LORA_CHECKPOINT + "/unet"),
    )
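
# Assemble a TextToVideoSDPipeline around the finetuned components, then swap in
# a DDIM scheduler (inversion needs DDIM's deterministic sampling).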
def initialize_pipeline(model: str, device: str = "cuda"):
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        scheduler, tokenizer, text_encoder, vae, unet = load_primary_models(model)
    pipe = TextToVideoSDPipeline.from_pretrained(
        pretrained_model_name_or_path="damo-vilab/text-to-video-ms-1.7b",
        scheduler=scheduler,
        tokenizer=tokenizer,
        text_encoder=text_encoder.to(device=device, dtype=torch.bfloat16),
        vae=vae.to(device=device, dtype=torch.bfloat16),
        unet=unet.to(device=device, dtype=torch.bfloat16),
    )
    pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
    return pipe, pipe.scheduler
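
# End-to-end generation for one uploaded sketch: invert it, sample `num_seeds`
# videos conditioned on the inverted latents, export each as mp4 + GIF, and
# re-binarize every frame so the output stays a black-and-white sketch.
# Note the `generator` argument is rebuilt inside, so the value passed in is ignored.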
def process(num_frames, num_seeds, generator, exp_dir, load_name, caption, lambda_):
    pipe_inversion.to(device)
    id_latents = invert(pipe_inversion, inv, load_name).to(device, dtype=dtype)
    latents = id_latents.repeat(num_seeds, 1, 1, 1, 1)
    # One generator per seed so each sample is individually reproducible.
    generator = [torch.Generator(device="cuda").manual_seed(i) for i in range(num_seeds)]
    video_frames = pipe(
        prompt=caption,
        negative_prompt="",
        num_frames=num_frames,
        num_inference_steps=25,
        inv_latents=latents,
        guidance_scale=9,
        generator=generator,
        lambda_=lambda_,
    ).frames
    load_name = os.path.basename(load_name)

    gifs = []
    for seed in range(num_seeds):
        vid_name = f"{exp_dir}/mp4_logs/vid_{load_name[:-4]}-rand{seed}.mp4"
        gif_name = f"{exp_dir}/gif_logs/vid_{load_name[:-4]}-rand{seed}.gif"
        export_to_video(video_frames[seed], output_video_path=vid_name)
        VideoFileClip(vid_name).write_gif(gif_name)
        with Image.open(gif_name) as im:
            frames = load_frames(im)

        # Binarize each frame (Otsu threshold on the inverted grayscale) to keep clean strokes.
        frames_collect = np.empty((0, 1024, 1024), int)
        for frame in frames:
            frame = cv2.resize(frame, (1024, 1024))[:, :, :3]
            frame = cv2.cvtColor(255 - frame, cv2.COLOR_RGB2GRAY)
            _, frame = cv2.threshold(255 - frame, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
            frames_collect = np.append(frames_collect, [frame], axis=0)

        save_gif(frames_collect, gif_name)
        gifs.append(gif_name)
    return gifs
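
# Build both pipelines once at import time. The modded text-to-video pipeline
# reuses the inversion pipeline's components, so the weights are presumably only
# held in memory once.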
pipe_inversion, inv = initialize_pipeline("Hmrishav/t2v_sketch-lora", device)
pipe = TextToVideoSDPipelineModded.from_pretrained(
    pretrained_model_name_or_path="damo-vilab/text-to-video-ms-1.7b",
    scheduler=pipe_inversion.scheduler,
    tokenizer=pipe_inversion.tokenizer,
    text_encoder=pipe_inversion.text_encoder,
    vae=pipe_inversion.vae,
    unet=pipe_inversion.unet,
).to(device)
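
# Gradio handler: saves the upload, cleans stale temp files, runs generation,
# and returns the first GIF that survives the similarity filter.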
def infer(
    prompt,
    image,
    num_gifs,
    num_frames,
    lambda_value,
    progress=gr.Progress(track_tqdm=True),
):
    if image is None:
        raise gr.Error("Please provide an image to animate.")

    directories_to_clean = [
        'static/app_tmp/mp4_logs',
        'static/app_tmp/gif_logs',
        'static/app_tmp/png_logs',
    ]

    # Perform cleanup
    os.makedirs('static/app_tmp', exist_ok=True)
    for directory in directories_to_clean:
        os.makedirs(directory, exist_ok=True)  # Ensure the directory exists
        cleanup_old_files(directory)

    # Save the uploaded image
    unique_id = str(uuid.uuid4())
    os.makedirs('upload', exist_ok=True)
    filepath = os.path.join("upload", f"{unique_id}_uploaded_image.png")
    image.save(filepath)

    exp_dir = "static/app_tmp"
    os.makedirs(exp_dir, exist_ok=True)

    generated_gifs = process(
        num_frames=num_frames,
        num_seeds=num_gifs,
        generator=None,
        exp_dir=exp_dir,
        load_name=filepath,
        caption=prompt,
        lambda_=lambda_value,
    )

    # Rename the generated GIFs with a fresh unique id so concurrent runs don't collide
    unique_id = str(uuid.uuid4())
    for i in range(len(generated_gifs)):
        renamed = f"{os.path.splitext(generated_gifs[i])[0]}_{unique_id}.gif"
        os.rename(generated_gifs[i], renamed)
        generated_gifs[i] = renamed

    # Filter the generated GIFs against the input image (see gifs_filter.filter)
    filtered_gifs = filter_gifs(generated_gifs, filepath)
    print(filtered_gifs)
    if not filtered_gifs:
        raise gr.Error("No generated animation passed the filter; please try again.")
    return filtered_gifs[0]

css = """
#col-container {
    margin: 0 auto;
    max-width: 640px;
}
"""
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(" # FlipSketch")
        gr.Markdown("https://github.com/hmrishavbandy/FlipSketch")
        with gr.Row():
            with gr.Column():
                image = gr.Image(label="Upload your image", type="pil")
                prompt = gr.Text(
                    label="Prompt",
                    show_label=False,
                    max_lines=1,
                    placeholder="Enter your prompt",
                    container=False,
                )
                with gr.Accordion("Advanced options", open=False):
                    num_gifs = gr.Slider(label="num_gifs", value=1, minimum=1, maximum=10, step=1)
                    num_frames = gr.Slider(label="num_frames", value=10, minimum=5, maximum=50, step=1)
                    lambda_value = gr.Slider(label="lambda", value=0, minimum=0, maximum=1, step=0.1)
                run_button = gr.Button("Run", scale=0, variant="primary")
            result = gr.Image(label="Result", elem_id="result", show_label=False, visible=True, type="filepath")
        # gr.Examples(examples=examples, inputs=[prompt])

    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[prompt, image, num_gifs, num_frames, lambda_value],
        outputs=[result],
    )

if __name__ == "__main__":
    demo.launch()