File size: 3,064 Bytes
0102e16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)

import torch
from typing import List, Optional, Tuple


class FeedForward(torch.nn.Module):
    """FeedForward (channel mixing) module definition.

    Each input frame is blended with the previous frame (a one-step shift
    along the time axis) using learned per-channel mixing weights. The blend
    feeds a squared-ReLU key/value projection whose output is gated by a
    sigmoid "receptance" projection.

    Args:
        size: Input/Output size.
        hidden_size: Hidden size.
        block_id: Block index.
        dropout_rate: Dropout probability applied to the activated key
            before the value projection.
        num_blocks: Number of blocks in the architecture.

    """

    def __init__(
        self,
        size: int,
        hidden_size: int,
        block_id: int,
        dropout_rate: float,
        num_blocks: int,
    ) -> None:
        """Construct a FeedForward object."""
        super().__init__()

        # Pads one zero frame at the start of the second-to-last axis and
        # crops the last frame, i.e. shifts the sequence one step forward.
        self.time_shift = torch.nn.ZeroPad2d((0, 0, 1, -1))

        # Allocated uninitialized here; filled in by reset_parameters() below.
        self.time_mix_key = torch.nn.Parameter(torch.empty(1, 1, size))
        self.time_mix_receptance = torch.nn.Parameter(torch.empty(1, 1, size))

        self.proj_key = torch.nn.Linear(size, hidden_size, bias=True)
        self.proj_value = torch.nn.Linear(hidden_size, size, bias=True)
        self.proj_receptance = torch.nn.Linear(size, size, bias=True)

        self.block_id = block_id

        self.reset_parameters(size, block_id, num_blocks)
        self.dropout = torch.nn.Dropout(p=dropout_rate)

    def reset_parameters(self, size: int, block_id: int, num_blocks: int) -> None:
        """Reset module parameters.

        Both time-mix parameters are set to ``(i / size) ** ratio`` for
        channel ``i``, where ``ratio = 1 - block_id / num_blocks``. The
        values are copied in place, so the parameters keep their current
        device and dtype.

        Args:
            size: Block size. Must match the size given at construction.
            block_id: Block index.
            num_blocks: Number of blocks in the architecture.

        """
        ratio_1_to_almost0 = 1.0 - (block_id / num_blocks)

        # Channel positions i / size, evaluated in float64 and then rounded to
        # the default dtype so the values match per-element Python division.
        time_weight = (
            (torch.arange(size, dtype=torch.float64) / size)
            .to(torch.get_default_dtype())
            .view(1, 1, size)
        )
        time_mix = torch.pow(time_weight, ratio_1_to_almost0)

        with torch.no_grad():
            # copy_ writes into the existing storage of each parameter instead
            # of re-binding .data to a new CPU tensor.
            self.time_mix_key.copy_(time_mix)
            self.time_mix_receptance.copy_(time_mix)

    def forward(
        self, x: torch.Tensor, state: Optional[List[torch.Tensor]] = None
    ) -> Tuple[torch.Tensor, Optional[List[torch.Tensor]]]:
        """Compute channel mixing.

        Without ``state`` the previous frame is obtained by shifting ``x`` one
        step along the time axis (zero for the first frame). With ``state``
        the previous frame is read from ``state[0][..., block_id]``, and that
        slot is then overwritten in place with the current ``x``.

        Args:
            x: FeedForward input sequences. (B, U, size)
            state: Decoder hidden state. [5 x (B, 1, size, N)]
                Only ``state[0]`` is read and written by this module.

        Returns:
            x: FeedForward output sequences. (B, U, size)
            state: Decoder hidden state. [5 x (B, 1, size, N)]
                The same list object that was passed in (or None).

        """
        shifted_x = (
            self.time_shift(x) if state is None else state[0][..., self.block_id]
        )

        # Per-channel interpolation between the current and previous frame.
        key = x * self.time_mix_key + shifted_x * (1 - self.time_mix_key)
        receptance = x * self.time_mix_receptance + shifted_x * (
            1 - self.time_mix_receptance
        )

        key = torch.square(torch.relu(self.proj_key(key)))
        value = self.proj_value(self.dropout(key))
        receptance = torch.sigmoid(self.proj_receptance(receptance))

        if state is not None:
            # Store the raw input (not the output) as the "previous frame"
            # for the next call.
            state[0][..., self.block_id] = x

        x = receptance * value

        return x, state