import torch
import torch.nn as nn

import clip

from .deform_conv import DCN_layer

# Load a frozen CLIP model once at import time; its text encoder provides the
# prompt embeddings used to modulate image features below.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
clip_model, preprocess = clip.load("ViT-B/32", device=device)

# Width of CLIP's projected text embeddings (512 for ViT-B/32).
text_embed_dim = clip_model.text_projection.shape[1]


def default_conv(in_channels, out_channels, kernel_size, bias=True):
    return nn.Conv2d(in_channels, out_channels, kernel_size,
                     padding=(kernel_size // 2), bias=bias)


class DGM(nn.Module):
    def __init__(self, channels_in, channels_out, kernel_size):
        super(DGM, self).__init__()
        self.channels_out = channels_out
        self.channels_in = channels_in
        self.kernel_size = kernel_size

        self.dcn = DCN_layer(self.channels_in, self.channels_out, kernel_size,
                             padding=(kernel_size - 1) // 2, bias=False)
        self.sft = SFT_layer(self.channels_in, self.channels_out)

        self.relu = nn.LeakyReLU(0.1, True)

    def forward(self, x, inter, text_prompt):
        '''
        :param x: feature map: B * C * H * W
        :param inter: degradation map: B * C * H * W
        :param text_prompt: list of B degradation prompts for CLIP
        '''
        # Residual fusion of the deformable-convolution branch and the
        # text-guided SFT branch.
        dcn_out = self.dcn(x, inter)
        sft_out = self.sft(x, inter, text_prompt)
        out = dcn_out + sft_out
        out = x + out

        return out


class TextProjectionHead(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(TextProjectionHead, self).__init__()
        self.proj = nn.Sequential(
            nn.Linear(input_dim, output_dim),
            nn.ReLU(),
            nn.Linear(output_dim, output_dim)
        ).float()

    def forward(self, x):
        # CLIP returns half-precision embeddings on GPU; cast to float32
        # before projecting.
        return self.proj(x.float())


class SFT_layer(nn.Module):
    def __init__(self, channels_in, channels_out):
        super(SFT_layer, self).__init__()
        self.conv_gamma = nn.Sequential(
            nn.Conv2d(channels_in, channels_out, 1, 1, 0, bias=False),
            nn.LeakyReLU(0.1, True),
            nn.Conv2d(channels_out, channels_out, 1, 1, 0, bias=False),
        )
        self.conv_beta = nn.Sequential(
            nn.Conv2d(channels_in, channels_out, 1, 1, 0, bias=False),
            nn.LeakyReLU(0.1, True),
            nn.Conv2d(channels_out, channels_out, 1, 1, 0, bias=False),
        )

        # Projects CLIP text embeddings into the feature channel space.
        self.text_proj_head = TextProjectionHead(text_embed_dim, channels_out)

        # Alternative: separate 1x1 conv stacks for text-driven gamma/beta
        # (currently disabled).
        '''
        self.text_gamma = nn.Sequential(
            nn.Conv2d(channels_out, channels_out, 1, 1, 0, bias=False),
            nn.LeakyReLU(0.1, True),
            nn.Conv2d(channels_out, channels_out, 1, 1, 0, bias=False),
        ).float()
        self.text_beta = nn.Sequential(
            nn.Conv2d(channels_out, channels_out, 1, 1, 0, bias=False),
            nn.LeakyReLU(0.1, True),
            nn.Conv2d(channels_out, channels_out, 1, 1, 0, bias=False),
        ).float()
        '''

        # Used only by the disabled cross-attention variant in forward().
        self.cross_attention = nn.MultiheadAttention(embed_dim=channels_out, num_heads=2)

    def forward(self, x, inter, text_prompt):
        '''
        :param x: feature map: B * C * H * W
        :param inter: degradation intermediate representation map: B * C * H * W
        :param text_prompt: list of B degradation prompts for CLIP
        '''
        B, C, H, W = inter.shape

        # Encode the prompts with the frozen CLIP text encoder.
        text_tokens = clip.tokenize(text_prompt).to(device)
        with torch.no_grad():
            text_embed = clip_model.encode_text(text_tokens)

        text_proj = self.text_proj_head(text_embed).float()

        # Broadcast the projected text embedding over the spatial dimensions
        # and gate the degradation map with it.
        text_proj_expanded = text_proj.unsqueeze(-1).unsqueeze(-1).expand(B, C, H, W)
        combined = inter * text_proj_expanded

        img_gamma = self.conv_gamma(combined)
        img_beta = self.conv_beta(combined)

        ''' simple concat (disabled)
        text_gamma = self.text_gamma(text_proj.unsqueeze(-1).unsqueeze(-1))  # reshape to match (B, C, H, W)
        text_beta = self.text_beta(text_proj.unsqueeze(-1).unsqueeze(-1))  # reshape to match (B, C, H, W)
        '''

        ''' cross-attention variant (disabled)
        text_proj = text_proj.unsqueeze(1).expand(-1, H * W, -1)  # B * (H*W) * C

        # Flatten the intermediate image representation to (H*W) * B * C.
        inter_flat = inter.view(B, C, -1).permute(2, 0, 1)

        # Apply cross-attention between text and image features.
        attn_output, _ = self.cross_attention(text_proj.permute(1, 0, 2), inter_flat, inter_flat)
        attn_output = attn_output.permute(1, 2, 0).view(B, C, H, W)  # B * C * H * W

        # Compute gamma and beta.
        img_gamma = self.conv_gamma(attn_output)
        img_beta = self.conv_beta(attn_output)
        '''

        return x * img_gamma + img_beta


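# A minimal usage sketch (not part of the original pipeline) showing the tensor
# shapes SFT_layer expects; the channel width (64), spatial size, and prompts
# below are illustrative assumptions.
def _sft_layer_example():
    sft = SFT_layer(64, 64).to(device)
    feat = torch.randn(2, 64, 32, 32, device=device)   # feature map x: B * C * H * W
    inter = torch.randn(2, 64, 32, 32, device=device)  # degradation map: B * C * H * W
    prompts = ["remove the rain streaks", "remove the haze"]  # one prompt per batch item
    out = sft(feat, inter, prompts)
    assert out.shape == feat.shape  # affine modulation is shape-preserving
    return out

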
class DGB(nn.Module):
    def __init__(self, conv, n_feat, kernel_size):
        super(DGB, self).__init__()

        self.dgm1 = DGM(n_feat, n_feat, kernel_size)
        self.dgm2 = DGM(n_feat, n_feat, kernel_size)
        self.conv1 = conv(n_feat, n_feat, kernel_size)
        self.conv2 = conv(n_feat, n_feat, kernel_size)

        self.relu = nn.LeakyReLU(0.1, True)

    def forward(self, x, inter, text_prompt):
        '''
        :param x: feature map: B * C * H * W
        :param inter: degradation representation: B * C * H * W
        :param text_prompt: list of B degradation prompts for CLIP
        '''
        out = self.relu(self.dgm1(x, inter, text_prompt))
        out = self.relu(self.conv1(out))
        out = self.relu(self.dgm2(out, inter, text_prompt))
        out = self.conv2(out) + x

        return out


class DGG(nn.Module):
    def __init__(self, conv, n_feat, kernel_size, n_blocks):
        super(DGG, self).__init__()
        self.n_blocks = n_blocks
        modules_body = [
            DGB(conv, n_feat, kernel_size)
            for _ in range(n_blocks)
        ]
        modules_body.append(conv(n_feat, n_feat, kernel_size))

        self.body = nn.Sequential(*modules_body)

    def forward(self, x, inter, text_prompt):
        '''
        :param x: feature map: B * C * H * W
        :param inter: degradation representation: B * C * H * W
        :param text_prompt: list of B degradation prompts for CLIP
        '''
        # self.body is indexed manually (never called as a whole Sequential)
        # because each DGB needs the extra (inter, text_prompt) arguments.
        res = x
        for i in range(self.n_blocks):
            res = self.body[i](res, inter, text_prompt)
        res = self.body[-1](res)
        res = res + x

        return res


class DGRN(nn.Module):
    def __init__(self, opt, conv=default_conv):
        # `opt` is currently unused; the architecture hyperparameters are fixed here.
        super(DGRN, self).__init__()

        self.n_groups = 5
        n_blocks = 5
        n_feats = 64
        kernel_size = 3

        # Head: shallow feature extraction from the RGB input.
        modules_head = [conv(3, n_feats, kernel_size)]
        self.head = nn.Sequential(*modules_head)

        # Body: degradation-guided groups followed by one fusion conv.
        modules_body = [
            DGG(conv, n_feats, kernel_size, n_blocks)
            for _ in range(self.n_groups)
        ]
        modules_body.append(conv(n_feats, n_feats, kernel_size))
        self.body = nn.Sequential(*modules_body)

        # Tail: project features back to RGB.
        modules_tail = [conv(n_feats, 3, kernel_size)]
        self.tail = nn.Sequential(*modules_tail)

    def forward(self, x, inter, text_prompt):
        x = self.head(x)

        res = x
        for i in range(self.n_groups):
            res = self.body[i](res, inter, text_prompt)
        res = self.body[-1](res)
        res = res + x

        x = self.tail(res)

        return x
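

# A minimal smoke test, not from the original repository: it assumes
# `.deform_conv` provides a working DCN_layer with the signature used above and
# that `opt` can be an empty namespace, since DGRN never reads it.
if __name__ == "__main__":
    from types import SimpleNamespace

    model = DGRN(SimpleNamespace()).to(device)
    x = torch.randn(1, 3, 64, 64, device=device)       # degraded RGB input
    inter = torch.randn(1, 64, 64, 64, device=device)  # degradation representation map
    restored = model(x, inter, ["remove the noise"])
    print(restored.shape)  # expected: torch.Size([1, 3, 64, 64])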