ImageGuard / model /Internlm.py
adwardlee's picture
Upload folder using huggingface_hub
11b9c8d verified
raw
history blame
3.88 kB
import torch
from model.internlm_xcomposer.configuration_internlm_xcomposer2 import InternLMXcomposer2Config
from model.internlm_xcomposer.modeling_internlm_xcomposer2 import InternLMXComposer2ForCausalLM
from model.internlm_xcomposer.tokenization_internlm_xcomposer2 import InternLMXComposer2Tokenizer
from peft import LoraConfig, get_peft_model
from .base import BaseModel
class InternLM(BaseModel):
    """InternLM-XComposer2 wrapper: loads model/tokenizer and configures trainable parameters.

    Reads its settings from instance attributes that are presumably populated
    by ``BaseModel.__init__`` (TODO confirm): ``model_path``, ``use_caption``,
    ``training_args``, ``data_args`` and ``lora_args``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def _unwrapped_model(self):
        """Return the underlying causal-LM module, stripping a PEFT wrapper if present.

        After ``get_peft_model`` (the ``'lora'`` path) ``self.model`` is a PEFT
        wrapper whose ``get_base_model()`` returns the original model; in every
        other case ``self.model`` is returned unchanged. This gives one stable
        handle on the module that owns ``vit``, ``vision_proj``, ``logit_scale``
        and the inner transformer ``model``.
        """
        get_base_model = getattr(self.model, 'get_base_model', None)
        return get_base_model() if callable(get_base_model) else self.model

    def load_model_tokenizer(self):
        """Load the model and tokenizer into ``self.model`` and ``self.tokenizer``.

        The tokenizer is loaded from ``lora_args.lora_weight_path`` when that is
        a non-empty string, otherwise from ``model_path``.
        """
        config = InternLMXcomposer2Config.from_pretrained(self.model_path)
        # Disable the KV cache and take max_length from the training arguments.
        config.use_cache = False
        config.max_length = self.training_args.max_length

        model = InternLMXComposer2ForCausalLM.from_pretrained(
            self.model_path,
            config=config,
            device_map=None,
            use_caption=self.use_caption,
        )
        # NOTE(review): 336 looks like the vision tower's native input size;
        # presumably resize_pos() adapts its position embeddings to
        # data_args.img_size — confirm against the vit implementation.
        if self.data_args.img_size != 336:
            model.vit.resize_pos()

        tokenizer_path = (
            self.model_path
            if self.lora_args.lora_weight_path == ''
            else self.lora_args.lora_weight_path
        )
        tokenizer = InternLMXComposer2Tokenizer.from_pretrained(
            tokenizer_path,
            padding_side='right',
            use_fast=False,
        )

        self.model = model
        self.tokenizer = tokenizer

    def configure_training_args(self):
        """Freeze or unfreeze the vision tower and vision projector per ``training_args``.

        * ``fix_vit`` set: the ViT is frozen. Otherwise it is made trainable and
          its ``post_layernorm`` is replaced by ``torch.nn.Identity``.
        * ``fix_sampler`` or ``use_caption`` set: ``vision_proj`` is frozen;
          otherwise it is made trainable.
        """
        training_args = self.training_args

        if training_args.fix_vit:
            self.model.vit.requires_grad_(False)
        else:
            self.model.vit.requires_grad_(True)
            # Replace the ViT's final LayerNorm with a no-op when fine-tuning it
            # (NOTE(review): rationale not visible here — confirm).
            self.model.vit.vision_tower.vision_model.post_layernorm = torch.nn.Identity()

        if training_args.fix_sampler or self.use_caption:
            self.model.vision_proj.requires_grad_(False)
        else:
            self.model.vision_proj.requires_grad_(True)

    def configure_peft(self):
        """Decide which language-model parameters train, optionally adding a LoRA adapter.

        * ``use_lora`` unset: freeze every ``self.model.model`` parameter whose
          name does not contain ``'vision_cross'`` and return immediately.
        * ``lora_type == 'lora'``: freeze the same parameters, then wrap
          ``self.model`` with ``get_peft_model`` using a LoRA config built from
          ``lora_args``.
        * ``lora_type == 'plora'``: freeze every ``self.model.model`` parameter
          whose name does not contain ``'Plora'``.

        When ``use_lora`` is set, the method then:

        * unfreezes ``vision_proj``, ``tok_embeddings`` and ``logit_scale`` if
          ``use_caption`` is set;
        * re-enables all ``vision_cross`` parameters;
        * enables input grads and vision-tower gradient checkpointing if
          ``training_args.gradient_checkpointing`` is set.
        """
        if not self.training_args.use_lora:
            for name, param in self.model.model.named_parameters():
                if 'vision_cross' not in name:
                    param.requires_grad = False
            return

        lora_args = self.lora_args
        if lora_args.lora_type == 'lora':
            for name, param in self.model.model.named_parameters():
                if 'vision_cross' not in name:
                    param.requires_grad = False
            lora_config = LoraConfig(
                r=lora_args.lora_r,
                lora_alpha=lora_args.lora_alpha,
                target_modules=lora_args.lora_target_modules,
                lora_dropout=lora_args.lora_dropout,
                bias=lora_args.lora_bias,
                task_type='CAUSAL_LM',
            )
            self.model = get_peft_model(self.model, lora_config)
        elif lora_args.lora_type == 'plora':
            for name, param in self.model.model.named_parameters():
                if 'Plora' not in name:
                    param.requires_grad = False

        # Resolve modules through the unwrapped model so the same attribute
        # paths work whether or not self.model is now a PEFT wrapper.
        base_model = self._unwrapped_model()

        if self.use_caption:
            base_model.vision_proj.requires_grad_(True)
            base_model.model.tok_embeddings.requires_grad_(True)
            base_model.logit_scale.requires_grad_(True)

        # Re-enable vision_cross parameters; the 'plora' freeze above turns them off.
        for name, param in self.model.model.named_parameters():
            if 'vision_cross' in name:
                param.requires_grad = True

        if self.training_args.gradient_checkpointing:
            self.model.enable_input_require_grads()
            # Only the vision tower is checkpointed here; LLM-level
            # gradient_checkpointing_enable() is intentionally not called.
            # `vit` is an attribute of the top-level causal LM (see
            # load_model_tokenizer), so reach it via the unwrapped model. Going
            # through `self.model.model` only worked when a PEFT wrapper was present.
            base_model.vit.vision_tower.gradient_checkpointing_enable({"use_reentrant": True})