kuko6 committed
Commit c583015 · 1 Parent(s): c228a11

added files

.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.ipynb filter=lfs diff=lfs merge=lfs -text
+ *.jpg filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,6 @@
+ __pycache__/
+ .DS_Store
+ .vscode/
+
+ test.*
+ test/
.python-version ADDED
@@ -0,0 +1 @@
+ ai-env
README.md CHANGED
@@ -1,12 +1,16 @@
  ---
  title: Style Transfer
- emoji: 📈
- colorFrom: purple
- colorTo: purple
+ emoji: 👨‍🎨
+ colorFrom: pink
+ colorTo: yellow
  sdk: gradio
  sdk_version: 4.36.1
  app_file: app.py
  pinned: false
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # Style Transfer
+
+ ## References
+
+ [1] Huang, Xun, and Serge Belongie. "Arbitrary style transfer in real-time with adaptive instance normalization." *Proceedings of the IEEE International Conference on Computer Vision*, 2017.
app.py ADDED
@@ -0,0 +1,86 @@
+ import gradio as gr
+ import torch
+ import torchvision.transforms.functional as TF
+ import torchvision.transforms as transforms
+ from src.model import Model
+ import os
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+
+ def denorm_img(img: torch.Tensor):
+     std = torch.Tensor([0.229, 0.224, 0.225]).reshape(-1, 1, 1)
+     mean = torch.Tensor([0.485, 0.456, 0.406]).reshape(-1, 1, 1)
+     return torch.clip(img * std + mean, min=0, max=1)
+
+
+ def main(inp1, inp2, alph, out_size=256):
+     model = Model()
+     model.load_state_dict(torch.load("models/model_puddle.pt", map_location=torch.device(device)))
+     model.to(device)
+     model.eval()
+
+     model.alpha = alph
+
+     style = TF.to_tensor(inp1)
+     content = TF.to_tensor(inp2)
+
+     norm = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+     transform = transforms.Compose(
+         [transforms.Resize(out_size, antialias=True), transforms.CenterCrop(out_size)]
+     )
+
+     style, content = norm(style), norm(content)
+     style, content = transform(style), transform(content)
+
+     style, content = style.unsqueeze(0).to(device), content.unsqueeze(0).to(device)
+
+     out = model(content, style)
+
+     return denorm_img(out[0].detach().cpu()).permute(1, 2, 0).numpy()
+
+
+ with gr.Blocks() as demo:
+     gr.Markdown("# Style Transfer with AdaIN")
+     with gr.Row(variant="compact"):
+         inp1 = gr.Image(type="pil", sources=["upload", "clipboard"], label="Style")
+         inp2 = gr.Image(type="pil", sources=["upload", "clipboard"], label="Content")
+         out = gr.Image(type="numpy", label="Output")
+     with gr.Row():
+         out_size = gr.Dropdown(
+             choices=[256, 512],
+             value=256,
+             multiselect=False,
+             interactive=True,
+             allow_custom_value=True,
+             label="Output size",
+             info="Size of the output image",
+         )
+         alph = gr.Slider(0, 1, value=1, label="Alpha", info="How much to change the original image", interactive=True, scale=3)
+
+     with gr.Row():
+         with gr.Column():
+             gr.Markdown("## Style Examples")
+             gr.Examples(
+                 examples=[
+                     os.path.join(os.path.dirname(__file__), "data/styles/25.jpg"),
+                     os.path.join(os.path.dirname(__file__), "data/styles/2272.jpg"),
+                     os.path.join(os.path.dirname(__file__), "data/styles/2314.jpg"),
+                 ],
+                 inputs=inp1,
+             )
+         with gr.Column():
+             gr.Markdown("## Content Examples")
+             gr.Examples(
+                 examples=[
+                     os.path.join(os.path.dirname(__file__), "data/content/bear.jpg"),
+                     os.path.join(os.path.dirname(__file__), "data/content/cow.jpg"),
+                     os.path.join(os.path.dirname(__file__), "data/content/ducks.jpg"),
+                 ],
+                 inputs=inp2,
+             )
+     btn = gr.Button("Run")
+     btn.click(fn=main, inputs=[inp1, inp2, alph, out_size], outputs=out)
+
+ demo.launch()
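
A minimal standalone inference sketch mirroring `app.py`'s preprocessing, assuming the repository root as working directory and that `models/model_puddle.pt` and the example images have been pulled via Git LFS:

```python
# Standalone inference sketch mirroring app.py (same preprocessing order: to_tensor -> normalize -> resize/crop).
# Assumes: repo root as cwd, git-lfs content pulled, PIL available (it ships with the Gradio stack).
import torch
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from PIL import Image

from src.model import Model

device = "cuda" if torch.cuda.is_available() else "cpu"

model = Model()
model.load_state_dict(torch.load("models/model_puddle.pt", map_location=torch.device(device)))
model.to(device).eval()

norm = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
crop = transforms.Compose([transforms.Resize(256, antialias=True), transforms.CenterCrop(256)])

def prep(path):
    # open an image and apply the same preprocessing as app.main()
    return crop(norm(TF.to_tensor(Image.open(path).convert("RGB")))).unsqueeze(0).to(device)

with torch.no_grad():
    out = model(prep("data/content/bear.jpg"), prep("data/styles/25.jpg"))

# undo the ImageNet normalization and save the result
std = torch.tensor([0.229, 0.224, 0.225]).reshape(-1, 1, 1)
mean = torch.tensor([0.485, 0.456, 0.406]).reshape(-1, 1, 1)
img = torch.clip(out[0].cpu() * std + mean, 0, 1)
TF.to_pil_image(img).save("stylized.png")
```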
data/content/bear.jpg ADDED

Git LFS Details

  • SHA256: f3a2974ce3686332609124c70e3e6a2e3aca43fccf1cd1bd7c5c03820977f57d
  • Pointer size: 131 Bytes
  • Size of remote file: 336 kB
data/content/cow.jpg ADDED

Git LFS Details

  • SHA256: a1d362810f97e0dd00ecda4f1d427aec52ba3361c1a15f00cc525d9dc8216ad3
  • Pointer size: 130 Bytes
  • Size of remote file: 90.9 kB
data/content/ducks.jpg ADDED

Git LFS Details

  • SHA256: ae0cf5374adfa2c78f50e5fc58b51a18e8db2285f00912c2eca6a2af204857d1
  • Pointer size: 131 Bytes
  • Size of remote file: 165 kB
data/styles/2272.jpg ADDED

Git LFS Details

  • SHA256: e8ba2aa73ebb7f4e1f8554c18a1e2b12ab60b6e4422a3a3651acf021ced59260
  • Pointer size: 133 Bytes
  • Size of remote file: 26.4 MB
data/styles/2314.jpg ADDED

Git LFS Details

  • SHA256: d7ae6fb18550ccedafb30b97d6f6ea4939ee82969bc9c74e1d8012741f746d3e
  • Pointer size: 131 Bytes
  • Size of remote file: 141 kB
data/styles/25.jpg ADDED

Git LFS Details

  • SHA256: 6f8f92da7b6113d62484d8f893b7ef4f4797804e0d3f2fb64a04c24f1d9a8269
  • Pointer size: 130 Bytes
  • Size of remote file: 73 kB
models/checkpoint_puddle_70k.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:090b259580c0a6a7fbece489de7177aa5decebeea7dc6f26dab4d9e9aeb6f700
+ size 41942833
models/checkpoint_puddle_79k.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f771b61eeb8b1505f690c6c09bbca269d17473b82cb8ed905877d9ed7de26bc2
+ size 41942833
models/model_puddle.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8a215c24cf629cf5f2b2cf96014347fda704c06acd68e312720c5783eebde5ca
+ size 23333701
nb.ipynb ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:53f653d323480c9ccce8283220a78f5f68fba868ca181923d41da2869408bd46
+ size 60328933
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ torch
+ torchvision
+ gradio
+ torchinfo
src/adain.py ADDED
@@ -0,0 +1,17 @@
+ import torch
+ from torch import nn
+
+
+ def mi(x: torch.Tensor) -> torch.Tensor:
+     return torch.sum(x, dim=(2, 3), keepdim=True) / (x.shape[2] * x.shape[3])
+
+ def sigma(x: torch.Tensor, epsilon=1e-5) -> torch.Tensor:
+     return torch.sqrt(torch.sum(((x - mi(x))**2 + epsilon), dim=(2, 3), keepdim=True) / (x.shape[2] * x.shape[3]))
+
+ class AdaIN(nn.Module):
+     def __init__(self, epsilon=1e-5):
+         super().__init__()
+         self.epsilon = epsilon
+
+     def forward(self, content: torch.Tensor, style: torch.Tensor) -> torch.Tensor:
+         return (torch.mul(sigma(style, self.epsilon), ((content - mi(content)) / sigma(content, self.epsilon))) + mi(style))
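
A quick numerical check of the property this module implements (the output's per-channel statistics should match the style's), assuming `src.adain` is importable from the repository root:

```python
# AdaIN sanity check: for each (sample, channel), the output's mean and std should equal the style's.
import torch
from src.adain import AdaIN, mi, sigma

content = torch.randn(2, 4, 8, 8)
style = torch.randn(2, 4, 8, 8) * 3.0 + 1.5   # style with clearly different statistics

out = AdaIN()(content, style)

print(torch.allclose(mi(out), mi(style), atol=1e-4))        # True: means transferred exactly
print(torch.allclose(sigma(out), sigma(style), atol=1e-2))  # True up to the epsilon inside sigma()
```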
src/loss.py ADDED
@@ -0,0 +1,32 @@
+ import torch
+ from torch import nn
+ import torch.nn.functional as F
+
+ from adain import mi, sigma
+
+
+ class Loss(nn.Module):
+     def __init__(self, lamb=8):
+         super().__init__()
+         self.lamb = lamb
+
+     def content_loss(self, enc_out: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
+         return F.mse_loss(enc_out, t)
+
+     def style_loss(self, out_activations: dict, style_activations: dict) -> torch.Tensor:
+         means, sds = 0, 0
+         for out_act, style_act in zip(out_activations.values(), style_activations.values()):
+             means += F.mse_loss(mi(out_act), mi(style_act))
+             sds += F.mse_loss(sigma(out_act), sigma(style_act))
+
+         return means + sds
+
+     def forward(self, enc_out: torch.Tensor, t: torch.Tensor, out_activations: dict, style_activations: dict) -> torch.Tensor:
+         self.loss_c = self.content_loss(enc_out, t)
+         self.loss_s = self.style_loss(out_activations, style_activations)
+
+         return (self.loss_c + self.lamb * self.loss_s)
+
+
+
+
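
A minimal sketch of how `Loss` combines the content term with the per-layer mean/std style terms, on dummy activation dictionaries (real VGG activations differ in shape per layer); it assumes `src/` is on `sys.path` so that `loss.py`'s flat `from adain import ...` resolves:

```python
# Dummy-activation sketch of the combined loss = content MSE + lambda * (mean/std MSE per layer).
import sys
sys.path.insert(0, "src")

import torch
from loss import Loss

loss_fn = Loss(lamb=8)

enc_out = torch.randn(4, 512, 32, 32)   # encoder output of the decoded image
t = torch.randn(4, 512, 32, 32)         # AdaIN target t, used as the content target

# activation dicts keyed the way Model stores them (relu1_1, relu2_1, relu3_1, relu4_1)
out_acts = {i: torch.randn(4, 64, 32, 32) for i in [1, 6, 11, 20]}
style_acts = {i: torch.randn(4, 64, 32, 32) for i in [1, 6, 11, 20]}

total = loss_fn(enc_out, t, out_acts, style_acts)
print(total.item(), loss_fn.loss_c.item(), loss_fn.loss_s.item())
```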
src/main.py ADDED
@@ -0,0 +1,200 @@
+ import os
+ import glob
+ import numpy as np
+ import wandb
+ import copy
+ import argparse
+ import matplotlib.pyplot as plt
+
+ import torch
+ import torchvision.transforms as transforms
+ from torchinfo import summary
+
+ from utils import StyleContentDataset, DataStore, denorm_img
+ from loss import Loss
+ from model import Model
+
+
+ config = {
+     "lr": 1e-4,
+     "max_iter": 80000,
+     "logging_interval": 100,
+     "preview_interval": 1000,
+     "batch_size": 4,
+     "activations": "ReLU",
+     "optimizer": "Adam",
+     "lambda": 7
+ }
+
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ print(f"Using {device} device")
+
+ def prepare_data(style_dir, content_dir, preview_dir):
+     norm = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+
+     # Training images
+     transform = transforms.Compose([transforms.Resize(512), transforms.RandomCrop(256)])
+     style_imgs = glob.glob(os.path.join(style_dir, '*.jpg'))
+     content_imgs = glob.glob(os.path.join(content_dir, '*.jpg'))
+
+     train_dataset = StyleContentDataset(style_imgs, content_imgs, transform=transform, normalize=norm)
+     datastore = DataStore(train_dataset, batch_size=config['batch_size'], shuffle=True)
+
+     # Preview images
+     transform = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(256)])
+     preview_style_imgs = glob.glob(os.path.join(preview_dir, 'style/*.jpg'))
+     preview_content_imgs = glob.glob(os.path.join(preview_dir, 'content/*.jpg'))
+
+     # preview_dataset = StyleContentDataset(preview_style_imgs, preview_content_imgs, transform=transform, normalize=norm)
+     preview_dataset = StyleContentDataset(preview_style_imgs, [preview_content_imgs[8]] * len(preview_style_imgs), transform=transform, normalize=norm)
+     preview_datastore = DataStore(preview_dataset, batch_size=len(preview_dataset), shuffle=False)
+
+     return datastore, preview_datastore
+
+
+ def preview(model: Model, datastore: DataStore, iteration, save=False, use_wandb=False):
+     model.eval()
+     with torch.no_grad():
+         # np.random.shuffle(datastore.dataset.style_imgs)
+         # np.random.shuffle(datastore.dataset.content_imgs)
+
+         style, content = datastore.get()
+         style, content = style.to(device), content.to(device)
+         out = model(content, style)
+
+         fig, axs = plt.subplots(8, 6, figsize=(20, 26))
+         axs = axs.flatten()
+         i = 0
+         for (s, c, o) in zip(style, content, out):  # style, content, out
+             axs[i].imshow(denorm_img(s.cpu()).permute(1, 2, 0))
+             axs[i].axis('off')
+             axs[i].set_title('style')
+             axs[i+1].imshow(denorm_img(c.cpu()).permute(1, 2, 0))
+             axs[i+1].axis('off')
+             axs[i+1].set_title('content')
+             axs[i+2].imshow(denorm_img(o.cpu()).permute(1, 2, 0))
+             axs[i+2].axis('off')
+             axs[i+2].set_title('output')
+             i += 3
+
+         if save:
+             fig.savefig(f'outputs/{iteration}_preview.png')
+             plt.close(fig)
+
+         if use_wandb:
+             wandb.log({'preview': wandb.Image(f'outputs/{iteration}_preview.png')}, step=iteration)
+
+
+ def train_one_iter(datastore: DataStore, model: Model, optimizer: torch.optim.Adam, loss_fn: Loss):
+     model.train()
+
+     style, content = datastore.get()
+     style, content = style.to(device), content.to(device)
+
+     optimizer.zero_grad()
+
+     # Forward
+     out = model(content, style)
+
+     # Save activations
+     style_activations = copy.deepcopy(model.activations)
+
+     enc_out = model.encoder(out)
+     out_activations = model.activations
+
+     # Compute loss
+     loss = loss_fn(enc_out, model.t, out_activations, style_activations)
+
+     # Update parameters
+     loss.backward()
+     optimizer.step()
+
+     return loss.item(), loss_fn.loss_c.item(), loss_fn.loss_s.item()
+
+
+ def train(datastore, preview_datastore, model: Model, optimizer: torch.optim.Adam, use_wandb=False):
+     train_history = {'style_loss': [], 'content_loss': [], 'loss': []}
+
+     # optimizer = torch.optim.Adam(model.decoder.parameters(), lr=config['lr'])
+     loss_fn = Loss(lamb=config['lambda'])
+
+     for i in range(config['max_iter']):
+         loss, content_loss, style_loss = train_one_iter(datastore, model, optimizer, loss_fn)
+         train_history['loss'].append(loss)
+         train_history['style_loss'].append(style_loss)
+         train_history['content_loss'].append(content_loss)
+
+         if i % config['logging_interval'] == 0:
+             print(f'iter: {i}')
+             print(f'loss: {loss:>5f}, style loss: {style_loss:>5f}, content loss: {content_loss:>5f}')
+             print('-------------------------------')
+
+             if use_wandb:
+                 wandb.log({
+                     'iter': i, 'loss': loss, 'style_loss': style_loss, 'content_loss': content_loss
+                 })
+
+         if i % config['preview_interval'] == 0:
+             torch.save({
+                 'iter': i, 'model_state': model.state_dict(), 'optimizer_state': optimizer.state_dict()
+             }, 'outputs/checkpoint.pt')
+             preview(model, preview_datastore, i, save=True, use_wandb=use_wandb)
+
+     return train_history
+
+
+ def main():
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--content_path', type=str, help='path to content dataset')
+     parser.add_argument('--style_path', type=str, help='path to style dataset')
+     parser.add_argument('--preview_path', type=str, help='path to preview dataset')
+     parser.add_argument('--wandb', type=str, help='wandb id')
+     parser.add_argument('--model_path', type=str, help='path to model')
+     args = parser.parse_args()
+
+     use_wandb = False
+     wandb_key = args.wandb
+     if wandb_key:
+         wandb.login(key=wandb_key)
+         wandb.init(project="assignment-3", name="", reinit=True, config=config)
+         use_wandb = True
+
+     if args.content_path and args.style_path and args.preview_path:
+         content_dir = args.content_path
+         style_dir = args.style_path
+         preview_dir = args.preview_path
+     else:
+         print("You didn't specify the data path >:(")
+         return
+
+     if not os.path.isdir('outputs'):
+         os.mkdir('outputs')
+
+     datastore, preview_datastore = prepare_data(style_dir, content_dir, preview_dir)
+
+     model = Model()
+     optimizer = torch.optim.Adam(model.decoder.parameters(), lr=config['lr'])
+     if args.model_path:
+         # From checkpoint
+         checkpoint = torch.load('outputs/checkpoint.pt')
+         model.load_state_dict(checkpoint['model_state'])
+         optimizer.load_state_dict(checkpoint['optimizer_state'])
+         config['max_iter'] -= checkpoint['iter']
+
+         # From final model
+         # model.load_state_dict(torch.load(args.model_path, map_location=torch.device(device)))
+     # print(summary(model))
+     model.to(device)
+
+     train(datastore, preview_datastore, model, optimizer, use_wandb)
+
+     torch.save(model.state_dict(), 'outputs/model.pt')
+     if use_wandb:
+         artifact = wandb.Artifact('model', type='model')
+         artifact.add_file('outputs/model.pt')
+         wandb.log_artifact(artifact)
+         wandb.finish()
+
+
+ if __name__ == '__main__':
+     main()
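
The `copy.deepcopy(model.activations)` in `train_one_iter` is what keeps the style-pass activations from being overwritten when the decoded image is re-encoded; a toy illustration of that shared-dictionary behaviour, where the hypothetical `run_encoder` stands in for a hooked encoder pass:

```python
# Why train_one_iter() deep-copies model.activations: the forward hooks write into one shared
# dict, so a second encoder pass would overwrite the activations captured during the first.
import copy
import torch

activations = {}

def run_encoder(tag):
    # stands in for a forward pass whose hooks repopulate `activations`
    activations[1] = torch.full((1,), float(tag))

run_encoder(1)                                  # pass 1: content/style through the encoder
style_activations = copy.deepcopy(activations)  # snapshot before the next pass
run_encoder(2)                                  # pass 2: decoded image through the encoder
out_activations = activations

print(style_activations[1], out_activations[1])  # tensor([1.]) tensor([2.]) -- the copy preserved pass 1
```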
src/model.py ADDED
@@ -0,0 +1,66 @@
+ from torch import nn
+ from torchvision.models import vgg19
+ import torchvision
+ from src.adain import AdaIN
+
+ class Model(nn.Module):
+     def __init__(self, alpha=1.0):
+         super().__init__()
+         self.alpha = alpha
+
+         self.encoder = nn.Sequential(*list(vgg19(weights=torchvision.models.VGG19_Weights.DEFAULT).features)[:21])
+
+         for param in self.encoder.parameters():
+             param.requires_grad = False
+
+         # set padding in conv layers to reflect
+         # create dict for saving activations used in the style loss
+         self.activations = {}
+         for i, module in enumerate(self.encoder.children()):
+             if isinstance(module, nn.Conv2d):
+                 module.padding_mode = 'reflect'
+
+             if i in [1, 6, 11, 20]:
+                 module.register_forward_hook(self._save_activations(i))
+
+         self.AdaIN = AdaIN()
+
+         self.decoder = nn.Sequential(
+             nn.Upsample(scale_factor=2.0, mode='nearest'),
+             nn.Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode='reflect'),
+             nn.ReLU(),
+
+             nn.Upsample(scale_factor=2.0, mode='nearest'),
+             nn.Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode='reflect'),
+             nn.ReLU(),
+             nn.Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode='reflect'),
+             nn.ReLU(),
+
+             nn.Upsample(scale_factor=2.0, mode='nearest'),
+             nn.Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode='reflect'),
+             nn.ReLU(),
+             nn.Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode='reflect'),
+             nn.ReLU(),
+
+             nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode='reflect'),
+             nn.ReLU(),
+             nn.Conv2d(64, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode='reflect'),
+             nn.Tanh()
+         )
+
+     # https://stackoverflow.com/a/68854535
+     def _save_activations(self, name):
+         def hook(module, input, output):
+             self.activations[name] = output
+         return hook
+
+     def forward(self, content, style):
+         enc_content = self.encoder(content)
+         enc_style = self.encoder(style)
+
+         self.t = self.AdaIN(enc_content, enc_style)
+         self.t = (1.0 - self.alpha) * enc_content + self.alpha * self.t
+         out = self.decoder(self.t)
+
+         return out
+
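
A shape-check sketch for the encoder → AdaIN → decoder round trip, assuming the repository root as working directory (so `src.model` and `src.adain` resolve) and that torchvision can fetch the pretrained VGG19 weights:

```python
# Shape check: relu4_1 features of a 256x256 input are 512x32x32, and the decoder's three
# 2x upsamples restore the input resolution.
import torch
from src.model import Model

model = Model(alpha=1.0).eval()

content = torch.randn(1, 3, 256, 256)
style = torch.randn(1, 3, 256, 256)

with torch.no_grad():
    out = model(content, style)

print(model.t.shape)   # torch.Size([1, 512, 32, 32])  -- the AdaIN target t
print(out.shape)       # torch.Size([1, 3, 256, 256])
```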
src/utils.py ADDED
@@ -0,0 +1,62 @@
+ import torch
+ from torch.utils.data import Dataset, DataLoader
+ from torchvision.io import read_image, ImageReadMode
+ import numpy as np
+
+
+ def denorm_img(img: torch.Tensor) -> torch.Tensor:
+     std = torch.Tensor([0.229, 0.224, 0.225]).reshape(-1, 1, 1)
+     mean = torch.Tensor([0.485, 0.456, 0.406]).reshape(-1, 1, 1)
+     return torch.clip(img * std + mean, min=0, max=1)
+
+
+ class StyleContentDataset(Dataset):
+     def __init__(self, style_imgs, content_imgs, transform=None, normalize=None):
+         self.style_imgs = style_imgs
+         self.content_imgs = content_imgs
+         self.transform = transform
+         self.normalize = normalize
+
+     def __len__(self):
+         if len(self.style_imgs) < len(self.content_imgs):
+             return len(self.style_imgs)
+         else:
+             return len(self.content_imgs)
+
+     def __getitem__(self, idx):
+         try:
+             style = read_image(self.style_imgs[idx], ImageReadMode.RGB).float() / 255.0
+             content = read_image(self.content_imgs[idx], ImageReadMode.RGB).float() / 255.0
+         except RuntimeError:
+             print(self.style_imgs[idx])
+             print(self.content_imgs[idx])
+             style = read_image(self.style_imgs[0], ImageReadMode.RGB).float() / 255.0
+             content = read_image(self.content_imgs[0], ImageReadMode.RGB).float() / 255.0
+
+         if self.normalize:
+             style = self.normalize(style)
+             content = self.normalize(content)
+
+         if self.transform:
+             style = self.transform(style)
+             content = self.transform(content)
+
+         return style, content
+
+
+ class DataStore():
+     def __init__(self, dataset: StyleContentDataset, batch_size, shuffle=False):
+         self.dataset = dataset
+         self.dataloader = DataLoader(self.dataset, batch_size=batch_size, shuffle=shuffle, num_workers=2)
+         self.iterator = iter(self.dataloader)
+
+     def get(self):
+         try:
+             style, content = next(self.iterator)
+         except StopIteration:
+             # print('| Repeating |')
+             # np.random.shuffle(self.dataset.style_imgs)
+             self.iterator = iter(self.dataloader)
+             style, content = next(self.iterator)
+
+         return style, content
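
A sketch of the data pipeline as `src/main.py` wires it up, pointed at the example images committed under `data/` (pulled via Git LFS) and assuming `src/` is on `sys.path`; the `__main__` guard matters because the `DataLoader` inside `DataStore` uses worker processes:

```python
# Dataset + DataStore usage, mirroring prepare_data() in src/main.py.
import sys
sys.path.insert(0, "src")

import glob
import torchvision.transforms as transforms
from utils import StyleContentDataset, DataStore

if __name__ == "__main__":
    norm = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    transform = transforms.Compose([transforms.Resize(512), transforms.RandomCrop(256)])

    style_imgs = glob.glob("data/styles/*.jpg")
    content_imgs = glob.glob("data/content/*.jpg")

    dataset = StyleContentDataset(style_imgs, content_imgs, transform=transform, normalize=norm)
    datastore = DataStore(dataset, batch_size=2, shuffle=True)

    # get() transparently restarts the DataLoader on StopIteration, so the training loop can
    # request batches for a fixed number of iterations instead of epochs
    style, content = datastore.get()
    print(style.shape, content.shape)   # torch.Size([2, 3, 256, 256]) each
```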