Style transfer in machine learning is a technique that allows the artistic style of one image to be applied to another image. It involves extracting the style features from one image (e.g. colors, textures, patterns) and applying them to the content of another image while preserving its original structure. This process is often achieved using deep neural networks to separate and recombine the content and style representations of images.
In this project I use a pretrained convolutional neural network to create a style transfer application in PyTorch. The code is based on Gatys et al. and the PyTorch for Deep Learning and Computer Vision Udemy course from Slim et al.
import torch
import torch.optim as optim
from torchvision import transforms, models
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
For this project we will use the pretrained VGG-19 model. We will keep its feature layers fixed throughout.
# Load the pretrained VGG-19 network and keep only its convolutional feature layers
# (the 'weights' argument is the current torchvision API; 'pretrained=True' is deprecated)
vgg = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1).features

# Freeze the feature layers; only the target image will be optimized
for param in vgg.parameters():
    param.requires_grad_(False)
# Use the GPU if one is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vgg.to(device)
Sequential(
  (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): ReLU(inplace=True)
  (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (3): ReLU(inplace=True)
  (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (6): ReLU(inplace=True)
  (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (8): ReLU(inplace=True)
  (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (11): ReLU(inplace=True)
  (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (13): ReLU(inplace=True)
  (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (15): ReLU(inplace=True)
  (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (17): ReLU(inplace=True)
  (18): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (19): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (20): ReLU(inplace=True)
  (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (22): ReLU(inplace=True)
  (23): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (24): ReLU(inplace=True)
  (25): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (26): ReLU(inplace=True)
  (27): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (29): ReLU(inplace=True)
  (30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (31): ReLU(inplace=True)
  (32): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (33): ReLU(inplace=True)
  (34): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (35): ReLU(inplace=True)
  (36): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
def load_image(img_path, max_size=400, rotate=None, shape=None):
    '''Load and transform an image. It can be rotated if needed.'''
    if rotate is not None:
        image = Image.open(img_path).convert('RGB').rotate(rotate, expand=True)
    else:
        image = Image.open(img_path).convert('RGB')

    # Cap the longest side at max_size to keep memory use reasonable
    if max(image.size) > max_size:
        print(image.size)
        size = max_size
    else:
        size = max(image.size)

    # An explicit shape (e.g. to match another image) overrides max_size
    if shape is not None:
        size = shape

    in_transform = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5),
                             (0.5, 0.5, 0.5))])

    # Add a batch dimension so the tensor can be fed through VGG
    image = in_transform(image).unsqueeze(0)
    return image
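One note on the transform: load_image normalizes each channel to a mean and standard deviation of 0.5, and im_convert below undoes exactly those values. Many style transfer implementations instead use the ImageNet statistics that VGG-19 was trained with. As a hedged alternative (not what this notebook uses), that transform would look like the following, and im_convert would then need to undo the same mean and standard deviation:

# Hypothetical alternative: normalize with the ImageNet statistics instead of 0.5s.
# If this is swapped in, im_convert must also be updated to un-normalize with these values.
imagenet_transform = transforms.Compose([
    transforms.Resize(400),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406),    # ImageNet channel means
                         (0.229, 0.224, 0.225))])  # ImageNet channel standard deviations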
We will now load the images needed for the style transfer. The model will try to transfer the qualities of the style image onto the content image. In this example, I'll use as the content image a picture I took while hiking of one of the many waterfalls in the Clemson area. For the style image we'll try Edvard Munch's "The Scream". These images can easily be swapped out to explore the effects of different content and styles!
# This notebook was originally run in Google Colab, so we need to mount our Google Drive.
from google.colab import drive
drive.mount('/content/gdrive')
Mounted at /content/gdrive
path = '/content/gdrive/MyDrive/Career/Projects/PyTorch/'
content = load_image(path+'waterfall.jpg', rotate=270).to(device)
style = load_image(path+'the-scream.jpg').to(device)
(3024, 4032)
(594, 750)
def im_convert(tensor):
    '''Converts a tensor to an image that can be viewed'''
    image = tensor.to("cpu").clone().detach()
    image = image.numpy().squeeze()
    image = image.transpose(1, 2, 0)
    # Undo the (0.5, 0.5, 0.5) normalization applied in load_image
    image = image * np.array((0.5, 0.5, 0.5)) + np.array((0.5, 0.5, 0.5))
    image = image.clip(0, 1)
    return image
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 10))
ax1.imshow(im_convert(content))
ax1.axis("off")
ax1.set_title('Content')
ax2.imshow(im_convert(style))
ax2.set_title('Style')
ax2.axis("off")
[Figure: content image (left) and style image (right)]
# Feature layers used for extraction follow the recommendations of the Gatys et al. paper
def get_features(image, model):
    '''Run the image through the model and record the activations of the selected layers'''
    layers = {'0': 'conv1_1',
              '5': 'conv2_1',
              '10': 'conv3_1',
              '19': 'conv4_1',
              '21': 'conv4_2',  # Content extraction
              '28': 'conv5_1'}
    features = {}
    for name, layer in model._modules.items():
        image = layer(image)
        if name in layers:
            features[layers[name]] = image
    return features
content_features = get_features(content, vgg)
style_features = get_features(style, vgg)
content_features.keys(), style_features.keys()
(dict_keys(['conv1_1', 'conv2_1', 'conv3_1', 'conv4_1', 'conv4_2', 'conv5_1']), dict_keys(['conv1_1', 'conv2_1', 'conv3_1', 'conv4_1', 'conv4_2', 'conv5_1']))
# The Gram matrix captures correlations between feature maps, which describe style
# while discarding the spatial arrangement that encodes content
def gram_matrix(tensor):
    _, d, h, w = tensor.size()
    # Flatten each of the d feature maps into a row vector
    tensor = tensor.view(d, h * w)
    gram = torch.mm(tensor, tensor.t())
    return gram
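For reference, this is the quantity Gatys et al. define for each layer $l$: with the layer's feature maps flattened into a matrix $F^l$ (one row per channel, one column per spatial position), the Gram matrix is the set of inner products between feature maps,

$$G^l_{ij} = \sum_k F^l_{ik} F^l_{jk},$$

which is exactly the torch.mm(tensor, tensor.t()) product computed above.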
style_grams = {layer: gram_matrix(style_features[layer]) for layer in style_features}
style_weights = {'conv1_1': 1.,
'conv2_1': 0.75,
'conv3_1': 0.2,
'conv4_1': 0.2,
'conv5_1': 0.6}
# These are weighting factors for content and style reconstruction,
# corresponding to alpha and beta in Eq 7 of Gatys et al.
content_weight = 1 # alpha
style_weight = 1e5 # beta
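For reference, Eq. 7 of Gatys et al. combines the two terms into the objective that the optimization below minimizes,

$$\mathcal{L}_{total}(\vec{p}, \vec{a}, \vec{x}) = \alpha \, \mathcal{L}_{content}(\vec{p}, \vec{x}) + \beta \, \mathcal{L}_{style}(\vec{a}, \vec{x}),$$

where $\vec{p}$, $\vec{a}$, and $\vec{x}$ are the content, style, and generated (target) images. One caveat: the loop below uses mean squared differences rather than the paper's exact normalization constants, so the values of alpha and beta here are not directly comparable to those reported in the paper.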
target = content.clone().requires_grad_(True).to(device)
show_every = 300
optimizer = optim.Adam([target], lr=0.003)
steps = 1000
frames = 50

# Pre-allocate an array of frames for the progress video; one frame is captured every capture_frame steps
height, width, channels = im_convert(target).shape
image_array = np.empty(shape=(frames + 1, height, width, channels))
capture_frame = steps // frames
counter = 0
for ii in range(1, steps + 1):
    target_features = get_features(target, vgg)

    # Content loss: mean squared difference of the conv4_2 activations
    content_loss = torch.mean((target_features['conv4_2'] - content_features['conv4_2'])**2)

    # Style loss: weighted, normalized Gram-matrix differences across the style layers
    style_loss = 0
    for layer in style_weights:
        target_feature = target_features[layer]
        target_gram = gram_matrix(target_feature)
        style_gram = style_grams[layer]
        layer_style_loss = style_weights[layer] * torch.mean((target_gram - style_gram)**2)
        _, d, h, w = target_feature.shape
        style_loss += layer_style_loss / (d * h * w)

    total_loss = content_weight * content_loss + style_weight * style_loss

    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

    if ii % show_every == 0:
        print('Total loss: ', total_loss.item())
        print('Iteration: ', ii)
        plt.imshow(im_convert(target))
        plt.axis("off")
        plt.show()

    # Save frames at regular intervals for the progress video
    if ii % capture_frame == 0 or ii == 1:
        image_array[counter] = im_convert(target)
        counter = counter + 1
Total loss: 84264.1640625
Iteration: 300
Total loss: 27073.107421875
Iteration: 600
Total loss: 17053.888671875
Iteration: 900
Here we see the comparison of the input content image, the style image, and the output image where the style features have been applied to the original content. Overall it creates a nice effect. In particular, the color transfers very well, and some of the texture transfers also. I think the texture/pattern transfer is lacking, however: our final image has more of a blurring effect versus the bolder lines in Munch's painting. Nevertheless, we have successfully constructed a style transfer application using pretrained models in PyTorch!
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(20, 10))
ax1.imshow(im_convert(content))
ax1.axis('off')
ax2.imshow(im_convert(style))
ax2.axis('off')
ax3.imshow(im_convert(target))
ax3.axis('off')
[Figure: content image, style image, and stylized output, left to right]
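As noted above, the color transfers more convincingly than Munch's bold strokes. One untested experiment would be to push more weight onto the style term and onto the deeper style layers, which tend to capture larger-scale texture, and then rerun the optimization loop above:

# Hypothetical re-weighting to emphasize texture; these values are a guess, not a verified result
style_weights = {'conv1_1': 1.,
                 'conv2_1': 0.8,
                 'conv3_1': 0.5,
                 'conv4_1': 0.5,
                 'conv5_1': 0.8}
style_weight = 1e6  # beta: lean harder on the style image
target = content.clone().requires_grad_(True).to(device)  # restart from the content image
optimizer = optim.Adam([target], lr=0.003)
# ...then rerun the training loop above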
These next sections create a video that follows the progress of the style transfer process through each step.
import cv2

frame_height, frame_width, _ = im_convert(target).shape
vid = cv2.VideoWriter('style_transfer.mp4', cv2.VideoWriter_fourcc(*'XVID'), 10, (frame_width, frame_height))

for img in image_array:
    # Convert from floats in [0, 1] to 8-bit pixel values
    img = img * 255
    img = np.array(img, dtype=np.uint8)
    # Frames from im_convert are RGB; OpenCV expects BGR channel order
    img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    vid.write(img)

vid.release()
from google.colab import files
files.download('style_transfer.mp4')