import transformers
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default='')


@dataclass
class DataArguments:
    given_num: bool = False
    img_size: int = 490
    hd_num: int = -1
    data_cfg: str = ''
    data_version: int = 3


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default='adamw_torch')
    max_length: int = field(
        default=4096,
        metadata={
            'help':
            'Maximum sequence length. Sequences will be right padded (and possibly truncated).'
        },
    )
    use_lora: bool = False
    fix_vit: bool = True
    fix_sampler: bool = False
    # eval_flag: int = 0
    label_names: List[str] = field(default_factory=lambda: ['samples'])
    seed: int = 3407
    gradient_checkpointing: bool = True


@dataclass
class LoraArguments:
    lora_r: int = 64
    lora_alpha: int = 64
    lora_dropout: float = 0.05
    # LoRA target modules for InternLM-style attention/FFN blocks.
    lora_target_modules: List[str] = field(default_factory=lambda: [
        'attention.wqkv',
        'attention.wo',
        'feed_forward.w1',
        'feed_forward.w2',
        'feed_forward.w3',
    ])
    # LoRA target modules for idefics2-style blocks (swap in if needed).
    # lora_target_modules: List[str] = field(default_factory=lambda: [
    #     'self_attn.q_proj',
    #     'self_attn.k_proj',
    #     'self_attn.v_proj',
    #     'self_attn.o_proj',
    #     'mlp.gate_proj',
    #     'mlp.up_proj',
    #     'mlp.down_proj',
    # ])
    lora_weight_path: str = ''
    lora_bias: str = 'none'
    lora_type: str = 'lora'
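

# Illustrative sketch only, not part of the original script: one plausible way
# to map LoraArguments onto a peft.LoraConfig when use_lora is enabled. The
# peft dependency and the helper name `build_lora_config` are assumptions.
def build_lora_config(lora_args: LoraArguments):
    from peft import LoraConfig  # assumed LoRA backend; not imported above

    return LoraConfig(
        r=lora_args.lora_r,
        lora_alpha=lora_args.lora_alpha,
        target_modules=lora_args.lora_target_modules,
        lora_dropout=lora_args.lora_dropout,
        bias=lora_args.lora_bias,
        task_type='CAUSAL_LM',
    )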


@dataclass
class EvalArguments:
    max_length: int = field(
        default=4096,
        metadata={
            'help':
            'Maximum sequence length. Sequences will be right padded (and possibly truncated).'
        },
    )
    use_lora: bool = False
    fix_vit: bool = True
    fix_sampler: bool = True
    # eval_flag: int = 0
    label_names: List[str] = field(default_factory=lambda: ['samples'])
    gradient_checkpointing: bool = False
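

# Minimal usage sketch (assumed, not shown in the original file): these argument
# groups are typically parsed together with transformers.HfArgumentParser. The
# __main__ guard and variable names below are illustrative assumptions.
if __name__ == '__main__':
    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments, LoraArguments))
    model_args, data_args, training_args, lora_args = (
        parser.parse_args_into_dataclasses())
    print(model_args)
    print(data_args)
    print(lora_args)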