import torch
from typing import List, Optional

from transformers import PretrainedConfig


class PathummaAudioConfig(PretrainedConfig):
    """Configuration class for PathummaAudio.

    Stores paths to the underlying LLM (Qwen2), Whisper, and BEATs checkpoints,
    LoRA hyperparameters, and Q-Former / audio windowing settings.
    """

    model_type: str = "pathumma_audio"

    def __init__(
        self,
        llm_path: str = "Qwen/Qwen2-7B-Instruct",
        whisper_path: str = "nectec/Pathumma-whisper-th-large-v3",
        beats_path: str = "",
        init_from_scratch: bool = True,
        lora: bool = True,
        lora_infer_mode: bool = True,
        lora_rank: int = 8,
        lora_alpha: int = 32,
        lora_dropout: float = 0.1,
        target_modules: Optional[List[str]] = None,
        qformer_query_token: int = 1,
        qformer_hidden_layers: int = 2,
        second_per_window: float = 0.333333,
        second_stride: float = 0.333333,
        torch_dtype: torch.dtype = torch.bfloat16,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.architectures = kwargs.get("architectures", ["PathummaAudioModel"])
        self.auto_map = kwargs.get("auto_map", {
            "AutoConfig": "configuration_pathumma_audio.PathummaAudioConfig",
            "AutoModel": "modeling_pathumma_audio.PathummaAudioModel",
        })

        # Checkpoints for the LLM and the audio encoders.
        self.llm_path = llm_path
        self.whisper_path = whisper_path
        self.beats_path = beats_path
        self.init_from_scratch = init_from_scratch

        # LoRA settings for adapting the LLM.
        self.lora = lora
        self.lora_infer_mode = lora_infer_mode
        self.lora_rank = lora_rank
        self.lora_alpha = lora_alpha
        self.lora_dropout = lora_dropout
        # Avoid a shared mutable default argument: fall back to the attention
        # projections used as the original default (["q_proj", "v_proj"]).
        self.target_modules = target_modules if target_modules is not None else ["q_proj", "v_proj"]

        # Q-Former size and audio windowing.
        self.qformer_query_token = qformer_query_token
        self.qformer_hidden_layers = qformer_hidden_layers
        self.second_per_window = second_per_window
        self.second_stride = second_stride

        self.torch_dtype = torch_dtype
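
# Usage sketch: build a config, save it, and reload it with from_pretrained.
# This is a minimal illustration; the output directory name is hypothetical.
if __name__ == "__main__":
    config = PathummaAudioConfig(lora_rank=16, qformer_query_token=1)
    config.save_pretrained("./pathumma-audio-config")  # hypothetical local path
    reloaded = PathummaAudioConfig.from_pretrained("./pathumma-audio-config")
    print(reloaded.llm_path, reloaded.lora_rank, reloaded.target_modules)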