The primary objective of this branch is to integrate MUTarget with contrastive learning while ensuring minimal modifications to the original MUTarget codebase.
In train.train_loop, the code first checks whether the batch needs to be extended and, if so, extends the batched data. model.Encoder::forward() then sends the batched data through the conditional branch that matches the current scenario to obtain the corresponding classification_head, motif_logits, and projection_head outputs.
The changes are in train.train_loop, model.Encoder::forward, and data.LocalizationDataset::get_pos_samples and get_neg_samples (sample_with_weight).
Below are the descriptions for some important new configuration parameters, specifically regarding the use of SupCon and other related settings:
- `apply`: Determines whether SupCon is used during model training. Setting it to `False` disables SupCon.
- `n_pos`: The number of positive samples used for each anchor in the contrastive learning setup.
- `n_neg`: The number of negative samples used for each anchor.
- `temperature`: A scaling parameter used in the contrastive loss function.
- `hard_neg`: A boolean. When set to `True`, the model selects harder negative samples for computing the loss. Hard negatives are those that are more challenging for the model to distinguish correctly.
- `weight`: Set it to 1. This parameter was used previously but is no longer necessary.
- `warm_start`: Specifies the epoch at which warm starting ends.
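For orientation, the parameters above could be grouped as in the following Python sketch. The names `supcon_config` and `check_supcon_config`, the flat dictionary layout, and every value are placeholders chosen for illustration; the branch's actual configuration file format and defaults may differ.

```python
# Hypothetical SupCon settings. All values are placeholders, not recommended defaults.
supcon_config = {
    "apply": True,        # False disables SupCon entirely
    "n_pos": 2,           # positives per anchor
    "n_neg": 4,           # negatives per anchor
    "temperature": 0.1,   # scaling factor in the contrastive loss
    "hard_neg": False,    # True switches negative selection to hard mining
    "weight": 1,          # legacy parameter, kept at 1
    "warm_start": 10,     # epoch at which warm starting ends
}


def check_supcon_config(cfg):
    """Basic sanity checks for the hypothetical dictionary above."""
    assert isinstance(cfg["apply"], bool)
    assert isinstance(cfg["hard_neg"], bool)
    assert cfg["n_pos"] >= 1 and cfg["n_neg"] >= 1
    assert cfg["temperature"] > 0
    assert cfg["warm_start"] >= 0
    return True
```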
- `__getitem__` method enhancement: A new return value, `pos_neg`, is added. When SupCon is used, it contains a list `[pos_samples, neg_samples]`, where `pos_samples` and `neg_samples` are lists of samples used for contrastive learning. The code that builds `pos_neg` runs even outside warm starting, although its results are then not used.
- `get_pos_samples` method: Identifies the positive samples for a given anchor index. A sample qualifies as positive if it shares at least one category with the anchor, for instance an anchor `[0100 0000]` and a positive `[1100 0000]`. If fewer than `n_pos` positive samples are selected, they are randomly duplicated until there are `n_pos`.
- `get_neg_samples` method: Finds the negative samples for a specified anchor index. The categories of a negative sample must not overlap with those of the anchor, for example an anchor `[0100 0000]` and a negative `[0011 0000]`.
- `hard_mining` method: Called when `get_neg_samples` runs in hard mining mode, to select a negative sample template. For each category, it selects the hardest negative template based on a distance map file, checks that template for overlap with the anchor, and excludes any overlapping items from it. If the exclusion leaves `[0000 0000]`, hard mining is abandoned in favor of standard negative sample selection.
- `prepare_samples` method: Fixed a bug that caused failures when reading datasets containing dual data.
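The overlap rule behind positive and negative selection can be sketched in plain Python. This is a minimal illustration, not the code in data.LocalizationDataset: the function signatures, the `labels` list, excluding the anchor from its own positives, the fallback when no positive exists, and sampling negatives with replacement are all assumptions.

```python
import random


def overlaps(a, b):
    """True if two multi-hot label vectors share at least one active category."""
    return any(x and y for x, y in zip(a, b))


def get_pos_samples(anchor_idx, labels, n_pos, rng=random):
    """Indices of n_pos samples sharing at least one category with the anchor."""
    anchor = labels[anchor_idx]
    candidates = [i for i, lab in enumerate(labels)
                  if i != anchor_idx and overlaps(anchor, lab)]
    if not candidates:
        # Assumption: fall back to the anchor itself when no other positive exists.
        candidates = [anchor_idx]
    if len(candidates) >= n_pos:
        return rng.sample(candidates, n_pos)
    # Too few positives: keep them all and pad with random repeats up to n_pos.
    return candidates + rng.choices(candidates, k=n_pos - len(candidates))


def get_neg_samples(anchor_idx, labels, n_neg, rng=random):
    """Indices of n_neg samples whose categories do not overlap with the anchor."""
    anchor = labels[anchor_idx]
    candidates = [i for i, lab in enumerate(labels) if not overlaps(anchor, lab)]
    if not candidates:
        return []
    # Assumption: negatives are drawn with replacement.
    return rng.choices(candidates, k=n_neg)


# Toy label matrix using the examples from the text.
labels = [
    [0, 1, 0, 0, 0, 0, 0, 0],  # anchor
    [1, 1, 0, 0, 0, 0, 0, 0],  # positive: shares the second category
    [0, 0, 1, 1, 0, 0, 0, 0],  # negative: no shared category
    [0, 0, 0, 0, 1, 0, 0, 0],  # negative: no shared category
]
pos = get_pos_samples(0, labels, n_pos=3)
neg = get_neg_samples(0, labels, n_neg=5)
```

With this toy data only one positive exists, so it is repeated to fill all three positive slots, while the negatives are drawn from the two non-overlapping samples.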
`train_loop` function: Added conditions that check whether SupCon is in use and whether training is in warm starting. During warm starting with SupCon, `loss_function` and `loss_function_pro` are not computed and their corresponding networks are not engaged. Only `loss_function_supcon` is calculated.
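The branching described for train_loop can be summarized in a small sketch. The helper name `active_losses` and its arguments are hypothetical. It also assumes that warm starting covers every epoch before the `warm_start` value, and that outside warm starting only the two classification losses are computed.

```python
def active_losses(use_supcon, epoch, warm_start):
    """Return the names of the losses computed for a given epoch.

    Assumption: warm starting is active while epoch < warm_start.
    """
    warm_starting = use_supcon and epoch < warm_start
    if warm_starting:
        # Only the contrastive loss is computed; the classification heads are skipped.
        return ["loss_function_supcon"]
    # Assumption: otherwise only the classification losses are computed.
    return ["loss_function", "loss_function_pro"]
```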
`Encoder` class: Depending on whether SupCon is used and on the warm starting status, it connects either `ParallelLinearDecoders` and `Linear`, or only `LayerNormNet`. `LayerNormNet` is from CLEAN. When `LayerNormNet` is connected, `pos_neg` is regrouped from `[bsz, 2 (0: pos, 1: neg), n_pos (or n_neg), 5 (variables)]` into `[n_pos, 5, bsz]` plus `[n_neg, 5, bsz]`. Embeddings are obtained for each positive and negative sample and concatenated into `[bsz, (1 + n_pos + n_neg), len(embedding)]`. The projection head output is then obtained with shape `[bsz, (1 + n_pos + n_neg), len(projection)]`.
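The regrouping of `pos_neg` and the stacking of embeddings can be traced with nested lists. This is a shape-only sketch: `regroup`, `stack_embeddings`, `toy_embed`, and `make_sample` are hypothetical helpers, the five variables per sample are placeholder integers, and the real model would operate on tensors and a trained encoder.

```python
def regroup(batch_pos_neg, which):
    """[bsz][2][n][5] -> [n][5][bsz], for which=0 (pos) or which=1 (neg)."""
    per_item = [item[which] for item in batch_pos_neg]   # [bsz][n][5]
    by_sample = zip(*per_item)                            # n groups of bsz samples
    return [[list(field) for field in zip(*samples)] for samples in by_sample]


def stack_embeddings(anchor_emb, pos_embs, neg_embs):
    """Anchor [bsz][emb] plus per-sample embeddings -> [bsz][1 + n_pos + n_neg][emb]."""
    return [
        [anchor_emb[b]] + [e[b] for e in pos_embs] + [e[b] for e in neg_embs]
        for b in range(len(anchor_emb))
    ]


def toy_embed(fields):
    """Stand-in encoder: [5][bsz] -> [bsz][2], using the first two variables."""
    return [[a, b] for a, b in zip(fields[0], fields[1])]


def make_sample(tag):
    """Five placeholder variables for one sample."""
    return tuple(tag * 10 + j for j in range(5))


# bsz = 2, n_pos = 1, n_neg = 2
batch_pos_neg = [
    [[make_sample(1)], [make_sample(2), make_sample(3)]],
    [[make_sample(4)], [make_sample(5), make_sample(6)]],
]
pos = regroup(batch_pos_neg, 0)   # [n_pos][5][bsz]
neg = regroup(batch_pos_neg, 1)   # [n_neg][5][bsz]
anchor_emb = [[0, 1], [0, 1]]     # [bsz][emb]
stacked = stack_embeddings(
    anchor_emb,
    [toy_embed(p) for p in pos],
    [toy_embed(n) for n in neg],
)                                  # [bsz][1 + n_pos + n_neg][emb]
```

Row 0 of each batch entry in `stacked` is the anchor, followed by its positives and then its negatives, which is the ordering a loss indexed by `n_pos` would rely on.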
`SupConHardLoss` function: Originates from CLEAN.
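As a rough illustration of what a loss of this kind computes, the sketch below implements a generic supervised contrastive loss over the `[bsz, (1 + n_pos + n_neg), dim]` layout in plain Python. It is not the CLEAN implementation, and the function names are hypothetical. It assumes row 0 is the anchor, the next `n_pos` rows are positives, and the denominator runs over all positives and negatives.

```python
import math


def l2_normalize(v):
    """Scale a vector to unit length (zero vectors are returned unchanged)."""
    norm = math.sqrt(sum(x * x for x in v)) or 1.0
    return [x / norm for x in v]


def supcon_loss_single(embeddings, n_pos, temperature):
    """Loss for one anchor; embeddings is [1 + n_pos + n_neg][dim]."""
    feats = [l2_normalize(e) for e in embeddings]
    anchor, others = feats[0], feats[1:]
    # Cosine similarity between the anchor and every other row, scaled by temperature.
    logits = [sum(a * b for a, b in zip(anchor, o)) / temperature for o in others]
    # Log-sum-exp over positives and negatives, shifted by the max for stability.
    m = max(logits)
    log_denom = m + math.log(sum(math.exp(l - m) for l in logits))
    # Average negative log-probability of the positives.
    return -sum(l - log_denom for l in logits[:n_pos]) / n_pos


def supcon_loss(batch, n_pos, temperature):
    """Mean loss over a batch shaped [bsz][1 + n_pos + n_neg][dim]."""
    return sum(supcon_loss_single(e, n_pos, temperature) for e in batch) / len(batch)


# Positive aligned with the anchor, negative opposite: loss close to zero.
easy = [[1.0, 0.0], [1.0, 0.0], [-1.0, 0.0]]
# Positive opposite to the anchor, negative aligned: large loss.
hard = [[1.0, 0.0], [-1.0, 0.0], [1.0, 0.0]]
easy_loss = supcon_loss_single(easy, n_pos=1, temperature=0.1)
hard_loss = supcon_loss_single(hard, n_pos=1, temperature=0.1)
```

A lower `temperature` sharpens the contrast between the two cases, since the similarities are divided by it before the softmax.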