import torch


def flag(model_forward, perturb_shape, y, args, optimizer, device, criterion):
    """FLAG adversarial training step: perturb the inputs with m gradient-ascent
    steps while accumulating the averaged loss gradients on the model parameters."""
    model, forward = model_forward
    model.train()
    optimizer.zero_grad()

    # Initialise the perturbation uniformly in [-step_size, step_size].
    perturb = torch.FloatTensor(*perturb_shape).uniform_(-args.step_size, args.step_size).to(device)
    perturb.requires_grad_()
    out = forward(perturb)
    loss = criterion(out, y)
    loss /= args.m

    # Inner ascent loop: m-1 sign-gradient steps on the perturbation,
    # accumulating the scaled gradients on the model parameters.
    for _ in range(args.m - 1):
        loss.backward()
        perturb_data = perturb.detach() + args.step_size * torch.sign(perturb.grad.detach())
        perturb.data = perturb_data.data
        perturb.grad[:] = 0

        out = forward(perturb)
        loss = criterion(out, y)
        loss /= args.m

    loss.backward()
    optimizer.step()

    return loss, out


def flag_sbap(model_forward, perturb_shape, step_size, m, optimizer, device):
    """FLAG step for the SBAP task: the forward pass returns the IC50 and K
    regression losses directly, so no external criterion or labels are needed."""
    model, forward = model_forward
    model.train()
    optimizer.zero_grad()

    # Initialise the perturbation uniformly in [-step_size, step_size].
    perturb = torch.FloatTensor(*perturb_shape).uniform_(-step_size, step_size).to(device)
    perturb.requires_grad_()
    (regression_loss_IC50, regression_loss_K), \
        (affinity_pred_IC50, affinity_pred_K), \
        (affinity_IC50, affinity_K) = forward(perturb)
    loss = regression_loss_IC50 + regression_loss_K
    loss /= m

    # Inner ascent loop: m-1 sign-gradient steps on the perturbation,
    # accumulating the scaled gradients on the model parameters.
    for _ in range(m - 1):
        loss.backward()
        perturb_data = perturb.detach() + step_size * torch.sign(perturb.grad.detach())
        perturb.data = perturb_data.data
        perturb.grad[:] = 0

        (regression_loss_IC50, regression_loss_K), \
            (affinity_pred_IC50, affinity_pred_K), \
            (affinity_IC50, affinity_K) = forward(perturb)
        loss = regression_loss_IC50 + regression_loss_K
        loss /= m

    loss.backward()
    optimizer.step()

    return loss
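

# --- Usage sketch (assumption, not part of the original file) ---
# A minimal illustration of how flag() might be driven from a training loop.
# The MyGNN model, the data loader, and the args namespace below are
# hypothetical placeholders; only flag()'s call signature comes from the
# code above.
#
#   import torch.nn.functional as F
#   from types import SimpleNamespace
#
#   args = SimpleNamespace(step_size=1e-3, m=3)   # ascent step size, number of ascent steps
#   device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#   model = MyGNN().to(device)                    # hypothetical model
#   optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
#
#   for batch in loader:                          # hypothetical data loader
#       batch = batch.to(device)
#       # forward() must accept the perturbation and add it to the node features.
#       forward = lambda perturb: model(batch.x + perturb, batch.edge_index)
#       loss, out = flag((model, forward), batch.x.shape, batch.y,
#                        args, optimizer, device, F.cross_entropy)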