# SonicDiffusion / app.py
# Last commit by burakcanbiner: "Update app.py" (5dcc4b2, verified) — 10.6 kB
# Adapted from https://github.com/guoyww/AnimateDiff/blob/main/app.py
import spaces
import os
import json
import torch
import random
import gradio as gr
from glob import glob
from omegaconf import OmegaConf
from datetime import datetime
from safetensors import safe_open
from PIL import Image
from unet2d_custom import UNet2DConditionModel
import torch
from pipeline_stable_diffusion_custom import StableDiffusionPipeline
from diffusers import DDIMScheduler
from pnp_utils import *
import torchvision.transforms as T
from preprocess import get_timesteps
from preprocess import Preprocess
from pnp import PNP
# NOTE(review): not referenced anywhere in this file's visible code — possibly
# a leftover; confirm before removing.
sample_idx = 0

# Custom stylesheet passed to gr.Blocks(css=...). The margin rule uses the
# single-value form that `margin-bottom` accepts (a misspelled property or a
# four-value list would be discarded by the browser).
css = """
.toolbutton {
    margin-bottom: 0em;
    max-width: 1.5em;
    min-width: 1.5em !important;
    height: 1.5em;
}
"""
class AnimateController:
    """Own the SonicDiffusion PnP pipeline and serve the Gradio callbacks.

    A single instance is created at module level and shared by every UI
    session, so per-request choices (such as the selected model) are taken
    from the callback arguments where possible rather than from instance state.
    """

    # Dropdown label -> (audio projector checkpoint, adapter checkpoint).
    MODEL_CKPTS = {
        "Landscape Model": ("ckpts/audio_projector_landscape.pth", "ckpts/landscape.pt"),
        "Greatest Hits Model": ("ckpts/audio_projector_gh.pth", "ckpts/greatest_hits.pt"),
    }
    # Model used by generate() when no model has been selected yet.
    DEFAULT_MODEL = "Landscape Model"

    def __init__(self):
        self.sr = 44100
        self.save_steps = 50
        self.device = 'cuda'
        self.seed = 42
        self.extract_reverse = False
        self.save_dir = 'latents'
        self.steps = 50
        self.inversion_prompt = ''
        seed_everything(self.seed)
        self.pnp = PNP(sd_version="1.4")
        self.pnp.unet.to(self.device)
        self.pnp.audio_projector.to(self.device)
        # Checkpoint paths stay unset until a model is chosen; generate()
        # falls back to DEFAULT_MODEL while they are still None.
        self.audio_projector_path = None
        self.adapter_ckpt_path = None

    def _select_model(self, model_name):
        """Store the checkpoint paths for *model_name* on the instance.

        Any label other than "Landscape Model" selects the Greatest Hits
        checkpoints, matching the two choices offered by the UI dropdown.
        """
        key = model_name if model_name == "Landscape Model" else "Greatest Hits Model"
        self.audio_projector_path, self.adapter_ckpt_path = self.MODEL_CKPTS[key]

    @spaces.GPU
    def preprocess(self, image=None):
        """Invert *image* into latents saved under ``<save_dir>_forward``.

        Args:
            image: path of the uploaded image (the UI component uses
                ``type="filepath"``), forwarded as ``data_path`` to
                ``Preprocess.extract_latents``.

        Returns:
            The reconstructed image as a PIL image. It is also written to
            ``<save_dir>_forward/recon.jpg``; returning it lets the wired
            Gradio output show it instead of being cleared.
        """
        model_key = "CompVis/stable-diffusion-v1-4"
        toy_scheduler = DDIMScheduler.from_pretrained(model_key, subfolder="scheduler")
        toy_scheduler.set_timesteps(self.save_steps)
        timesteps_to_save, _ = get_timesteps(
            toy_scheduler,
            num_inference_steps=self.save_steps,
            strength=1.0,
            device=self.device,
        )
        save_path = self.save_dir + "_forward"
        os.makedirs(save_path, exist_ok=True)
        model = Preprocess(self.device, sd_version='1.4', hf_key=None)
        recon_image = model.extract_latents(
            data_path=image,
            num_steps=self.steps,
            save_path=save_path,
            timesteps_to_save=timesteps_to_save,
            inversion_prompt=self.inversion_prompt,
            extract_reverse=self.extract_reverse,
        )
        recon_pil = T.ToPILImage()(recon_image[0])
        recon_pil.save(os.path.join(save_path, 'recon.jpg'))
        return recon_pil

    @spaces.GPU
    def generate(self, file=None, audio=None, prompt=None,
                 cfg_scale=5, image_path=None,
                 pnp_f_t=0.8, pnp_attn_t=0.8,):
        """Run audio-conditioned PnP editing and return the generated image.

        Args:
            file: label of the SonicDiffusion model to use; the UI wires the
                model dropdown into this slot. If it is not a key of
                MODEL_CKPTS, the paths last stored on the instance are used,
                defaulting to DEFAULT_MODEL.
            audio: path of the conditioning audio file (forwarded as
                ``audio_path``).
            prompt: text prompt forwarded to ``run_pnp``.
            cfg_scale: guidance scale forwarded to ``run_pnp``.
            image_path: path of the input image forwarded to ``run_pnp``.
            pnp_f_t: residual-injection value forwarded to ``run_pnp``
                (the UI slider spans 0.0-1.0).
            pnp_attn_t: attention-injection value forwarded to ``run_pnp``
                (the UI slider spans 0.0-1.0).

        Returns:
            Whatever ``self.pnp.run_pnp`` returns; the UI shows it in a
            ``gr.Image``.
        """
        if isinstance(file, str) and file in self.MODEL_CKPTS:
            # Prefer the model sent with this request: the instance paths are
            # shared by all sessions and another user may have changed them.
            audio_projector_path, adapter_ckpt_path = self.MODEL_CKPTS[file]
        else:
            if self.audio_projector_path is None:
                print("audio projector path is None, using the default model")
                self._select_model(self.DEFAULT_MODEL)
            audio_projector_path = self.audio_projector_path
            adapter_ckpt_path = self.adapter_ckpt_path
        print(f"before run_pnp {audio_projector_path} -- {adapter_ckpt_path}")
        image = self.pnp.run_pnp(
            n_timesteps=50,
            pnp_f_t=pnp_f_t, pnp_attn_t=pnp_attn_t,
            prompt=prompt,
            negative_prompt="",
            audio_path=audio,
            image_path=image_path,
            audio_projector_path=audio_projector_path,
            adapter_ckpt_path=adapter_ckpt_path,
            cfg_scale=cfg_scale,
        )
        return image

    def update_audio_model(self, audio_model_update):
        """Dropdown ``change`` handler: remember the chosen model's checkpoints.

        Args:
            audio_model_update: the dropdown's new label.

        Returns:
            A bare ``gr.Dropdown()`` (no properties set) as the component update.
        """
        print(f"changing ckpts audio model {audio_model_update}")
        self._select_model(audio_model_update)
        print(f"audio_projector_path {self.audio_projector_path} -- {self.adapter_ckpt_path}")
        return gr.Dropdown()
