import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from layers.Transformer_EncDec import Decoder, DecoderLayer, Encoder, EncoderLayer
from layers.SelfAttention_Family import FullAttention, AttentionLayer
from layers.Embed import PatchEmbedding, DataEmbedding
from collections import Counter
from layers.SharedWavMoE import WavMoE
from layers.RevIN import RevIN
import torch.fft
class FlattenHead(nn.Module):
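    """Linear projection head applied per variable: maps features of size nf to the target window."""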
def __init__(self, n_vars, nf, target_window, head_dropout=0):
super().__init__()
self.n_vars = n_vars
self.linear = nn.Linear(nf, target_window)
self.dropout = nn.Dropout(head_dropout)
    def forward(self, x):
        # x: [bs x nvars x nf]
        x = self.linear(x)
x = self.dropout(x)
return x
class Model(nn.Module):
"""
"""
def __init__(self, configs):
super(Model, self).__init__()
self.task_name = configs.task_name
self.seq_len = configs.seq_len
self.patch_len = configs.input_token_len
self.stride = self.patch_len
self.pred_len = configs.test_pred_len
self.test_seq_len = configs.test_seq_len
# embedding configs
self.output_attention = configs.output_attention
self.padding = configs.padding
        # MoE settings
self.hidden_size = configs.hidden_size
self.intermediate_size = configs.intermediate_size
self.top_k = configs.top_k
self.shared_experts = configs.shared_experts
self.wavelet = configs.wavelet
        self.level = configs.level
self.proj_wight = configs.proj_wight
# Embedding
self.patch_embedding = PatchEmbedding(
configs.d_model, self.patch_len, self.stride, self.padding, configs.dropout)
self.data_embedding = DataEmbedding(configs.enc_in, configs.d_model, configs.embed, configs.freq,
configs.dropout)
self.revin_layer = RevIN(configs.enc_in)
self.encoder_patch = Encoder(
[
EncoderLayer(
AttentionLayer(
FullAttention(False, configs.factor,
attention_dropout=configs.dropout,
output_attention=configs.output_attention),
configs.d_model, configs.n_heads),
configs.d_model,
configs.d_ff,
dropout=configs.dropout,
activation=configs.activation
) for l in range(configs.e_layers)
],
norm_layer=torch.nn.LayerNorm(configs.d_model)
)
self.encoder_time = Encoder(
[
EncoderLayer(
AttentionLayer(
FullAttention(False, configs.factor,
attention_dropout=configs.dropout,
output_attention=configs.output_attention),
configs.d_model, configs.n_heads),
configs.d_model,
configs.d_ff,
dropout=configs.dropout,
activation=configs.activation
) for l in range(configs.e_layers)
],
norm_layer=torch.nn.LayerNorm(configs.d_model)
)
self.head_nf = configs.d_model * \
int((configs.seq_len - self.patch_len) / self.stride + 1)
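        # head_nf = d_model * number of patches; projection maps it to int(seq_len * proj_wight)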
self.projection = nn.Linear(self.head_nf, int(configs.seq_len*self.proj_wight), bias=True)
self.data_projection = nn.Linear(configs.d_model, configs.enc_in, bias=True)
self.wavmoe = WavMoE(configs)
        self.head = FlattenHead(configs.enc_in, nf=int(configs.seq_len * self.proj_wight),
                                target_window=self.seq_len, head_dropout=configs.dropout)
self.gelu = nn.GELU()
def main(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
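        """
        Shared pipeline used by all tasks.

        x_enc: [B, S, D] input series; x_mark_enc: time-feature marks. x_dec and
        x_mark_dec are accepted for interface compatibility but are not used here.
        Returns the de-normalized output of shape [B, seq_len, D].
        """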
        # Normalize with RevIN and reshape to [B, D, S]
        x_revin = self.revin_layer(x_enc, 'norm').permute(0, 2, 1)
B, D, S = x_revin.shape
        # Variate-level attention: embed the full series and run it through encoder_time
        x_inver = self.data_embedding(x_revin.permute(0, 2, 1), x_mark_enc)
nav_out, attn_w = self.encoder_time(x_inver, attn_mask=None)
#print("nav_out.shape:", nav_out.shape,self.data_projection)
nav_out = self.data_projection(nav_out)
#print("nav_out.shape:", nav_out.shape)
#patch embedding进入多头FullAttention
# u: [bs * nvars x patch_num x d_model]
x_pe, n_vars = self.patch_embedding(x_revin+nav_out.permute(0, 2, 1))
#print("x_pe.shape:",x_pe.shape, n_vars)
enc_out, attn = self.encoder_patch(x_pe)
        dec_out = enc_out.reshape(B, D, -1)
        # Project the flattened patch features to length int(seq_len * proj_wight)
        act_val = self.projection(dec_out)
        # Wavelet mixture-of-experts over the combined features
        moe_out, router_logits = self.wavmoe(act_val + nav_out.permute(0, 2, 1))
head_out = self.head(moe_out)
        # De-normalize the head output with RevIN
        x_out = self.revin_layer(head_out.permute(0, 2, 1), 'denorm')
return x_out
def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
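        """Dispatch by task: forecasting slices the last test_seq_len steps of main();
        anomaly detection returns the full output; other tasks return None."""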
        if self.task_name in ('long_term_forecast', 'forecast'):
            dec_out = self.main(x_enc, x_mark_enc, x_dec, x_mark_dec)
            return dec_out[:, -self.test_seq_len:, :]  # [B, L, D]
if self.task_name == 'anomaly_detection':
dec_out = self.main(x_enc, x_mark_enc, x_dec, x_mark_dec)
return dec_out # [B, L, D]
return None
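

# --- Minimal smoke-test sketch (illustrative only) ---
# The config values below are assumptions chosen for demonstration, not the
# repository's defaults. WavMoE, PatchEmbedding and DataEmbedding may read
# additional config fields that are not visible in this file, and the expected
# time-mark feature count depends on the DataEmbedding implementation and freq,
# so adjust the namespace to your own configs before running.
if __name__ == "__main__":
    from types import SimpleNamespace

    configs = SimpleNamespace(
        task_name='long_term_forecast', seq_len=96, input_token_len=16,
        test_pred_len=96, test_seq_len=96, output_attention=False, padding=0,
        hidden_size=64, intermediate_size=128, top_k=2, shared_experts=1,
        wavelet='db1', level=1, proj_wight=1.0, d_model=64, enc_in=7,
        embed='timeF', freq='h', dropout=0.1, factor=1, n_heads=4, d_ff=128,
        activation='gelu', e_layers=2,
    )
    model = Model(configs)
    x_enc = torch.randn(2, configs.seq_len, configs.enc_in)  # [B, S, D]
    x_mark_enc = torch.randn(2, configs.seq_len, 4)          # time marks (assumed 4 features for freq='h')
    out = model(x_enc, x_mark_enc, None, None)
    print(out.shape)  # expected: [2, test_seq_len, enc_in]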