# et-viewer / app.py
import spaces
from functools import partial
from typing import Any, Callable, Dict
import clip
import gradio as gr
from gradio_rerun import Rerun
import numpy as np
import trimesh
import rerun as rr
import torch
from utils.common_viz import init, get_batch
from utils.random_utils import set_random_seed
from utils.rerun import log_sample
from src.diffuser import Diffuser
from src.datasets.multimodal_dataset import MultimodalDataset
# ------------------------------------------------------------------------------------- #
batch_size, num_cams, num_verts = None, None, None
SAMPLE_IDS = [
"2011_KAeAqaA0Llg_00005_00001",
"2011_F_EuMeT2wBo_00014_00001",
"2011_MCkKihQrNA4_00014_00000",
]
LABEL_TO_IDS = {
"right": 0,
"static": 1,
"complex": 2,
}
EXAMPLES = [
"While the character moves right, the camera trucks right.",
"While the character moves right, the camera performs a push in.",
"While the character moves right, the camera performs a pull out.",
"Movement: shortArcShotRight Easing: easeInOutQuad Frames: 30 Camera Angle: birdsEyeView Shot Type: mediumShot Subject Index: 0",
"Movement: fullZoomIn Easing: easeInOutSine Frames: 30 Camera Angle: highAngle Shot Type: closeUp",
"Movement: pedestalDown Easing: easeOutExpo Frames: 30 Camera Angle: mediumAngle Shot Type: longShot", # noqa
"Movement: dollyIn Easing: easeOutBounce Frames: 30 Camera Angle: mediumAngle Shot Type: longShot", # noqa
]
DEFAULT_TEXT = [
"While the character moves right, the camera [...].",
"Movement: dollyIn Easing: easeOutBounce Frames: 30 [...].",
"Movement: shortArcShotRight Easing: easeInOutQuad [...]. "
"Movement: fullZoomIn Easing: easeInOutSine [...].",
]
HEADER = """
<div align="center">
<h1 style='text-align: center'>E.T. the Exceptional Trajectories (Static Character Pose</h2>
<a href="https://robincourant.github.io/info/"><strong>Robin Courant</strong></a>
<a href="https://nicolas-dufour.github.io/"><strong>Nicolas Dufour</strong></a>
<a href="https://triocrossing.github.io/"><strong>Xi Wang</strong></a>
<a href="http://people.irisa.fr/Marc.Christie/"><strong>Marc Christie</strong></a>
<a href="https://vicky.kalogeiton.info/"><strong>Vicky Kalogeiton</strong></a>
</div>
<div align="center">
<a href="https://www.lix.polytechnique.fr/vista/projects/2024_et_courant/" class="button"><b>[Webpage]</b></a> &nbsp;&nbsp;&nbsp;&nbsp;
<a href="https://github.com/robincourant/DIRECTOR" class="button"><b>[DIRECTOR]</b></a> &nbsp;&nbsp;&nbsp;&nbsp;
<a href="https://github.com/robincourant/CLaTr" class="button"><b>[CLaTr]</b></a> &nbsp;&nbsp;&nbsp;&nbsp;
<a href="https://github.com/robincourant/the-exceptional-trajectories" class="button"><b>[Data]</b></a> &nbsp;&nbsp;&nbsp;&nbsp;
</div>
<br/>
"""
# ------------------------------------------------------------------------------------- #
def get_normals(vertices: torch.Tensor, faces: torch.Tensor) -> torch.Tensor:
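    """Compute per-frame vertex normals for a sequence of character meshes.

    `faces` is broadcast to every frame and trimesh recomputes the vertex
    normals of each frame's mesh independently.
    """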
num_frames, num_faces = vertices.shape[0], faces.shape[-2]
faces = faces.expand(num_frames, num_faces, 3)
normals = [
trimesh.Trimesh(vertices=v, faces=f, process=False).vertex_normals
for v, f in zip(vertices, faces)
]
normals = torch.from_numpy(np.stack(normals))
return normals
@spaces.GPU
def generate(
prompt: str,
seed: int,
guidance_weight: float,
character_position: list,
# ----------------------- #
dataset: MultimodalDataset,
device: torch.device,
diffuser: Diffuser,
clip_model: clip.model.CLIP,
) -> Dict[str, Any]:
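    """Generate a camera trajectory from a text prompt and log it to a Rerun recording.

    Only `prompt`, `seed`, `guidance_weight`, and `character_position` come from the
    UI; the dataset, device, diffuser, and CLIP model are bound via `functools.partial`
    at startup. Returns the path of the saved .rrd file for the Rerun viewer component.
    """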
diffuser.to(device)
clip_model.to(device)
# Set arguments
set_random_seed(seed)
diffuser.gen_seeds = np.array([seed])
diffuser.guidance_weight = guidance_weight
# Inference
sample_id = SAMPLE_IDS[0] # Default to the first sample ID
seq_feat = diffuser.net.model.clip_sequential
    batch = get_batch(
        prompt, sample_id, character_position, clip_model, dataset, seq_feat, device
    )
with torch.no_grad():
out = diffuser.predict_step(batch, 0)
# Run visualization
padding_mask = out["padding_mask"][0].to(bool).cpu()
padded_traj = out["gen_samples"][0].cpu()
traj = padded_traj[padding_mask]
char_traj = out["char_feat"][0].cpu()
padded_vertices = out["char_raw"]["char_vertices"][0]
vertices = padded_vertices[padding_mask]
faces = out["char_raw"]["char_faces"][0]
normals = get_normals(vertices, faces)
fx, fy, cx, cy = out["intrinsics"][0].cpu().numpy()
K = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]])
caption = out["caption_raw"][0]
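    # Log the trajectory, character mesh, and caption to a temporary Rerun recording
    # that the viewer component in the UI loads from disk.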
    rr.init(sample_id)
rr.save(".tmp_gr.rrd")
log_sample(
root_name="world",
traj=traj.numpy(),
char_traj=char_traj.numpy(),
K=K,
vertices=vertices.numpy(),
faces=faces.numpy(),
normals=normals.numpy(),
caption=caption,
mesh_masks=None,
)
return "./.tmp_gr.rrd"
# ------------------------------------------------------------------------------------- #
def launch_app(gen_fn: Callable):
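    """Build the Gradio interface and wire it to the given generation callback."""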
theme = gr.themes.Default(primary_hue="blue", secondary_hue="gray")
with gr.Blocks(theme=theme) as demo:
gr.Markdown(HEADER)
with gr.Row():
with gr.Column(scale=3):
with gr.Column(scale=2):
char_position = gr.Textbox(
placeholder="Enter character position as [x, y, z]",
show_label=True,
label="Character Position (3D vector)",
value="[0.0, 0.0, 0.0]",
)
text = gr.Textbox(
placeholder="Type the camera motion you want to generate",
show_label=True,
label="Text prompt",
value=DEFAULT_TEXT[0],
)
seed = gr.Number(value=33, label="Seed")
guidance = gr.Slider(0, 10, value=1.4, label="Guidance", step=0.1)
with gr.Column(scale=1):
btn = gr.Button("Generate", variant="primary")
with gr.Column(scale=2):
examples = gr.Examples(
examples=[[x, None, None] for x in EXAMPLES],
inputs=[text],
)
with gr.Row():
output = Rerun()
def load_example(example_id):
processed_example = examples.non_none_processed_examples[example_id]
return gr.utils.resolve_singleton(processed_example)
inputs = [text, seed, guidance, char_position]
examples.dataset.click(
load_example,
inputs=[examples.dataset],
outputs=examples.inputs_with_examples,
show_progress=False,
postprocess=False,
queue=False,
).then(fn=gen_fn, inputs=inputs, outputs=[output])
btn.click(fn=gen_fn, inputs=inputs, outputs=[output])
text.submit(fn=gen_fn, inputs=inputs, outputs=[output])
demo.queue().launch(share=False)
# ------------------------------------------------------------------------------------- #
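# Load the diffuser, CLIP model, and dataset once at startup, then bind them to
# `generate` so the Gradio callbacks only need the user-facing inputs.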
diffuser, clip_model, dataset, device = init("config")
generate_sample = partial(
generate,
dataset=dataset,
device=device,
diffuser=diffuser,
clip_model=clip_model,
)
launch_app(generate_sample)