BenkHel committed on
Commit
92944bc
·
verified ·
1 Parent(s): 83b7d53

Upload llava_llama.py

Browse files
cumo/model/language_model/llava_llama.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+
21
+ from transformers import AutoConfig, AutoModelForCausalLM, \
22
+ LlamaConfig, LlamaModel, LlamaForCausalLM
23
+
24
+ from transformers.modeling_outputs import CausalLMOutputWithPast
25
+ from transformers.generation.utils import GenerateOutput
26
+
27
+ from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
28
+
29
+
30
class LlavaConfig(LlamaConfig):
    """Configuration class for the LLaVA variant of LLaMA.

    Inherits everything from ``LlamaConfig`` and only overrides
    ``model_type`` with the identifier ``"llava_llama"``.
    """

    model_type = "llava_llama"
32
+
33
+
34
class LlavaLlamaModel(LlavaMetaModel, LlamaModel):
    """``LlamaModel`` combined with the ``LlavaMetaModel`` mixin.

    ``LlavaMetaModel`` comes first in the bases, so its ``__init__`` is the
    one reached by the ``super().__init__`` call below.
    """

    config_class = LlavaConfig

    def __init__(self, config: LlamaConfig):
        """Build the model from ``config`` via the cooperative MRO chain."""
        super().__init__(config)
39
+
40
+
41
class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM):
    """Causal language model head on top of ``LlavaLlamaModel``.

    Adds ``images`` / ``image_sizes`` arguments to ``forward`` and
    ``generate`` and routes them through
    ``prepare_inputs_labels_for_multimodal`` (not defined in this file;
    presumably inherited from ``LlavaMetaForCausalLM`` -- confirm in
    ``llava_arch``) before delegating to the parent implementation.
    """

    config_class = LlavaConfig

    def __init__(self, config):
        """Create the backbone and the LM head from ``config``."""
        # Deliberately start the MRO lookup *after* LlamaForCausalLM so that
        # its own __init__ is skipped; the backbone and head are built here.
        super(LlamaForCausalLM, self).__init__(config)

        self.model = LlavaLlamaModel(config)
        self.pretraining_tp = config.pretraining_tp
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(
            config.hidden_size,
            config.vocab_size,
            bias=False,
        )

        # Initialize weights and apply final processing
        self.post_init()

    def get_model(self):
        """Return the wrapped ``LlavaLlamaModel`` backbone."""
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        image_sizes: Optional[List[List[int]]] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the parent forward pass, preparing multimodal inputs first.

        When ``inputs_embeds`` is not supplied, the text inputs together with
        ``images`` and ``image_sizes`` are handed to
        ``prepare_inputs_labels_for_multimodal`` and its six return values
        replace ``input_ids``, ``position_ids``, ``attention_mask``,
        ``past_key_values``, ``inputs_embeds`` and ``labels``. When
        ``inputs_embeds`` is given, that step is skipped and ``images`` /
        ``image_sizes`` are not used.

        Returns:
            Whatever the parent class's ``forward`` returns.
        """
        if inputs_embeds is None:
            prepared = self.prepare_inputs_labels_for_multimodal(
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                labels,
                images,
                image_sizes,
            )
            (
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                inputs_embeds,
                labels,
            ) = prepared

        # images / image_sizes are consumed above and not forwarded.
        return super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        images: Optional[torch.Tensor] = None,
        image_sizes: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        """Generate from token ids and optional images.

        The prompt is always converted to ``inputs_embeds`` before the parent
        ``generate`` is called: through
        ``prepare_inputs_labels_for_multimodal`` when ``images`` is given,
        otherwise through the backbone's ``embed_tokens``.

        Raises:
            NotImplementedError: if the caller passes ``inputs_embeds``.
        """
        if "inputs_embeds" in kwargs:
            raise NotImplementedError("`inputs_embeds` is not supported")

        # Pull these out of kwargs so they are not passed twice below.
        position_ids = kwargs.pop("position_ids", None)
        attention_mask = kwargs.pop("attention_mask", None)

        if images is None:
            # Text-only prompt: plain token embedding lookup.
            inputs_embeds = self.get_model().embed_tokens(inputs)
        else:
            # No past_key_values and no labels at generation start; the
            # corresponding return slots are discarded.
            prepared = self.prepare_inputs_labels_for_multimodal(
                inputs,
                position_ids,
                attention_mask,
                None,
                None,
                images,
                image_sizes=image_sizes,
            )
            inputs, position_ids, attention_mask, _, inputs_embeds, _ = prepared

        return super().generate(
            position_ids=position_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            **kwargs,
        )

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
                                      inputs_embeds=None, **kwargs):
        """Extend the parent's generation inputs with image arguments.

        ``images`` and ``image_sizes`` are removed from ``kwargs`` before the
        parent implementation is called, then added to the mapping it returns
        if they are not ``None``.
        """
        images = kwargs.pop("images", None)
        image_sizes = kwargs.pop("image_sizes", None)

        model_inputs = super().prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            **kwargs,
        )

        for key, value in (("images", images), ("image_sizes", image_sizes)):
            if value is not None:
                model_inputs[key] = value
        return model_inputs
156
+
157
# Register the config under its model_type string, then map that config class
# to the causal-LM class, so the transformers Auto* factories know about them.
AutoConfig.register("llava_llama", LlavaConfig)
AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM)