File size: 290 Bytes
22b8701 |
1 2 3 4 5 6 7 8 9 10 11 12 13 |
import torch
import torch.nn as nn
class BlendModule(nn.Module):
    """Wrap a serialized TorchScript blending model as an ``nn.Module``.

    The wrapped model is loaded once at construction time. ``forward``
    simply delegates to it, so the output is whatever the scripted model
    returns.
    """

    def __init__(self, model_path, device):
        """Load the TorchScript archive at *model_path* onto *device*.

        Args:
            model_path: Path (or file-like object) of an archive written by
                ``torch.jit.save``.
            device: Target device. Accepts anything ``torch.device``
                accepts, e.g. ``"cpu"``, ``"cuda:0"``, ``0`` or a
                ``torch.device`` instance.
        """
        super().__init__()
        # Normalise first: ``torch.jit.load`` only accepts a str or
        # ``torch.device`` for ``map_location``, while ``.to()`` also
        # accepts an int index.
        target = torch.device(device)
        # Remap storages while deserializing. Without ``map_location``,
        # tensors are restored onto the device they were saved from. A
        # CUDA-exported model would then fail to load on a CPU-only host
        # before ``.to()`` ever runs.
        self.model = torch.jit.load(model_path, map_location=target).to(target)

    def forward(self, swap, mask, att_img):
        """Run the wrapped model on ``(swap, mask, att_img)``.

        The three tensors are forwarded positionally and unchanged.
        NOTE(review): the names suggest a swapped image, a blending mask
        and an attribute/target image. Expected shapes and dtypes are
        defined by the scripted model, not here; TODO confirm.

        Returns:
            The scripted model's output, unchanged.
        """
        return self.model(swap, mask, att_img)
|