BeanSamuel commited on
Commit
4db9546
·
1 Parent(s): 3cffe9d

Add application file

Browse files
Files changed (1) hide show
  1. app.py +100 -0
app.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from torchvision import transforms
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+
7
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
8
+ print(device)
9
+
10
+ class ResidualBlock(nn.Module):
11
+ def __init__(self, channels):
12
+ super(ResidualBlock, self).__init__()
13
+ self.block = nn.Sequential(
14
+ nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1, bias=False),
15
+ nn.InstanceNorm2d(channels),
16
+ nn.ReLU(inplace=True),
17
+ nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1, bias=False),
18
+ nn.InstanceNorm2d(channels)
19
+ )
20
+
21
+ def forward(self, x):
22
+ return x + self.block(x)
23
+
24
+ # 強化版生成器:利用下採樣、殘差塊和上採樣結構
25
+ class StrongGenerator(nn.Module):
26
+ def __init__(self, num_residual_blocks=6):
27
+ super(StrongGenerator, self).__init__()
28
+ # 初始卷積層
29
+ model = [
30
+ nn.Conv2d(3, 64, kernel_size=7, stride=1, padding=3, bias=False),
31
+ nn.InstanceNorm2d(64),
32
+ nn.ReLU(inplace=True)
33
+ ]
34
+
35
+ # 下採樣:連續兩次卷積降維
36
+ in_channels = 64
37
+ for _ in range(2):
38
+ out_channels = in_channels * 2
39
+ model += [
40
+ nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1, bias=False),
41
+ nn.InstanceNorm2d(out_channels),
42
+ nn.ReLU(inplace=True)
43
+ ]
44
+ in_channels = out_channels
45
+
46
+ # 多個殘差塊
47
+ for _ in range(num_residual_blocks):
48
+ model += [ResidualBlock(in_channels)]
49
+
50
+ # 上採樣:連續兩次反捲積提升解析度
51
+ for _ in range(2):
52
+ out_channels = in_channels // 2
53
+ model += [
54
+ nn.ConvTranspose2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False),
55
+ nn.InstanceNorm2d(out_channels),
56
+ nn.ReLU(inplace=True)
57
+ ]
58
+ in_channels = out_channels
59
+
60
+ # 輸出層
61
+ model += [
62
+ nn.Conv2d(in_channels, 3, kernel_size=7, stride=1, padding=3),
63
+ nn.Tanh()
64
+ ]
65
+
66
+ self.model = nn.Sequential(*model)
67
+
68
+ def forward(self, x):
69
+ return self.model(x)
70
+
71
+ generator = StrongGenerator().to(device)
72
+
73
+
74
+ # 載入訓練好的 Generator 模型(此處以第 10 個 epoch 為例,請根據實際情況修改)
75
+ generator.load_state_dict(torch.load("./generator_epoch_10.pth", map_location=device))
76
+ generator.eval()
77
+
78
+ def restore_image(mosaic_image):
79
+ # 與訓練時相同的圖像轉換
80
+ transform_in = transforms.Compose([
81
+ transforms.Resize((256, 256)),
82
+ transforms.ToTensor(),
83
+ transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))
84
+ ])
85
+ input_tensor = transform_in(mosaic_image).unsqueeze(0).to(device)
86
+ with torch.no_grad():
87
+ restored_tensor = generator(input_tensor)
88
+ restored_tensor = restored_tensor.squeeze(0).cpu()
89
+ restored_tensor = (restored_tensor * 0.5 + 0.5).clamp(0, 1)
90
+ restored_image = transforms.ToPILImage()(restored_tensor)
91
+ return restored_image
92
+
93
+ iface = gr.Interface(
94
+ fn=restore_image,
95
+ inputs=gr.Image(type="pil"),
96
+ outputs="image",
97
+ title="Dog Image Mosaic Restoration",
98
+ description="上傳打碼後的狗狗圖像,模型將嘗試還原原始圖像。"
99
+ )
100
+ iface.launch()