# ImageGuard/utils/arguments.py
from dataclasses import dataclass, field
from typing import List, Optional

import transformers


@dataclass
class ModelArguments:
    # Path or Hugging Face Hub name of the base model to load.
    model_name_or_path: Optional[str] = field(default='')


@dataclass
class DataArguments:
    given_num: bool = False
    img_size: int = 490
    hd_num: int = -1
    data_cfg: str = ''
    data_version: int = 3


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default='adamw_torch')
    max_length: int = field(
        default=4096,
        metadata={
            'help':
            'Maximum sequence length. Sequences will be right padded (and possibly truncated).'
        },
    )
    use_lora: bool = False
    fix_vit: bool = True
    fix_sampler: bool = False
    # eval_flag: int = 0
    label_names: List[str] = field(default_factory=lambda: ['samples'])
    seed: int = 3407
    gradient_checkpointing: bool = True


@dataclass
class LoraArguments:
    lora_r: int = 64
    lora_alpha: int = 64
    lora_dropout: float = 0.05
    # Default target modules for InternLM-style attention/feed-forward layers.
    lora_target_modules: List[str] = field(default_factory=lambda: [
        'attention.wqkv',
        'attention.wo',
        'feed_forward.w1',
        'feed_forward.w2',
        'feed_forward.w3',
    ])
    # Alternative target modules for Idefics2-style layers (commented out):
    # lora_target_modules: List[str] = field(default_factory=lambda: [
    #     'self_attn.q_proj',
    #     'self_attn.k_proj',
    #     'self_attn.v_proj',
    #     'self_attn.o_proj',
    #     'mlp.gate_proj',
    #     'mlp.up_proj',
    #     'mlp.down_proj',
    # ])
    lora_weight_path: str = ''
    lora_bias: str = 'none'
    lora_type: str = 'lora'
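# --- Hedged sketch (not part of the original file) --------------------------
# A minimal illustration of how LoraArguments could be mapped onto a LoRA
# adapter config, assuming the training code uses the `peft` library; the
# helper name `build_lora_config` is hypothetical and only for illustration.
def build_lora_config(lora_args):
    from peft import LoraConfig  # assumed dependency, imported lazily

    return LoraConfig(
        r=lora_args.lora_r,
        lora_alpha=lora_args.lora_alpha,
        lora_dropout=lora_args.lora_dropout,
        target_modules=lora_args.lora_target_modules,
        bias=lora_args.lora_bias,
        task_type='CAUSAL_LM',
    )
# -----------------------------------------------------------------------------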
@dataclass
class EvalArguments:
    max_length: int = field(
        default=4096,
        metadata={
            'help':
            'Maximum sequence length. Sequences will be right padded (and possibly truncated).'
        },
    )
    use_lora: bool = False
    fix_vit: bool = True
    fix_sampler: bool = True
    # eval_flag: int = 0
    label_names: List[str] = field(default_factory=lambda: ['samples'])
    gradient_checkpointing: bool = False
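# --- Hedged usage sketch (not part of the original file) --------------------
# A minimal sketch of how these dataclasses are typically consumed, assuming
# the training entry point parses them with transformers.HfArgumentParser
# (the actual ImageGuard training script may wire them up differently).
# Example invocation: python arguments.py --output_dir ./out
if __name__ == '__main__':
    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments, LoraArguments))
    model_args, data_args, training_args, lora_args = \
        parser.parse_args_into_dataclasses()
    print(model_args)
    print(data_args)
    print(lora_args)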