from typing import List, Optional, Tuple

import torch
from torch import nn

def preprocess(image_a: Optional[torch.Tensor], image_b: Optional[torch.Tensor], image_encoder: nn.Module,
               clip_mean: torch.Tensor, clip_std: torch.Tensor,
               should_drop_cond: Optional[List[Tuple[bool, bool]]] = None,
               concat_hidden_states: Optional[torch.Tensor] = None,
               image_list: Optional[List[torch.Tensor]] = None) -> Tuple[torch.Tensor, torch.Tensor]:
    """Encode and normalize conditioning images into stacked hidden states.

    Returns a tuple (image_embeds, out_hidden_states): image_embeds is a zero
    tensor shaped like one normalized embedding, and out_hidden_states stacks
    all conditioning embeddings along dim=1.
    """
    with torch.no_grad():
        image_list = [] if image_list is None else image_list
        additional_list = []
        if image_a is not None:
            additional_list.append(image_a)
        if image_b is not None:
            additional_list.append(image_b)
        image_list = additional_list + image_list
        embeds_list = []
        for image in image_list:
            # 2-D inputs are treated as precomputed embeddings and skip the encoder
            if image.ndim == 2:
                image_embeds = image
            else:
                encoder_outs = image_encoder(image, output_hidden_states=False)
                image_embeds = encoder_outs.image_embeds
            image_embeds = (image_embeds - clip_mean) / clip_std
            embeds_list.append(image_embeds.unsqueeze(1))
        if should_drop_cond is not None:
            # Zero out per-sample conditioning embeddings that should be dropped
            # (indices 0 and 1 correspond to image_a and image_b when both are given)
            for b_ind in range(embeds_list[0].shape[0]):
                should_drop_a, should_drop_b = should_drop_cond[b_ind]
                if should_drop_a:
                    embeds_list[0][b_ind] = torch.zeros_like(embeds_list[0][b_ind])
                if should_drop_b and image_b is not None:
                    embeds_list[1][b_ind] = torch.zeros_like(embeds_list[1][b_ind])
        if concat_hidden_states is not None:
            embeds_list.append(concat_hidden_states)
        out_hidden_states = torch.cat(embeds_list, dim=1)

        # The pooled image_embeds output is returned as zeros; the conditioning
        # signal is carried entirely through out_hidden_states
        image_embeds = torch.zeros_like(embeds_list[0].squeeze(1))

    return image_embeds, out_hidden_states
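

# Minimal usage sketch (illustrative only, not part of the original file): it feeds
# precomputed 2-D embeddings, so the encoder branch is never exercised and
# nn.Identity() stands in for a real CLIP vision encoder. The batch size, embedding
# dimension, and normalization statistics below are arbitrary assumptions.
if __name__ == "__main__":
    batch, dim = 2, 768
    clip_mean = torch.zeros(dim)
    clip_std = torch.ones(dim)
    image_a = torch.randn(batch, dim)  # precomputed embedding for condition A
    image_b = torch.randn(batch, dim)  # precomputed embedding for condition B
    image_embeds, hidden_states = preprocess(
        image_a, image_b, nn.Identity(), clip_mean, clip_std,
        should_drop_cond=[(False, True), (True, False)],
    )
    print(image_embeds.shape)   # torch.Size([2, 768]), zeroed pooled embeddings
    print(hidden_states.shape)  # torch.Size([2, 2, 768]), stacked conditioning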