import torch
import torch.nn as nn
import clip

from .deform_conv import DCN_layer

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
clip_model, preprocess = clip.load("ViT-B/32", device=device)

# Read the text embedding dimension dynamically from the loaded CLIP model
text_embed_dim = clip_model.text_projection.shape[1]


def default_conv(in_channels, out_channels, kernel_size, bias=True):
    return nn.Conv2d(in_channels, out_channels, kernel_size,
                     padding=(kernel_size // 2), bias=bias)


class DGM(nn.Module):
    def __init__(self, channels_in, channels_out, kernel_size):
        super(DGM, self).__init__()
        self.channels_out = channels_out
        self.channels_in = channels_in
        self.kernel_size = kernel_size

        self.dcn = DCN_layer(self.channels_in, self.channels_out, kernel_size,
                             padding=(kernel_size - 1) // 2, bias=False)
        self.sft = SFT_layer(self.channels_in, self.channels_out)

        self.relu = nn.LeakyReLU(0.1, True)

    def forward(self, x, inter, text_prompt):
        '''
        :param x: feature map: B * C * H * W
        :param inter: degradation representation map: B * C * H * W
        :param text_prompt: list of degradation text prompts, one per batch element
        '''
        dcn_out = self.dcn(x, inter)
        sft_out = self.sft(x, inter, text_prompt)
        out = dcn_out + sft_out
        out = x + out

        return out


# Projection head that maps CLIP text embeddings into the feature-channel space
class TextProjectionHead(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(TextProjectionHead, self).__init__()
        self.proj = nn.Sequential(
            nn.Linear(input_dim, output_dim),
            nn.ReLU(),
            nn.Linear(output_dim, output_dim)
        ).float()

    def forward(self, x):
        return self.proj(x.float())


class SFT_layer(nn.Module):
    def __init__(self, channels_in, channels_out):
        super(SFT_layer, self).__init__()
        self.conv_gamma = nn.Sequential(
            nn.Conv2d(channels_in, channels_out, 1, 1, 0, bias=False),
            nn.LeakyReLU(0.1, True),
            nn.Conv2d(channels_out, channels_out, 1, 1, 0, bias=False),
        )
        self.conv_beta = nn.Sequential(
            nn.Conv2d(channels_in, channels_out, 1, 1, 0, bias=False),
            nn.LeakyReLU(0.1, True),
            nn.Conv2d(channels_out, channels_out, 1, 1, 0, bias=False),
        )

        self.text_proj_head = TextProjectionHead(text_embed_dim, channels_out)

        '''
        self.text_gamma = nn.Sequential(
            nn.Conv2d(channels_out, channels_out, 1, 1, 0, bias=False),
            nn.LeakyReLU(0.1, True),
            nn.Conv2d(channels_out, channels_out, 1, 1, 0, bias=False),
        ).float()
        self.text_beta = nn.Sequential(
            nn.Conv2d(channels_out, channels_out, 1, 1, 0, bias=False),
            nn.LeakyReLU(0.1, True),
            nn.Conv2d(channels_out, channels_out, 1, 1, 0, bias=False),
        ).float()
        '''

        self.cross_attention = nn.MultiheadAttention(embed_dim=channels_out, num_heads=2)

    def forward(self, x, inter, text_prompt):
        '''
        :param x: feature map: B * C * H * W
        :param inter: degradation intermediate representation map: B * C * H * W
        :param text_prompt: list of degradation text prompts, one per batch element
        '''
        # img_gamma = self.conv_gamma(inter)
        # img_beta = self.conv_beta(inter)

        B, C, H, W = inter.shape

        # Encode the text prompts with the frozen CLIP text encoder,
        # then project the embeddings into the feature-channel space
        text_tokens = clip.tokenize(text_prompt).to(device)
        with torch.no_grad():
            text_embed = clip_model.encode_text(text_tokens)
        text_proj = self.text_proj_head(text_embed).float()

        # Broadcast the projected text embedding to (B, C, H, W)
        # text_proj_expanded = text_proj.unsqueeze(-1).unsqueeze(-1).expand(B, self.conv_gamma[0].out_channels, H, W)
        text_proj_expanded = text_proj.unsqueeze(-1).unsqueeze(-1).expand(B, C, H, W)

        # Fuse the intermediate image representation with the text embedding
        # (element-wise product here; a concat variant is kept commented out)
        combined = inter * text_proj_expanded
        # combined = torch.cat([inter, text_proj_expanded], dim=1)

        # Compute gamma and beta from the fused image/text representation
        img_gamma = self.conv_gamma(combined)
        img_beta = self.conv_beta(combined)

        '''
        simple concat
        text_gamma = self.text_gamma(text_proj.unsqueeze(-1).unsqueeze(-1))  # reshape to match (B, C, H, W)
        text_beta = self.text_beta(text_proj.unsqueeze(-1).unsqueeze(-1))    # reshape to match (B, C, H, W)
        '''

        '''
        text_proj = text_proj.unsqueeze(1).expand(-1, H * W, -1)  # B * (H*W) * C

        # Flatten the intermediate image representation to (H*W) * B * C
        inter_flat = inter.view(B, C, -1).permute(2, 0, 1)

        # Apply cross-attention between text and image features
        attn_output, _ = self.cross_attention(text_proj.permute(1, 0, 2), inter_flat, inter_flat)
        attn_output = attn_output.permute(1, 2, 0).view(B, C, H, W)  # B * C * H * W

        # Compute gamma and beta
        img_gamma = self.conv_gamma(attn_output)
        img_beta = self.conv_beta(attn_output)
        '''

        # Note: concat-based text fusion was also experimented with (see the commented-out variant above)
        return x * img_gamma + img_beta


class DGB(nn.Module):
    def __init__(self, conv, n_feat, kernel_size):
        super(DGB, self).__init__()

        # self.da_conv1 = DGM(n_feat, n_feat, kernel_size)
        # self.da_conv2 = DGM(n_feat, n_feat, kernel_size)
        self.dgm1 = DGM(n_feat, n_feat, kernel_size)
        self.dgm2 = DGM(n_feat, n_feat, kernel_size)
        self.conv1 = conv(n_feat, n_feat, kernel_size)
        self.conv2 = conv(n_feat, n_feat, kernel_size)

        self.relu = nn.LeakyReLU(0.1, True)

    def forward(self, x, inter, text_prompt):
        '''
        :param x: feature map: B * C * H * W
        :param inter: degradation representation: B * C * H * W
        '''
        out = self.relu(self.dgm1(x, inter, text_prompt))
        out = self.relu(self.conv1(out))
        out = self.relu(self.dgm2(out, inter, text_prompt))
        out = self.conv2(out) + x

        return out


class DGG(nn.Module):
    def __init__(self, conv, n_feat, kernel_size, n_blocks):
        super(DGG, self).__init__()
        self.n_blocks = n_blocks
        modules_body = [
            DGB(conv, n_feat, kernel_size)
            for _ in range(n_blocks)
        ]
        modules_body.append(conv(n_feat, n_feat, kernel_size))

        self.body = nn.Sequential(*modules_body)

    def forward(self, x, inter, text_prompt):
        '''
        :param x: feature map: B * C * H * W
        :param inter: degradation representation: B * C * H * W
        '''
        res = x
        for i in range(self.n_blocks):
            res = self.body[i](res, inter, text_prompt)
        res = self.body[-1](res)
        res = res + x

        return res


class DGRN(nn.Module):
    def __init__(self, opt, conv=default_conv):
        super(DGRN, self).__init__()

        self.n_groups = 5
        n_blocks = 5
        n_feats = 64
        kernel_size = 3

        # head module
        modules_head = [conv(3, n_feats, kernel_size)]
        self.head = nn.Sequential(*modules_head)

        # body
        modules_body = [
            DGG(default_conv, n_feats, kernel_size, n_blocks)
            for _ in range(self.n_groups)
        ]
        modules_body.append(conv(n_feats, n_feats, kernel_size))
        self.body = nn.Sequential(*modules_body)

        # tail
        modules_tail = [conv(n_feats, 3, kernel_size)]
        self.tail = nn.Sequential(*modules_tail)

    def forward(self, x, inter, text_prompt):
        # head
        x = self.head(x)

        # body
        res = x
        for i in range(self.n_groups):
            res = self.body[i](res, inter, text_prompt)
        res = self.body[-1](res)
        res = res + x

        # tail
        x = self.tail(res)

        return x
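

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module; shapes and prompts are
# assumptions for illustration). It only exercises the call signature defined
# above: an RGB image, a degradation representation map `inter` with n_feats
# (64) channels, and one text prompt per batch element. `opt` is unused by this
# module, so None is passed. Because of the relative import of DCN_layer, run
# it with `python -m <package>.<this_module>` rather than as a bare script.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = DGRN(opt=None).to(device)
    img = torch.randn(2, 3, 64, 64, device=device)      # B * 3 * H * W input image
    inter = torch.randn(2, 64, 64, 64, device=device)    # B * n_feats * H * W degradation map (assumed from an external encoder)
    prompts = ["a photo degraded by rain", "a photo degraded by noise"]  # hypothetical prompts
    with torch.no_grad():
        out = model(img, inter, prompts)
    print(out.shape)  # expected: torch.Size([2, 3, 64, 64])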