Spaces:
baselqt
/
No application file

File size: 290 Bytes
22b8701
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
import torch
import torch.nn as nn


class BlendModule(nn.Module):
    """Thin ``nn.Module`` wrapper around a TorchScript blending model.

    The scripted model is loaded once at construction time; every
    ``forward`` call is delegated to it unchanged.
    """

    def __init__(self, model_path, device):
        """Load the TorchScript model and place it on ``device``.

        Args:
            model_path: Path or file-like object accepted by ``torch.jit.load``.
            device: Target device, e.g. ``"cpu"``, ``"cuda"`` or a
                ``torch.device``.
        """
        super().__init__()

        # Remap storages while deserializing. TorchScript archives restore
        # tensors onto the device they were saved from, so a model exported on
        # a GPU would otherwise fail to load where that GPU is unavailable.
        # The trailing ``.to(device)`` still guarantees final placement.
        self.model = torch.jit.load(model_path, map_location=device).to(device)

    def forward(self, swap, mask, att_img):
        """Run the wrapped model on ``(swap, mask, att_img)`` and return its output.

        NOTE(review): the expected shapes/dtypes of the inputs are defined by
        the exported scripted model and are not visible here. Presumably these
        are image tensors plus a blending mask; confirm against the export.
        """
        return self.model(swap, mask, att_img)