File size: 3,881 Bytes
be13417
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
"""
 Copyright (c) 2023, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

from lavis.common.registry import registry
from lavis.datasets.builders.base_dataset_builder import BaseDatasetBuilder, MultiModalDatasetBuilder

from lavis.datasets.datasets.audio_captioning_datasets import (
    AudioSetDataset,
    AudioSetEvalDataset,
    AudioSetInstructDataset,
    AudioCapsDataset,
    AudioCapsEvalDataset,
    AudioCapsInstructDataset,
    ClothoV2Dataset,
    ClothoV2InstructDataset,
    ClothoV2EvalDataset,
    AudioLanguagePretrainDataset,
    AudioLanguagePretrainEvalDataset,
    AudioLanguagePretrainInstructDataset
)


class AudioCapBuilder(MultiModalDatasetBuilder):
    """Base builder for audio-captioning datasets that use caption templates.

    Extends the parent ``build`` by calling ``_build_templates`` on every
    built split. The argument is the ``templates.storage`` entry of the
    builder's ``build_info`` config when a ``templates`` section is present,
    and ``None`` otherwise.
    """

    train_dataset_cls = AudioSetDataset
    eval_dataset_cls = AudioSetEvalDataset

    DATASET_CONFIG_DICT = {
        "default": "configs/datasets/audioset/defaults_mm_cap.yaml",
    }

    def build(self):
        """Build all dataset splits and initialise their templates.

        Returns:
            The split-to-dataset mapping produced by the parent builder, after
            ``_build_templates`` has been invoked on each dataset.
        """
        datasets = super().build()

        # The template location does not depend on the split, so resolve it
        # once up front instead of once per dataset.
        # TODO: add option to download templates
        templates = self.config.build_info.get('templates')
        template_storage = None if templates is None else templates.storage

        for ds in datasets.values():
            ds._build_templates(template_storage)
        return datasets

@registry.register_builder("audioset_mm_caption")
class AudioSetBuilder(AudioCapBuilder):
    """AudioSet captioning builder, registered as ``audioset_mm_caption``."""

    # Default dataset config shipped for this builder.
    DATASET_CONFIG_DICT = {"default": "configs/datasets/audioset/defaults_mm_cap.yaml"}

    # Dataset classes used for the training and evaluation splits.
    eval_dataset_cls = AudioSetEvalDataset
    train_dataset_cls = AudioSetDataset

@registry.register_builder("audioset_mm_caption_instruct")
class AudioSetInstructBuilder(AudioCapBuilder):
    """AudioSet instruction-style captioning builder.

    Registered as ``audioset_mm_caption_instruct``; training uses the
    instruct dataset class while evaluation uses the plain eval class.
    """

    # Default dataset config shipped for this builder.
    DATASET_CONFIG_DICT = {"default": "configs/datasets/audioset/defaults_mm_cap_instruct.yaml"}

    # Dataset classes used for the training and evaluation splits.
    eval_dataset_cls = AudioSetEvalDataset
    train_dataset_cls = AudioSetInstructDataset

@registry.register_builder("audiocaps_mm_caption")
class AudioCapsCapBuilder(AudioCapBuilder):
    """AudioCaps captioning builder, registered as ``audiocaps_mm_caption``."""

    # Default dataset config shipped for this builder.
    DATASET_CONFIG_DICT = {"default": "configs/datasets/audiocaps/defaults_mm_cap.yaml"}

    # Dataset classes used for the training and evaluation splits.
    eval_dataset_cls = AudioCapsEvalDataset
    train_dataset_cls = AudioCapsDataset

@registry.register_builder("audiocaps_mm_caption_instruct")
class AudioCapsInstructCapBuilder(AudioCapBuilder):
    """AudioCaps instruction-style captioning builder.

    Registered as ``audiocaps_mm_caption_instruct``; training uses the
    instruct dataset class while evaluation uses the plain eval class.
    """

    # Default dataset config shipped for this builder.
    DATASET_CONFIG_DICT = {"default": "configs/datasets/audiocaps/defaults_mm_cap_instruct.yaml"}

    # Dataset classes used for the training and evaluation splits.
    eval_dataset_cls = AudioCapsEvalDataset
    train_dataset_cls = AudioCapsInstructDataset

@registry.register_builder("clothov2")
class ClothoCapBuilder(MultiModalDatasetBuilder):
    """Clotho v2 captioning builder, registered as ``clothov2``.

    This class was previously declared under the same name as the
    ``clothov2_instruct`` builder defined after it, which shadowed it at
    module level. It now has its own name; the registry key is unchanged.
    """

    # NOTE(review): this builder derives directly from MultiModalDatasetBuilder
    # rather than AudioCapBuilder, so AudioCapBuilder.build's template step is
    # not applied here - confirm this is intended.
    train_dataset_cls = ClothoV2Dataset
    eval_dataset_cls = ClothoV2EvalDataset

    DATASET_CONFIG_DICT = {
        "default": "configs/datasets/clotho/defaults_mm_cap.yaml",
    }

@registry.register_builder("clothov2_instruct")
class ClothoCapInstructBuilder(MultiModalDatasetBuilder):
    """Clotho v2 instruction-style captioning builder.

    Registered as ``clothov2_instruct``; training uses the instruct dataset
    class while evaluation uses the plain eval class.
    """

    # NOTE(review): this builder derives directly from MultiModalDatasetBuilder
    # rather than AudioCapBuilder, so AudioCapBuilder.build's template step is
    # not applied here - confirm this is intended.

    # Default dataset config shipped for this builder.
    DATASET_CONFIG_DICT = {"default": "configs/datasets/clotho/defaults_mm_cap_instruct.yaml"}

    # Dataset classes used for the training and evaluation splits.
    eval_dataset_cls = ClothoV2EvalDataset
    train_dataset_cls = ClothoV2InstructDataset


@registry.register_builder("wavcaps_mm_caption")
class WavCapsCapBuilder(AudioCapBuilder):
    """WavCaps captioning builder, registered as ``wavcaps_mm_caption``."""

    # Default dataset config shipped for this builder.
    DATASET_CONFIG_DICT = {"default": "configs/datasets/wavcaps/defaults_mm_cap.yaml"}

    # Dataset classes used for the training and evaluation splits.
    eval_dataset_cls = AudioLanguagePretrainEvalDataset
    train_dataset_cls = AudioLanguagePretrainDataset



@registry.register_builder("wavcaps_mm_caption_instruct")
class WavCapsCapInstructBuilder(AudioCapBuilder):
    """WavCaps instruction-style captioning builder.

    Registered as ``wavcaps_mm_caption_instruct``; training uses the instruct
    dataset class while evaluation uses the plain eval class.
    """

    # Default dataset config shipped for this builder.
    DATASET_CONFIG_DICT = {"default": "configs/datasets/wavcaps/defaults_mm_cap_instruct.yaml"}

    # Dataset classes used for the training and evaluation splits.
    eval_dataset_cls = AudioLanguagePretrainEvalDataset
    train_dataset_cls = AudioLanguagePretrainInstructDataset