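"""Argument dataclasses for HuggingFace-style fine-tuning: model, data,
training, LoRA, and evaluation configuration."""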
import transformers
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default='')

@dataclass
class DataArguments:
    given_num: bool = False
    img_size: int = 490
    hd_num: int = -1
    data_cfg: str = ''
    data_version: int = 3


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default='adamw_torch')
    max_length: int = field(
        default=4096,
        metadata={
            'help':
            'Maximum sequence length. Sequences will be right padded (and possibly truncated).'
        },
    )
    use_lora: bool = False
    fix_vit: bool = True
    fix_sampler: bool = False
    # eval_flag: int = 0
    label_names: List[str] = field(default_factory=lambda: ['samples'])
    seed: int = 3407
    gradient_checkpointing: bool = True

@dataclass
class LoraArguments:
    lora_r: int = 64
    lora_alpha: int = 64
    lora_dropout: float = 0.05
    ### for internlm ###
    lora_target_modules: List[str] = field(default_factory=lambda: [
        'attention.wqkv',
        'attention.wo',
        'feed_forward.w1',
        'feed_forward.w2',
        'feed_forward.w3',
    ])
    ### for idefics2 ###
    # lora_target_modules: List[str] = field(default_factory=lambda: [
    #     'self_attn.q_proj',
    #     'self_attn.k_proj',
    #     'self_attn.v_proj',
    #     'self_attn.o_proj',
    #     'mlp.gate_proj',
    #     'mlp.up_proj',
    #     'mlp.down_proj',
    # ])
    lora_weight_path: str = ''
    lora_bias: str = 'none'
    lora_type: str = 'lora'


@dataclass
class EvalArguments:
    max_length: int = field(
        default=4096,
        metadata={
            'help':
            'Maximum sequence length. Sequences will be right padded (and possibly truncated).'
        },
    )
    use_lora: bool = False
    fix_vit: bool = True
    fix_sampler: bool = True
    # eval_flag: int = 0
    label_names: List[str] = field(default_factory=lambda: ['samples'])
    gradient_checkpointing: bool = False
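

if __name__ == '__main__':
    # Illustrative sketch (an assumption, not part of the original training
    # entry point): argument dataclasses like these are typically consumed
    # together via transformers.HfArgumentParser, which maps command-line
    # flags such as --model_name_or_path, --data_cfg and --lora_r onto the
    # matching dataclass fields.
    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments, LoraArguments))
    model_args, data_args, training_args, lora_args = (
        parser.parse_args_into_dataclasses())
    print(model_args, data_args, training_args, lora_args, sep='\n\n')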