March 14, 2021

Line Art Colorization Using a Deep Learning Model

Colorization of line art images is a tedious task, even more so when we have many similar line art images that differ only slightly, so that the colors could almost be copied and pasted from one to another. Using deep learning, we can train a model to learn this task and reduce the time spent on colorization.

In this blog post we will review the paper Deep Line Art Video Colorization with a Few References and implement its models. There are two main models, each containing several sub-networks: the first one is the Color Transform Network and the second one is the Temporal Constraint Network. The following image shows the inputs of these models:

Input images

To see how this data can be obtained, check the last section of this post.

We have three frames in chronological order, and each frame has 3 images: a color image, a line art image and a distance field map image. These 9 images form one sample of a batch, and all the images in a sample belong to the same sequence.

The target color image is only used in the loss function, as the ground truth (label) image. The job of the Color Transform Network is to take the color information from the reference color images and use it to colorize the target line art image.

To transfer the color information of the reference images to the line art image, we use non-local similarity to obtain local color information and Adaptive Instance Normalization (AdaIN) to obtain global color information from the reference images.

The Temporal Constraint Network takes the image generated by the Color Transform Network along with its line art image, and the reference color and line art images, in chronological order. It further refines the generated image by learning spatial-temporal features with 2D convolutions that behave like 3D convolutions, thanks to a technique called the Learnable Gated Temporal Shift Module (LGTSM).

Color Transform Network

This model contains a Generator and a Discriminator, so it is a GAN.

The most important parts of the Generator are the Similarity Based Color Transform Layer and the Embedder.

Similarity Based Color Transform Layer

The inputs of the Similarity Based Color Transform Layer are feature maps extracted from the input images by three different encoders, one for each type of image (line art, color, distance field).

These encoders share the same architecture, which is composed of 3 convolutional layers, and they use Spectral Normalization, a weight normalization technique that GANs often use to stabilize training.

TensorFlow Addons has an implementation of this normalization (tfa.layers.SpectralNormalization). However, I used a custom version; the original code of the custom implementation can be found here.
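As an illustration of what spectral normalization computes (not the custom implementation used in this post), the NumPy sketch below divides a weight matrix by an estimate of its largest singular value obtained with power iteration. The function name and the number of iterations are illustrative choices; in GAN training the layer usually keeps the vector u between training steps and runs a single iteration per step.

```python
import numpy as np

def spectral_normalize(w, num_iters=20, eps=1e-12, seed=0):
    """Return (w / sigma, sigma), where sigma estimates the largest singular value of w.

    w is a 2D array; a conv kernel of shape (kh, kw, c_in, c_out) is first
    reshaped to (kh * kw * c_in, c_out).
    """
    rng = np.random.default_rng(seed)
    u = rng.normal(size=w.shape[1])
    u /= np.linalg.norm(u) + eps
    for _ in range(num_iters):
        # Power iteration: alternate between the left and right singular vectors.
        v = w @ u
        v /= np.linalg.norm(v) + eps
        u = w.T @ v
        u /= np.linalg.norm(u) + eps
    sigma = v @ w @ u
    return w / sigma, sigma

# Example: a 3x3 kernel with 64 input and 128 output channels.
kernel = np.random.default_rng(1).normal(size=(3, 3, 64, 128))
w_sn, sigma = spectral_normalize(kernel.reshape(-1, kernel.shape[-1]), num_iters=50)
# The largest singular value of the normalized matrix should now be close to 1.
print(np.linalg.norm(w_sn, 2))
```

Dividing by the largest singular value bounds how much each layer can stretch its input, which is what makes the discriminator (and here the encoders) better behaved during training.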

Since the code is quite long, we will only review the most important parts. I recommend checking the complete code in the Jupyter Notebook and the GitHub repository here.

target_line_art_images_features = self.lineart_encoder(target_line_art_images) # EnL
target_distance_maps_features = self.distance_encoder(target_distance_maps) # EnD
reference_distance_maps_0_features = self.distance_encoder(reference_distance_maps_0) # EnD
reference_distance_maps_1_features = self.distance_encoder(reference_distance_maps_1) # EnD
reference_color_images_0_features = self.color_encoder(reference_color_images_0) # EnC
reference_color_images_1_features = self.color_encoder(reference_color_images_1) # EnC

Once we have the feature maps from our input images, we use them in the Similarity Based Color Transform Layer:

f_sim = self.color_transform_layer([target_distance_maps_features,
                              reference_distance_maps_0_features,
                              reference_distance_maps_1_features,
                              reference_color_images_0_features,
                              reference_color_images_1_features])

This layer takes as input the features of the target distance field map (target_distance_maps_features) and the features of the reference distance field maps and reference color images.

The idea behind this layer comes from the paper Deep Exemplar-based Video Colorization, where we also have a reference color image and a target grayscale image that is painted using the colors of the reference image. To achieve this, we first match similar features from both images to build a matching matrix, which is then used to gather color information from the features of the color image that are similar to those of the grayscale image.

We can split the work of the Similarity Based Color Transform Layer into a left part and a right part, and each part is computed twice, once for each reference frame. In the left part, we pass the target distance field map features together with the features of each reference distance field map separately:

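M_0 = self.lp([target_distance_map, reference_distance_0]) # B x HW x HW
M_1 = self.lp([target_distance_map, reference_distance_1]) # B x HW x HW

import tensorflow as tf
from tensorflow.keras import layers

class LeftPart(tf.keras.Model):
  def __init__(self):
    super(LeftPart, self).__init__()
    kernels = 256 // 8  # integer division: Conv2D expects an int number of filters
    self.conv = layers.Conv2D(kernels, (1, 1), padding="same")
  
  def call(self, inputs):
    target_distance_map = inputs[0]
    reference_distance_feat = inputs[1]
    # The same 1x1 convolution projects both feature maps into the same reduced space.
    reference_distance_x = self.conv(reference_distance_feat)
    target_distance_map_x = self.conv(target_distance_map)
    
    _, H, W, C = target_distance_map_x.shape

    reference_distance_x = tf.reshape(reference_distance_x, [-1, H * W, C])  # B x HW x C
    target_distance_map_x = tf.reshape(target_distance_map_x, [-1, H * W, C])  # B x HW x C

    # Rows index reference positions and columns index target positions, so the
    # matrices of both references can be stacked along axis 1 afterwards.
    M = tf.linalg.matmul(reference_distance_x, target_distance_map_x, transpose_b=True) # B x HW_ref x HW_target

    return M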

We use a 1x1 convolutional layer to reduce the number of channels to 256 / 8 = 32, which reduces the number of computations, and a matrix multiplication to obtain the similarity between the target image and one of the reference images, reference_distance_0 or reference_distance_1.

matching_matrix = tf.concat([M_0, M_1], 1) # B x KHW x HW (reference positions x target positions)
matching_matrix = tf.keras.activations.softmax(matching_matrix, axis=1) # normalize over all reference positions

We concatenate both similarity matrices and apply a softmax over all the reference positions to build the matching matrix, which we use to select the color information from the reference color features that match the target features:

f_mat = tf.linalg.matmul(reference_color_matrix, matching_matrix) #BxCxHW
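To check the shapes and the normalization, here is a minimal NumPy sketch of the matching step for a single image (the batch axis is dropped). It assumes the arrangement described by the shape comments in the code (HWKxHW for the matching matrix, BxCxHW for f_mat): the similarity matrices have reference positions as rows and target positions as columns, so after concatenating the K references along the rows, a softmax over that axis gives every target position a weight distribution over all K*HW reference positions. The function and variable names are hypothetical.

```python
import numpy as np

def softmax(x, axis):
    """Numerically stable softmax along the given axis."""
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def match_colors(target_feat, ref_feats, ref_color_feats):
    """Non-local color matching for one image (no batch axis).

    target_feat:     (HW, C)          reduced target distance-map features
    ref_feats:       K arrays (HW, C), reduced reference distance-map features
    ref_color_feats: K arrays (C_col, HW), reduced reference color features
    Returns f_mat with shape (C_col, HW) and the (K * HW, HW) matching matrix.
    """
    # Similarity per reference: rows are reference positions, columns are target positions.
    sims = [ref @ target_feat.T for ref in ref_feats]
    # Softmax over all K * HW reference positions, separately for every target position.
    matching = softmax(np.concatenate(sims, axis=0), axis=0)
    colors = np.concatenate(ref_color_feats, axis=1)  # (C_col, K * HW)
    return colors @ matching, matching

rng = np.random.default_rng(0)
HW, C, C_col = 16, 32, 8
target = rng.normal(size=(HW, C))
refs = [rng.normal(size=(HW, C)) for _ in range(2)]
ref_colors = [rng.normal(size=(C_col, HW)) for _ in range(2)]
f_mat, matching = match_colors(target, refs, ref_colors)
print(f_mat.shape, matching.shape)  # (8, 16) (32, 16)
```

Each column of f_mat is a weighted average of reference color features, with the weights given by how similar each reference position is to that target position.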

To obtain the reference color matrix we have to work on the right part of the layer:

c_0, fm_0 = self.rp([small_m_0, reference_color_0]) #BxCxHW
c_1, fm_1 = self.rp([small_m_1, reference_color_1]) #BxCxHW

Notice that the right part takes as input the reference color features and one of two masks, small_m_0 or small_m_1. We obtain these masks with the following code:

small_m_0, n_0 = self.get_masks([target_distance_map, reference_distance_0])
small_m_1, n_1 = self.get_masks([target_distance_map, reference_distance_1])

class CreateMasks(tf.keras.Model):
  def __init__(self):
    super(CreateMasks, self).__init__()
    self.conv_m = layers.Conv2D(256, (3, 3), padding="same")
    self.conv_n = layers.Conv2D(256, (3, 3), padding="same")
  
  def call(self, inputs):
    target_distance_map = inputs[0]
    reference_distance = inputs[1]

    tensor_input = layers.Concatenate(axis=-1)([target_distance_map, reference_distance])

    m = self.conv_m(tensor_input)
    m = tf.keras.activations.sigmoid(m)

    n = self.conv_n(tensor_input)
    n = tf.keras.activations.sigmoid(n)

    return m, n

This layer takes the same inputs as the left part; however, in this case we create two masks, m and n.

Unlike the grayscale images used in the Deep Exemplar-based Video Colorization paper, line art and distance map images differ much more from one frame to another, which makes the feature matching task harder.

For this reason, we compute two masks. m is used to select new features from the reference color features. It may learn to identify which features keep information about the colors and which ones are not that useful for obtaining color information, like shapes.

n is used to combine the matching features f_mat (obtained from the color features with the matching matrix, which we are about to compute) with the new features selected by m. We will see n in action in a moment.

Now that we have m for each pair of target and reference distance map features, we can compute the right part:

class RightPart(tf.keras.Model):
  def __init__(self):
    super(RightPart, self).__init__()
    kernels = 256 // 8  # integer number of filters
    self.conv = layers.Conv2D(kernels, (1, 1), padding="same")
  
  def call(self, inputs):
    m = inputs[0]
    reference_color = inputs[1]
    fm = reference_color * m # the mask selects features, like an attention map

    x = self.conv(fm) # reduce the channels to 32
    
    _, H, W, C = x.shape

    x = tf.reshape(x, [-1, H * W, C]) # BxHWxC
    x = tf.transpose(x, [0, 2, 1]) # BxCxHW

    return x, fm

Notice how the mask m keeps or removes information from the features, and how a convolutional layer reduces the number of channels of these features.

From the right part we obtain the new selected features fm_0, fm_1 and the reduced channel feature maps c_0, c_1. Using the reduced channel feature maps we construct the reference color matrix, which contains the color information from the reference images:

reference_color_matrix = tf.concat([c_0, c_1], -1) #BxCxKHW

This matrix is the one from where we extract the color information of similar features using the matching matrix:

f_mat = tf.linalg.matmul(reference_color_matrix, matching_matrix)

Now f_mat contains the color information from similar areas of the reference and target images. However, as we discussed, this is not enough for line art images. Therefore, we combine the new features selected by m (fm_0, fm_1) with f_mat using n:

f_mat = self.conv(f_mat) # BxHxWxC

# n decides how much of each masked reference feature to keep; (1 - n) keeps the rest from f_mat
f_sim_left = (fm_1 * n_1) + ((1 - n_1) * f_mat)
f_sim_right = (fm_0 * n_0) + ((1 - n_0) * f_mat)

f_sim = (f_sim_left + f_sim_right) / 2

The job of n is to find the best combination of f_mat and fm_1, fm_0 that contains the most important and useful color information. Notice how the information that we keep from fm_1 and fm_0 (n_1, n_0) is the information that we remove from f_mat (1 - n_1, 1 - n_0), and vice versa.

With this, we keep the best information from f_mat and add useful information from the color features fm_1, fm_0.
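The blend can be checked in isolation with the NumPy sketch below, which uses the (1 - n) weighting: whatever is kept from a masked reference feature is removed from f_mat, so n = 1 keeps only the masked reference features and n = 0 keeps only f_mat. The function name and the use of raw logits for n are assumptions made for the illustration.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def blend_features(f_mat, fms, n_logits):
    """Average of n_k * fm_k + (1 - n_k) * f_mat over the K references.

    f_mat, each fm_k and each n_logits_k share the same shape, e.g. (H, W, C).
    """
    blends = []
    for fm, logits in zip(fms, n_logits):
        n = sigmoid(logits)  # per-element gate in (0, 1), like the output of CreateMasks
        blends.append(n * fm + (1.0 - n) * f_mat)
    return sum(blends) / len(blends)

rng = np.random.default_rng(0)
f_mat = rng.normal(size=(4, 4, 8))
fms = [rng.normal(size=(4, 4, 8)) for _ in range(2)]
# Gates saturated near 1 keep the masked reference features; gates near 0 keep f_mat.
print(np.allclose(blend_features(f_mat, fms, [np.full_like(f_mat, 50.0)] * 2), (fms[0] + fms[1]) / 2))  # True
```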

The complete code for the Similarity Based Color Transform Layer is the following one:

class ColorTransformLayer(tf.keras.Model):
  def __init__(self):
    super(ColorTransformLayer, self).__init__()
    self.lp = LeftPart()
    self.rp = RightPart()
    self.get_masks = CreateMasks()
    self.conv = layers.Conv2D(256, (1, 1), padding="same")
  
  def call(self, inputs):
    target_distance_map = inputs[0]
    reference_distance_0 = inputs[1]
    reference_distance_1 = inputs[2]
    reference_color_0 = inputs[3]
    reference_color_1 = inputs[4]

    B, H, W, _ = target_distance_map.shape

    M_0 = self.lp([target_distance_map, reference_distance_0]) #HWxHW
    M_1 = self.lp([target_distance_map, reference_distance_1]) #HWxHW

    matching_matrix = layers.Concatenate(axis=1)([M_0, M_1]) # B x KHW x HW
    matching_matrix = tf.keras.activations.softmax(matching_matrix, axis=1) # softmax over all reference positions

    small_m_0, n_0 = self.get_masks([target_distance_map, reference_distance_0])
    small_m_1, n_1 = self.get_masks([target_distance_map, reference_distance_1])

    c_0, fm_0 = self.rp([small_m_0, reference_color_0]) #BxCxHW
    c_1, fm_1 = self.rp([small_m_1, reference_color_1]) #BxCxHW

    reference_color_matrix = layers.Concatenate(axis=-1)([c_0, c_1])

    f_mat = tf.linalg.matmul(reference_color_matrix, matching_matrix) #BxCxHW
    _, C, _ = f_mat.shape

    f_mat = layers.Reshape([C, H, W])(f_mat)
    f_mat = tf.transpose(f_mat, [0, 2, 3, 1])

    f_mat = self.conv(f_mat) # BxHxWxC

    f_sim_left = (fm_1 * n_1) + ((1 - n_1) * f_mat)
    f_sim_right = (fm_0 * n_0) + ((1 - n_0) * f_mat)

    f_sim = (f_sim_left + f_sim_right) / 2

    return f_sim

To make the tensor shapes match, we use several reshapes and transposes so that the operations are computed correctly.

Embedder

The Embedder takes as input the reference line art and color images (not their features) and, through the SEV model shown below, provides the means and standard deviations (the affine transformation parameters) for the AdaIN normalization layers. AdaIN is similar to instance normalization, where, unlike batch normalization, each channel of each image is normalized separately. However, instead of learning the gamma (scale) and beta (shift) parameters, we use a standard deviation as gamma and a mean as beta, both computed from the reference images, to rescale the normalized features:

class Embedder(tf.keras.Model):
  def __init__(self):
    super(Embedder, self).__init__()
    self.conv_1 = layers.Conv2D(64, (3, 3), strides=(1, 1), padding="same")
    self.conv_2 = layers.Conv2D(128, (3, 3), strides=(2, 2), padding="same")
    self.conv_3 = layers.Conv2D(256, (3, 3), strides=(2, 2), padding="same")
    self.conv_4 = layers.Conv2D(512, (3, 3), strides=(2, 2), padding="same")
    self.conv_5 = layers.Conv2D(512, (3, 3), strides=(2, 2), padding="same")
  
  def call(self, inputs):
    reference_line_art = inputs[0]
    reference_color = inputs[1]

    x = layers.Concatenate(axis=-1)([reference_line_art, reference_color])

    x = self.conv_1(x) # 256
    x = self.conv_2(x) # 128
    x = self.conv_3(x) # 64
    x = self.conv_4(x) # 32
    x = self.conv_5(x) # 16

    x = layers.AveragePooling2D((16, 16))(x) # Bx1x1x512

    return x

class SEV(tf.keras.Model):
  def __init__(self):
    super(SEV, self).__init__()
    self.embedder = Embedder()
    self.dense_1 = layers.Dense(512)
    self.dense_2 = layers.Dense(512)
  
  def call(self, inputs):
    reference_line_art_0 = inputs[0]
    reference_color_0 = inputs[1] 
    reference_line_art_1 = inputs[2]
    reference_color_1 = inputs[3]

    latent_vector_0 = self.embedder([reference_line_art_0, reference_color_0])
    latent_vector_1 = self.embedder([reference_line_art_1, reference_color_1])

    x = (latent_vector_0 + latent_vector_1) / 2
    x = self.dense_1(x)
    x = self.dense_2(x)

    return x

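We use the Embedder to obtain a latent vector for each pair of reference line art and reference color images, average the two vectors, and pass the result through two dense layers. The resulting style vector holds the means and standard deviations used by AdaIN.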

Finally, the code for the AdaIN normalization layer is:

class AdaInNormalization(tf.keras.layers.Layer):
  def __init__(self):
    super(AdaInNormalization, self).__init__()
    self.epsilon = 1e-5

  def call(self, x, style_vector):
    content_mean, content_variance = tf.nn.moments(x, [1, 2], keepdims=True) # Bx1x1xC
    content_sigma = tf.sqrt(tf.add(content_variance, self.epsilon))

    num_features = x.shape[-1]

    style_mean = style_vector[:, :, :, :num_features]
    style_sigma = style_vector[:, :, :, num_features:num_features*2]

    out = (x - content_mean) / content_sigma
    out = style_sigma * out + style_mean

    return out

Here, style_vector is the output of the SEV model: we take the first half of the vector (:num_features) as the mean and the second half (num_features:num_features*2) as the standard deviation. First, we normalize the input features x with their own mean and standard deviation:

out = (x - content_mean) / content_sigma

We compute these statistics with the tf.nn.moments function over axes 1 and 2 (height and width), which gives a mean and a variance for each channel of each image. Then we scale and shift the result with the standard deviation and mean from the style vector:

out = style_sigma * out + style_mean

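In a standard instance normalization layer, style_sigma and style_mean would be the learned parameters gamma (scale) and beta (shift). In AdaIN they are computed from the reference images instead, which transfers the global color statistics of the references to the features.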
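The NumPy sketch below reproduces the computation of AdaInNormalization, which makes it easy to check that, after the layer, each channel of each image has approximately the mean and standard deviation given by the style vector (up to the epsilon term). The function name is hypothetical; the [means | standard deviations] layout of the style vector follows the code above.

```python
import numpy as np

def adain(x, style_vector, eps=1e-5):
    """x: (B, H, W, C) content features; style_vector: (B, 1, 1, 2C) laid out as [means | sigmas]."""
    c = x.shape[-1]
    content_mean = x.mean(axis=(1, 2), keepdims=True)                 # B x 1 x 1 x C
    content_sigma = np.sqrt(x.var(axis=(1, 2), keepdims=True) + eps)  # population variance, like tf.nn.moments
    style_mean = style_vector[..., :c]
    style_sigma = style_vector[..., c:2 * c]
    return style_sigma * (x - content_mean) / content_sigma + style_mean

rng = np.random.default_rng(0)
x = rng.normal(loc=3.0, scale=2.0, size=(2, 16, 16, 4))
means = np.array([0.0, 1.0, -1.0, 5.0])
sigmas = np.array([1.0, 0.5, 2.0, 0.1])
style = np.concatenate([means, sigmas]).reshape(1, 1, 1, 8).repeat(2, axis=0)
out = adain(x, style)
print(out.mean(axis=(1, 2))[0])  # approximately [0, 1, -1, 5]
print(out.std(axis=(1, 2))[0])   # approximately [1, 0.5, 2, 0.1]
```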

The complete Color Transform Network also has 8 residual blocks, whose normalization layers are AdaIN layers, and a decoder that produces the final generated image of size 256x256. This decoder also uses Spectral Normalization:

style_vector = self.sev([reference_line_art_images_0,
                        reference_color_images_0,
                        reference_line_art_images_1,
                        reference_color_images_1]) # [Batch, 1, 1, 512]

Y_trans_sim = self.sim_conv(f_sim) # [Batch, 64, 64, 3]
Y_trans_sim = layers.UpSampling2D(size=(2, 2))(Y_trans_sim)
Y_trans_sim = layers.UpSampling2D(size=(2, 2))(Y_trans_sim) # [Batch, 256, 256, 3]

res_input = layers.add([target_line_art_images_features, f_sim]) # [Batch, 64, 64, 256]

x = self.res_block_1([res_input, style_vector])
x = self.res_block_2([x, style_vector])
x = self.res_block_3([x, style_vector])
x = self.res_block_4([x, style_vector])
x = self.res_block_5([x, style_vector])
x = self.res_block_6([x, style_vector])
x = self.res_block_7([x, style_vector])
x = self.res_block_8([x, style_vector])

Y_trans_mid = self.mid_conv(x)
Y_trans_mid = layers.UpSampling2D(size=(2, 2))(Y_trans_mid)
Y_trans_mid = layers.UpSampling2D(size=(2, 2))(Y_trans_mid)

Y_trans = self.decoder(x)

return Y_trans, Y_trans_mid, Y_trans_sim

The model returns 3 outputs. The first one, Y_trans, is the final generated image produced by the decoder. The other two, Y_trans_mid and Y_trans_sim, are intermediate outputs used in the loss function to improve the gradient flow to the early and middle parts of the model. Both use convolutional layers (mid_conv and sim_conv) followed by upsampling to produce images of size 256x256x3 for the loss function.

You can find the discriminator of the Color Transform Network in the GitHub repository.

Temporal Constraint Network

The model that caught my attention the most was the Temporal Constraint Network, built with Gated Temporal Shift Module layers that replicate the effect of 3D convolutions using 2D convolutions.

The inputs of this network are the reference color images along with the color image generated by the Color Transform Network. The images are fed in chronological order, with the generated image as the middle frame. Therefore, the input is a 5D tensor of shape B, T, H, W, C, where T is the temporal axis, which in this case has a size of 3 since we have 3 frames.
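As a small illustration of this layout, the NumPy snippet below stacks a previous reference frame, a generated frame and a next reference frame along a new temporal axis. The variable names and the constant placeholder values are hypothetical; in the real pipeline the frames come from the dataset and from the Color Transform Network.

```python
import numpy as np

B, H, W, C = 2, 256, 256, 3
previous_color = np.zeros((B, H, W, C), dtype=np.float32)  # reference color frame t - 1
generated = np.ones((B, H, W, C), dtype=np.float32)        # generated color frame t
next_color = np.full((B, H, W, C), 2.0, dtype=np.float32)  # reference color frame t + 1

# Axis 1 becomes the temporal axis T, with the generated image in the middle.
frames = np.stack([previous_color, generated, next_color], axis=1)
print(frames.shape)  # (2, 3, 256, 256, 3)
```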

The purpose of this model is to further improve the quality of the generated images, taking into account the colors of the previous and next frames (the reference color images).

This network is made up of two modules, the Learnable Temporal Shift Module and the Gated Convolution module:

Learnable Temporal Shift Module

The original Temporal Shift Module (TSM) was introduced in the paper TSM: Temporal Shift Module for Efficient Video Understanding to give 2D convolutions the ability to capture temporal relationships, as 3D convolutions do, but at a much lower computational cost.

To understand this, let's say we have a tensor of size 3x3x3x1, where the first dimension is the temporal dimension and the last one is the channel dimension. We can see it as a small video with 3 frames of size 3x3 and only one channel each. We also have a 3D kernel of shape 3, 1, 1, where the first dimension is the depth (temporal) dimension, with values 0, 0, 1:

Input example

Each frame is represented by one color

What the Temporal Shift Module does is to move the channels of one frame to another frame:

Shift example

Since we want to keep the same output size, we add padding (represented by the transparent blocks) at the front and at the back of the input. Our kernel only operates along the depth (temporal) dimension.

In the first step the kernel only keeps the yellow frame, in the second step the kernel keeps the pink frame and in the last step we don't keep any frame.

This makes the frames' channels move backwards. Notice also that we leave an empty area full of zeros.

Thanks to this shift, we can capture temporal relationships between the 3 frames, since each frame now contains information about the previous or next frames. In our case we also have 3 frames (first, middle, last), but each frame has more channels.
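The example can be reproduced numerically with the NumPy sketch below, where the frames are labeled with the values 1, 2 and 3 instead of colors. The helper is hypothetical; it applies a length-3 kernel along the temporal axis with zero padding at both ends, which is what a Conv3D with kernel size 3, 1, 1 and "same" padding computes (as a cross-correlation).

```python
import numpy as np

def temporal_correlate(video, kernel):
    """Correlate a (T, H, W, C) video with a length-3 kernel along T, zero-padding both ends.

    out[t] = kernel[0] * video[t - 1] + kernel[1] * video[t] + kernel[2] * video[t + 1]
    """
    T = video.shape[0]
    zeros = np.zeros_like(video[:1])
    padded = np.concatenate([zeros, video, zeros], axis=0)
    return sum(kernel[k] * padded[k:k + T] for k in range(3))

# 3 frames of size 3x3 with one channel; frame i is filled with the value i + 1.
video = np.stack([np.full((3, 3, 1), v) for v in (1.0, 2.0, 3.0)])
print(temporal_correlate(video, [0.0, 0.0, 1.0])[:, 0, 0, 0])  # [2. 3. 0.] frames move backward
print(temporal_correlate(video, [1.0, 0.0, 0.0])[:, 0, 0, 0])  # [0. 1. 2.] frames move forward
```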

In order to produce this shift between frames, we need to use specific weights that we call shifting kernels or shifting weights:

pre_weights = tf.constant([0.0, 0.0, 1.0], dtype=tf.float32)
post_weights = tf.constant([1.0, 0.0, 0.0], dtype=tf.float32)

Notice that pre_weights is the kernel we used in the previous example to move the frames' channels backwards, while post_weights moves the channels forwards.

In the original Temporal Shift Module these kernels are constant, like the ones we just saw. However, an improved version of the module was presented in the paper Learnable Gated Temporal Shift Module for Deep Video Inpainting, where the shifting kernels are learnable. This is the version we will use, so the kernels can learn which information needs to move between frames to achieve a better result.

The original code is available in this repository and is implemented in PyTorch; here we will write an implementation in TensorFlow:

import tensorflow as tf
from tensorflow.keras import layers

class LearnableTSM(tf.keras.Model):
  def __init__(self):
    super(LearnableTSM, self).__init__()
    self.shift_ratio = 0.5
    self.shift_groups = 2
    self.shift_width = 3

    # Initial values of the (3, 1, 1, 1, 1) shifting kernels; they are updated during training.
    pre_weights = tf.constant_initializer([0.0, 0.0, 1.0])   # moves channels backward
    post_weights = tf.constant_initializer([1.0, 0.0, 0.0])  # moves channels forward

    self.pre_shift_conv = layers.Conv3D(1, [3, 1, 1], use_bias=False, padding="same", kernel_initializer=pre_weights)
    self.post_shift_conv = layers.Conv3D(1, [3, 1, 1], use_bias=False, padding="same", kernel_initializer=post_weights)

  def apply_tsm(self, tensor, conv):
    _, T, H, W, C = tensor.shape
    B = tf.shape(tensor)[0]  # dynamic batch size, so this also works when B is None

    # Fold the channels into the batch axis so the single-channel Conv3D
    # shifts every channel independently along the temporal axis.
    tensor = tf.transpose(tensor, [0, 4, 1, 2, 3])  # B, C, T, H, W
    tensor = conv(tf.reshape(tensor, [B * C, T, H, W, 1]))
    tensor = tf.reshape(tensor, [B, C, T, H, W])
    tensor = tf.transpose(tensor, [0, 2, 3, 4, 1])  # B, T, H, W, C

    return tensor

  def call(self, input_tensor):
    C = input_tensor.shape[-1]
    split_size = int(C * self.shift_ratio) // self.shift_groups

    split_sizes = [split_size] * self.shift_groups + [C - split_size * self.shift_groups]
    tensors = tf.split(input_tensor, split_sizes, -1)
    assert len(tensors) == self.shift_groups + 1

    tensor_1 = self.apply_tsm(tensors[0], self.pre_shift_conv)   # shifted backward
    tensor_2 = self.apply_tsm(tensors[1], self.post_shift_conv)  # shifted forward

    # The third group keeps its channels unchanged.
    final_tensor = tf.concat([tensor_1, tensor_2, tensors[2]], -1)
    
    return final_tensor

The first step of the temporal shift module is to determine the split size of the channels using two predefined values, shift_ratio and shift_groups. The idea is to split the input tensor of shape B, T, H, W, C into 3 groups along the channel axis: each group contains all the image sequences B and all the frames T, but only part of the channels of each frame. For example, if we have an input tensor of shape:

8x3x256x256x128

Here we have features from 8 image sequences, each with 3 frames of size 256x256 and 128 channels. The frames in the first two groups keep 32 channels each, and the frames in the last group keep 64 channels:

split_size = int(128 * 0.5) // 2 # 32
split_sizes = [32] * 2 + [128 - 32 * 2] # [32, 32, 64]

Once we have the 3 groups, we pass the first group to a 3D convolutional layer (pre_shift_conv) initialized with the pre_weights kernel, so its channels move backward, from the later frames towards the first frame.

Then we pass the second group to another 3D convolutional layer (post_shift_conv), this time initialized with the post_weights kernel, so its channels move forward, from the first frame towards the last frame.

As we train the network, the kernels of these 3D convolutional layers learn patterns from the data and change their values; we only initialize them to guide them towards what they should do.

Notice that these two 3D convolutions have only 1 output channel and a kernel size of 3, 1, 1, so the convolution operates only along the temporal axis to move the channels. We also have to transpose and reshape the input tensors before passing them to the 3D convolutional layer:

B, T, H, W, C = tensor.shape

tensor = tf.transpose(tensor, [0, 4, 1, 2, 3])
tensor = conv(tf.reshape(tensor, [B * C, T, H, W, 1]))

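Since a single-channel 3D convolution only mixes values along the temporal axis, we move the channel axis next to the batch axis and merge both into one axis of size B * C, adding a trailing channel axis of size 1. Every channel of every sequence is then treated as a separate one-channel video, and the same 3, 1, 1 kernel shifts all of them along the temporal axis.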

After the convolution, we reshape and transpose the output tensor back to the B, T, H, W, C layout so we can use it in the regular convolutional layers:

tensor = tf.reshape(tensor, [B, C, T, H, W])
tensor = tf.transpose(tensor, [0, 2, 3, 4, 1])

Finally, the shifted tensors are concatenated with the third group, whose channels were not modified, to form a tensor with the original shape but with part of its channels shifted across the temporal axis.
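To see the whole module at once, the NumPy sketch below mirrors apply_tsm and call from LearnableTSM with the kernels frozen at their initial values, so it shows the behavior at initialization rather than after training. It also reproduces the trick of folding the channels into the batch axis. The function names are hypothetical.

```python
import numpy as np

def shift_time(x, kernel):
    """Zero-padded length-3 correlation along axis 1 of an (N, T, ...) array."""
    T = x.shape[1]
    pad = np.zeros_like(x[:, :1])
    padded = np.concatenate([pad, x, pad], axis=1)
    return sum(kernel[k] * padded[:, k:k + T] for k in range(3))

def temporal_shift(x, shift_ratio=0.5, pre=(0.0, 0.0, 1.0), post=(1.0, 0.0, 0.0)):
    """Fixed-kernel temporal shift of x with shape (B, T, H, W, C), using two shifted groups."""
    B, T, H, W, C = x.shape
    split = int(C * shift_ratio) // 2  # shift_groups = 2, as in LearnableTSM
    groups = [x[..., :split], x[..., split:2 * split], x[..., 2 * split:]]

    def apply(group, kernel):
        c = group.shape[-1]
        # (B, T, H, W, c) -> (B, c, T, H, W) -> (B * c, T, H, W): one video per channel.
        folded = group.transpose(0, 4, 1, 2, 3).reshape(B * c, T, H, W)
        out = shift_time(folded, kernel)
        # Undo the folding: (B * c, T, H, W) -> (B, c, T, H, W) -> (B, T, H, W, c).
        return out.reshape(B, c, T, H, W).transpose(0, 2, 3, 4, 1)

    return np.concatenate([apply(groups[0], pre), apply(groups[1], post), groups[2]], axis=-1)

x = np.random.default_rng(0).normal(size=(2, 3, 4, 4, 8))  # split = int(8 * 0.5) // 2 = 2
y = temporal_shift(x)
```

In the first group each frame now holds the channels of the next frame (and the last frame is zero), in the second group each frame holds the channels of the previous frame (and the first frame is zero), and the remaining channels are untouched.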

Gated Convolution

Gated Convolutions were introduced in the paper Free-Form Image Inpainting with Gated Convolution for image inpainting, where parts of an image are missing and the network has to recreate them. In image inpainting we have invalid areas (the missing parts) that should not be taken into account by the convolutions. Gated Convolutions learn to detect and filter out those areas using attention maps, so the network can select the valid features of the image.

Although in this work all the line art, color and distance field images are complete, shifting the channels of the image features can leave empty (zero) areas, as we saw earlier. The Gated Convolutions learn to identify these areas, as proposed in the Learnable Gated Temporal Shift Module for Deep Video Inpainting paper:

class GatedConv(tf.keras.Model):
  def __init__(self, kernels, kernel_size, strides, dilation=(1, 1)):
    super(GatedConv, self).__init__()

    self.learnableTSM = LearnableTSM()
    self.feature_conv = SpectralNormalization(layers.Conv2D(kernels, kernel_size, strides=strides, padding="same", dilation_rate=dilation))

    self.gate_conv = SpectralNormalization(layers.Conv2D(kernels, kernel_size, strides=strides, padding="same", dilation_rate=dilation))

    self.activation = layers.LeakyReLU(0.2)
  
  def call(self, input_tensor):
    B, T, H, W, C = input_tensor.shape
    xs = tf.split(input_tensor, num_or_size_splits=T, axis=1)
    gating = tf.stack([self.gate_conv(tf.squeeze(x, axis=1)) for x in xs], axis=1)
    gating = tf.keras.activations.sigmoid(gating)

    feature = self.learnableTSM(input_tensor)
    # shape B, T, H, W, C

    # Merge the batch and temporal axes so the 2D convolution processes every frame.
    feature = self.feature_conv(tf.reshape(feature, [-1, H, W, C]))
    _, H_, W_, C_ = feature.shape
    feature = tf.reshape(feature, [-1, T, H_, W_, C_])
    feature = self.activation(feature)

    out = gating * feature

    return out

In this gated convolution, we first split the input images or input features along the temporal axis. Since we have 3 frames, we end up with 3 tensors; each one goes through a regular convolutional layer (gate_conv), and the results are stacked back into one tensor. Then a sigmoid activation turns this tensor into an attention map (gating).

With the learnableTSM we shift the channels of the input images or input features as we have seen.

We merge the batch and temporal axes so that the shifted tensor can go through a regular convolutional layer (feature_conv), then reshape to split the batch and temporal axes again, apply a leaky ReLU activation and finally multiply the result by the attention map gating.

The attention map can then learn to filter out the zero values that the shift leaves in the features.
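The last step of GatedConv can be sketched in NumPy as follows, taking the outputs of the two convolutions as given. The function name is hypothetical; the example shows that a gate logit far below zero pushes the output towards zero whatever the feature value, which is how the gating can suppress unwanted areas.

```python
import numpy as np

def gated_output(feature_pre, gate_pre, alpha=0.2):
    """LeakyReLU(feature_pre) * sigmoid(gate_pre), element-wise."""
    feature = np.where(feature_pre > 0, feature_pre, alpha * feature_pre)  # LeakyReLU(0.2)
    gate = 1.0 / (1.0 + np.exp(-gate_pre))                                 # attention map in (0, 1)
    return feature * gate

features = np.array([3.0, -2.0, 3.0])
gates = np.array([20.0, 20.0, -20.0])
print(gated_output(features, gates))  # close to [3.0, -0.4, 0.0]
```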

Notice that this layer also takes the dilation rate as a parameter, which we will use in some of these blocks.

class GatedDeConv(tf.keras.Model):
  def __init__(self, kernels):
    super(GatedDeConv, self).__init__()
    self.gate_conv = GatedConv(kernels, (3, 3), (1, 1))
    self.upsampling = layers.UpSampling3D(size=(1, 2, 2))
  
  def call(self, input_tensor):
    x = self.upsampling(input_tensor)
    x = self.gate_conv(x)

    return x

We also have a gated deconvolution layer, which is a GatedConv layer preceded by an upsampling layer that doubles the height and width of the frames (UpSampling3D with size 1, 2, 2 leaves the temporal axis unchanged).

Putting everything together, we get the following network:

class TemporalConstraintNetwork(tf.keras.Model):
  def __init__(self):
    super(TemporalConstraintNetwork, self).__init__()
    self.conv_1 = layers.Conv2D(64, (3, 3), strides=(1, 1), padding="same")
    self.conv_2 = GatedConv(64, (3, 3), (1, 1))
    self.conv_3 = GatedConv(128, (3, 3), (2, 2))
    self.conv_4 = GatedConv(256, (3, 3), (2, 2))

    self.dilation_1 = GatedConv(256, (3, 3), (1, 1), (2, 2))   # dilation 2
    self.dilation_2 = GatedConv(256, (3, 3), (1, 1), (4, 4))   # dilation 4
    self.dilation_3 = GatedConv(256, (3, 3), (1, 1), (8, 8))   # dilation 8
    self.dilation_4 = GatedConv(256, (3, 3), (1, 1), (16, 16)) # dilation 16

    self.conv_5 = GatedConv(256, (3, 3), (1, 1))
    self.up_conv_1 = GatedDeConv(128)
    self.up_conv_2 = GatedDeConv(3)
  
  def call(self, input_tensor):
    x = self.conv_1(input_tensor)
    x = self.conv_2(x) # Bx3x256x256x64
    x_1 = self.conv_3(x) # Bx3x128x128x128
    x_2 = self.conv_4(x_1) # Bx3x64x64x256

    x = self.dilation_1(x_2)
    x = self.dilation_2(x)
    x = self.dilation_3(x)
    x = self.dilation_4(x) # Bx3x64x64x256

    x = self.conv_5(x) # Bx3x64x64x256
    x = layers.concatenate([x, x_2], axis=-1) # skip connection along the channel axis
    x = self.up_conv_1(x)
    x = layers.concatenate([x, x_1], axis=-1) # skip connection along the channel axis
    x = self.up_conv_2(x)

    return x

Here, the outputs of earlier layers are concatenated with later layers along the channel axis (skip connections). In the paper that we are reviewing, the Temporal Constraint Network does not contain these concatenations; however, the authors of the paper where this architecture was introduced did use them, and in my tests the outputs of the network were just noise when the concatenations were removed.

Loss

In this work we have several loss terms:

One of the loss functions we often encounter when working with GANs is the adversarial loss:

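# Assumption: the discriminators output raw logits, so the cross-entropy is computed from logits.
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

def compute_discriminator_2d_loss(y_class, y_trans_class):
  real_loss = cross_entropy(tf.ones_like(y_class), y_class)               # real images -> 1
  fake_loss = cross_entropy(tf.zeros_like(y_trans_class), y_trans_class)  # generated images -> 0
  loss = real_loss + fake_loss

  return loss

def generator_loss(y_trans_class):
  # The generator tries to make the discriminator classify its images as real.
  return cross_entropy(tf.ones_like(y_trans_class), y_trans_class)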

Here, y_class and y_trans_class are the discriminator's predictions (real/fake) for the real images and for the generated (fake) images, respectively.
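The losses above use a binary cross-entropy object, cross_entropy. Assuming it operates on raw discriminator logits (as with tf.keras.losses.BinaryCrossentropy(from_logits=True)), the two losses can be written in NumPy with the numerically stable form of the cross-entropy. The helper names are hypothetical.

```python
import numpy as np

def bce_with_logits(labels, logits):
    """Mean binary cross-entropy computed from logits z.

    max(z, 0) - z * y + log(1 + exp(-|z|)) equals -[y log(sigmoid(z)) + (1 - y) log(1 - sigmoid(z))]
    but does not overflow for large |z|.
    """
    return np.mean(np.maximum(logits, 0) - logits * labels + np.log1p(np.exp(-np.abs(logits))))

def discriminator_loss(real_logits, fake_logits):
    # Real images should be classified as 1 and generated images as 0.
    return (bce_with_logits(np.ones_like(real_logits), real_logits)
            + bce_with_logits(np.zeros_like(fake_logits), fake_logits))

def generator_loss(fake_logits):
    # The generator wants its images to be classified as real (1).
    return bce_with_logits(np.ones_like(fake_logits), fake_logits)

print(generator_loss(np.zeros(4)))  # log(2), about 0.693: the discriminator is undecided
```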

The L1 loss encourages the generated images to be similar to the real target color images:

def l1_loss(y, y_trans):
  return tf.reduce_mean(tf.abs(y - y_trans))

The Perceptual Loss and the Style Loss extract feature maps of the real and generated images from a VGG19 model and compute the differences between them:

def perceptual_loss(y_list, y_trans_list):
  loss = 0
  for feature_map_y, feature_map_y_trans in zip(y_list, y_trans_list):
    loss += tf.reduce_mean(tf.math.abs(feature_map_y - feature_map_y_trans))
  
  return (loss / 5) * 3e-2

def style_loss(y_list, y_trans_list):
  loss = 0
  for feature_map_y, feature_map_y_trans in zip(y_list, y_trans_list):
    loss += tf.reduce_mean(tf.abs(get_gram_matrix(feature_map_y) - get_gram_matrix(feature_map_y_trans)))
  
  return (loss / 5) * 1e-6

To compute the Style Loss we need to compute the Gram matrices:

def get_gram_matrix(feature_map):
  _, H, W, C = feature_map.shape
  matrix = tf.transpose(feature_map, [0, 3, 1, 2])  # B, C, H, W
  matrix = tf.reshape(matrix, [-1, C, H * W])        # B, C, HW

  num_locations = tf.cast(H * W, tf.float32)

  gram_matrix = tf.linalg.matmul(matrix, matrix, transpose_b=True) # (B, C, HW) x (B, HW, C) -> (B, C, C)
  gram_matrix = gram_matrix / num_locations

  return gram_matrix
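A NumPy version of get_gram_matrix is handy for checking its properties: the result is a symmetric C x C matrix per image, and it does not change if the spatial positions are shuffled, which is why it describes the textures and colors of a feature map rather than where things are. The function name gram_matrix_np is hypothetical.

```python
import numpy as np

def gram_matrix_np(feature_map):
    """(B, H, W, C) -> (B, C, C): channel correlations divided by the number of locations H * W."""
    B, H, W, C = feature_map.shape
    flat = feature_map.reshape(B, H * W, C)  # one row per spatial location
    return np.einsum("bnc,bnd->bcd", flat, flat) / (H * W)

rng = np.random.default_rng(0)
fmap = rng.normal(size=(2, 8, 8, 16))
gram = gram_matrix_np(fmap)

# Shuffle the spatial positions of each feature map: the Gram matrix stays the same.
perm = rng.permutation(64)
shuffled = fmap.reshape(2, 64, 16)[:, perm].reshape(2, 8, 8, 16)
print(gram.shape, np.allclose(gram, gram_matrix_np(shuffled)))  # (2, 16, 16) True
```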

The Latent Constraint Loss uses the intermediate outputs of the network to improve stability and gradient flow:

def latent_constraint_loss(y, y_trans_sim, y_trans_mid):
  loss = tf.reduce_mean(tf.abs(y - y_trans_sim) + tf.abs(y - y_trans_mid))
  return loss

Here the L1 loss measures the differences between the color target image and the intermediate results y_trans_sim and y_trans_mid.

The final objective loss function is:

def compute_color_network_loss(y_trans_class, y, y_trans, y_trans_sim, y_trans_mid, y_list, y_trans_list, lambda_style=1000, lambda_l1=10):
  loss = 0
  gen_loss = generator_loss(y_trans_class)

  loss += gen_loss
  latent_loss = latent_constraint_loss(y, y_trans_sim, y_trans_mid)

  loss += latent_loss
  s_loss = style_loss(y_list, y_trans_list) * lambda_style

  loss += s_loss
  p_loss = perceptual_loss(y_list, y_trans_list)

  loss += p_loss
  l_loss = l1_loss(y, y_trans) * lambda_l1

  loss += l_loss

  return loss, gen_loss, latent_loss, s_loss, p_loss, l_loss

Here, the weights lambda_style and lambda_l1 balance the loss terms against each other.
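Written out as I read compute_color_network_loss, the objective is as follows, where phi_l is the l-th VGG19 feature map, G is the Gram matrix, y-hat is y_trans, D(y-hat) is y_trans_class, and the constants come from the code above (the division by 5 assumes five feature maps, as the code does):

```latex
$$
\begin{aligned}
\mathcal{L} &= \mathcal{L}_{adv} + \mathcal{L}_{latent} + \lambda_{style}\,\mathcal{L}_{style} + \mathcal{L}_{perc} + \lambda_{L1}\,\mathcal{L}_{L1} \\
\mathcal{L}_{adv} &= \mathrm{BCE}\big(1, D(\hat{y})\big) \\
\mathcal{L}_{latent} &= \operatorname{mean}\big(|y - \hat{y}_{sim}| + |y - \hat{y}_{mid}|\big) \\
\mathcal{L}_{style} &= \frac{10^{-6}}{5} \sum_{l=1}^{5} \operatorname{mean}\big|G(\phi_l(y)) - G(\phi_l(\hat{y}))\big| \\
\mathcal{L}_{perc} &= \frac{3 \cdot 10^{-2}}{5} \sum_{l=1}^{5} \operatorname{mean}\big|\phi_l(y) - \phi_l(\hat{y})\big| \\
\mathcal{L}_{L1} &= \operatorname{mean}|y - \hat{y}|
\end{aligned}
$$
```

With lambda_style = 1000 and lambda_l1 = 10, each style layer term effectively carries a weight of 1000 · 10^-6 / 5 = 2 · 10^-4, each perceptual layer term 3 · 10^-2 / 5 = 6 · 10^-3, and the pixel-wise L1 term a weight of 10. Seeing the effective weights side by side helps when tuning the balance between the terms.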

Training

Some details differ from the paper because the original source code is not available, so there is still room to improve the quality of the generated images.

I trained the Color Network Model on Google Colab for 350 epochs, each taking around 110 seconds. Once the results of the Color Network Model are good enough, we can start training the Temporal Network.

The temporal network should be trained for fewer epochs than the color network; I trained it for 150 epochs.

Both networks use the same loss functions and the same hyperparameters, such as the learning rate.

Generated Images

Here are some examples of images generated by the Python model.

Generated Images

Some of them look almost like the original ones. However, the quality decreases when we increase the image size. It would be interesting to investigate whether a larger model capacity would let us output larger images.

How to get the data

To obtain the color images, we can extract frames from an anime using cv2. One of the anime selected by the authors was Little Witch Academia, so I chose the same anime to compare results. I extracted frames from the first 3 episodes.

We need several shots, where each shot contains similar frames like those in the first image of this post. To find similar frames, the authors computed the mean squared error between feature vectors obtained from the histograms of the frames; you can read more about this in the Data Collection and Training section of the paper. I couldn't use the same method, so I selected frames manually until I had collected 341 different shots and 4527 images.
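For reference, the histogram idea can be sketched in NumPy. This is only an approximation of the paper's procedure, under assumptions: per-channel color histograms with a hypothetical 32 bins, and a hypothetical MSE threshold of 0.01 above which a new shot starts. Frames are assumed to be uint8 arrays of shape (H, W, 3), as returned by cv2; the function names are hypothetical.

```python
import numpy as np

def histogram_features(frame, bins=32):
    """Concatenate per-channel color histograms, normalized by the pixel count."""
    pixels = frame.shape[0] * frame.shape[1]
    hists = [np.histogram(frame[..., c], bins=bins, range=(0, 256))[0] / pixels
             for c in range(frame.shape[-1])]
    return np.concatenate(hists)

def split_into_shots(frames, threshold=0.01, bins=32):
    """Group consecutive frames into shots.

    A new shot starts when the mean squared error between the histogram
    features of a frame and the previous frame exceeds `threshold`.
    Returns a list of shots, each a list of frame indices.
    """
    shots, prev = [], None
    for i, frame in enumerate(frames):
        feat = histogram_features(frame, bins)
        if prev is None or np.mean((feat - prev) ** 2) > threshold:
            shots.append([])  # large change: assume a cut between shots
        shots[-1].append(i)
        prev = feat
    return shots
```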

To obtain the line art images, the authors used the SketchKeras model. It takes the color images as input, and after some post-processing with cv2 we get pretty good line art images. For better line art, the input color images should be 1024x1024 or 512x512, since 256x256 inputs give lower-quality results.

Finally, the distance field map images are obtained by applying the ndimage.distance_transform_edt function from scipy to the line art images. Similarly, to obtain a high-quality distance image we should use a line art image of size 1024x1024.
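A minimal sketch with scipy, assuming grayscale line art (for example loaded with cv2.imread(path, cv2.IMREAD_GRAYSCALE)) with dark lines on a white background and a hypothetical binarization threshold of 128. distance_transform_edt gives, for every nonzero pixel, the Euclidean distance to the nearest zero pixel, so the line pixels must be the zeros. The rescaling to [0, 255] for saving is also an assumption, since the post doesn't state how the maps are normalized; the function names are hypothetical.

```python
import numpy as np
from scipy import ndimage

def distance_field_map(line_art, threshold=128):
    """line_art: 2-D uint8 array, dark lines on a white background (assumed)."""
    # Background pixels become True (nonzero) and line pixels False (zero).
    background = line_art >= threshold
    # Euclidean distance from every pixel to the nearest line pixel.
    return ndimage.distance_transform_edt(background)

def to_uint8(dist):
    """Assumed normalization: scale distances to [0, 255] for saving as an image."""
    max_val = dist.max()
    if max_val == 0:
        return np.zeros_like(dist, dtype=np.uint8)
    return (dist / max_val * 255).astype(np.uint8)
```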

The code to obtain these images is also in the GitHub repository.

In the end we have 3 folders, each containing the same shots but a different type of image. At training time we go shot by shot, randomly selecting 3 frames to train our models.

As a recommendation, frames should have the same file name across image types: a color image named 455.jpg should have its corresponding line art image and distance image named 455.jpg as well. This makes it easy to create a generator that loads all the needed frames.
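Following that naming convention, a loader can be sketched with the standard library alone. The folder layout <root>/<kind>/<shot>/<number>.jpg, the folder names color, line and distance, and the function names are hypothetical. The check for missing counterparts catches naming errors where one folder is out of sync with the others.

```python
import os
import random

# Hypothetical layout: <root>/<kind>/<shot>/<frame_number>.jpg, where every
# kind folder contains the same shots and the same file names.
IMAGE_TYPES = ("color", "line", "distance")

def frame_key(name):
    # Sort "9.jpg" before "10.jpg" by the numeric part of the file name.
    return int(os.path.splitext(name)[0])

def sample_shot(root, shot, num_frames=3, rng=random):
    """Pick `num_frames` random frames of one shot in chronological order,
    returning for each frame the paths of all its image types."""
    frames = os.listdir(os.path.join(root, IMAGE_TYPES[0], shot))
    chosen = sorted(rng.sample(frames, num_frames), key=frame_key)
    sample = []
    for name in chosen:
        paths = {kind: os.path.join(root, kind, shot, name) for kind in IMAGE_TYPES}
        missing = [p for p in paths.values() if not os.path.exists(p)]
        if missing:
            # The folders are out of sync (a naming error).
            raise FileNotFoundError(f"Missing counterparts: {missing}")
        sample.append(paths)
    return sample

def shot_generator(root, num_frames=3, rng=random):
    """Endlessly yield samples, visiting the shots in a new random order each pass."""
    shots = sorted(os.listdir(os.path.join(root, IMAGE_TYPES[0])))
    while True:
        rng.shuffle(shots)
        for shot in shots:
            yield sample_shot(root, shot, num_frames, rng)
```

Each yielded item is a list of three dicts in chronological order, which can then be read with cv2.imread and stacked into the 9 images of a batch.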

Results and Conclusions

The output images from the color model are good, but their quality could be improved. We can improve how the training data is obtained, for example by selecting frames that are more similar to each other. I also noticed some naming errors where frames in the distance folder don't correspond to frames in the color folder, so there are definitely some fixes that could improve the quality of the generated images.

Better tuning of the loss function could further improve the generated images. My first implementation of the loss terms was a disaster and the generated images were poor: the model paid too much attention to the perceptual and style losses, so I had to reduce the weight of these terms. Better tuning of the weight of each loss could lead to better results.

In the end I trained two versions of the color model: the first uses the custom Spectral Normalization and gives the best results, while the second doesn't use Spectral Normalization and its results have lower quality.
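The custom layer itself isn't shown in this part of the post. As a minimal sketch of what spectral normalization computes, assuming the standard formulation (Miyato et al., 2018), the NumPy function below views a convolution kernel as a matrix, estimates its largest singular value with power iteration, and divides the kernel by it, so the normalized kernel has a spectral norm of roughly 1. The function name and iteration count are illustrative.

```python
import numpy as np

def spectral_normalize(kernel, num_iters=50, seed=0):
    """Divide a conv kernel by an estimate of its largest singular value.

    The kernel (kh, kw, in_ch, out_ch) is viewed as a matrix W of shape
    (kh * kw * in_ch, out_ch); power iteration estimates sigma(W).
    num_iters must be at least 1.
    """
    w = kernel.reshape(-1, kernel.shape[-1])
    u = np.random.default_rng(seed).standard_normal(w.shape[1])
    u /= np.linalg.norm(u)
    for _ in range(num_iters):
        v = w @ u
        v /= np.linalg.norm(v)  # left singular vector estimate
        u = w.T @ v
        u /= np.linalg.norm(u)  # right singular vector estimate
    sigma = v @ w @ u  # estimate of the largest singular value
    return kernel / sigma, sigma
```

In a training layer one usually runs a single power-iteration step per forward pass and keeps u between steps, instead of the many iterations used here for a standalone estimate.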

To build a real-life example of this model, I created an Electron app that runs the color model using JavaScript. When I saved the first version of the model, TensorFlow printed some export alerts; when I removed the Spectral Normalization layers, the alerts disappeared.

I tested the first version of the model using the tf.node.loadSavedModel function from TensorFlow.js, which can load models directly without conversion. The images generated with TensorFlow.js are not as good as those from the Python version. At first I thought the export alerts were the reason, which is why I trained a second version of the model without Spectral Normalization. However, the results are still worse.

Despite the results of the JavaScript implementation, the Electron app still has some good ideas, which you can read about here in case you want to serve a model as an app.

For the temporal model, I didn't notice an increase in the quality of the generated images. I also changed the loss function so the model would pay more attention to the generated image than to the reference color images that are also used as input, but this change didn't affect the results either. There could therefore be an error in my implementation, since the original model was implemented in PyTorch.
