BenkHel committed (verified)
Commit 83b7d53 · Parent: 9ae17c6

Delete cumo/model/language_model/llava_llama.py

cumo/model/language_model/llava_llama.py DELETED
@@ -1,159 +0,0 @@
- # Copyright 2023 Haotian Liu
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
-
- from typing import List, Optional, Tuple, Union
-
- import torch
- import torch.nn as nn
-
- from transformers import AutoConfig, AutoModelForCausalLM, \
-     LlamaConfig, LlamaModel, LlamaForCausalLM
-
- from transformers.modeling_outputs import CausalLMOutputWithPast
- from transformers.generation.utils import GenerateOutput
-
- from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
-
-
- class LlavaConfig(LlamaConfig):
-     model_type = "llava_llama"
-
-
- class LlavaLlamaModel(LlavaMetaModel, LlamaModel):
-     config_class = LlavaConfig
-
-     def __init__(self, config: LlamaConfig):
-         super(LlavaLlamaModel, self).__init__(config)
-
-
- class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM):
-     config_class = LlavaConfig
-
-     def __init__(self, config):
-         super(LlamaForCausalLM, self).__init__(config)
-         self.model = LlavaLlamaModel(config)
-         self.pretraining_tp = config.pretraining_tp
-         self.vocab_size = config.vocab_size
-         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
-
-         # Initialize weights and apply final processing
-         self.post_init()
-
-     def get_model(self):
-         return self.model
-
-     def forward(
-         self,
-         input_ids: torch.LongTensor = None,
-         attention_mask: Optional[torch.Tensor] = None,
-         position_ids: Optional[torch.LongTensor] = None,
-         past_key_values: Optional[List[torch.FloatTensor]] = None,
-         inputs_embeds: Optional[torch.FloatTensor] = None,
-         labels: Optional[torch.LongTensor] = None,
-         use_cache: Optional[bool] = None,
-         output_attentions: Optional[bool] = None,
-         output_hidden_states: Optional[bool] = None,
-         images: Optional[torch.FloatTensor] = None,
-         image_sizes: Optional[List[List[int]]] = None,
-         return_dict: Optional[bool] = None,
-     ) -> Union[Tuple, CausalLMOutputWithPast]:
-
-         if inputs_embeds is None:
-             (
-                 input_ids,
-                 position_ids,
-                 attention_mask,
-                 past_key_values,
-                 inputs_embeds,
-                 labels
-             ) = self.prepare_inputs_labels_for_multimodal(
-                 input_ids,
-                 position_ids,
-                 attention_mask,
-                 past_key_values,
-                 labels,
-                 images,
-                 image_sizes
-             )
-
-         return super().forward(
-             input_ids=input_ids,
-             attention_mask=attention_mask,
-             position_ids=position_ids,
-             past_key_values=past_key_values,
-             inputs_embeds=inputs_embeds,
-             labels=labels,
-             use_cache=use_cache,
-             output_attentions=output_attentions,
-             output_hidden_states=output_hidden_states,
-             return_dict=return_dict
-         )
-
-     @torch.no_grad()
-     def generate(
-         self,
-         inputs: Optional[torch.Tensor] = None,
-         images: Optional[torch.Tensor] = None,
-         image_sizes: Optional[torch.Tensor] = None,
-         **kwargs,
-     ) -> Union[GenerateOutput, torch.LongTensor]:
-         position_ids = kwargs.pop("position_ids", None)
-         attention_mask = kwargs.pop("attention_mask", None)
-         if "inputs_embeds" in kwargs:
-             raise NotImplementedError("`inputs_embeds` is not supported")
-
-         if images is not None:
-             (
-                 inputs,
-                 position_ids,
-                 attention_mask,
-                 past_key_values,
-                 inputs_embeds,
-                 labels,
-                 *_
-             ) = self.prepare_inputs_labels_for_multimodal(
-                 inputs,
-                 position_ids,
-                 attention_mask,
-                 None,
-                 None,
-                 images,
-                 image_sizes=image_sizes
-             )
-         else:
-             inputs_embeds = self.get_model().embed_tokens(inputs)
-
-         return super().generate(
-             position_ids=position_ids,
-             attention_mask=attention_mask,
-             inputs_embeds=inputs_embeds,
-             **kwargs
-         )
-
-     def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
-                                       inputs_embeds=None, **kwargs):
-         images = kwargs.pop("images", None)
-         image_sizes = kwargs.pop("image_sizes", None)
-         inputs = super().prepare_inputs_for_generation(
-             input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
-         )
-         if images is not None:
-             inputs['images'] = images
-         if image_sizes is not None:
-             inputs['image_sizes'] = image_sizes
-         return inputs
-
- AutoConfig.register("llava_llama", LlavaConfig)
- AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM)
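
For context on what this deletion removes: the last two lines of the file registered the custom "llava_llama" model type with the transformers Auto classes, so that any checkpoint whose config.json declares model_type "llava_llama" dispatches to LlavaLlamaForCausalLM. A minimal usage sketch of how that registration was typically consumed, with a placeholder checkpoint path:

from transformers import AutoModelForCausalLM

# Importing the module executes the AutoConfig.register /
# AutoModelForCausalLM.register calls at the bottom of the file.
import cumo.model.language_model.llava_llama  # noqa: F401

# Any checkpoint whose config.json has model_type == "llava_llama" now
# dispatches to LlavaLlamaForCausalLM. The path below is a placeholder.
model = AutoModelForCausalLM.from_pretrained("path/to/llava-llama-checkpoint")

After this commit, that import path no longer exists, so this loading route is removed along with the file.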