import numpy as np
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

## _____________________ Define: MODEL-F & MODEL-G _____________________

# Positional Encoding
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=1024):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(0.1)
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe.unsqueeze(0))  # (1, max_len, d_model)

    def forward(self, x):
        # x: (batch_size, seq_len, d_model)
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)


# Transformer Encoder
class TransformerEncoder(nn.Module):
    def __init__(self, d_model=256, nhead=8, num_layers=6, dim_feedforward=1024, dropout=0.1):
        super(TransformerEncoder, self).__init__()
        self.positional_encoding = PositionalEncoding(d_model)
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead,
                                       dim_feedforward=dim_feedforward,
                                       dropout=dropout, batch_first=True),
            num_layers=num_layers
        )

    def preprocess_latent(self, Z):
        # (batch_size, 256, 32, 32) -> (batch_size, 1024, 256)
        batch_size, channels, height, width = Z.shape
        seq_len = height * width
        Z = Z.permute(0, 2, 3, 1).reshape(batch_size, seq_len, channels)
        return Z

    def postprocess_latent(self, Z):
        # (batch_size, 1024, 256) -> (batch_size, 256, 32, 32)
        batch_size, seq_len, channels = Z.shape
        height = width = int(math.sqrt(seq_len))
        Z = Z.reshape(batch_size, height, width, channels).permute(0, 3, 1, 2)
        return Z

    def forward(self, Z):
        Z = self.preprocess_latent(Z)
        Z = self.positional_encoding(Z)
        Z = self.encoder(Z)
        Z = self.postprocess_latent(Z)
        return Z  # latent of transformer


class TransformerDecoder(nn.Module):
    def __init__(self, d_model=256, nhead=8, num_layers=12, dim_feedforward=1024, dropout=0.1):
        super().__init__()
        self.d_model = d_model

        # Enhanced positional encoding
        self.positional_encoding = PositionalEncoding(d_model)

        # Multi-layer learnable start tokens
        self.base_start = nn.Parameter(torch.randn(1, 1024, d_model))
        self.start_net = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, dim_feedforward),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim_feedforward, d_model),
            nn.LayerNorm(d_model)
        )

        # Context-aware transformer decoder
        self.decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=dim_feedforward,
                dropout=dropout,
                batch_first=True
            ),
            num_layers=num_layers
        )

        # Output projection with residual
        self.output_layer = nn.Sequential(
            nn.Linear(d_model, d_model * 2),
            nn.GELU(),
            nn.Linear(d_model * 2, d_model)
        )

        self.init_weights()

    def init_weights(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        nn.init.normal_(self.base_start, mean=0, std=0.02)

    def preprocess_latent(self, Z):
        # Convert (B, C, H, W) to (B, H*W, C)
        return Z.permute(0, 2, 3, 1).flatten(1, 2)

    def postprocess_latent(self, Z):
        # Convert (B, H*W, C) back to (B, C, H, W)
        B, L, C = Z.shape
        H = W = int(L ** 0.5)
        return Z.view(B, H, W, C).permute(0, 3, 1, 2)

    def forward(self, Z, Z1_start_tokens=None, teacher_forcing_ratio=0.5):
        # Process input latent (used as memory for cross-attention)
        Z = self.preprocess_latent(Z)
        # Z = self.positional_encoding(Z)

        # Generate enhanced start tokens
        B = Z.size(0)
        base_tokens = self.base_start.expand(B, -1, -1)
        processed_start = self.start_net(base_tokens)

        # Teacher forcing integration
        if Z1_start_tokens is not None and teacher_forcing_ratio > 0:
            # Positional encoding is applied once to the mixed start tokens below,
            # so the ground-truth tokens are only reshaped here.
            Z1_processed = self.preprocess_latent(Z1_start_tokens)
            # Create per-sample mixing mask: with probability `teacher_forcing_ratio`,
            # a sample uses the ground-truth start tokens instead of the learned ones.
            mask = torch.rand(B, 1, 1, device=Z.device) < teacher_forcing_ratio
            processed_start = torch.where(mask, Z1_processed, processed_start)

        # Decoder processing with residual
        decoder_input = self.positional_encoding(processed_start)
        outputs = self.decoder(decoder_input, Z)
        outputs = self.output_layer(outputs + decoder_input)

        return self.postprocess_latent(outputs)
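
# --- Illustrative sanity check (not part of the original model code) ---
# Both the encoder and decoder above rely on the same (B, C, H, W) <-> (B, H*W, C)
# reshaping. The helper below is a minimal sketch, assuming the default
# 256-channel, 32x32 latent used throughout this file, showing that the round
# trip is lossless. It is not called anywhere; run it manually if desired.
def _check_latent_roundtrip():
    enc = TransformerEncoder()
    Z = torch.randn(2, 256, 32, 32)                     # assumed latent shape
    seq = enc.preprocess_latent(Z)                      # (2, 1024, 256)
    assert seq.shape == (2, 1024, 256)
    assert torch.equal(enc.postprocess_latent(seq), Z)  # exact round trip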

class DeepfakeToSourceTransformer(nn.Module):
    def __init__(self, d_model=256, encoder_nhead=8, decoder_nhead=8,
                 num_encoder_layers=6, num_decoder_layers=12,
                 dim_feedforward=1024, dropout=0.1):
        super().__init__()
        self.encoder = TransformerEncoder(
            d_model=d_model,
            nhead=encoder_nhead,
            num_layers=num_encoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout
        )
        self.decoder = TransformerDecoder(
            d_model=d_model,
            nhead=decoder_nhead,
            num_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout
        )

    def forward(self, Z, Z1_start_tokens=None, teacher_forcing_ratio=0.5):
        memory = self.encoder(Z)
        Z1 = self.decoder(memory, Z1_start_tokens, teacher_forcing_ratio=teacher_forcing_ratio)
        return Z1
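
# --- Minimal usage sketch (illustrative only, not from the original code) ---
# Assumes 256-channel, 32x32 latents, matching the defaults above; which of
# MODEL-F / MODEL-G this instance corresponds to is left open here.
if __name__ == "__main__":
    model = DeepfakeToSourceTransformer()

    Z_fake = torch.randn(2, 256, 32, 32)   # hypothetical deepfake latent
    Z1_gt = torch.randn(2, 256, 32, 32)    # hypothetical ground-truth source latent

    # Training-style call: ground-truth start tokens are mixed in per sample
    # with probability `teacher_forcing_ratio`.
    Z1_pred = model(Z_fake, Z1_start_tokens=Z1_gt, teacher_forcing_ratio=0.5)
    print(Z1_pred.shape)                   # torch.Size([2, 256, 32, 32])

    # Inference-style call: only the learned start tokens are used.
    Z1_pred = model(Z_fake, Z1_start_tokens=None, teacher_forcing_ratio=0.0)
    print(Z1_pred.shape)                   # torch.Size([2, 256, 32, 32])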