import torch
import torchvision.utils
from PIL import Image

from AdaIN import AdaINNet
from utils import adaptive_instance_normalization, transform, linear_histogram_matching

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def style_transfer(content_tensor, style_tensor, encoder, decoder, alpha=1.0):
    """
    Given content image and style image, generate feature maps with encoder, apply
    neural style transfer with adaptive instance normalization, generate output image
    with decoder

    Args:
        content_tensor (torch.FloatTensor): Content image
        style_tensor (torch.FloatTensor): Style Image
        encoder: Encoder (vgg19) network
        decoder: Decoder network
        alpha (float, default=1.0): Weight of style image feature

    Return:
        output_tensor (torch.FloatTensor): Style Transfer output image
    """

    content_enc = encoder(content_tensor)
    style_enc = encoder(style_tensor)

    transfer_enc = adaptive_instance_normalization(content_enc, style_enc)

    mix_enc = alpha * transfer_enc + (1 - alpha) * content_enc
    return decoder(mix_enc)
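

# For reference, a minimal sketch of the AdaIN operation that the imported
# adaptive_instance_normalization helper is assumed to implement: re-normalize
# the content features so they carry the channel-wise mean and standard
# deviation of the style features. Illustration only; the implementation
# actually used by style_transfer above lives in utils.py.
def _adain_sketch(content_feat, style_feat, eps=1e-5):
    """AdaIN(x, y) = std(y) * (x - mean(x)) / std(x) + mean(y), per channel."""
    n, c = content_feat.shape[:2]
    content_flat = content_feat.reshape(n, c, -1)
    style_flat = style_feat.reshape(n, c, -1)
    c_mean = content_flat.mean(dim=2).view(n, c, 1, 1)
    c_std = content_flat.std(dim=2).view(n, c, 1, 1) + eps
    s_mean = style_flat.mean(dim=2).view(n, c, 1, 1)
    s_std = style_flat.std(dim=2).view(n, c, 1, 1)
    return s_std * (content_feat - c_mean) / c_std + s_mean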


def convert(content_path, style_path, vgg_weights_path, decoder_weights_path, alpha, color_control):
    """
    Load the pretrained encoder/decoder weights, run style transfer on one
    content/style image pair, and save the result to disk.

    Args:
        content_path (str): Path to the content image
        style_path (str): Path to the style image
        vgg_weights_path (str): Path to the pretrained vgg19 encoder weights
        decoder_weights_path (str): Path to the trained decoder weights
        alpha (float): Weight of the stylized features
        color_control (bool): If True, match style colors to the content image

    Return:
        outimage_fname (str): Filename of the saved output image
    """
    # Build the model: pretrained vgg19 encoder plus trained decoder weights
    vgg = torch.load(vgg_weights_path)
    model = AdaINNet(vgg).to(device)
    model.decoder.load_state_dict(torch.load(decoder_weights_path))
    model.eval()

    # Prepare image transform
    t = transform(512)

    # Load the images, apply the preprocessing transform, and add a batch dimension
    content_img = Image.open(content_path)
    content_tensor = t(content_img).unsqueeze(0).to(device)
    style_tensor = t(Image.open(style_path)).unsqueeze(0).to(device)

    # Optionally match the style image's color statistics to the content image
    # so that the stylized output preserves the content colors
    if color_control:
        style_tensor = linear_histogram_matching(content_tensor, style_tensor)

    # Run style transfer without tracking gradients and move the result to the CPU
    with torch.no_grad():
        out_tensor = style_transfer(content_tensor, style_tensor, model.encoder, model.decoder, alpha).cpu()

    # Save the stylized image to disk
    outimage_fname = 'output.png'
    torchvision.utils.save_image(out_tensor.squeeze(0), outimage_fname)

    return outimage_fname
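

# A minimal usage sketch (not part of the original script). The image and
# weight paths below are placeholders; this assumes a pretrained vgg19 encoder
# checkpoint and a trained AdaIN decoder checkpoint are available locally.
if __name__ == '__main__':
    result = convert(
        content_path='content.jpg',             # placeholder content image
        style_path='style.jpg',                 # placeholder style image
        vgg_weights_path='vgg19_weights.pth',   # placeholder encoder checkpoint
        decoder_weights_path='decoder.pth',     # placeholder decoder checkpoint
        alpha=1.0,                              # full style strength
        color_control=False                     # keep the style image's colors
    )
    print('Saved style transfer output to ' + result)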