Upload 2 files
- src/src_loss.py
- src/src_scheduler_config.json
src/src_loss.py
ADDED
import torch


class LossSchedulerModel(torch.nn.Module):
    """Learned linear combination of the initial latent and the noise predictions so far."""

    def __init__(self, wx, we):
        super().__init__()
        # wx: one weight per step for x_T; we: per-step weights for each past noise.
        assert wx.dim() == 1 and we.shape == (wx.shape[0], wx.shape[0])
        self.register_parameter('wx', torch.nn.Parameter(wx))
        self.register_parameter('we', torch.nn.Parameter(we))

    def forward(self, t, xT, e_prev):
        assert len(e_prev) == t + 1  # one noise prediction per step taken so far
        out = xT * self.wx[t]
        for e, w in zip(e_prev, self.we[t]):
            out += e * w
        return out.to(xT.dtype)


class LossScheduler:
    """Diffusers-style scheduler that replays the learned combination at each step."""

    def __init__(self, timesteps, model):
        self.timesteps, self.model = timesteps, model
        self.init_noise_sigma, self.order = 1.0, 1

    @staticmethod
    def load(path):
        timesteps, wx, we = torch.load(path, map_location='cpu')
        return LossScheduler(timesteps, LossSchedulerModel(wx, we))

    def save(self, path):
        torch.save((self.timesteps, self.model.wx, self.model.we), path)

    def set_timesteps(self, num_inference_steps, device='cuda'):
        # Reset per-trajectory state and move the weights to the target device.
        self.xT, self.e_prev, self.t_prev = None, [], -1
        self.model = self.model.to(device)
        self.timesteps = self.timesteps.to(device)

    def scale_model_input(self, sample, *args, **kwargs):
        return sample

    @torch.no_grad()
    def step(self, model_output, timestep, sample, *args, **kwargs):
        t = self.timesteps.tolist().index(timestep)
        assert self.t_prev == -1 or t == self.t_prev + 1  # steps must run in order
        if self.t_prev == -1:  # first step: remember the initial latent x_T
            self.xT = sample
        self.e_prev.append(model_output)
        prev_sample = self.model(t, self.xT, self.e_prev)
        if t + 1 == len(self.timesteps):  # last step: reset for the next trajectory
            self.xT, self.e_prev, self.t_prev = None, [], -1
        else:
            self.t_prev = t
        return (prev_sample,)


class SchedulerWrapper:
    """Wraps a stock scheduler, caching (x_t, eps, x_prev) per timestep until
    learned parameters are loaded, after which the LossScheduler takes over."""

    def __init__(self, scheduler, loss_params_path='loss_params.pth'):
        self.scheduler, self.loss_params_path = scheduler, loss_params_path
        self.catch_x, self.catch_e, self.catch_x_ = {}, {}, {}
        self.loss_scheduler = None

    def set_timesteps(self, num_inference_steps, **kwargs):
        active = self.scheduler if self.loss_scheduler is None else self.loss_scheduler
        out = active.set_timesteps(num_inference_steps, **kwargs)
        self.timesteps = active.timesteps
        self.init_noise_sigma, self.order = self.scheduler.init_noise_sigma, self.scheduler.order
        return out

    def step(self, model_output, timestep, sample, **kwargs):
        if self.loss_scheduler is not None:
            return self.loss_scheduler.step(model_output, timestep, sample, **kwargs)
        out = self.scheduler.step(model_output, timestep, sample, **kwargs)
        t = timestep.tolist()
        if t not in self.catch_x:
            self.catch_x[t], self.catch_e[t], self.catch_x_[t] = [], [], []
        self.catch_x[t].append(sample.clone().detach().cpu())
        self.catch_e[t].append(model_output.clone().detach().cpu())
        self.catch_x_[t].append(out[0].clone().detach().cpu())
        return out

    def scale_model_input(self, sample, timestep):
        return sample

    def add_noise(self, original_samples, noise, timesteps):
        return self.scheduler.add_noise(original_samples, noise, timesteps)

    def get_path(self):
        # Cached trajectory, ordered from the highest timestep down to the final output.
        ts = sorted(self.catch_x, reverse=True)
        xs = [torch.cat(self.catch_x[t], dim=0) for t in ts]
        es = [torch.cat(self.catch_e[t], dim=0) for t in ts]
        xs.append(torch.cat(self.catch_x_[ts[-1]], dim=0))  # final denoised sample
        return torch.tensor(ts, dtype=torch.int32), torch.stack(xs), torch.stack(es)

    def load_loss_params(self):
        timesteps, wx, we = torch.load(self.loss_params_path, map_location='cpu')
        self.loss_model = LossSchedulerModel(wx, we)
        self.loss_scheduler = LossScheduler(timesteps, self.loss_model)

    def prepare_loss(self, num_accelerate_steps=15):
        # num_accelerate_steps is currently unused.
        self.load_loss_params()
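Note: a minimal usage sketch (not part of the commit) of how SchedulerWrapper is meant to sit between a diffusers-style pipeline and its stock scheduler. The checkpoint name and prompt are illustrative assumptions, and the sketch presumes the pipeline only touches set_timesteps / step / scale_model_input / add_noise on the scheduler:

from diffusers import StableDiffusionPipeline
from src.src_loss import SchedulerWrapper

pipe = StableDiffusionPipeline.from_pretrained('runwayml/stable-diffusion-v1-5')
pipe.scheduler = SchedulerWrapper(pipe.scheduler)

# Pass 1: the stock scheduler runs; the wrapper caches (x_t, eps, x_prev).
pipe('a photo of an astronaut', num_inference_steps=20)
timesteps, xs, es = pipe.scheduler.get_path()

# Pass 2: once loss_params.pth has been fitted, replay the learned combination.
pipe.scheduler.prepare_loss()
pipe('a photo of an astronaut', num_inference_steps=20)

Since LossScheduler.step returns a plain tuple, the second pass suits pipelines that index the step output (return_dict=False style) rather than reading .prev_sample.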
src/src_scheduler_config.json
ADDED
{
  "_by": "RobertML"
}
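The commit only ships the inference-side classes; loss_params.pth has to be produced elsewhere. A hypothetical sketch (not from this repo) of fitting LossSchedulerModel to a trajectory cached by get_path, where wrapper is a SchedulerWrapper that has already completed a run under the stock scheduler:

import torch
from src.src_loss import LossScheduler, LossSchedulerModel

timesteps, xs, es = wrapper.get_path()  # xs: (T+1, B, ...), es: (T, B, ...)
T = len(timesteps)
model = LossSchedulerModel(torch.zeros(T), torch.zeros(T, T))
opt = torch.optim.Adam(model.parameters(), lr=1e-2)

for _ in range(1000):
    opt.zero_grad()
    loss = torch.zeros(())
    for t in range(T):
        # Match the cached sample at step t+1 from x_T and the noises up to t.
        pred = model(t, xs[0], list(es[:t + 1]))
        loss = loss + torch.nn.functional.mse_loss(pred, xs[t + 1])
    loss.backward()
    opt.step()

LossScheduler(timesteps, model).save('loss_params.pth')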