File size: 4,199 Bytes
4135cc9
836ce90
 
 
 
 
 
 
 
4135cc9
 
836ce90
 
 
 
 
4135cc9
836ce90
 
4135cc9
836ce90
 
 
4135cc9
836ce90
 
 
4135cc9
836ce90
4135cc9
836ce90
 
 
 
 
 
 
4135cc9
836ce90
 
4135cc9
836ce90
4135cc9
836ce90
 
 
4135cc9
836ce90
 
 
 
 
 
 
 
4135cc9
836ce90
4135cc9
836ce90
 
 
4135cc9
836ce90
4135cc9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import os

import gradio as gr
# import torch
# import torch.optim as optim
# import torchvision.models as models
# import torchvision.transforms as transforms
# from PIL import Image
# import numpy as np
# import requests
# from io import BytesIO

# Load VGG19 model
# vgg = models.vgg19(pretrained=True).features
# for param in vgg.parameters():
#     param.requires_grad_(False)
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# vgg.to(device)

# # Helper functions (load_image, im_convert, get_features, gram_matrix)
# # TODO: these helpers are not defined anywhere in this file; add them here before re-enabling style_transfer below.

# def style_transfer(content_image, style_image, alpha, beta, conv1_1, conv2_1, conv3_1, conv4_1, conv5_1, steps):
#     content = load_image(content_image).to(device)
#     style = load_image(style_image, shape=content.shape[-2:]).to(device)
    
#     content_features = get_features(content, vgg)
#     style_features = get_features(style, vgg)
#     style_grams = {layer: gram_matrix(style_features[layer]) for layer in style_features}
    
#     target = content.clone().requires_grad_(True).to(device)
    
#     style_weights = {
#         'conv1_1': conv1_1,
#         'conv2_1': conv2_1,
#         'conv3_1': conv3_1,
#         'conv4_1': conv4_1,
#         'conv5_1': conv5_1
#     }
    
#     content_weight = alpha
#     style_weight = beta * 1e6
    
#     optimizer = optim.Adam([target], lr=0.003)
    
#     for ii in range(1, steps+1):
#         target_features = get_features(target, vgg)
#         content_loss = torch.mean((target_features['conv4_2'] - content_features['conv4_2'])**2)
        
#         style_loss = 0
#         for layer in style_weights:
#             target_feature = target_features[layer]
#             target_gram = gram_matrix(target_feature)
#             _, d, h, w = target_feature.shape
#             style_gram = style_grams[layer]
#             layer_style_loss = style_weights[layer] * torch.mean((target_gram - style_gram)**2)
#             style_loss += layer_style_loss / (d * h * w)
        
#         total_loss = content_weight * content_loss + style_weight * style_loss
        
#         optimizer.zero_grad()
#         total_loss.backward()
#         optimizer.step()
    
#     return im_convert(target)

# Example images
# Each row is a [content image path, style image path] pair, in the same
# order as the example inputs (content first, then style).
examples = [
    list(pair)
    for pair in (
        ("path/to/content1.jpg", "path/to/style1.jpg"),
        ("path/to/content2.jpg", "path/to/style2.jpg"),
        ("path/to/content3.jpg", "path/to/style3.jpg"),
    )
]

# Gradio interface

# The button below is wired to `style_transfer`, but the implementation earlier
# in this file is commented out (together with its torch imports and helper
# functions). Without a definition, building the UI raises NameError and the
# app never starts. Define a fallback only when no real implementation is in
# scope, so re-enabling the original automatically takes precedence.
if "style_transfer" not in globals():

    def style_transfer(content_image, style_image, alpha, beta,
                       conv1_1, conv2_1, conv3_1, conv4_1, conv5_1, steps):
        """Fallback click handler used while the real implementation is disabled.

        Accepts the same ten positional inputs that the "Run Style Transfer"
        button passes (two images, the content/style weights, the five
        per-layer weights and the step count) and ignores all of them.

        Raises:
            gr.Error: always, so the user sees an explanatory message in the
                UI instead of a silent failure.
        """
        raise gr.Error(
            "Style transfer is not available: the model code is currently "
            "disabled in this app."
        )


# Keep only example rows whose files all exist on disk. The entries in
# `examples` may still be placeholder paths, and depending on the Gradio
# version, pointing gr.Examples at missing files can fail at startup or
# when an example is clicked.
available_examples = [
    row for row in examples if all(os.path.exists(path) for path in row)
]

with gr.Blocks() as demo:
    gr.Markdown("# Neural Style Transfer")

    # Inputs on the left, result on the right.
    with gr.Row():
        with gr.Column():
            content_input = gr.Image(label="Content Image")
            style_input = gr.Image(label="Style Image")
        with gr.Column():
            output_image = gr.Image(label="Output Image")

    # Overall content/style trade-off.
    with gr.Row():
        alpha_slider = gr.Slider(minimum=0, maximum=1, value=1, step=0.1, label="Content Weight (α)")
        beta_slider = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.1, label="Style Weight (β)")

    # Per-layer style weights.
    with gr.Row():
        conv1_1_slider = gr.Slider(minimum=0, maximum=1, value=1, step=0.1, label="Conv1_1 Weight")
        conv2_1_slider = gr.Slider(minimum=0, maximum=1, value=0.8, step=0.1, label="Conv2_1 Weight")
        conv3_1_slider = gr.Slider(minimum=0, maximum=1, value=0.5, step=0.1, label="Conv3_1 Weight")
        conv4_1_slider = gr.Slider(minimum=0, maximum=1, value=0.3, step=0.1, label="Conv4_1 Weight")
        conv5_1_slider = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.1, label="Conv5_1 Weight")

    steps_slider = gr.Slider(minimum=100, maximum=2000, value=1000, step=100, label="Number of Steps")

    run_button = gr.Button("Run Style Transfer")

    # The input order must match the positional parameters of style_transfer.
    run_button.click(
        style_transfer,
        inputs=[
            content_input,
            style_input,
            alpha_slider,
            beta_slider,
            conv1_1_slider,
            conv2_1_slider,
            conv3_1_slider,
            conv4_1_slider,
            conv5_1_slider,
            steps_slider,
        ],
        outputs=output_image,
    )

    # Only render the examples section when there is at least one usable row.
    if available_examples:
        gr.Examples(
            available_examples,
            inputs=[content_input, style_input],
        )

demo.launch()