from typing import Any, Dict, List, Optional, Union

import torch
from transformers import LlamaModel, LlamaTokenizer, LlamaTokenizerFast

from .base import ProcessorMixin


DEFAULT_PROMPT_TEMPLATE = {
    "template": (
        "<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: "
        "1. The main content and theme of the video."
        "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
        "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
        "4. background environment, light, style and atmosphere."
        "5. camera angles, movements, and transitions used in the video:<|eot_id|>"
        "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
    ),
    "crop_start": 95,
}
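
# NOTE: `crop_start` above is expected to match the number of tokens that the template prefix (everything before the
# user caption) occupies once tokenized; that many leading positions are cropped from the encoder output so that only
# the caption itself is used for conditioning. If the template changes, update this value or drop the key so it is
# derived from the tokenizer at runtime (see `LlamaProcessor.forward`).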


class LlamaProcessor(ProcessorMixin):
    r"""
    Processor for the Llama family of models. Encodes text prompts with a Llama text encoder and returns the prompt
    embeddings together with the corresponding attention mask.

    Args:
        output_names (`List[str]`):
            The names of the outputs that the processor should return. The first output is the embeddings of the input
            text and the second output is the attention mask for the input text.
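
    Example (illustrative sketch; the checkpoint path, subfolder layout, and output names below are assumptions, not
    part of this module):

        from transformers import LlamaModel, LlamaTokenizerFast

        tokenizer = LlamaTokenizerFast.from_pretrained("<llama-text-encoder-checkpoint>", subfolder="tokenizer")
        text_encoder = LlamaModel.from_pretrained("<llama-text-encoder-checkpoint>", subfolder="text_encoder")

        processor = LlamaProcessor(output_names=["encoder_hidden_states", "encoder_attention_mask"])
        outputs = processor.forward(
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            caption="A cat walks on the grass, realistic style.",
            max_sequence_length=256,
        )
        prompt_embeds = outputs["encoder_hidden_states"]
        prompt_attention_mask = outputs["encoder_attention_mask"]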
    """

    def __init__(self, output_names: Optional[List[str]] = None):
        super().__init__()

        if output_names is None or len(output_names) != 2:
            raise ValueError(
                "`output_names` must contain exactly two names: one for the prompt embeddings and one for the "
                "prompt attention mask."
            )

        self.output_names = output_names

    def forward(
        self,
        tokenizer: Union[LlamaTokenizer, LlamaTokenizerFast],
        text_encoder: LlamaModel,
        caption: Union[str, List[str]],
        max_sequence_length: int,
        prompt_template: Optional[Dict[str, Any]] = None,
        num_layers_to_skip: int = 2,
    ) -> Dict[str, torch.Tensor]:
        r"""
        Encode the input text and return the embeddings and attention mask for the input text.

        Args:
            tokenizer (`Union[LlamaTokenizer, LlamaTokenizerFast]`):
                The tokenizer used to tokenize the input text.
            text_encoder (`LlamaModel`):
                The text encoder used to encode the input text.
            caption (`Union[str, List[str]]`):
                The input text to be encoded.
            max_sequence_length (`int`):
                The maximum sequence length of the input text.
            prompt_template (`Optional[Dict[str, Any]]`):
                The prompt template used to wrap the input text before encoding. Defaults to
                `DEFAULT_PROMPT_TEMPLATE`.
            num_layers_to_skip (`int`, defaults to `2`):
                The number of final hidden states to skip when selecting the layer used for the prompt embeddings
                (with the default, `hidden_states[-3]` is used).

        Returns:
            A dictionary mapping `output_names[0]` to the prompt embeddings and `output_names[1]` to the boolean
            prompt attention mask.
        """
        if prompt_template is None:
            prompt_template = DEFAULT_PROMPT_TEMPLATE
        if isinstance(caption, str):
            caption = [caption]

        device = text_encoder.device
        dtype = text_encoder.dtype

        batch_size = len(caption)
        caption = [prompt_template["template"].format(c) for c in caption]

        crop_start = prompt_template.get("crop_start", None)
        if crop_start is None:
            prompt_template_input = tokenizer(
                prompt_template["template"],
                padding="max_length",
                return_tensors="pt",
                return_length=False,
                return_overflowing_tokens=False,
                return_attention_mask=False,
            )
            crop_start = prompt_template_input["input_ids"].shape[-1]
            # Remove <|eot_id|> token and placeholder {}
            crop_start -= 2

        max_sequence_length += crop_start
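        # Tokenize the full templated captions; the extra `crop_start` tokens added above keep the entire
        # `max_sequence_length` budget available for the user caption once the template prefix is cropped off.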
        text_inputs = tokenizer(
            caption,
            max_length=max_sequence_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
            return_length=False,
            return_overflowing_tokens=False,
            return_attention_mask=True,
        )
        text_input_ids = text_inputs.input_ids.to(device)
        prompt_attention_mask = text_inputs.attention_mask.bool().to(device)

        prompt_embeds = text_encoder(
            text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True
        ).hidden_states[-(num_layers_to_skip + 1)]
        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

        if crop_start is not None and crop_start > 0:
            prompt_embeds = prompt_embeds[:, crop_start:]
            prompt_attention_mask = prompt_attention_mask[:, crop_start:]

        prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)

        return {
            self.output_names[0]: prompt_embeds,
            self.output_names[1]: prompt_attention_mask,
        }