Neural Style Transfer without pystiche
This example showcases how a basic Neural Style Transfer (NST), i.e. image-based optimization, could be performed without pystiche.
Note
This is an example of how to implement an NST and not a tutorial on how NST works. As such, it will not explain why a specific choice was made or how a component works. If you have never worked with NST before, we strongly suggest you read the Gist first.
Setup
We start this example by importing everything we need and setting the device we will be working on. torch and torchvision will be used for the actual NST. Furthermore, we use PIL.Image for the file input, matplotlib.pyplot to show the images, and tqdm to display the optimization progress.
import itertools
import os.path
from collections import OrderedDict
from urllib.request import urlopen

import matplotlib.pyplot as plt
from PIL import Image
from tqdm.auto import tqdm

import torch
import torchvision
from torch import nn, optim
from torch.nn.functional import mse_loss
from torchvision import transforms
from torchvision.models import vgg19
from torchvision.transforms.functional import resize

print(f"I'm working with torch=={torch.__version__}")
print(f"I'm working with torchvision=={torchvision.__version__}")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"I'm working with {device}")
The core component of different NSTs is the perceptual loss, which is used as the optimization criterion. The perceptual loss is usually, and also in this example, calculated on feature maps, also called encodings. These encodings are generated by different layers of a Convolutional Neural Network (CNN), also called the encoder.
A common implementation strategy for the perceptual loss is to weave transparent loss layers into the encoder. These loss layers are called transparent since, from an outside view, they simply pass the input through without alteration. Internally, though, they calculate the loss from the encodings of the previous layer and store the result inside the layer. After the forward pass is completed, the stored losses are aggregated and propagated backwards to the image. While this is simple to implement, this practice has two downsides:
The calculated score is part of the current optimization state, but it has to be stored inside the layer. This is generally not recommended.
While the encoder is a part of the perceptual loss, it does not generate the loss itself. One should be able to use the same encoder with a different perceptual loss without modification.
Thus, this example (and pystiche) follows a different approach and separates the encoder and the perceptual loss into individual entities.
Multi-layer Encoder
In a first step we define a MultiLayerEncoder that should have the following properties:
Given an image and a set of layers, the MultiLayerEncoder should return the encodings of every given layer. Since the encodings have to be generated in every optimization step, they should be calculated in a single forward pass to keep the processing costs low.
To reduce the static memory requirement, the MultiLayerEncoder should be trimmable in order to remove unused layers.
We achieve the main functionality by subclassing torch.nn.Sequential and defining a custom forward method, i.e. the behavior when the module is called. Besides the image, it takes an arbitrary number of layer_cfgs, each of which is a sequence of layer names. In the method body we first find the deepest_layer that was requested. Subsequently, we calculate and store all encodings of the image up to that layer. Finally, we return all requested encodings without processing the same layer twice.
class MultiLayerEncoder(nn.Sequential):
    def forward(self, image, *layer_cfgs):
        storage = {}
        deepest_layer = self._find_deepest_layer(*layer_cfgs)
        for layer, module in self.named_children():
            image = storage[layer] = module(image)
            if layer == deepest_layer:
                break

        return [[storage[layer] for layer in layers] for layers in layer_cfgs]

    def children_names(self):
        for name, module in self.named_children():
            yield name

    def _find_deepest_layer(self, *layer_cfgs):
        # find all unique requested layers
        req_layers = set(itertools.chain(*layer_cfgs))
        try:
            # find the deepest requested layer by indexing the layers within
            # the multi layer encoder
            children_names = list(self.children_names())
            return sorted(req_layers, key=children_names.index)[-1]
        except ValueError as error:
            layer = str(error).split()[0]
            raise ValueError(f"Layer {layer} is not part of the multi-layer encoder.")

    def trim(self, *layer_cfgs):
        deepest_layer = self._find_deepest_layer(*layer_cfgs)
        children_names = list(self.children_names())
        del self[children_names.index(deepest_layer) + 1 :]
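The single-pass logic of forward and trim does not depend on PyTorch. The stdlib-only sketch below mirrors it under stated assumptions: the names are hypothetical, and plain arithmetic functions stand in for the layers of a CNN. It finds the deepest requested layer, runs the layers only up to that point, and returns the encodings grouped per layer configuration. It uses max() instead of sorted(...)[-1], which picks the same layer.

```python
from itertools import chain

# Hypothetical stand-in for an encoder: an ordered list of (name, function)
# pairs. Plain arithmetic functions replace the convolutional layers.
layers = [
    ("conv1_1", lambda x: x + 1),
    ("relu1_1", lambda x: max(x, 0)),
    ("conv1_2", lambda x: x * 2),
    ("relu1_2", lambda x: max(x, 0)),
    ("pool1", lambda x: x - 3),
]
names = [name for name, _ in layers]


def find_deepest_layer(*layer_cfgs):
    # Collect every requested layer once and pick the one that comes last in
    # the encoder; an unknown name raises ValueError via names.index.
    requested = set(chain(*layer_cfgs))
    return max(requested, key=names.index)


def encode(x, *layer_cfgs):
    storage = {}
    deepest = find_deepest_layer(*layer_cfgs)
    for name, fn in layers:
        x = storage[name] = fn(x)
        if name == deepest:
            break  # nothing deeper is needed, so stop the forward pass here
    # Every requested encoding is looked up, never recomputed
    return [[storage[name] for name in cfg] for cfg in layer_cfgs]


def trimmed_names(*layer_cfgs):
    # Layers kept after trimming: everything up to and including the deepest
    return names[: names.index(find_deepest_layer(*layer_cfgs)) + 1]


print(encode(1, ("relu1_1",), ("conv1_1", "conv1_2")))  # [[2], [2, 4]]
print(trimmed_names(("relu1_1",), ("conv1_2",)))  # ['conv1_1', 'relu1_1', 'conv1_2']
```

Here pool1 and relu1_2 are never evaluated, and trimming would drop them, which is exactly what saves compute and memory in the real encoder.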
The pretrained models the MultiLayerEncoder is based on are usually trained on preprocessed images. In PyTorch, the pretrained torchvision models expect images that are normalized with a per-channel mean = (0.485, 0.456, 0.406) and standard deviation std = (0.229, 0.224, 0.225). To include this in a MultiLayerEncoder, we implement the normalization as a torch.nn.Module.
class Normalize(nn.Module):
    def __init__(self, mean, std):
        super().__init__()
        self.register_buffer("mean", torch.tensor(mean).view(1, -1, 1, 1))
        self.register_buffer("std", torch.tensor(std).view(1, -1, 1, 1))

    def forward(self, image):
        return (image - self.mean) / self.std


class TorchNormalize(Normalize):
    def __init__(self):
        super().__init__((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
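As a quick worked example of the normalization arithmetic, the stdlib-only sketch below applies (value - mean) / std to a single RGB pixel. The helper names are hypothetical; Normalize performs the same computation for every pixel at once by broadcasting its (1, 3, 1, 1) buffers against a (1, 3, H, W) image.

```python
# ImageNet statistics, as used by TorchNormalize
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)


def normalize_pixel(rgb):
    # Per-channel (value - mean) / std for one pixel with values in [0, 1]
    return tuple((v - m) / s for v, m, s in zip(rgb, MEAN, STD))


def denormalize_pixel(rgb):
    # Inverse transform: value * std + mean
    return tuple(v * s + m for v, m, s in zip(rgb, MEAN, STD))


print(normalize_pixel(MEAN))  # (0.0, 0.0, 0.0)
print(normalize_pixel((1.0, 1.0, 1.0)))  # a white pixel maps to roughly 2.25 to 2.64
```

A pixel equal to the mean maps to zero, and a white pixel lands a little more than two standard deviations above it in every channel.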
As a final step, we need to specify the structure of the MultiLayerEncoder. In this example we use a VGGMultiLayerEncoder based on the VGG19 CNN introduced by Simonyan and Zisserman [SZ2014]. We only include the feature extraction stage (vgg_net.features), i.e. the convolutional stage, since the classifier stage (vgg_net.classifier) only accepts feature maps of a single size.
For our convenience, we rename the layers following the scheme the authors used instead of keeping the consecutive indices of a default torch.nn.Sequential. The first layer, however, is the TorchNormalize defined above.
class VGGMultiLayerEncoder(MultiLayerEncoder):
    def __init__(self, vgg_net):
        modules = OrderedDict((("preprocessing", TorchNormalize()),))

        block = depth = 1
        for module in vgg_net.features.children():
            if isinstance(module, nn.Conv2d):
                layer = f"conv{block}_{depth}"
            elif isinstance(module, nn.BatchNorm2d):
                layer = f"bn{block}_{depth}"
            elif isinstance(module, nn.ReLU):
                # without inplace=False the encodings of the previous layer would no
                # longer be accessible after the ReLU layer is executed
                module = nn.ReLU(inplace=False)
                layer = f"relu{block}_{depth}"
                # each ReLU layer increases the depth of the current block by one
                depth += 1
            elif isinstance(module, nn.MaxPool2d):
                layer = f"pool{block}"
                # each max pooling layer marks the end of the current block
                block += 1
                depth = 1
            else:
                msg = f"Type {type(module)} is not part of the VGG architecture."
                raise RuntimeError(msg)

            modules[layer] = module

        super().__init__(modules)


def vgg19_multi_layer_encoder():
    return VGGMultiLayerEncoder(vgg19(pretrained=True))


multi_layer_encoder = vgg19_multi_layer_encoder().to(device)
print(multi_layer_encoder)
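The renaming rules can be checked without downloading any weights. The stdlib-only sketch below assumes that vgg19().features consists of Conv2d/ReLU pairs and MaxPool2d layers following the usual VGG19 layer configuration (16 convolutions, 5 poolings, no batch normalization). The configuration list and helper names are our own illustration, and plain strings stand in for the modules. Note that the real VGGMultiLayerEncoder additionally prepends the preprocessing layer.

```python
# Assumed VGG19 layer configuration: numbers are conv output channels,
# "M" marks a max pooling layer.
VGG19_CFG = [64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M",
             512, 512, 512, 512, "M", 512, 512, 512, 512, "M"]


def module_types(cfg):
    # Expand the configuration into module kinds: conv + relu per number,
    # a single pool per "M"
    types = []
    for v in cfg:
        types.extend(["pool"] if v == "M" else ["conv", "relu"])
    return types


def vgg_layer_names(types):
    # Same renaming rules as VGGMultiLayerEncoder (batch norm omitted)
    names = []
    block = depth = 1
    for kind in types:
        if kind == "conv":
            names.append(f"conv{block}_{depth}")
        elif kind == "relu":
            names.append(f"relu{block}_{depth}")
            depth += 1  # each ReLU closes one conv/relu pair
        elif kind == "pool":
            names.append(f"pool{block}")
            block += 1  # each pooling layer ends the current block
            depth = 1
        else:
            raise RuntimeError(f"Type {kind} is not part of the VGG architecture.")
    return names


names = vgg_layer_names(module_types(VGG19_CFG))
print(names[:5])  # ['conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1']
print(len(names))  # 37
```

All layers requested later in this example (relu1_1 to relu5_1 and relu4_2) appear in this list.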
Perceptual Loss
In order to calculate the perceptual loss, i.e. the optimization criterion, we define a MultiLayerLoss to have a convenient interface. This will be subclassed later by the ContentLoss and StyleLoss.
If called with a sequence of input_encs, the MultiLayerLoss should calculate layerwise scores together with the corresponding target_encs. For that, a MultiLayerLoss needs the ability to store the target_encs so that they can be reused for every call. The individual layer scores should be averaged over the number of encodings and finally weighted by a score_weight.
To achieve this we subclass torch.nn.Module. The target_encs are stored as buffers, since they are not trainable parameters. The actual functionality has to be defined in calculate_score by a subclass.
def mean(sized):
    return sum(sized) / len(sized)


class MultiLayerLoss(nn.Module):
    def __init__(self, score_weight=1e0):
        super().__init__()
        self.score_weight = score_weight
        self._numel_target_encs = 0

    def _target_enc_name(self, idx):
        return f"_target_encs_{idx}"

    def set_target_encs(self, target_encs):
        self._numel_target_encs = len(target_encs)
        for idx, enc in enumerate(target_encs):
            self.register_buffer(self._target_enc_name(idx), enc.detach())

    @property
    def target_encs(self):
        return tuple(
            getattr(self, self._target_enc_name(idx))
            for idx in range(self._numel_target_encs)
        )

    def forward(self, input_encs):
        if len(input_encs) != self._numel_target_encs:
            msg = (
                f"The number of given input encodings and stored target encodings "
                f"does not match: {len(input_encs)} != {self._numel_target_encs}"
            )
            raise RuntimeError(msg)

        layer_losses = [
            self.calculate_score(input, target)
            for input, target in zip(input_encs, self.target_encs)
        ]
        return mean(layer_losses) * self.score_weight

    def calculate_score(self, input, target):
        raise NotImplementedError
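Written out, the score a MultiLayerLoss returns can be summarized as follows. The notation is our own, not taken from the original: $\hat{\Phi}_l$ are the input encodings, $\Phi_l$ the stored target encodings, $\ell$ is calculate_score, $N$ is the number of stored target encodings, and $w$ is the score_weight. The perceptual loss used later is simply the sum of the content and the style score.

```latex
\mathcal{L}_{\text{layer}} = w \cdot \frac{1}{N} \sum_{l=1}^{N} \ell\left(\hat{\Phi}_l, \Phi_l\right),
\qquad
\mathcal{L}_{\text{perceptual}} = \mathcal{L}_{\text{content}} + \mathcal{L}_{\text{style}}
```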
In this example we use the feature_reconstruction_loss introduced by Mahendran and Vedaldi [MV2015] as ContentLoss as well as the gram_loss introduced by Gatys, Ecker, and Bethge [GEB2016] as StyleLoss.
def feature_reconstruction_loss(input, target):
    return mse_loss(input, target)


class ContentLoss(MultiLayerLoss):
    def calculate_score(self, input, target):
        return feature_reconstruction_loss(input, target)


def channelwise_gram_matrix(x, normalize=True):
    x = torch.flatten(x, 2)
    G = torch.bmm(x, x.transpose(1, 2))
    if normalize:
        return G / x.size()[-1]
    else:
        return G


def gram_loss(input, target):
    return mse_loss(channelwise_gram_matrix(input), channelwise_gram_matrix(target))


class StyleLoss(MultiLayerLoss):
    def calculate_score(self, input, target):
        return gram_loss(input, target)
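For a single image, channelwise_gram_matrix computes G = F F^T / N, where F holds the C channels of an encoding flattened over its N = H * W spatial positions. The stdlib-only sketch below is our own illustration, with hypothetical helper names and plain lists standing in for tensors. It computes the Gram matrix of a tiny 2-channel encoding and also hints at why it suits style: rearranging the features spatially leaves it unchanged.

```python
def gram_matrix(features, normalize=True):
    # features: C channels, each a list of N = H * W values.
    # G[i][j] is the inner product of channels i and j, optionally divided by N.
    n = len(features[0])
    gram = [[sum(a * b for a, b in zip(fi, fj)) for fj in features] for fi in features]
    if normalize:
        gram = [[g / n for g in row] for row in gram]
    return gram


def mse(a, b):
    # Mean squared error over all matrix entries
    diffs = [(x - y) ** 2 for row_a, row_b in zip(a, b) for x, y in zip(row_a, row_b)]
    return sum(diffs) / len(diffs)


def gram_loss_lists(input_features, target_features):
    return mse(gram_matrix(input_features), gram_matrix(target_features))


# Two channels with four spatial positions each (a flattened 2x2 feature map)
x = [[1.0, 0.0, 1.0, 0.0],
     [0.0, 1.0, 0.0, 1.0]]
x_reversed = [channel[::-1] for channel in x]  # same values, other positions
z = [[2.0, 0.0, 2.0, 0.0],
     [0.0, 1.0, 0.0, 1.0]]

print(gram_matrix(x))  # [[0.5, 0.0], [0.0, 0.5]]
print(gram_loss_lists(x, x_reversed))  # 0.0
print(gram_loss_lists(z, x))  # 0.5625
```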
Images
Before we can load the content and style images, we need to define some basic I/O utilities.
On import, a fake batch dimension is added to each image so that it can be passed through the MultiLayerEncoder without further modification. This dimension is removed again on export. Furthermore, all images are resized so that their shorter edge is size=500 pixels, keeping the aspect ratio.
import_from_pil = transforms.Compose(
    (
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.unsqueeze(0)),
        transforms.Lambda(lambda x: x.to(device)),
    )
)

export_to_pil = transforms.Compose(
    (
        transforms.Lambda(lambda x: x.cpu()),
        transforms.Lambda(lambda x: x.squeeze(0)),
        transforms.Lambda(lambda x: x.clamp(0.0, 1.0)),
        transforms.ToPILImage(),
    )
)


def download_image(url):
    file = os.path.abspath(os.path.basename(url))
    with open(file, "wb") as fh, urlopen(url) as response:
        fh.write(response.read())

    return file


def read_image(file, size=500):
    image = Image.open(file)
    image = resize(image, size)
    return import_from_pil(image)


def show_image(image, title=None):
    _, ax = plt.subplots()
    ax.axis("off")
    if title is not None:
        ax.set_title(title)

    image = export_to_pil(image)
    ax.imshow(image)
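read_image passes an int to resize, so only the shorter edge is matched to size and the aspect ratio is kept. The helper below is a hypothetical sketch of that size computation, not part of the example; it assumes torchvision's rule of truncating the scaled longer edge to an int. Under that assumption, and combined with the fake batch dimension, a 750 x 1000 (height x width) RGB photo would end up as a tensor of shape (1, 3, 500, 666).

```python
def resized_size(height, width, size=500):
    """Output (height, width) when an int size is passed to resize().

    Assumption: the smaller edge is matched to `size` and the larger edge is
    scaled by the same factor, truncated to an int.
    """
    if height <= width:
        return size, int(size * width / height)
    return int(size * height / width), size


print(resized_size(750, 1000))  # (500, 666)
print(resized_size(1200, 600))  # (1000, 500)
```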
With the I/O utilities set up, we now download, read, and show the images that will be used in the NST.
Note
The images used in this example are licensed under the permissive Pixabay License.
content_url = "https://download.pystiche.org/images/bird1.jpg"
content_file = download_image(content_url)
content_image = read_image(content_file)
show_image(content_image, title="Content image")

style_url = "https://download.pystiche.org/images/paint.jpg"
style_file = download_image(style_url)
style_image = read_image(style_file)
show_image(style_image, title="Style image")
Neural Style Transfer
First, we choose the content_layers and style_layers on which the encodings are compared. With them we trim the multi_layer_encoder to remove unused layers that would otherwise occupy memory.
Afterwards, we calculate the target content and style encodings. The calculation is performed without gradient tracking, since the gradient of the target encodings is not needed for the optimization.
content_layers = ("relu4_2",)
style_layers = ("relu1_1", "relu2_1", "relu3_1", "relu4_1", "relu5_1")

multi_layer_encoder.trim(content_layers, style_layers)

with torch.no_grad():
    target_content_encs = multi_layer_encoder(content_image, content_layers)[0]
    target_style_encs = multi_layer_encoder(style_image, style_layers)[0]
Next up, we instantiate the ContentLoss and StyleLoss with a corresponding weight. Afterwards, we store the previously calculated target encodings.
content_weight = 1e0
content_loss = ContentLoss(score_weight=content_weight)
content_loss.set_target_encs(target_content_encs)

style_weight = 1e3
style_loss = StyleLoss(score_weight=style_weight)
style_loss.set_target_encs(target_style_encs)
We start the NST from the content_image, since this way it converges quickly.
input_image = content_image.clone()
show_image(input_image, "Input image")
Note
If you want to start from a white noise image instead, use
input_image = torch.rand_like(content_image)
As a final preliminary step, we create the optimizer that will perform the NST. Since we want to adapt the pixels of the input_image directly, we pass it as the optimization parameter.
optimizer = optim.LBFGS([input_image.requires_grad_(True)], max_iter=1)
Finally, we run the NST. The loss calculation has to happen inside a closure, since the LBFGS optimizer may need to reevaluate it multiple times per optimization step. This structure also works for all other optimizers.
num_steps = 500

with tqdm(desc="Image optimization", total=num_steps) as progress_bar:
    for _ in range(num_steps):

        def closure():
            optimizer.zero_grad()

            input_encs = multi_layer_encoder(input_image, content_layers, style_layers)
            input_content_encs, input_style_encs = input_encs

            content_score = content_loss(input_content_encs)
            style_score = style_loss(input_style_encs)

            perceptual_loss = content_score + style_score
            perceptual_loss.backward()

            progress_bar.set_postfix(
                loss=f"{float(perceptual_loss):.3e}", refresh=False
            )
            progress_bar.update()

            return perceptual_loss

        optimizer.step(closure)

output_image = input_image.detach()
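To see why the loss has to be recomputed inside a closure, consider the following toy sketch. It does not use LBFGS. Instead, a hypothetical step function performs gradient descent with a backtracking line search on f(x) = (x - 3)^2 and, like LBFGS, may evaluate the closure several times before accepting a step. The closure always reads the current parameter value, just as the NST closure re-encodes the current input_image.

```python
# Toy stand-in for input_image: a single trainable number (hypothetical)
params = [0.0]
evaluations = 0


def closure():
    # Recomputes loss and gradient from the current parameter value
    global evaluations
    evaluations += 1
    x = params[0]
    return (x - 3.0) ** 2, 2.0 * (x - 3.0)


def step(closure, lr=1.0, min_lr=1e-8):
    # Hypothetical optimizer step: gradient descent with backtracking line
    # search. Like LBFGS, it may call the closure several times per step.
    start = params[0]
    loss, grad = closure()
    while lr >= min_lr:
        params[0] = start - lr * grad
        new_loss, _ = closure()
        if new_loss < loss:
            return new_loss  # the loss decreased: accept the step
        lr /= 2  # step too large: shrink it and evaluate the closure again
    params[0] = start  # no decrease found: keep the previous value
    return loss


step(closure)
print(params[0], evaluations)  # 3.0 3
```

A single optimization step needed three closure evaluations here; a precomputed loss value could not have reflected the trial points.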
After the NST we show the resulting image.
show_image(output_image, title="Output image")
Conclusion
As has hopefully become clear, even in its simplest form an NST requires quite a lot of utilities and boilerplate code. This makes it hard to maintain and keep bug-free, since it is easy to lose track of everything.
Judging by the lines of code, one could (falsely) conclude that the actual NST is just an appendix. If you feel the same, you can stop worrying now: in Neural Style Transfer with pystiche we showcase how to achieve the same result with pystiche.