import math
from typing import Optional
from absl import flags
from functools import partial

import jax
from jax import random
import jax.numpy as jnp
from transformers import FlaxCLIPModel

from nerf import utils

FLAGS = flags.FLAGS

@partial(jax.jit, static_argnums=[0])
def semantic_loss(clip_model, src_image, target_embedding):
    """Semantic consistency loss: cosine distance between the CLIP embeddings
    of the coarse/fine renderings and a (normalized) target embedding."""
    c_image = utils.unshard(src_image[0])  # coarse rendering
    f_image = utils.unshard(src_image[1])  # fine rendering
    w = int(math.sqrt(f_image.shape[0]))   # the rays form a square w x w patch
    c_image = c_image.reshape([w, w, 3])
    f_image = f_image.reshape([w, w, 3])

    # Embed both renderings with CLIP (expects NCHW) and L2-normalize.
    pixels = preprocess_for_CLIP(
        jnp.stack([c_image, f_image], 0).transpose(0, 3, 1, 2))
    src_embedding = clip_model.get_image_features(pixel_values=pixels)
    src_embedding /= jnp.linalg.norm(src_embedding, axis=-1, keepdims=True)
    # Summed cosine similarity over the two renderings is at most 2,
    # so the loss is non-negative.
    sc_loss = 2 - jnp.sum(src_embedding * target_embedding)
    return sc_loss, f_image

def semantic_step_multi(render_pfn, clip_model, rng, state, batch, lr):
    """One semantic-consistency gradient step using a pmapped render function.

    `lr` is unused but kept for interface parity with other step functions."""
    random_rays = batch["random_rays"]
    target_embedding = batch["embedding"]
    rng, key_0, key_1 = random.split(rng, 3)

    def loss_fn(variables):
        images = render_pfn(variables, key_0, key_1, random_rays)
        sc_loss, f_image = semantic_loss(clip_model, images, target_embedding)
        return sc_loss * FLAGS.sc_loss_mult, f_image

    # Differentiate w.r.t. the parameters replicated on the first device.
    variables = jax.device_get(jax.tree_map(lambda x: x[0], state)).optimizer.target
    (sc_loss, src_image), grad = jax.value_and_grad(loss_fn, has_aux=True)(variables)
    return sc_loss, grad, src_image

@partial(jax.jit, static_argnums=[0, 1])
def semantic_step_single(model, clip_model, rng, state, batch, lr):
    """One semantic-consistency gradient step rendering directly through the model.

    `lr` is unused but kept for interface parity with other step functions."""
    random_rays = jax.tree_map(lambda x: x.reshape(-1, 3), batch["random_rays"])
    target_embedding = batch["embedding"]
    rng, key_0, key_1 = random.split(rng, 3)

    def loss_fn(variables):
        c_image, f_image = model.apply(
            variables, key_0, key_1, random_rays, False, rgb_only=True)
        w = int(math.sqrt(f_image.shape[0]))  # the rays form a square w x w patch
        c_image = c_image.reshape([w, w, 3])
        f_image = f_image.reshape([w, w, 3])

        pixels = preprocess_for_CLIP(
            jnp.stack([c_image, f_image], 0).transpose(0, 3, 1, 2))
        src_embedding = clip_model.get_image_features(pixel_values=pixels)
        src_embedding /= jnp.linalg.norm(src_embedding, axis=-1, keepdims=True)
        sc_loss = 2 - jnp.sum(src_embedding * target_embedding)
        return sc_loss * FLAGS.sc_loss_mult, f_image

    variables = jax.device_get(jax.tree_map(lambda x: x[0], state)).optimizer.target
    (sc_loss, src_image), grad = jax.value_and_grad(loss_fn, has_aux=True)(variables)
    return sc_loss, grad, src_image

def trans_t(t):
    """Homogeneous 4x4 translation by t along the z-axis."""
    return jnp.array([
        [1, 0, 0, 0],
        [0, 1, 0, 0],
        [0, 0, 1, t],
        [0, 0, 0, 1]], dtype=jnp.float32)

def rot_phi(phi):
    """Homogeneous 4x4 rotation about the x-axis."""
    return jnp.array([
        [1, 0, 0, 0],
        [0, jnp.cos(phi), jnp.sin(phi), 0],
        [0, -jnp.sin(phi), jnp.cos(phi), 0],
        [0, 0, 0, 1]], dtype=jnp.float32)

def rot_theta(th):
    """Homogeneous 4x4 rotation about the y-axis."""
    return jnp.array([
        [jnp.cos(th), 0, -jnp.sin(th), 0],
        [0, 1, 0, 0],
        [jnp.sin(th), 0, jnp.cos(th), 0],
        [0, 0, 0, 1]], dtype=jnp.float32)

def pose_spherical(radius, theta, phi):
    """Camera-to-world pose on a sphere: translate out by radius, then rotate."""
    c2w = trans_t(radius)
    c2w = rot_phi(phi) @ c2w
    c2w = rot_theta(theta) @ c2w
    # Swap the y/z axes and flip x to match the world coordinate convention.
    c2w = jnp.array([[-1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]]) @ c2w
    return c2w

def random_pose(rng, bds):
    """Sample a camera pose on the upper hemisphere, with radius within bds."""
    rng, *rng_inputs = jax.random.split(rng, 4)
    # Use an independent key per quantity so radius/theta/phi are uncorrelated.
    radius = random.uniform(rng_inputs[0], minval=bds[0], maxval=bds[1])
    theta = random.uniform(rng_inputs[1], minval=-jnp.pi, maxval=jnp.pi)
    phi = random.uniform(rng_inputs[2], minval=0, maxval=jnp.pi / 2)
    return pose_spherical(radius, theta, phi)
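
# Illustrative usage only (the radius bounds below are hypothetical, not from
# the training configuration):
#   pose = random_pose(random.PRNGKey(0), bds=(3.0, 5.0))  # 4x4 camera-to-world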

def preprocess_for_CLIP(image):
    """JAX-based preprocessing for CLIP.

    Args:
        image: [B, 3, H, W] batch of images with values in [0, 1].
    Returns:
        [B, 3, 224, 224] batch, resized and normalized for CLIP.
    """
    B, D, H, W = image.shape
    mean = jnp.array([0.48145466, 0.4578275, 0.40821073]).reshape(1, 3, 1, 1)
    std = jnp.array([0.26862954, 0.26130258, 0.27577711]).reshape(1, 3, 1, 1)
    # Resize directly to 224x224; assumes square inputs, so no aspect distortion.
    image = jax.image.resize(image, (B, D, 224, 224), 'bicubic')
    image = (image - mean.astype(image.dtype)) / std.astype(image.dtype)
    return image

def init_CLIP(dtype: str, model_name: Optional[str]) -> FlaxCLIPModel:
    """Load a pretrained FlaxCLIPModel with the requested parameter dtype."""
    if dtype == 'float16':
        dtype = jnp.float16
    elif dtype == 'float32':
        dtype = jnp.float32
    else:
        raise ValueError(f"unsupported dtype {dtype!r}; expected 'float16' or 'float32'")

    if model_name is None:
        model_name = 'openai/clip-vit-base-patch32'

    return FlaxCLIPModel.from_pretrained(model_name, dtype=dtype)
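
# Minimal smoke-test sketch, not part of the training pipeline: it embeds a
# random square "rendering" the same way semantic_loss does. Assumes network
# access (or a cached checkpoint) for FlaxCLIPModel.from_pretrained; the image
# size and seed below are arbitrary.
if __name__ == "__main__":
    clip_model = init_CLIP('float32', None)
    image = random.uniform(random.PRNGKey(0), (64, 64, 3))  # stand-in rendering
    pixels = preprocess_for_CLIP(image[None].transpose(0, 3, 1, 2))
    embedding = clip_model.get_image_features(pixel_values=pixels)
    embedding /= jnp.linalg.norm(embedding, axis=-1, keepdims=True)
    print('embedding shape:', embedding.shape)  # (1, 512) for ViT-B/32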