File size: 3,387 Bytes
33aff71
 
 
 
 
 
 
31b64ca
 
 
 
33aff71
 
 
 
 
 
 
 
 
 
31b64ca
33aff71
 
31b64ca
33aff71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31b64ca
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import os
from typing import Optional, Union

import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from pesq import pesq
from joblib import Parallel, delayed

from toolbox.torchaudio.configuration_utils import CONFIG_FILE
from toolbox.torchaudio.models.nx_mpnet.configuration_nx_mpnet import NXMPNetConfig
from toolbox.torchaudio.models.nx_mpnet.utils import LearnableSigmoid1d


class MetricDiscriminator(nn.Module):
    """
    Convolutional discriminator that scores a pair of equally-shaped inputs.

    The two inputs are stacked into a two-channel map, passed through four
    stride-2, spectrally-normalised convolution stages (each followed by
    instance norm and PReLU), globally max-pooled, flattened, and fed to a
    two-layer fully-connected head whose last linear layer has one output
    feature, followed by ``LearnableSigmoid1d``.
    """

    def __init__(self, config: NXMPNetConfig):
        """
        :param config: provides ``discriminator_dim`` (width of the first
            convolution stage; later stages use 2x, 4x and 8x that width) and
            ``discriminator_in_channel`` (input channels of the first
            convolution).
        """
        super().__init__()
        width = config.discriminator_dim
        # Channel progression: in_channel -> dim -> 2*dim -> 4*dim -> 8*dim.
        channels = (
            config.discriminator_in_channel,
            width,
            width * 2,
            width * 4,
            width * 8,
        )

        # Modules are created strictly in network order so that the
        # Sequential indices (and therefore the state_dict keys) and the
        # order of random initialisation are the same as a flat definition.
        modules = []
        for c_in, c_out in zip(channels[:-1], channels[1:]):
            conv = nn.Conv2d(
                c_in,
                c_out,
                kernel_size=(4, 4),
                stride=(2, 2),
                padding=(1, 1),
                bias=False,
            )
            modules.append(nn.utils.spectral_norm(conv))
            modules.append(nn.InstanceNorm2d(c_out, affine=True))
            modules.append(nn.PReLU(c_out))

        # Collapse the spatial dimensions to one value per channel.
        modules.append(nn.AdaptiveMaxPool2d(1))
        modules.append(nn.Flatten())

        # Fully-connected head: 8*dim -> 4*dim -> 1.
        modules.append(nn.utils.spectral_norm(nn.Linear(width * 8, width * 4)))
        modules.append(nn.Dropout(0.3))
        modules.append(nn.PReLU(width * 4))
        modules.append(nn.utils.spectral_norm(nn.Linear(width * 4, 1)))
        # NOTE(review): presumably squashes the score into a bounded range;
        # confirm against LearnableSigmoid1d's definition.
        modules.append(LearnableSigmoid1d(1))

        self.layers = nn.Sequential(*modules)

    def forward(self, x, y):
        """
        Score the pair ``(x, y)``.

        :param x: first input tensor.
        :param y: second input tensor; must have the same shape as ``x``
            because the two are stacked.
        :return: output of the layer stack applied to the stacked pair.

        NOTE(review): inputs are presumably (batch, freq, time) spectrogram
        magnitudes - TODO confirm with the caller. Stacking on dim 1 yields
        two channels, so ``discriminator_in_channel`` is expected to be 2.
        """
        pair = torch.stack([x, y], dim=1)
        return self.layers(pair)


# File name of the discriminator checkpoint (state dict) inside a model directory.
MODEL_FILE = "discriminator.pt"


class MetricDiscriminatorPretrainedModel(MetricDiscriminator):
    """
    ``MetricDiscriminator`` that keeps its config and can be saved to, or
    restored from, disk.
    """

    def __init__(self, config: NXMPNetConfig):
        """
        :param config: forwarded to ``MetricDiscriminator`` and kept on the
            instance as ``self.config`` so it can be written by
            ``save_pretrained``.
        """
        super().__init__(config=config)
        self.config = config

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """
        Build a model and load its weights from disk.

        :param pretrained_model_name_or_path: either a directory containing
            ``MODEL_FILE`` or the path of a checkpoint file itself. The same
            value is also handed to ``NXMPNetConfig.from_pretrained``.
        :param kwargs: forwarded to ``NXMPNetConfig.from_pretrained``.
        :return: the model with the checkpoint's state dict loaded.
        """
        config = NXMPNetConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        model = cls(config)

        # A directory means "look for the standard checkpoint name inside it";
        # anything else is taken to be the checkpoint file itself.
        ckpt_file = (
            os.path.join(pretrained_model_name_or_path, MODEL_FILE)
            if os.path.isdir(pretrained_model_name_or_path)
            else pretrained_model_name_or_path
        )

        # Weights are mapped to CPU and loaded with weights_only=True.
        with open(ckpt_file, "rb") as f:
            weights = torch.load(f, map_location="cpu", weights_only=True)

        # strict=True: missing or unexpected keys raise instead of being ignored.
        model.load_state_dict(weights, strict=True)
        return model

    def save_pretrained(self,
                        save_directory: Union[str, os.PathLike],
                        state_dict: Optional[dict] = None,
                        ):
        """
        Write the weights and the config into ``save_directory``.

        :param save_directory: target directory; created if it does not exist.
        :param state_dict: weights to store; when ``None`` the model's own
            ``state_dict()`` is used.
        :return: ``save_directory`` unchanged.
        """
        weights = self.state_dict() if state_dict is None else state_dict

        os.makedirs(save_directory, exist_ok=True)

        # Weights go to MODEL_FILE, config goes to CONFIG_FILE (YAML).
        torch.save(weights, os.path.join(save_directory, MODEL_FILE))
        self.config.to_yaml_file(os.path.join(save_directory, CONFIG_FILE))
        return save_directory


# Script entry point: intentionally a no-op; this module only provides definitions.
if __name__ == '__main__':
    pass