rahul7star commited on
Commit
d446991
·
verified ·
1 Parent(s): 909058a

Create critic.py

Browse files
Files changed (1) hide show
  1. jobs/process/models/critic.py +234 -0
jobs/process/models/critic.py ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import glob
2
+ import os
3
+ from typing import TYPE_CHECKING, Union
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn as nn
8
+ from safetensors.torch import load_file, save_file
9
+
10
+ from toolkit.losses import get_gradient_penalty
11
+ from toolkit.metadata import get_meta_for_safetensors
12
+ from toolkit.optimizer import get_optimizer
13
+ from toolkit.train_tools import get_torch_dtype
14
+
15
+
16
class MeanReduce(nn.Module):
    """Collapse a (B, C, H, W) tensor to (B, C, 1, 1) by averaging over H and W."""

    def __init__(self):
        super().__init__()

    def forward(self, inputs):
        # Global average over the spatial axes; batch and channel dims survive.
        return inputs.mean(dim=(2, 3), keepdim=True)
23
+
24
+
25
class SelfAttention2d(nn.Module):
    """
    SAGAN-style self-attention over the spatial positions of a 2D feature
    map. Spatial resolution is unchanged. The learned gate `gamma` is
    zero-initialised, so the layer starts out as an exact identity and
    gradually blends attention into the residual path during training.
    """

    def __init__(self, in_channels: int):
        super().__init__()
        self.query = nn.Conv1d(in_channels, in_channels // 8, 1)
        self.key = nn.Conv1d(in_channels, in_channels // 8, 1)
        self.value = nn.Conv1d(in_channels, in_channels, 1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        batch, channels, height, width = x.shape
        tokens = x.view(batch, channels, height * width)       # (B, C, N)
        queries = self.query(tokens).permute(0, 2, 1)          # (B, N, C//8)
        keys = self.key(tokens)                                # (B, C//8, N)
        weights = torch.bmm(queries, keys).softmax(dim=-1)     # (B, N, N), rows sum to 1
        values = self.value(tokens)                            # (B, C, N)
        attended = torch.bmm(values, weights.permute(0, 2, 1)) # (B, C, N)
        attended = attended.view(batch, channels, height, width)
        return x + self.gamma * attended                       # gated residual
49
+
50
+
51
class CriticModel(nn.Module):
    """
    Fully-convolutional critic built from spectral-normalised convolutions,
    a single self-attention block, and a global-mean head, so any input
    resolution reduces to one score per sample, shape (B, 1).
    """

    def __init__(self, base_channels: int = 64):
        super().__init__()

        def spectral_conv(in_ch, out_ch, kernel, stride, pad):
            # every conv is spectrally normalised for Lipschitz control
            return nn.utils.spectral_norm(
                nn.Conv2d(in_ch, out_ch, kernel_size=kernel, stride=stride, padding=pad)
            )

        blocks = [
            # initial down-sample from RGB
            spectral_conv(3, base_channels, 3, 2, 1),
            nn.LeakyReLU(0.2, inplace=True),
        ]

        channels = base_channels
        # three progressive downsamples (64 -> 128 -> 256 -> 512 by default)
        for _ in range(3):
            next_channels = min(channels * 2, 1024)
            blocks.append(spectral_conv(channels, next_channels, 3, 2, 1))
            blocks.append(nn.LeakyReLU(0.2, inplace=True))
            if next_channels == 256:
                # single attention block once 256 channels are reached
                blocks.append(SelfAttention2d(next_channels))
            channels = next_channels

        blocks.extend([
            # extra depth, spatial size unchanged
            spectral_conv(channels, 1024, 3, 1, 1),
            nn.LeakyReLU(0.2, inplace=True),
            # final 1-channel prediction map
            spectral_conv(1024, 1, 3, 1, 1),
            MeanReduce(),   # -> (B, 1, 1, 1)
            nn.Flatten(),   # -> (B, 1)
        ])

        self.main = nn.Sequential(*blocks)

    def forward(self, inputs):
        # run the critic in full precision even inside an AMP autocast region
        with torch.cuda.amp.autocast(False):
            return self.main(inputs.float())
96
+
97
+
98
+ if TYPE_CHECKING:
99
+ from jobs.process.TrainVAEProcess import TrainVAEProcess
100
+ from jobs.process.TrainESRGANProcess import TrainESRGANProcess
101
+
102
+
103
class Critic:
    """
    Adversarial critic (hinge loss + gradient penalty) trained alongside a
    generator process. Wraps a `CriticModel`, its optimizer/scheduler, and
    safetensors checkpoint save/load plumbing.
    """

    process: Union['TrainVAEProcess', 'TrainESRGANProcess']

    def __init__(
            self,
            learning_rate=1e-5,
            device='cpu',
            optimizer='adam',
            num_critic_per_gen=1,
            dtype='float32',
            lambda_gp=10,
            start_step=0,
            warmup_steps=1000,
            process=None,
            optimizer_params=None,
    ):
        self.learning_rate = learning_rate
        self.device = device
        self.optimizer_type = optimizer
        self.num_critic_per_gen = num_critic_per_gen
        self.dtype = dtype
        self.torch_dtype = get_torch_dtype(self.dtype)
        self.process = process
        self.model = None        # built lazily in setup()
        self.optimizer = None
        self.scheduler = None
        self.warmup_steps = warmup_steps
        self.start_step = start_step
        self.lambda_gp = lambda_gp

        # never share a mutable default between instances
        self.optimizer_params = {} if optimizer_params is None else optimizer_params
        # fall back to builtin print so a process-less Critic does not crash
        # (process defaults to None but the original dereferenced it unconditionally)
        self.print = self.process.print if self.process is not None else print
        self.print(f" Critic config: {self.__dict__}")

    def setup(self):
        """Build the critic model, restore the latest checkpoint, and create optimizer/scheduler."""
        self.model = CriticModel().to(self.device)
        self.load_weights()
        self.model.train()
        self.model.requires_grad_(True)
        self.optimizer = get_optimizer(
            self.model.parameters(),
            self.optimizer_type,
            self.learning_rate,
            optimizer_params=self.optimizer_params,
        )
        # factor=1 keeps the configured learning rate constant for the whole run.
        # (the deprecated `verbose=` kwarg is intentionally omitted)
        self.scheduler = torch.optim.lr_scheduler.ConstantLR(
            self.optimizer,
            total_iters=self.process.max_steps * self.num_critic_per_gen,
            factor=1,
        )

    def load_weights(self):
        """Restore the newest CRITIC_* safetensors checkpoint from save_root, if one exists."""
        path_to_load = None
        self.print(f"Critic: Looking for latest checkpoint in {self.process.save_root}")
        files = glob.glob(os.path.join(self.process.save_root, f"CRITIC_{self.process.job.name}*.safetensors"))
        if files:
            latest_file = max(files, key=os.path.getmtime)
            # use the process logger consistently (was a stray builtin print)
            self.print(f" - Latest checkpoint is: {latest_file}")
            path_to_load = latest_file
        else:
            self.print(" - No checkpoint found, starting from scratch")
        if path_to_load:
            self.model.load_state_dict(load_file(path_to_load))

    def save(self, step=None):
        """Write the critic weights (plus process metadata) as a safetensors checkpoint."""
        self.process.update_training_metadata()
        save_meta = get_meta_for_safetensors(self.process.meta, self.process.job.name)
        step_num = f"_{str(step).zfill(9)}" if step is not None else ''
        save_path = os.path.join(
            self.process.save_root, f"CRITIC_{self.process.job.name}{step_num}.safetensors"
        )
        save_file(self.model.state_dict(), save_path, save_meta)
        self.print(f"Saved critic to {save_path}")

    def get_critic_loss(self, vgg_output):
        """
        Generator-side adversarial loss: -E[critic(pred)], linearly warmed up
        over `warmup_steps` after `start_step`. `vgg_output` is the combined
        [pred|target] batch; only the pred half is scored.
        """
        if self.start_step > self.process.step_num:
            # critic not active yet
            return torch.tensor(0.0, dtype=self.torch_dtype, device=self.device)

        warmup_scaler = 1.0
        # guard warmup_steps == 0, which previously divided by zero
        if self.warmup_steps > 0 and self.process.step_num < self.start_step + self.warmup_steps:
            warmup_scaler = (self.process.step_num - self.start_step) / self.warmup_steps

        # freeze the critic while scoring the generator
        self.model.eval()
        self.model.requires_grad_(False)

        vgg_pred, _ = torch.chunk(vgg_output.float(), 2, dim=0)
        stacked_output = self.model(vgg_pred)
        return (-torch.mean(stacked_output)) * warmup_scaler

    def step(self, vgg_output):
        """One critic optimisation step (hinge loss + gradient penalty). Returns the scalar loss."""
        self.model.train()
        self.model.requires_grad_(True)
        self.optimizer.zero_grad()

        inputs = vgg_output.detach().to(self.device, dtype=torch.float32)

        vgg_pred, vgg_target = torch.chunk(inputs, 2, dim=0)
        stacked_output = self.model(inputs).float()
        out_pred, out_target = torch.chunk(stacked_output, 2, dim=0)

        # hinge loss + gradient penalty
        loss_real = torch.relu(1.0 - out_target).mean()
        loss_fake = torch.relu(1.0 + out_pred).mean()
        gradient_penalty = get_gradient_penalty(self.model, vgg_target, vgg_pred, self.device)
        critic_loss = loss_real + loss_fake + self.lambda_gp * gradient_penalty

        critic_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
        self.optimizer.step()
        self.scheduler.step()

        # single loss per call; the old single-element list + np.mean was redundant
        return float(critic_loss.item())

    def get_lr(self):
        """Current learning rate, accounting for adaptive optimizers (D-Adaptation / Prodigy)."""
        if hasattr(self.optimizer, 'get_avg_learning_rate'):
            return self.optimizer.get_avg_learning_rate()
        if self.optimizer_type.startswith('dadaptation') or \
                self.optimizer_type.lower().startswith('prodigy'):
            # effective LR is the adapted distance estimate times the base LR
            return (
                self.optimizer.param_groups[0]["d"] *
                self.optimizer.param_groups[0]["lr"]
            )
        return self.optimizer.param_groups[0]['lr']