File size: 1,850 Bytes
dae6ad4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
from typing import Dict, List, Optional, Tuple, Union

import torch
from transformers import PretrainedConfig

class PathummaAudioConfig(PretrainedConfig):
    """Configuration for the Pathumma audio model (``model_type="pathumma_audio"``).

    This class only records settings; the code that consumes them lives in the
    modeling module referenced by the default ``auto_map``
    (``modeling_pathumma_audio.PathummaAudioModel``), which is not part of this
    file. Semantics below marked "presumably" are inferred from parameter names
    and should be confirmed against that module.

    Args:
        llm_path: Checkpoint of the language model (the default looks like a
            Hugging Face Hub id).
        whisper_path: Checkpoint of the Whisper model (default looks like a
            Hub id).
        beats_path: Checkpoint of the BEATs model; empty string by default.
        init_from_scratch: Flag stored verbatim for the modeling code.
        lora: Flag stored verbatim; presumably enables LoRA adapters.
        lora_infer_mode: Flag stored verbatim; presumably selects
            inference-time LoRA behaviour — TODO confirm.
        lora_rank: LoRA rank hyper-parameter.
        lora_alpha: LoRA alpha hyper-parameter.
        lora_dropout: LoRA dropout probability.
        target_modules: Module names for LoRA. ``None`` (the default) stands
            for ``["q_proj", "v_proj"]``; a new list is created for every
            instance so instances never share one mutable default.
        qformer_query_token: Presumably the number of Q-Former query tokens.
        qformer_hidden_layers: Presumably the number of Q-Former layers.
        second_per_window: Presumably the audio window length in seconds.
        second_stride: Presumably the audio window stride in seconds.
        torch_dtype: A ``torch.dtype`` or the string name of one (e.g.
            ``"bfloat16"`` or ``"torch.bfloat16"``, as found in a saved
            ``config.json``); names of real dtypes are resolved to the
            ``torch.dtype`` object, any other value is stored unchanged.
        **kwargs: Forwarded to ``PretrainedConfig``. ``architectures`` and
            ``auto_map`` are additionally read from here and default to this
            model's classes when absent.
    """

    model_type: str = "pathumma_audio"

    def __init__(
        self,
        # Sub-model checkpoints.
        llm_path: str = "Qwen/Qwen2-7B-Instruct",
        whisper_path: str = "openai/whisper-large-v3",
        beats_path: str = "",
        init_from_scratch: bool = True,
        # LoRA / Q-Former / windowing settings.
        lora: bool = True,
        lora_infer_mode: bool = True,
        lora_rank: int = 8,
        lora_alpha: int = 32,
        lora_dropout: float = 0.1,
        target_modules: Optional[List[str]] = None,
        qformer_query_token: int = 1,
        qformer_hidden_layers: int = 2,
        second_per_window: float = 0.333333,
        second_stride: float = 0.333333,
        # Numeric precision.
        torch_dtype: Union[torch.dtype, str] = torch.bfloat16,
        **kwargs
    ):
        super().__init__(**kwargs)

        # Fall back to this model's classes when neither the caller nor a
        # saved config supplies them, so AutoConfig/AutoModel can resolve it.
        self.architectures = kwargs.get("architectures", ["PathummaAudioModel"])
        self.auto_map = kwargs.get("auto_map", {
            "AutoConfig": "configuration_pathumma_audio.PathummaAudioConfig",
            "AutoModel": "modeling_pathumma_audio.PathummaAudioModel"
        })

        self.llm_path = llm_path
        self.whisper_path = whisper_path
        self.beats_path = beats_path
        self.init_from_scratch = init_from_scratch

        self.lora = lora
        self.lora_infer_mode = lora_infer_mode
        self.lora_rank = lora_rank
        self.lora_alpha = lora_alpha
        self.lora_dropout = lora_dropout
        # Build the default per call: a list literal in the signature would be
        # created once and shared (and mutable) across every instance.
        if target_modules is None:
            target_modules = ["q_proj", "v_proj"]
        self.target_modules = target_modules

        self.qformer_query_token = qformer_query_token
        self.qformer_hidden_layers = qformer_hidden_layers
        self.second_per_window = second_per_window
        self.second_stride = second_stride

        # Because torch_dtype is an explicit parameter it bypasses
        # PretrainedConfig's own str -> torch.dtype conversion; when reloaded
        # from config.json it arrives as e.g. "bfloat16". Resolve such names
        # so the attribute is a torch.dtype both before and after a round trip.
        # Strings that do not name a dtype (e.g. "auto") are kept as-is.
        if isinstance(torch_dtype, str):
            resolved = getattr(torch, torch_dtype.split(".")[-1], None)
            if isinstance(resolved, torch.dtype):
                torch_dtype = resolved
        self.torch_dtype = torch_dtype