Vicente Rodríguez

May 17, 2021

Face Detection for low-end hardware using the BlazeFace Architecture

Face detection is a cool task that can be hard to get right and difficult to run on hardware that is not very powerful.

In this post we will review the paper BlazeFace: Sub-millisecond Neural Face Detection on Mobile GPUs, where the authors present an architecture that performs face detection even on less capable hardware.

The original model is part of the MediaPipe library, which also contains other deep learning models.

In my search for more information about this model, I found this repository, where the owner extracted the model architecture and its weights from a tflite file (which is really cool), along with part of the code that filters the predictions to keep only the good ones.

The code from that repository is in PyTorch; in this case, we will use TensorFlow and we will also train the model from scratch. Thus, most of the presented code is written from scratch, except for the model architecture and the prediction filtering part, which are translated from PyTorch.

You can check all the code in this repository.

Face detection and object detection are tasks that have always caught my attention. The BlazeFace architecture claims that we can get good face detection even if our hardware is not powerful, and the results are pretty good, especially if we take into account that our training data is different and much smaller than the original one. In fact, the key to getting a good model was the data preprocessing and dataset creation, as we will see in the following section.

Data

The datasets used to train the model were the WIDER FACE: A Face Detection Benchmark dataset, and the Face Detection Data Set and Benchmark (FDDB) dataset.

The former contains more than 32,000 images and 393,703 faces. In both datasets we have faces of several sizes, covering anywhere from a small part of the image to almost the whole image.

The original BlazeFace model was trained on a dataset of 66k images. The authors trained two models, one for the front camera and one for the rear camera of a phone. The model for the front camera used images where the faces covered more than 20% of the image area, and the model for the rear camera used faces that covered more than 5% of the image area.

In this case, we will only train one BlazeFace model. Furthermore, we don't have such a big dataset to train it. Thus, I kept faces that covered at least 4% of the image area and, in some cases, faces that covered 2% of the image area. This selection depends on the type of image we are dealing with: the model expects a square image of size 128x128, so we have to resize and crop our images, and therefore the area of each face can vary.

Both datasets have their own annotations with information about the images and the number of faces, as well as their positions and sizes. In the case of the WIDER FACE dataset we have box coordinates with x_min, x_max, etc. values. However, the FDDB dataset describes the faces as ellipses with radii and an angle, so we have to transform them into box coordinates.
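
As a reference, converting an ellipse annotation into a bounding box could look like the following sketch. This is not the exact code from the repository; it assumes the standard FDDB format (major_axis_radius, minor_axis_radius, angle, center_x, center_y) and that the angle is measured from the horizontal axis:

import numpy as np

def ellipse_to_box(major_r, minor_r, angle, center_x, center_y):
  # Half extents of the axis-aligned bounding box of a rotated ellipse
  half_w = np.sqrt((major_r * np.cos(angle)) ** 2 + (minor_r * np.sin(angle)) ** 2)
  half_h = np.sqrt((major_r * np.sin(angle)) ** 2 + (minor_r * np.cos(angle)) ** 2)

  x_min, x_max = center_x - half_w, center_x + half_w
  y_min, y_max = center_y - half_h, center_y + half_h

  return x_min, y_min, x_max, y_max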

Once I had removed the images with small faces and cropped the remaining images, I started training the model. The results were really good from the first training epochs, which was not a good sign. I trained the model for some epochs and tested it, only to realize that the outputs for all the images were always the same: the coordinates for the faces were always the same numbers, and the predicted boxes for each face were also the same.

When I checked the dataset, I noticed that the faces always appeared in the middle of the image due to the resizing and cropping steps. Thus, the model assumes that all the labels are basically the same, outputs the same coordinates, and is still able to achieve a good loss.

Then, I had to think of a way to randomize the cropping and resizing of the images. The final idea was: first, take into account whether the image is wider or taller. If it is wider, we can take the extra space from the sides, left or right, also considering where the face appears: if it appears on the left side, we can take more space from the right side, and vice versa. If the image is taller, we can take space from the bottom and the top. The final thing I had to consider was the final aspect of the image: we need to resize to 128x128, so we need a square image.

In this way, we can randomly take some space from the side where we have more room and take all the available space from the other side. If the resulting image would be bigger than needed on the side we are cropping, we take less space; if it would be smaller, we take more space.
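
As an illustration, a simplified sketch of such a randomized square crop could look like this (the names and the offset logic here are assumptions, not the exact code from the repository):

import random

def random_square_crop(image, box):
  # image: numpy array of shape (H, W, 3); box: (x_min, y_min, x_max, y_max) in pixels
  h, w, _ = image.shape
  x_min, y_min, x_max, y_max = box
  side = min(h, w) # the square side is limited by the shorter dimension

  if w > h:
    # Wider image: randomly choose how much free space to take from the left,
    # while keeping the whole face inside the crop
    min_left = max(0, x_max - side)
    max_left = min(x_min, w - side)
    left = random.randint(int(min_left), int(max(min_left, max_left)))
    top = 0
  else:
    # Taller image: same idea, taking space from the top and bottom
    min_top = max(0, y_max - side)
    max_top = min(y_min, h - side)
    top = random.randint(int(min_top), int(max(min_top, max_top)))
    left = 0

  crop = image[top:top + side, left:left + side]
  new_box = (x_min - left, y_min - top, x_max - left, y_max - top)

  return crop, new_box # resize the crop to 128x128 and rescale new_box afterwards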

You can check the code to get this done in the create_dataset.py file from the repository.

If we run this for each image, we end up with faces that appear in different areas of the image and not only in the center.

Furthermore, I noticed that some images were not good enough to be used as training data, so I had to manually inspect all the images and remove the ones where the faces were occluded or where the person was looking completely sideways, as well as images that I considered would be hard for the network to learn from.

From the original datasets I kept only 4,477 images. Thus, I ran the image generation twice, ending up with 8,954 images, a number that is far from the 66k used for the original model but was enough to get a functional model.

Boxes

Since we want to detect faces in an image, we have an object detection task. To solve this task it is common to use a single shot detector model. This kind of model outputs the coordinates of each object in the image with the help of some predefined boxes; you can read more about this in my blog post about one shot object detection.

In short, we will use these boxes, sometimes called anchors, as a guide for the model's output coordinates, also called offsets. Hence, the offsets are relative to the center of each box instead of the whole image. For instance, if we have one box at 0.234, 0.456 and the face is at 0.240, 0.457, the model only needs to predict the difference with respect to the box, 0.006 and 0.001. Additionally, each box has to predict whether there is a face in that area or not; this is the classification task, whereas the offsets are the regression task.

In our model we will use boxes of two sizes and assign each face to the boxes whose intersection over union with the face is good enough.

In this way, each face has several boxes assigned that should detect the presence of the face and its coordinates relative to the center of the box.
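
To make the assignment concrete, an intersection over union computation between a face box and all the anchor boxes could look like this sketch (the actual matching code lives in the repository; the 0.3 threshold below is only an example value):

import numpy as np

def iou(face, anchors):
  # face: (4,) [x_min, y_min, x_max, y_max]; anchors: (N, 4) in the same format
  ix_min = np.maximum(face[0], anchors[:, 0])
  iy_min = np.maximum(face[1], anchors[:, 1])
  ix_max = np.minimum(face[2], anchors[:, 2])
  iy_max = np.minimum(face[3], anchors[:, 3])

  intersection = np.clip(ix_max - ix_min, 0, None) * np.clip(iy_max - iy_min, 0, None)
  face_area = (face[2] - face[0]) * (face[3] - face[1])
  anchor_areas = (anchors[:, 2] - anchors[:, 0]) * (anchors[:, 3] - anchors[:, 1])

  return intersection / (face_area + anchor_areas - intersection)

# Assign the face to every anchor whose IoU is above a threshold, e.g. 0.3
# assigned_indices = np.where(iou(face_box, anchor_boxes) > 0.3)[0]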

We have in total 896 boxes, where 512 boxes are of size 0.0625 (or a grid of 16x16 over the image) and 384 boxes are of size 0.125 (a grid of 8x8 over the image). Thus, if the face is closer to the camera, the big boxes (8x8) should output more accurate coordinates and if the face is far from the camera, the small boxes (16x16) should output more accurate coordinates.

To generate these boxes we can use the following code:


import numpy as np
import tensorflow as tf

small_boxes = np.linspace(0.03125, 0.96875, 16, endpoint=True, dtype=np.float32) # 16x16 grid, box size 0.0625

big_boxes = np.linspace(0.0625, .9375, 8, endpoint=True, dtype=np.float32) # 8x8 size of 0.125

list_of_boxes = [small_boxes, big_boxes]



small_x = tf.tile(tf.repeat(small_boxes, repeats=2), [16]) # x

small_y = tf.repeat(small_boxes, repeats=32)



small = tf.stack([small_x, small_y], axis=1)



big_x = tf.tile(tf.repeat(big_boxes, repeats=6), [8]) # x

big_y = tf.repeat(big_boxes, repeats=48)



big = tf.stack([big_x, big_y], axis=1)



reference_anchors = tf.concat([small, big], axis=0)

The reference_anchors will be used in the loss function to guide the output coordinates.

One more thing that we have to take into account is the position of our boxes. In total we have 896 boxes, where 512 are of size 0.0625, two grids of size 16x16, and 384 are of size 0.125, six grids of size 8x8. For each image, each face can be assigned to up to two grid resolutions; some faces fit in both of them. However, the model outputs 8 grids in total, and we want the model to use all of them to predict the coordinates. Hence, we can randomly assign each face to one of the two 16x16 grids and to one of the six 8x8 grids. We could also repeat the faces across all the grids so that every grid is used, but I didn't try this option.
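
As a rough sketch of this random assignment, assuming the label tensors have the layout used later in the loss function, a (16, 16, 2, 5) tensor for the two 16x16 grids and an (8, 8, 6, 5) tensor for the six 8x8 grids, with [class, x_min, y_min, x_max, y_max] in the last axis, the idea is simply to pick one grid index at random for each face:

import numpy as np

# small_label = np.zeros((16, 16, 2, 5), dtype=np.float32)
# big_label = np.zeros((8, 8, 6, 5), dtype=np.float32)

def assign_face(small_label, big_label, face_box, small_cell, big_cell):
  # face_box: normalized [x_min, y_min, x_max, y_max]
  # small_cell / big_cell: (row, col) of the grid cells that overlap the face
  grid_small = np.random.randint(2) # one of the two 16x16 grids
  grid_big = np.random.randint(6) # one of the six 8x8 grids

  small_label[small_cell[0], small_cell[1], grid_small] = [1.0, *face_box]
  big_label[big_cell[0], big_cell[1], grid_big] = [1.0, *face_box]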

Model

The model architecture is easy to implement. The main point of this architecture is the use of a small number of filters in each convolutional layer, and also the use of depthwise convolutional layers as in the MobileNet architecture. One important difference with respect to the architecture presented in the paper is that here we have a simpler structure, perhaps because it was good enough and the extra layers were not needed:


import tensorflow as tf
from tensorflow.keras import layers


class BlazeBlock(tf.keras.Model):

  def __init__(self, filters, strides=1):

    super(BlazeBlock, self).__init__()

    self.strides = strides

    self.filters = filters



    if strides == 2:

      self.pool = layers.MaxPool2D()



    self.dw_conv = layers.DepthwiseConv2D((3, 3), strides=strides, padding="same")

    self.conv = layers.Conv2D(filters, (1, 1), strides=(1, 1))



    self.norm_1 = layers.BatchNormalization()

    self.norm_2 = layers.BatchNormalization()



    self.activation = layers.ReLU()



  def call(self, x_input):

    x = self.dw_conv(x_input)

    x = self.norm_1(x)

    x = self.conv(x)

    x = self.norm_2(x)



    if self.strides == 2:

      x_input = self.pool(x_input)



    padding = self.filters - x_input.shape[-1]



    if padding != 0:

      padding_values = [[0, 0], [0, 0], [0, 0], [0, padding]]

      x_input = tf.pad(x_input, padding_values)



    x = x + x_input

    x = self.activation(x)



    return x

Our BlazeFace model is made of multiple Blaze blocks; the architecture is similar to a Residual Network:


class BlazeModel(tf.keras.Model):

  def __init__(self):

    super(BlazeModel, self).__init__()



    self.conv = layers.Conv2D(24, (3, 3), strides=2, padding="same")

    self.activation = layers.ReLU()



    self.block_1 = BlazeBlock(24)

    self.block_2 = BlazeBlock(28)

    self.block_3 = BlazeBlock(32, strides=2)

    self.block_4 = BlazeBlock(36)

    self.block_5 = BlazeBlock(42)



    self.block_6 = BlazeBlock(48, strides=2)

    self.block_7 = BlazeBlock(56)

    self.block_8 = BlazeBlock(64)

    self.block_9 = BlazeBlock(72)

    self.block_10 = BlazeBlock(80)

    self.block_11 = BlazeBlock(88)



    self.block_12 = BlazeBlock(96, strides=2)

    self.block_13 = BlazeBlock(96)

    self.block_14 = BlazeBlock(96)

    self.block_15 = BlazeBlock(96)



    self.classifier_8 = layers.Conv2D(2, (1, 1), strides=(1, 1), activation="sigmoid")

    self.classifier_16 = layers.Conv2D(6, (1, 1), strides=(1, 1), activation="sigmoid")



    self.regressor_8 = layers.Conv2D(8, (1, 1), strides=(1, 1)) # 32

    self.regressor_16 = layers.Conv2D(24, (1, 1), strides=(1, 1)) # 96



  def call(self, x):

    B, H, W, C = x.shape



    x = self.conv(x)

    x = self.activation(x) # (B, 64, 64, 24)



    x = self.block_1(x) # (B, 64, 64, 24)

    x = self.block_2(x) # (B, 64, 64, 28)

    x = self.block_3(x) # (B, 32, 32, 32)

    x = self.block_4(x) # (B, 32, 32, 36)

    x = self.block_5(x) # (B, 32, 32, 42)



    # Double Blocks



    x = self.block_6(x) # (4, 16, 16, 48)

    x = self.block_7(x) # (4, 16, 16, 56)

    x = self.block_8(x) # (4, 16, 16, 64)

    x = self.block_9(x) # (4, 16, 16, 72)

    x = self.block_10(x) # (4, 16, 16, 80)

    x = self.block_11(x) # (4, 16, 16, 88) output size



    h = self.block_12(x) # (4, 8, 8, 96)

    h = self.block_13(h) # (4, 8, 8, 96)

    h = self.block_14(h) # (4, 8, 8, 96)

    h = self.block_15(h) # (4, 8, 8, 96) output size



    c1 = self.classifier_8(x) # B, 16, 16, 2 output size

    c1 = layers.Reshape((-1, 1))(c1) # B, 512, 1 output size



    c2 = self.classifier_16(h) # B, 8, 8, 6 output size

    c2 = layers.Reshape((-1, 1))(c2) # B, 384, 1 output size



    c = layers.concatenate([c1, c2], axis=1) # B, 896, 1



    r1 = self.regressor_8(x) # B, 16, 16, 8 output size

    r1 = layers.Reshape((-1, 4))(r1) # B, 512, 4 output size



    r2 = self.regressor_16(h) # B, 8, 8, 24 output size

    r2 = layers.Reshape((-1, 4))(r2) # B, 384, 4 output size



    r = layers.concatenate([r1, r2], axis=1) # B, 896, 4



    return r, c

The most notable thing is the number of filters in each layer: we go from 24 to 96 filters, which is a really small amount if we take into consideration that there are models with 512 and 1024 filters.

Additionally, to avoid the use of dense layers, which would add more size to the model, we reshape and concatenate our layers to get the final output.

The original model from MediaPipe also outputs landmark coordinates. However, we don't have a dataset with both face position and face landmark annotations.

Loss function

To train this model we need a loss function for the coordinates and a loss function for the class prediction. To evaluate the coordinates we only use the coordinates from the boxes that should contain a face:


def compute_loss(class_predictions, anchor_predictions, big_anchors, small_anchors, reference_anchors, ratio=3, scale=128):

  B = big_anchors.shape[0]

  list_big_anchors = tf.reshape(big_anchors, (B, -1, 5)) # shape [B, 384, 5]



  list_small_anchors = tf.reshape(small_anchors, (B, -1, 5)) # shape [B, 512, 5]



  list_true_anchors = tf.concat([list_small_anchors, list_big_anchors], axis=1) # shape [B, 896, 5]



  true_classes = list_true_anchors[:, :, 0] # shape [B, 896]

  true_coords = list_true_anchors[:, :, 1:] # shape [B, 896, 4]



  faces_mask_bool = tf.dtypes.cast(true_classes, tf.bool)

The big_anchors and small_anchors variables are the true label boxes that contain the coordinates and the classes.

We can observe how we split the labels into classes and coordinates in these lines of code:


true_classes = list_true_anchors[:, :, 0] # shape [B, 896]

true_coords = list_true_anchors[:, :, 1:] # shape [B, 896, 4]

The model will output the x, y coordinates (offsets) of the box and its width and height. We first compute the predicted center of the face:


x_center = reference_anchors[:, 0:1] + (anchor_predictions[..., 0:1] / scale) # 8, 896, 1

y_center = reference_anchors[:, 1:2] + (anchor_predictions[..., 1:2] / scale) # 8, 896, 1

We can notice how we sum the reference_anchors and the anchor_predictions variables to get the center position of the face.

We also get the width and height and scale them so they are normalized to the [0, 1] range:


w = anchor_predictions[..., 2:3] / scale # B, 896, 1

h = anchor_predictions[..., 3:4] / scale # B, 896, 1

We compute the coordinates of all the predicted boxes:


y_min = y_center - h / 2.  # ymin

x_min = x_center - w / 2.  # xmin

y_max = y_center + h / 2.  # ymax

x_max = x_center + w / 2.  # xmax



offset_boxes = tf.concat([x_min, y_min, x_max, y_max], axis=-1) # B, 896, 4

And we finally filter the boxes to only keep the ones where faces were assigned and compute the loss:


filtered_pred_coords = tf.boolean_mask(offset_boxes, faces_mask_bool) # ~faces_num, 4

filtered_true_coords = tf.boolean_mask(true_coords, faces_mask_bool) # ~faces_num, 4.



detection_loss = huber_loss(filtered_true_coords, filtered_pred_coords)

Once we have the detection loss done, we can talk about the classification loss. In this case we have an imbalance problem due to the number of boxes without faces assigned. If we think about it, in the case where we have around 4 faces in one image we could assign some 80 boxes, 20 boxes for each face; however, we would still be left with 816 empty boxes.

To overcome this problem we can use hard negative mining, where we choose the boxes with the highest confidence but the wrong class. In other words, predictions where the model is highly sure that the class is A when in reality the class is B.

We first compute the number of boxes that contain a face and then the number of empty boxes that we want to use in the loss function. Often a ratio of 3 is used, so we keep three times as many empty boxes as boxes with faces:


faces_num = tf.keras.backend.sum(true_classes)

background_num = int(faces_num * ratio) // B

Then, we change the class score of the boxes with faces to -99.0 so that when we sort the predictions, only the values of the empty boxes end up at the top:


predicted_classes_scores = tf.where(faces_mask_bool, -99.0, class_predictions) # B, 896



background_class_predictions = tf.sort(predicted_classes_scores, axis=-1, direction='DESCENDING')[:, :background_num]

In this way we now have a tensor with the background_num highest incorrect class predictions, which we can use in the loss function:


background_loss = tf.math.reduce_mean(tf.keras.losses.binary_crossentropy(tf.zeros_like(background_class_predictions), background_class_predictions))

All these predictions should be close to 0, which means there is no face, instead of 1, which means there is a face.

Finally, we compute the loss for the positive class predictions and calculate the final loss, which also includes the detection loss:


positive_loss = tf.math.reduce_mean(tf.keras.losses.binary_crossentropy(tf.ones_like(positive_class_predictions), positive_class_predictions))



loss = tf.math.reduce_mean(detection_loss) * 150 + (background_loss * 35) + (positive_loss * 35)

The complete code to compute the loss function is the following (here huber_loss is a Huber loss implementation defined outside this snippet, for example an instance of tf.keras.losses.Huber):


def compute_loss(class_predictions, anchor_predictions, big_anchors, small_anchors, reference_anchors, ratio=3, scale=128):

  B = big_anchors.shape[0]

  list_big_anchors = tf.reshape(big_anchors, (B, -1, 5)) # shape [B, 384, 5]



  list_small_anchors = tf.reshape(small_anchors, (B, -1, 5)) # shape [B, 512, 5]



  list_true_anchors = tf.concat([list_small_anchors, list_big_anchors], axis=1) # shape [B, 896, 5]



  true_classes = list_true_anchors[:, :, 0] # shape [B, 896]

  true_coords = list_true_anchors[:, :, 1:] # shape [B, 896, 4]



  faces_mask_bool = tf.dtypes.cast(true_classes, tf.bool)



  faces_num = tf.keras.backend.sum(true_classes)

  background_num = int(faces_num * ratio) // B



  class_predictions = tf.squeeze(class_predictions, axis=-1)



  # Hard negatives



  predicted_classes_scores = tf.where(faces_mask_bool, -99.0, class_predictions) # B, 896



  background_class_predictions = tf.sort(predicted_classes_scores, axis=-1, direction='DESCENDING')[:, :background_num]



  positive_class_predictions = tf.boolean_mask(class_predictions, faces_mask_bool)



  # Class loss



  background_loss = tf.math.reduce_mean(tf.keras.losses.binary_crossentropy(tf.zeros_like(background_class_predictions), background_class_predictions))

  positive_loss = tf.math.reduce_mean(tf.keras.losses.binary_crossentropy(tf.ones_like(positive_class_predictions), positive_class_predictions))



  # Anchors offset



  # anchor_predictions (shape) B, 896, 4



  x_center = reference_anchors[:, 0:1] + (anchor_predictions[..., 0:1] / scale) # 8, 896, 1

  y_center = reference_anchors[:, 1:2] + (anchor_predictions[..., 1:2] / scale) # 8, 896, 1



  w = anchor_predictions[..., 2:3] / scale # B, 896, 1

  h = anchor_predictions[..., 3:4] / scale # B, 896, 1



  y_min = y_center - h / 2.  # ymin

  x_min = x_center - w / 2.  # xmin

  y_max = y_center + h / 2.  # ymax

  x_max = x_center + w / 2.  # xmax



  offset_boxes = tf.concat([x_min, y_min, x_max, y_max], axis=-1) # B, 896, 4



  filtered_pred_coords = tf.boolean_mask(offset_boxes, faces_mask_bool) # ~faces_num, 4

  filtered_true_coords = tf.boolean_mask(true_coords, faces_mask_bool) # ~faces_num, 4.



  detection_loss = huber_loss(filtered_true_coords, filtered_pred_coords)



  loss = tf.math.reduce_mean(detection_loss) * 150 + (background_loss * 35) + (positive_loss * 35)



  return loss, filtered_pred_coords, filtered_true_coords, positive_class_predictions, background_class_predictions

There is room for improvement in the hard negative sample selection. In TensorFlow it is harder to access parts of tensors, for example using indices, so we have to opt for functions like tf.where or tf.boolean_mask to get the parts of the tensors that we want.

Training

The most difficult part about training the model is the construction of the box labels; although it is not a hard task, it can take some time. Therefore, I recommend paying special attention to the data generators and the box construction. The code to train the model is simple but long, so you can check it in the GitHub repository.
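
As a reference, a minimal training step could look like the sketch below. It assumes the BlazeModel, compute_loss, and reference_anchors from the previous sections, a data generator that yields images plus the big and small anchor labels, and an Adam optimizer (the optimizer and the learning rate are assumptions, not necessarily what the repository uses):

import tensorflow as tf

model = BlazeModel()
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)

def train_step(images, big_anchors, small_anchors):
  with tf.GradientTape() as tape:
    # The model returns the regression offsets first and the class scores second
    anchor_predictions, class_predictions = model(images, training=True)
    loss, *_ = compute_loss(class_predictions, anchor_predictions, big_anchors, small_anchors, reference_anchors)

  gradients = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))

  return loss

# for epoch in range(epochs):
#   for images, big_anchors, small_anchors in train_generator:
#     loss = train_step(images, big_anchors, small_anchors)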

Filter Boxes

We already know that our model outputs 896 boxes; however, not all of these boxes contain faces. We can tell whether a box contains a face by looking at the predicted class: 1 for face and 0 for background. To remove the boxes with background predictions, Non-Maximum Suppression (NMS) is often used, where only one of the several boxes that contain the same face is selected as the final prediction. In contrast to NMS, the authors of the paper present a blending strategy where, instead of selecting only one box, we compute a weighted mean of the overlapping predictions:


def filter_boxes(class_predictions, anchor_predictions, reference_anchors, scale=128):

  # class_predictions # B, 896, 1

  # anchor_predictions # B, 896, 4



  x_center = reference_anchors[:, 0:1] + (anchor_predictions[..., 0:1] / scale) # 8, 896, 1

  y_center = reference_anchors[:, 1:2] + (anchor_predictions[..., 1:2] / scale) # 8, 896, 1



  w = anchor_predictions[..., 2:3] / scale # B, 896, 1

  h = anchor_predictions[..., 3:4] / scale # B, 896, 1



  y_min = y_center - h / 2.  # ymin

  x_min = x_center - w / 2.  # xmin

  y_max = y_center + h / 2.  # ymax

  x_max = x_center + w / 2.  # xmax



  offset_boxes = tf.concat([x_min, y_min, x_max, y_max], axis=-1) # B, 896, 4



  class_predictions = tf.squeeze(class_predictions, axis=-1)



  mask = class_predictions >= 0.75 # 0.75 B, 896



  final_detections = [] # final shape B, num_image_detections, 5 where num_image_detections can vary by image

We first compute the final box coordinates as we did in the loss function and create a mask to filter the boxes where the prediction confidence is at least 0.75. The final_detections list will contain these filtered predictions:


for index, image_detections in enumerate(mask): # each 896 for every image

  num_image_detections = tf.keras.backend.sum(tf.dtypes.cast(image_detections, tf.int32))



  if num_image_detections == 0:

    final_detections.append([])

  else:

    filtered_boxes = tf.boolean_mask(offset_boxes[index], image_detections)

    filtered_scores = tf.boolean_mask(class_predictions[index], image_detections)



    final_detections.append(tf.concat([tf.expand_dims(filtered_scores, axis=-1), filtered_boxes], axis=-1)) # num_image_detections, 5



output_detections = []

The filter_boxes function works for multiple images so we can pass a batch of images and go through each image to compute its final boxes:


for image_detections in final_detections: # for each image in batch B

    # num_image_detections, 5

    if len(image_detections) == 0:

      output_detections.append([])

      continue



    remaining = tf.argsort(image_detections[:, 0], axis=0, direction='DESCENDING') # num_image_detections



    faces = []

We first sort the class predictions so that the most confident prediction is at the beginning of the remaining list:


while remaining.shape[0] > 0:

  detection = image_detections[remaining[0]]

  first_box = detection[1:] # 1, 4

  other_boxes = tf.gather(image_detections, remaining)[:, 1:] # 4, 4



  ious = mean_iou(np.array(first_box) * 128.0, np.array(other_boxes) * 128.0, return_mean=False) # num_image_detections



  overlapping = tf.boolean_mask(remaining, ious > 0.3)

  remaining = tf.boolean_mask(remaining, ious <= 0.3) # When all false, returns shape 0

  # The remaining boxes should belong to a different face



  if overlapping.shape[0] > 1:

    overlapping_boxes = tf.gather(image_detections, overlapping)

    coordinates = overlapping_boxes[:, 1:] # overlapping, 4

    scores = overlapping_boxes[:, 0:1] # overlapping, 1

    total_score = tf.keras.backend.sum(scores)



    weighted_boxes = tf.keras.backend.sum((coordinates * scores), axis=0) / total_score # 4,

    weighted_score = total_score / overlapping.shape[0] # scalar



    weighted_score = tf.reshape(weighted_score, (1,))



    weighted_detection = tf.concat([weighted_score, weighted_boxes], axis=0) # 5,



    faces.append(weighted_detection)



  else:

    faces.append(detection)



output_detections.append(faces)

While we have boxes in the remaining list, we take the first box (the one with the most confident face prediction) and the remaining boxes in the list, and compute the intersection over union (IoU) between them using the mean_iou helper from the repository. If there are boxes whose IoU is greater than 0.3, it means they overlap with the box that we took from the list and belong to the same face.

If we have overlapping boxes, we compute the weighted mean of their coordinates, where the most confident boxes have more weight, and add the resulting box to the faces list. If we don't have overlapping boxes, we simply add the box that we took from the list to the faces list.

We can also notice how we update the remaining list to hold the boxes that don't overlap, that is, whose IoU is less than or equal to 0.3.

The complete function is the following:


def filter_boxes(class_predictions, anchor_predictions, reference_anchors, scale=128):

  # class_predictions # B, 896, 1

  # anchor_predictions # B, 896, 4



  x_center = reference_anchors[:, 0:1] + (anchor_predictions[..., 0:1] / scale) # 8, 896, 1

  y_center = reference_anchors[:, 1:2] + (anchor_predictions[..., 1:2] / scale) # 8, 896, 1



  w = anchor_predictions[..., 2:3] / scale # B, 896, 1

  h = anchor_predictions[..., 3:4] / scale # B, 896, 1



  y_min = y_center - h / 2.  # ymin

  x_min = x_center - w / 2.  # xmin

  y_max = y_center + h / 2.  # ymax

  x_max = x_center + w / 2.  # xmax



  offset_boxes = tf.concat([x_min, y_min, x_max, y_max], axis=-1) # B, 896, 4



  class_predictions = tf.squeeze(class_predictions, axis=-1)



  mask = class_predictions >= 0.75 # 0.75 B, 896



  final_detections = [] # final shape B, num_image_detections, 5 where num_image_detections can vary by image



  for index, image_detections in enumerate(mask): # each 896 for every image

    num_image_detections = tf.keras.backend.sum(tf.dtypes.cast(image_detections, tf.int32))



    if num_image_detections == 0:

      final_detections.append([])

    else:

      filtered_boxes = tf.boolean_mask(offset_boxes[index], image_detections)

      filtered_scores = tf.boolean_mask(class_predictions[index], image_detections)



      final_detections.append(tf.concat([tf.expand_dims(filtered_scores, axis=-1), filtered_boxes], axis=-1)) # num_image_detections, 5



  output_detections = []



  for image_detections in final_detections: # for each image in batch B

    # num_image_detections, 5

    if len(image_detections) == 0:

      output_detections.append([])

      continue



    remaining = tf.argsort(image_detections[:, 0], axis=0, direction='DESCENDING') # num_image_detections



    faces = []



    while remaining.shape[0] > 0:

      detection = image_detections[remaining[0]]

      first_box = detection[1:] # 1, 4

      other_boxes = tf.gather(image_detections, remaining)[:, 1:] # 4, 4



      ious = mean_iou(np.array(first_box) * 128.0, np.array(other_boxes) * 128.0, return_mean=False) # num_image_detections



      overlapping = tf.boolean_mask(remaining, ious > 0.3)

      remaining = tf.boolean_mask(remaining, ious <= 0.3) # When all false, returns shape 0

      # The remaining boxes should belong to a different face



      if overlapping.shape[0] > 1:

        overlapping_boxes = tf.gather(image_detections, overlapping)

        coordinates = overlapping_boxes[:, 1:] # overlapping, 4

        scores = overlapping_boxes[:, 0:1] # overlapping, 1

        total_score = tf.keras.backend.sum(scores)



        weighted_boxes = tf.keras.backend.sum((coordinates * scores), axis=0) / total_score # 4,

        weighted_score = total_score / overlapping.shape[0] # scalar



        weighted_score = tf.reshape(weighted_score, (1,))



        weighted_detection = tf.concat([weighted_score, weighted_boxes], axis=0) # 5,



        faces.append(weighted_detection)



      else:

        faces.append(detection)



    output_detections.append(faces)



  return output_detections, final_detections

Model results and Testing

The final model was trained for 500 epochs. I tested the model with my face and some other faces at the same time to check how it handles multiple faces, and the results were good enough. In the repository you can also find a script to run the model in a Python environment (you need TensorFlow installed). I ran the model on a fourth-generation Intel i5 with 2 cores, and without any optimization of the model or TensorFlow we get an average of 15 frames per second.
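
As a rough illustration, running the model on a single image could look like the following sketch. It assumes the trained model, reference_anchors, and filter_boxes defined above; the [-1, 1] normalization and the file name are assumptions, the actual script is in the repository:

import cv2
import numpy as np

image = cv2.imread("face.jpg")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
image = cv2.resize(image, (128, 128))

x = (image.astype(np.float32) / 127.5) - 1.0 # normalize to [-1, 1] (assumed preprocessing)
x = np.expand_dims(x, axis=0) # add the batch dimension, (1, 128, 128, 3)

coords, classes = model(x) # (1, 896, 4) offsets and (1, 896, 1) scores
detections, _ = filter_boxes(classes, coords, reference_anchors)

for detection in detections[0]: # detections for the first (and only) image
  score, x_min, y_min, x_max, y_max = np.array(detection)
  print(f"face {score:.2f}: ({x_min:.2f}, {y_min:.2f}) -> ({x_max:.2f}, {y_max:.2f})")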

I also checked the cases where the model detected more faces than the annotated ones. In most of these cases the model correctly detects faces that were not annotated in the original labels (missing annotations, that is, false negatives in the dataset), so the dataset does contain errors that could lead to worse measured performance of the final model, but the model is still able to detect those "extra" faces.

In the Jupyter notebook we can find code to quantize and prune the model; since the model is already small, there is no big difference in the final size.
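
For reference, post-training quantization with the TensorFlow Lite converter usually looks like the following sketch; this is not necessarily the exact procedure used in the notebook:

import tensorflow as tf

# Convert the trained Keras model to TensorFlow Lite with default post-training quantization
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()

with open("blaze_face.tflite", "wb") as f:
  f.write(tflite_model)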

Bias

One recurrent problem in deep learning is the bias that some models present. The trained model, although this has not been verified, could present some bias due to the datasets used to train it. For instance, the FDDB dataset contains more diverse images than the WIDER FACE dataset; however, this diversity might not be enough to remove the model's bias with respect to certain demographics.

iOS Apps

This model is also available as an iOS app written in Swift, actually as two iOS apps. The first app uses TensorFlow Lite to load and run the model, and the second app uses CoreML.

You can check the post about the TensorFlow Lite version here, and the CoreML version here.

I recommend reading both posts, especially the TensorFlow Lite one, since there I explain how to translate the filter_boxes function into Swift code.