added files

- .gitattributes +2 -0
- .gitignore +6 -0
- .python-version +1 -0
- README.md +8 -4
- app.py +86 -0
- data/content/bear.jpg +3 -0
- data/content/cow.jpg +3 -0
- data/content/ducks.jpg +3 -0
- data/styles/2272.jpg +3 -0
- data/styles/2314.jpg +3 -0
- data/styles/25.jpg +3 -0
- models/checkpoint_puddle_70k.pt +3 -0
- models/checkpoint_puddle_79k.pt +3 -0
- models/model_puddle.pt +3 -0
- nb.ipynb +3 -0
- requirements.txt +4 -0
- src/adain.py +17 -0
- src/loss.py +32 -0
- src/main.py +200 -0
- src/model.py +66 -0
- src/utils.py +62 -0
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.ipynb filter=lfs diff=lfs merge=lfs -text
+*.jpg filter=lfs diff=lfs merge=lfs -text

.gitignore
ADDED
@@ -0,0 +1,6 @@
+__pycache__/
+.DS_Store
+.vscode/
+
+test.*
+test/

.python-version
ADDED
@@ -0,0 +1 @@
+ai-env

README.md
CHANGED
@@ -1,12 +1,16 @@
 ---
 title: Style Transfer
-emoji:
-colorFrom:
-colorTo:
+emoji: 👨‍🎨
+colorFrom: pink
+colorTo: yellow
 sdk: gradio
 sdk_version: 4.36.1
 app_file: app.py
 pinned: false
 ---
 
-
+# Style Transfer
+
+## References
+
+[1] Huang, Xun, and Serge Belongie. "Arbitrary style transfer in real-time with adaptive instance normalization." *Proceedings of the IEEE International Conference on Computer Vision*. 2017.

app.py
ADDED
@@ -0,0 +1,86 @@
+import gradio as gr
+import torch
+import torchvision.transforms.functional as TF
+import torchvision.transforms as transforms
+from src.model import Model
+import os
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+
+def denorm_img(img: torch.Tensor):
+    std = torch.Tensor([0.229, 0.224, 0.225]).reshape(-1, 1, 1)
+    mean = torch.Tensor([0.485, 0.456, 0.406]).reshape(-1, 1, 1)
+    return torch.clip(img * std + mean, min=0, max=1)
+
+
+def main(inp1, inp2, alph, out_size=256):
+    model = Model()
+    model.load_state_dict(torch.load("models/model_puddle.pt", map_location=torch.device(device)))
+    model.eval()
+
+    model.alpha = alph
+
+    style = TF.to_tensor(inp1)
+    content = TF.to_tensor(inp2)
+
+    norm = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+    transform = transforms.Compose(
+        [transforms.Resize(out_size, antialias=True), transforms.CenterCrop(out_size)]
+    )
+
+    style, content = norm(style), norm(content)
+    style, content = transform(style), transform(content)
+
+    style, content = style.unsqueeze(0).to(device), content.unsqueeze(0).to(device)
+
+    out = model(content, style)
+
+    return denorm_img(out[0].detach()).permute(1, 2, 0).numpy()
+
+
+with gr.Blocks() as demo:
+    gr.Markdown("# Style Transfer with AdaIN")
+    with gr.Row(variant="compact"):
+        inp1 = gr.Image(type="pil", sources=["upload", "clipboard"], label="Style")
+        inp2 = gr.Image(type="pil", sources=["upload", "clipboard"], label="Content")
+        out = gr.Image(type="numpy", label="Output")
+    with gr.Row():
+        out_size = (
+            gr.Dropdown(
+                choices=[256, 512],
+                value=256,
+                multiselect=False,
+                interactive=True,
+                allow_custom_value=True,
+                label="Output size",
+                info="Size of the output image",
+            ),
+        )
+        alph = gr.Slider(0, 1, value=1, label="Alpha", info="How much to change the original image", interactive=True, scale=3)
+
+    with gr.Row():
+        with gr.Column():
+            gr.Markdown("## Style Examples")
+            gr.Examples(
+                examples=[
+                    os.path.join(os.path.dirname(__file__), "data/styles/25.jpg"),
+                    os.path.join(os.path.dirname(__file__), "data/styles/2272.jpg"),
+                    os.path.join(os.path.dirname(__file__), "data/styles/2314.jpg"),
+                ],
+                inputs=inp1,
+            )
+        with gr.Column():
+            gr.Markdown("## Content Examples")
+            gr.Examples(
+                examples=[
+                    os.path.join(os.path.dirname(__file__), "data/content/bear.jpg"),
+                    os.path.join(os.path.dirname(__file__), "data/content/cow.jpg"),
+                    os.path.join(os.path.dirname(__file__), "data/content/ducks.jpg"),
+                ],
+                inputs=inp2,
+            )
+    btn = gr.Button("Run")
+    btn.click(fn=main, inputs=[inp1, inp2, alph, out_size[0]], outputs=out)
+
+demo.launch()

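As a quick sanity check outside the Gradio UI, the same inference path can be exercised directly. This is a minimal sketch, not part of the commit; it assumes the command is run from the repository root and that the LFS-tracked weights and example images have been pulled locally.

```python
# Sketch of the inference path in app.py, bypassing the Gradio interface.
# Assumes `git lfs pull` has fetched models/model_puddle.pt and the example images.
import torch
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from PIL import Image

from src.model import Model
from src.utils import denorm_img

device = "cuda" if torch.cuda.is_available() else "cpu"

model = Model()
model.load_state_dict(torch.load("models/model_puddle.pt", map_location=torch.device(device)))
model.eval()

norm = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
crop = transforms.Compose([transforms.Resize(256, antialias=True), transforms.CenterCrop(256)])

# Same preprocessing order as main(): to_tensor -> normalize -> resize/crop -> add batch dim
style = crop(norm(TF.to_tensor(Image.open("data/styles/25.jpg")))).unsqueeze(0).to(device)
content = crop(norm(TF.to_tensor(Image.open("data/content/bear.jpg")))).unsqueeze(0).to(device)

with torch.no_grad():
    out = model(content, style)  # (1, 3, 256, 256), still ImageNet-normalized

img = denorm_img(out[0].cpu()).permute(1, 2, 0).numpy()  # HxWx3 array in [0, 1]
```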
data/content/bear.jpg
ADDED
Git LFS Details

data/content/cow.jpg
ADDED
Git LFS Details

data/content/ducks.jpg
ADDED
Git LFS Details

data/styles/2272.jpg
ADDED
Git LFS Details

data/styles/2314.jpg
ADDED
Git LFS Details

data/styles/25.jpg
ADDED
Git LFS Details

models/checkpoint_puddle_70k.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:090b259580c0a6a7fbece489de7177aa5decebeea7dc6f26dab4d9e9aeb6f700
+size 41942833

models/checkpoint_puddle_79k.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f771b61eeb8b1505f690c6c09bbca269d17473b82cb8ed905877d9ed7de26bc2
+size 41942833

models/model_puddle.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8a215c24cf629cf5f2b2cf96014347fda704c06acd68e312720c5783eebde5ca
+size 23333701

nb.ipynb
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:53f653d323480c9ccce8283220a78f5f68fba868ca181923d41da2869408bd46
+size 60328933

requirements.txt
ADDED
@@ -0,0 +1,4 @@
+torch
+torchvision
+gradio
+torchinfo

src/adain.py
ADDED
@@ -0,0 +1,17 @@
+import torch
+from torch import nn
+
+
+def mi(x: torch.Tensor) -> torch.Tensor:
+    return torch.sum(x, dim=(2, 3), keepdim=True) / (x.shape[2] * x.shape[3])
+
+def sigma(x: torch.Tensor, epsilon=1e-5) -> torch.Tensor:
+    return torch.sqrt(torch.sum(((x - mi(x))**2 + epsilon), dim=(2, 3), keepdim=True) / (x.shape[2] * x.shape[3]))
+
+class AdaIN(nn.Module):
+    def __init__(self, epsilon=1e-5):
+        super().__init__()
+        self.epsilon = epsilon
+
+    def forward(self, content: torch.Tensor, style: torch.Tensor) -> torch.Tensor:
+        return (torch.mul(sigma(style, self.epsilon), ((content - mi(content)) / sigma(content, self.epsilon))) + mi(style))

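Here `mi` and `sigma` are the per-channel spatial mean μ(x) and standard deviation σ(x) (with ε inside the square root for numerical stability), so `AdaIN.forward` computes the adaptive instance normalization of [1]:

$$\mathrm{AdaIN}(x, y) = \sigma(y)\,\frac{x - \mu(x)}{\sigma(x)} + \mu(y)$$

where x is the content feature map and y is the style feature map.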
src/loss.py
ADDED
@@ -0,0 +1,32 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from adain import mi, sigma
+
+
+class Loss(nn.Module):
+    def __init__(self, lamb=8):
+        super().__init__()
+        self.lamb = lamb
+
+    def content_loss(self, enc_out: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
+        return F.mse_loss(enc_out, t)
+
+    def style_loss(self, out_activations: dict, style_activations: dict) -> torch.Tensor:
+        means, sds = 0, 0
+        for out_act, style_act in zip(out_activations.values(), style_activations.values()):
+            means += F.mse_loss(mi(out_act), mi(style_act))
+            sds += F.mse_loss(sigma(out_act), sigma(style_act))
+
+        return means + sds
+
+    def forward(self, enc_out: torch.Tensor, t: torch.Tensor, out_activations: dict, style_activations: dict) -> torch.Tensor:
+        self.loss_c = self.content_loss(enc_out, t)
+        self.loss_s = self.style_loss(out_activations, style_activations)
+
+        return (self.loss_c + self.lamb * self.loss_s)
+
+
+
+

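With encoder f, decoder g, AdaIN target t, and φ_i the hooked VGG layers (indices 1, 6, 11, 20 in src/model.py, i.e. relu1_1 through relu4_1), `Loss.forward` corresponds to the objective of [1], up to `F.mse_loss` averaging over elements rather than summing:

$$\mathcal{L} = \mathcal{L}_c + \lambda\,\mathcal{L}_s,\qquad
\mathcal{L}_c = \lVert f(g(t)) - t \rVert_2^2,\qquad
\mathcal{L}_s = \sum_i \Big( \lVert \mu(\phi_i(g(t))) - \mu(\phi_i(s)) \rVert_2^2 + \lVert \sigma(\phi_i(g(t))) - \sigma(\phi_i(s)) \rVert_2^2 \Big)$$

λ is `lamb` (8 by default here; the training config in src/main.py passes 7).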
src/main.py
ADDED
@@ -0,0 +1,200 @@
+import os
+import glob
+import numpy as np
+import wandb
+import copy
+import argparse
+import matplotlib.pyplot as plt
+
+import torch
+import torchvision.transforms as transforms
+from torchinfo import summary
+
+from utils import StyleContentDataset, DataStore, denorm_img
+from loss import Loss
+from model import Model
+
+
+config = {
+    "lr": 1e-4,
+    "max_iter": 80000,
+    "logging_interval": 100,
+    "preview_interval": 1000,
+    "batch_size": 4,
+    "activations": "ReLU",
+    "optimizer": "Adam",
+    "lambda": 7
+}
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+print(f"Using {device} device")
+
+def prepare_data(style_dir, content_dir, preview_dir):
+    norm = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+
+    # Training images
+    transform = transforms.Compose([transforms.Resize(512), transforms.RandomCrop(256)])
+    style_imgs = glob.glob(os.path.join(style_dir, '*.jpg'))
+    content_imgs = glob.glob(os.path.join(content_dir, '*.jpg'))
+
+    train_dataset = StyleContentDataset(style_imgs, content_imgs, transform=transform, normalize=norm)
+    datastore = DataStore(train_dataset, batch_size=config['batch_size'], shuffle=True)
+
+    # Preview images
+    transform = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(256)])
+    preview_style_imgs = glob.glob(os.path.join(preview_dir, 'style/*.jpg'))
+    preview_content_imgs = glob.glob(os.path.join(preview_dir, 'content/*.jpg'))
+
+    # preview_dataset = StyleContentDataset(preview_style_imgs, preview_content_imgs, transform=transform, normalize=norm)
+    preview_dataset = StyleContentDataset(preview_style_imgs, [preview_content_imgs[8]] * len(preview_style_imgs), transform=transform, normalize=norm)
+    preview_datastore = DataStore(preview_dataset, batch_size=len(preview_dataset), shuffle=False)
+
+    return datastore, preview_datastore
+
+
+def preview(model: Model, datastore: DataStore, iteration, save=False, use_wandb=False):
+    model.eval()
+    with torch.no_grad():
+        # np.random.shuffle(datastore.dataset.style_imgs)
+        # np.random.shuffle(datastore.dataset.content_imgs)
+
+        style, content = datastore.get()
+        style, content = style.to(device), content.to(device)
+        out = model(content, style)
+
+    fig, axs = plt.subplots(8, 6, figsize=(20, 26))
+    axs = axs.flatten()
+    i = 0
+    for (s, c, o) in zip(style, content, out):  # style, content, out
+        axs[i].imshow(denorm_img(s.cpu()).permute(1, 2, 0))
+        axs[i].axis('off')
+        axs[i].set_title('style')
+        axs[i+1].imshow(denorm_img(c.cpu()).permute(1, 2, 0))
+        axs[i+1].axis('off')
+        axs[i+1].set_title('content')
+        axs[i+2].imshow(denorm_img(o.cpu()).permute(1, 2, 0))
+        axs[i+2].axis('off')
+        axs[i+2].set_title('output')
+        i += 3
+
+    if save:
+        fig.savefig(f'outputs/{iteration}_preview.png')
+        plt.close(fig)
+
+    if use_wandb:
+        wandb.log({'preview': wandb.Image(f'outputs/{iteration}_preview.png')}, step=iteration)
+
+
+def train_one_iter(datastore: DataStore, model: Model, optimizer: torch.optim.Adam, loss_fn: Loss):
+    model.train()
+
+    style, content = datastore.get()
+    style, content = style.to(device), content.to(device)
+
+    optimizer.zero_grad()
+
+    # Forward
+    out = model(content, style)
+
+    # Save activations
+    style_activations = copy.deepcopy(model.activations)
+
+    enc_out = model.encoder(out)
+    out_activations = model.activations
+
+    # Compute loss
+    loss = loss_fn(enc_out, model.t, out_activations, style_activations)
+
+    # Update parameters
+    loss.backward()
+    optimizer.step()
+
+    return loss.item(), loss_fn.loss_c.item(), loss_fn.loss_s.item()
+
+
+def train(datastore, preview_datastore, model: Model, optimizer: torch.optim.Adam, use_wandb=False):
+    train_history = {'style_loss': [], 'content_loss': [], 'loss': []}
+
+    # optimizer = torch.optim.Adam(model.decoder.parameters(), lr=config['lr'])
+    loss_fn = Loss(lamb=config['lambda'])
+
+    for i in range(config['max_iter']):
+        loss, content_loss, style_loss = train_one_iter(datastore, model, optimizer, loss_fn)
+        train_history['loss'].append(loss)
+        train_history['style_loss'].append(style_loss)
+        train_history['content_loss'].append(content_loss)
+
+        if i % config['logging_interval'] == 0:
+            print(f'iter: {i}')
+            print(f'loss: {loss:>5f}, style loss: {style_loss:>5f}, content loss: {content_loss:>5f}')
+            print('-------------------------------')
+
+            if use_wandb:
+                wandb.log({
+                    'iter': i, 'loss': loss, 'style_loss': style_loss, 'content_loss': content_loss
+                })
+
+        if i % config['preview_interval'] == 0:
+            torch.save({
+                'iter': i, 'model_state': model.state_dict(), 'optimizer_state': optimizer.state_dict()
+            }, 'outputs/checkpoint.pt')
+            preview(model, preview_datastore, i, save=True, use_wandb=use_wandb)
+
+    return train_history
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--content_path', type=str, help='path to content dataset')
+    parser.add_argument('--style_path', type=str, help='path to style dataset')
+    parser.add_argument('--preview_path', type=str, help='path to preview dataset')
+    parser.add_argument('--wandb', type=str, help='wandb id')
+    parser.add_argument('--model_path', type=str, help='path to model')
+    args = parser.parse_args()
+
+    use_wandb = False
+    wandb_key = args.wandb
+    if wandb_key:
+        wandb.login(key=wandb_key)
+        wandb.init(project="assignment-3", name="", reinit=True, config=config)
+        use_wandb = True
+
+    if args.content_path and args.style_path and args.preview_path:
+        content_dir = args.content_path
+        style_dir = args.style_path
+        preview_dir = args.preview_path
+    else:
+        print("You didn't specify the data path >:(")
+        return
+
+    if not os.path.isdir('outputs'):
+        os.mkdir('outputs')
+
+    datastore, preview_datastore = prepare_data(style_dir, content_dir, preview_dir)
+
+    model = Model()
+    optimizer = torch.optim.Adam(model.decoder.parameters(), lr=config['lr'])
+    if args.model_path:
+        # From checkpoint
+        checkpoint = torch.load('outputs/checkpoint.pt')
+        model.load_state_dict(checkpoint['model_state'])
+        optimizer.load_state_dict(checkpoint['optimizer_state'])
+        config['max_iter'] -= checkpoint['iter']
+
+        # From final model
+        # model.load_state_dict(torch.load(args.model_path, map_location=torch.device(device)))
+    # print(summary(model))
+    model.to(device)
+
+    train(datastore, preview_datastore, model, optimizer, use_wandb)
+
+    torch.save(model.state_dict(), 'outputs/model.pt')
+    if use_wandb:
+        artifact = wandb.Artifact('model', type='model')
+        artifact.add_file('outputs/model.pt')
+        wandb.log_artifact(artifact)
+        wandb.finish()
+
+
+if __name__ == '__main__':
+    main()

src/model.py
ADDED
@@ -0,0 +1,66 @@
+from torch import nn
+from torchvision.models import vgg19
+import torchvision
+from src.adain import AdaIN
+
+class Model(nn.Module):
+    def __init__(self, alpha=1.0):
+        super().__init__()
+        self.alpha = alpha
+
+        self.encoder = nn.Sequential(*list(vgg19(weights=torchvision.models.VGG19_Weights.DEFAULT).features)[:21])
+
+        for param in self.encoder.parameters():
+            param.requires_grad = False
+
+        # set padding in conv layers to reflect
+        # create dict for saving activations used in the style loss
+        self.activations = {}
+        for i, module in enumerate(self.encoder.children()):
+            if isinstance(module, nn.Conv2d):
+                module.padding_mode = 'reflect'
+
+            if i in [1, 6, 11, 20]:
+                module.register_forward_hook(self._save_activations(i))
+
+        self.AdaIN = AdaIN()
+
+        self.decoder = nn.Sequential(
+            nn.Upsample(scale_factor=2.0, mode='nearest'),
+            nn.Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode='reflect'),
+            nn.ReLU(),
+
+            nn.Upsample(scale_factor=2.0, mode='nearest'),
+            nn.Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode='reflect'),
+            nn.ReLU(),
+            nn.Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode='reflect'),
+            nn.ReLU(),
+
+            nn.Upsample(scale_factor=2.0, mode='nearest'),
+            nn.Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode='reflect'),
+            nn.ReLU(),
+            nn.Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode='reflect'),
+            nn.ReLU(),
+
+            nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode='reflect'),
+            nn.ReLU(),
+            nn.Conv2d(64, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode='reflect'),
+            nn.Tanh()
+        )
+
+    # https://stackoverflow.com/a/68854535
+    def _save_activations(self, name):
+        def hook(module, input, output):
+            self.activations[name] = output
+        return hook
+
+    def forward(self, content, style):
+        enc_content = self.encoder(content)
+        enc_style = self.encoder(style)
+
+        self.t = self.AdaIN(enc_content, enc_style)
+        self.t = (1.0 - self.alpha) * enc_content + self.alpha * self.t
+        out = self.decoder(self.t)
+
+        return out
+

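The blend applied in `forward` is the content–style interpolation from [1]: with the frozen encoder f (VGG-19 features up to relu4_1) and the trainable decoder g,

$$t = (1 - \alpha)\,f(c) + \alpha\,\mathrm{AdaIN}\big(f(c),\, f(s)\big),\qquad T(c, s) = g(t)$$

so α = 1 gives full stylization and α = 0 reconstructs the content image; app.py exposes α as the "Alpha" slider via `model.alpha`.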
src/utils.py
ADDED
@@ -0,0 +1,62 @@
+import torch
+from torch.utils.data import Dataset, DataLoader
+from torchvision.io import read_image, ImageReadMode
+import numpy as np
+
+
+def denorm_img(img: torch.Tensor) -> torch.Tensor:
+    std = torch.Tensor([0.229, 0.224, 0.225]).reshape(-1, 1, 1)
+    mean = torch.Tensor([0.485, 0.456, 0.406]).reshape(-1, 1, 1)
+    return torch.clip(img * std + mean, min=0, max=1)
+
+
+class StyleContentDataset(Dataset):
+    def __init__(self, style_imgs, content_imgs, transform=None, normalize=None):
+        self.style_imgs = style_imgs
+        self.content_imgs = content_imgs
+        self.transform = transform
+        self.normalize = normalize
+
+    def __len__(self):
+        if len(self.style_imgs) < len(self.content_imgs):
+            return len(self.style_imgs)
+        else:
+            return len(self.content_imgs)
+
+    def __getitem__(self, idx):
+        try:
+            style = read_image(self.style_imgs[idx], ImageReadMode.RGB).float() / 255.0
+            content = read_image(self.content_imgs[idx], ImageReadMode.RGB).float() / 255.0
+        except RuntimeError:
+            print(self.style_imgs[idx])
+            print(self.content_imgs[idx])
+            style = read_image(self.style_imgs[0], ImageReadMode.RGB).float() / 255.0
+            content = read_image(self.content_imgs[0], ImageReadMode.RGB).float() / 255.0
+
+        if self.normalize:
+            style = self.normalize(style)
+            content = self.normalize(content)
+
+        if self.transform:
+            style = self.transform(style)
+            content = self.transform(content)
+
+        return style, content
+
+
+class DataStore():
+    def __init__(self, dataset: StyleContentDataset, batch_size, shuffle=False):
+        self.dataset = dataset
+        self.dataloader = DataLoader(self.dataset, batch_size=batch_size, shuffle=shuffle, num_workers=2)
+        self.iterator = iter(self.dataloader)
+
+    def get(self):
+        try:
+            style, content = next(self.iterator)
+        except (StopIteration):
+            # print('| Repeating |')
+            # np.random.shuffle(self.dataset.style_imgs)
+            self.iterator = iter(self.dataloader)
+            style, content = next(self.iterator)
+
+        return style, content
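To illustrate how src/main.py consumes these helpers, here is a small sketch (the image folders are placeholders, not files shipped with this Space): `StyleContentDataset` pairs the i-th style image with the i-th content image, and `DataStore.get` transparently restarts the DataLoader once an epoch is exhausted.

```python
# Hypothetical usage of the data helpers; run from inside src/ so that
# `from utils import ...` resolves, and point the globs at local .jpg folders.
import glob
import torchvision.transforms as transforms

from utils import StyleContentDataset, DataStore

style_imgs = glob.glob("path/to/styles/*.jpg")      # placeholder path
content_imgs = glob.glob("path/to/content/*.jpg")   # placeholder path

norm = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
crop = transforms.Compose([transforms.Resize(512), transforms.RandomCrop(256)])

dataset = StyleContentDataset(style_imgs, content_imgs, transform=crop, normalize=norm)
datastore = DataStore(dataset, batch_size=4, shuffle=True)

style, content = datastore.get()  # two tensors of shape (4, 3, 256, 256)
```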