In my post a year ago, I discussed a deep learning technique called DeepRanking. Its core is the triplet loss, which minimizes the distance between images of the same group and maximizes the distance between images of different categories. At the moment, I am working on a person re-identification project which requires building a representative vector for each person, so I have had some time to dig deeper into sampling strategies, and I want to discuss them in this blog post.

I. Revisit of Triplet Loss

In the blog post about Deep Ranking, I covered the idea of Triplet Loss. In general, the CNN generates an embedding vector for each image, and we hope that the distances between images of the same class will be as small as possible.

There are two challenges here: how to build a representative embedding for each image, and how to create a discriminative metric.

Triplet Loss attacks the first challenge: the loss function encourages the in-class distance to be smaller than the out-of-class distance by a margin (I write out the standard form right after the list below). At this point, another problem arises: a training set of images produces a myriad of triplets, and most of them eventually become too easy, contributing almost nothing to the training progress. So we need strategies to sample only the hard triplets that are actually useful for training. There are two noticeable ideas for this:

  • BatchNegative: Sampling only the useful hard samples based on the margin.

  • BatchHard: Sampling the hardest positive sample and hardest negative sample within the batch.
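
For reference, using an anchor image \(I_a\), a positive \(I_p\) from the same class, and a negative \(I_n\) from a different class, the triplet loss both strategies start from has the standard hinge form

\[L(I_a, I_p, I_n) = \max\big(0, D(I_a, I_p) - D(I_a, I_n) + \alpha\big)\]

where \(D\) is the distance between embeddings and \(\alpha\) is the margin.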

II. BatchNegative

I first read about this sampling algorithm in the FaceNet paper. In that paper, the authors take every possible positive sample and, based on the given margin \(\alpha\) and the positive distance, choose the negative samples which satisfy:

\[D(I_a, I_n) - D(I_a, I_p) < \alpha\]

This condition makes sense: a negative sample chosen this way makes the loss positive, so there is actually something for us to minimize.

Please notice that the sampling is done on the fly, which means the distances are computed from the current embeddings produced by the network. This also means that the sampling algorithm adds an overhead to the computation. Furthermore, although it filters out a large number of unnecessary triplets, the remaining triplets are still abundant enough for the training not to over-fit.

The sample code I wrote using PyTorch:


import torch


def select_hard_triplets(dist_mat, labels, margin):
    """Sample the triplets whose distance difference is smaller than the margin.

    Args:
        dist_mat: torch.Tensor [N, N], pair-wise distances between the embeddings
        labels: torch.Tensor [N], the corresponding class labels
        margin: float, the margin used to filter the negative samples

    Returns:
        dist_ap: chosen positive distance vector
        dist_an: chosen negative distance vector
    """
    assert len(dist_mat.size()) == 2
    assert dist_mat.size(0) == dist_mat.size(1)

    N = dist_mat.size(0)

    # [N, N] masks: is_pos[i, j] is True when samples i and j share a label
    is_pos = labels.expand(N, N).eq(labels.expand(N, N).t())
    is_neg = labels.expand(N, N).ne(labels.expand(N, N).t())

    # Assumes every class has the same number of images in the batch,
    # so each row keeps the same number of positives / negatives
    dist_ap = dist_mat[is_pos].contiguous().view(N, -1)
    dist_an = dist_mat[is_neg].contiguous().view(N, -1)

    # Drop the zero self-distances (anchor paired with itself)
    nonzero_pos_mask = dist_ap.ne(0)
    dist_ap = dist_ap[nonzero_pos_mask].contiguous().view(N, -1)

    triplets = []
    for i, ap in enumerate(dist_ap):
        for pos in ap:
            # Keep only the negatives violating the margin:
            # D(a, n) - D(a, p) < margin, i.e. the triplet loss is positive
            diff = dist_an[i] - pos
            neg_dists = dist_an[i][diff < margin]

            for neg in neg_dists:
                # torch.stack keeps the computation graph, so gradients still flow
                triplets.append(torch.stack([pos, neg]))

    triplets = torch.stack(triplets, dim=1)
    dist_ap, dist_an = triplets[0], triplets[1]
    return dist_ap, dist_an
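
A minimal usage sketch (the toy batch, the embedding dimension, and the margin value here are my own assumptions for illustration; only select_hard_triplets comes from the code above):

import torch
import torch.nn.functional as F

# Toy batch: 8 embeddings of 4 identities, 2 images each (hypothetical setup)
embeddings = F.normalize(torch.randn(8, 128), dim=1)
labels = torch.tensor([0, 0, 1, 1, 2, 2, 3, 3])

# Pair-wise Euclidean distances, shape [8, 8]; force exact zeros on the
# diagonal so the self-distances can be filtered out inside the function
dist_mat = torch.cdist(embeddings, embeddings)
dist_mat.fill_diagonal_(0)

margin = 0.3  # hypothetical value
dist_ap, dist_an = select_hard_triplets(dist_mat, labels, margin)

# Hinge triplet loss over the selected triplets
loss = F.relu(dist_ap - dist_an + margin).mean()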

III. BatchHard

The idea of BatchNegative is intuitive and feasible; however, in some cases it produces too few triplets or, more often, too many triplets for the training process, depending on the value of the margin, a hyperparameter we have to choose.

To push the idea to its limit and tackle the above problem at the same time, the authors of the paper In Defense of the Triplet Loss for Person Re-Identification introduce another algorithm: BatchHard. In this version, we only choose the hardest positive and the hardest negative sample for each image within the batch. Thus, it does not depend on the value of the margin, and we can control the number of triplets in each batch.

Some people worry that this hardest strategy may jeopardize the training process: the hardest samples may prevent the training loss from converging, and even if it does converge, the model might over-fit since it is trained only on the hardest samples. However, the authors argue that this concern is unfounded, since the hardest samples are selected within mini-batches, not over the whole data-set.

Furthermore, regarding the selection of the margin value, the authors advise us to use the following softplus function as the loss:

\[f(x) = \ln(1 + \exp(x))\]

instead of the hinge function:

\[f(x) = \max(0, x + \alpha)\]

In my experience, this modification saves us from choosing the margin; however, the difference in performance compared to the hinge loss with its default margin value is not impressive.
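
As a quick illustration (a sketch only; the dummy distance vectors and the margin value are my own assumptions standing in for the mining output), the two formulations differ only in the final element-wise function:

import torch
import torch.nn.functional as F

# Dummy per-anchor distances standing in for the mining output (illustration only)
dist_ap = torch.rand(16)
dist_an = torch.rand(16)

margin = 0.3  # only needed for the hinge variant; hypothetical value

# Hinge triplet loss: max(0, d_ap - d_an + margin)
hinge_loss = F.relu(dist_ap - dist_an + margin).mean()

# Softplus variant: ln(1 + exp(d_ap - d_an)), no margin to tune
soft_loss = F.softplus(dist_ap - dist_an).mean()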

My implementation of BatchHard in PyTorch:


def hardest_triplet_mining(dist_mat, labels):
    """For each anchor, find the hardest positive and the hardest negative sample.

    Args:
        dist_mat: torch.Tensor, pair-wise distances between samples, shape [N, N]
        labels: torch.LongTensor, class labels, shape [N]

    Returns:
        dist_ap: torch.Tensor, distance(anchor, hardest positive), shape [N]
        dist_an: torch.Tensor, distance(anchor, hardest negative), shape [N]
    """
    assert len(dist_mat.size()) == 2
    assert dist_mat.size(0) == dist_mat.size(1)
    N = dist_mat.size(0)

    # shape [N, N]: masks of same-class and different-class pairs
    is_pos = labels.expand(N, N).eq(labels.expand(N, N).t())
    is_neg = labels.expand(N, N).ne(labels.expand(N, N).t())

    # `dist_ap` means distance(anchor, positive): for each anchor, keep the
    # largest positive distance, shape [N, 1]
    dist_ap, _ = torch.max(
        dist_mat[is_pos].contiguous().view(N, -1), 1, keepdim=True)
    # `dist_an` means distance(anchor, negative): for each anchor, keep the
    # smallest negative distance, shape [N, 1]
    dist_an, _ = torch.min(
        dist_mat[is_neg].contiguous().view(N, -1), 1, keepdim=True)
    # shape [N]
    dist_ap = dist_ap.squeeze(1)
    dist_an = dist_an.squeeze(1)

    return dist_ap, dist_an
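
A minimal sketch of how I wire BatchHard into a training step (the P x K batch layout, the pair-wise distance computation, and the softplus loss are my own assumptions for illustration; only hardest_triplet_mining comes from the code above):

import torch
import torch.nn.functional as F

# Toy P x K batch: P = 4 identities, K = 4 images each (hypothetical setup)
embeddings = torch.randn(16, 128, requires_grad=True)
labels = torch.arange(4).repeat_interleave(4)  # [0, 0, 0, 0, 1, 1, 1, 1, ...]

# Pair-wise Euclidean distances via the squared expansion, with a small
# clamp for numerical stability, shape [16, 16]
sq = embeddings.pow(2).sum(dim=1, keepdim=True)
dist_mat = (sq + sq.t() - 2 * embeddings @ embeddings.t()).clamp(min=1e-12).sqrt()

dist_ap, dist_an = hardest_triplet_mining(dist_mat, labels)

# Soft-margin BatchHard loss: ln(1 + exp(d_ap - d_an)), averaged over anchors
loss = F.softplus(dist_ap - dist_an).mean()
loss.backward()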

IV. References

  1. In Defense of the Triplet Loss for Person Re-Identification

  2. FaceNet: A Unified Embedding for Face Recognition and Clustering