# Single controller shared by every Gradio callback wired up in ui().
# Constructing it builds the PNP pipeline and moves its UNet and audio
# projector to the configured device at import time.
controller = AnimateController()
def ui():
    """Build the SonicDiffusion Gradio interface.

    Returns:
        The ``gr.Blocks`` app. All callbacks are bound to the module-level
        ``controller``.
    """
    with gr.Blocks(css=css) as demo:
        gr.Markdown(
            """
            # [SonicDiffusion: Audio-Driven Image Generation and Editing with Pretrained Diffusion Models]
            """
        )
        with gr.Row():
            audio_input = gr.Audio(sources="upload", type="filepath")
            prompt_textbox = gr.Textbox(label="Prompt", lines=2)
        with gr.Row():
            with gr.Column():
                pnp_f_t = gr.Slider(label="PNP Residual Injection", step=0.1, value=0.8, minimum=0.0, maximum=1.0)
                pnp_attn_t = gr.Slider(label="PNP Attention Injection", step=0.1, value=0.8, minimum=0.0, maximum=1.0)
            with gr.Column():
                audio_model_dropdown = gr.Dropdown(
                    label="Select SonicDiffusion model",
                    value="Landscape Model",
                    choices=["Landscape Model", "Greatest Hits Model"],
                    interactive=True,
                )
                audio_model_dropdown.change(
                    fn=controller.update_audio_model,
                    inputs=[audio_model_dropdown],
                    outputs=[audio_model_dropdown],
                )
                cfg_scale_slider = gr.Slider(label="CFG Scale", step=0.5, value=7.5, minimum=0, maximum=20)
        with gr.Row():
            preprocess_button = gr.Button(value="Preprocess", variant='primary')
            generate_button = gr.Button(value="Generate", variant='primary')
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(label="Input Image Component", sources="upload", type="filepath")
            with gr.Column():
                output = gr.Image(label="Output Image Component",
                                  height=512, width=512)
        with gr.Row():
            # Examples are given as file paths, which gr.Examples loads itself,
            # rather than as PIL images: Image.open is lazy and would keep each
            # file open until the image is decoded.
            image_names = [
                "corridor", "desert", "forest", "forest_painting", "golf_field",
                "human", "wood", "house", "apple", "chair", "hands",
                "pineapple", "table",
            ]
            gr.Examples(
                examples=[[f"assets/{name}.png"] for name in image_names],
                inputs=[image_input],
                label="Images",
            )
            landscape_audio_names = [
                "fire_crackling", "forest_birds", "forest_stepping_on_branches",
                "howling_wind", "rain", "splashing_water", "splashing_water_soft",
                "steps_on_snow", "thunder", "underwater", "waterfall_burble",
                "wind_noise_birds",
            ]
            gr.Examples(
                examples=[[f"./assets/{name}.wav"] for name in landscape_audio_names],
                inputs=[audio_input],
                label="Landscape Audios",
            )
            greatest_hits_audio_names = [
                "cardboard", "carpet", "ceramic", "cloth", "gravel", "leaf",
                "metal", "plastic_bag", "plastic", "rock", "wood",
            ]
            gr.Examples(
                examples=[[f"./assets/{name}.wav"] for name in greatest_hits_audio_names],
                inputs=[audio_input],
                label="Greatest Hits Audios",
            )
        preprocess_button.click(
            fn=controller.preprocess,
            inputs=[image_input],
            outputs=output,
        )
        generate_button.click(
            fn=controller.generate,
            # Inputs map positionally onto generate(); the dropdown value
            # lands in its first (`file`) parameter.
            inputs=[
                audio_model_dropdown,
                audio_input,
                prompt_textbox,
                cfg_scale_slider,
                image_input,
                pnp_f_t,
                pnp_attn_t,
            ],
            outputs=output,
        )
    return demo
if __name__ == "__main__":
    # Build the Blocks app and serve it right away; share=True additionally
    # asks Gradio for a public share link.
    ui().launch(share=True)