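0. Preface
Motivation: multi-label learning is powerful and the results are indeed good, but reading the paper made my head spin a bit. Some of the formulas I had never seen before, and the motivation behind some of them is not backed by experiments.
This paper is from Tencent. This year 25 papers from Tencent YouTu Lab and 33 from Tencent AI Lab (58 in total) were accepted to CVPR 2019.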
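1. Introduction
This is a cross-dataset person re-identification paper: the source is MSMT, the targets are Market and Duke.
To deal with the missing labels in cross-dataset re-ID, the authors propose four components:
soft multilabel learning (illustrated in the figure below)
soft multilabel-guided hard negative mining
cross-view consistent soft multi-label learning
reference agent learning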
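Soft multilabel learning: use the (soft) classification result as the image's label, i.e., the soft multilabel.
Hard negative mining: pairs with similar features but dissimilar soft multilabels are treated as (hard) negatives; pairs with similar features and similar soft multilabels are treated as positives.
Cross-view consistency: within one dataset, the distribution of soft multilabels should be roughly the same no matter which camera the images come from.
Reference agent learning: in the labeled dataset, features of images of the agent's class should be similar to the agent and features of images of other classes should be dissimilar; in the unlabeled dataset, every image feature should be dissimilar to the agents.
These points seem completely unrelated. How did the authors come up with them and decide to put them together?
The remaining three contributions are explained one by one below; they still take some effort to understand.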
Notes:
In this paper, the auxiliary dataset is the same thing as the source dataset.
The agents are a separate class×2048 matrix; the code simply uses fc.weight.
2. Deep Soft Multilabel Reference Learning
| Symbol | Meaning |
| --- | --- |
| $X=\lbrace x_i \rbrace_{i=1}^{N_u}$ | Unlabeled dataset with $N_u$ images |
| $Z=\lbrace z_i, w_i \rbrace_{i=1}^{N_a}$, where $w_i\in\lbrace 1,2,\dots,N_p\rbrace$ | Labeled auxiliary dataset: $z_i$ is an image and $w_i$ its label; $N_a$ images of $N_p$ people. The labeled and unlabeled datasets share no identities |
| $f(\cdot)$ | Discriminative deep feature embedding, i.e., the feature extractor $f(x)$, with $\parallel f(\cdot) \parallel_2=1$ |
| $\lbrace a_i \rbrace_{i=1}^{N_p}$ | Reference person features (reference agents), $\parallel a_i \parallel_2=1$ |
| $y=l(f(x),\lbrace a_i \rbrace_{i=1}^{N_p})\in \mathbb{R}^{N_p}$ | Soft multilabel function $l(\cdot)$, where $y=(y^{(1)}, y^{(2)}, \dots, y^{(N_p)})$, $\sum_{i=1}^{N_p} y^{(i)}=1$, $y^{(i)}\in [0, 1]$ |

Two things have to be learned: $f(\cdot)$ and $\lbrace a_i \rbrace_{i=1}^{N_p}$.
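2.2 Soft multilabel-guided hard negative mining
Come on, could you define your subscripts and superscripts up front? [facepalm] Oh well, it's Tencent, I can't afford to pick a fight with them.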
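Define the soft multilabel function:
$$y^{(k)}=l\left(f(x),\lbrace a_i \rbrace_{i=1}^{N_p}\right)^{(k)}=\frac{\exp\left(a_k^T f(x)\right)}{\sum_{i=1}^{N_p}\exp\left(a_i^T f(x)\right)}$$
That is, take the dot product of $f(x)$ with each $a_k$ in turn, then apply a softmax.
By this formula the operation happens in feature space, somewhat like the predicted probabilities in ECN: the more similar, the larger the value, rather than probabilities predicted by a classifier.
The code also has a temperature $T$: the similarities are multiplied by a scale factor (`scala_ce`) before the softmax.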
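A minimal NumPy sketch of the soft multilabel function, assuming L2-normalized features and agents and a scale factor that multiplies the similarities (the code's `scala_ce`; the value 10.0 below is an arbitrary illustration, not the repo's setting). The helper name `soft_multilabel` is hypothetical.

```python
import numpy as np


def soft_multilabel(feature, agents, scale=1.0):
    """Hypothetical helper: y^(k) = softmax_k(scale * a_k^T f(x)).

    feature: (d,) L2-normalized feature f(x)
    agents:  (N_p, d) L2-normalized reference agents a_1..a_{N_p}
    scale:   factor multiplying the similarities (scala_ce in the repo)
    """
    logits = scale * (agents @ feature)  # a_k^T f(x) for every agent k
    logits = logits - logits.max()       # subtract the max for numerical stability
    e = np.exp(logits)
    return e / e.sum()


agents = np.eye(4)                       # 4 toy agents, already unit length
f = np.array([0.1, 0.2, 0.9, 0.1])
f = f / np.linalg.norm(f)                # unit-length feature, closest to agent 2
y = soft_multilabel(f, agents, scale=10.0)
print(y.round(3), y.sum())
```

A larger `scale` makes $y$ sharper (closer to one-hot); `scale=0` gives a uniform distribution over the agents.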
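Hypothesis: if a pair of samples $x_i, x_j$ has high feature similarity $f(x_i)^Tf(x_j)$, call it a similar pair. If the pair is also similar in other respects, it is most likely a positive pair; if it is not, it is most likely a hard negative pair.
Note: the paper does not show experimentally that a hard negative pair is most likely dissimilar in those other respects (mainly the soft multilabel agreement defined below), so I am skeptical of this assumption. On second thought: in batch-hard triplet loss, a hard negative pair is a pair with similar features but different labels, a positive pair shares a label, an easy positive pair shares a label and has similar features, and a hard positive pair shares a label but has dissimilar features. If we view the soft multilabel as the sample's label, the assumption makes sense.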
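Definition 1 (similarity in other respects): the authors use the soft multilabel as the "other respect", and the soft multilabel agreement $A(\cdot, \cdot)$ measures that similarity. It is defined as
$$A(y_i, y_j)=\sum_{k=1}^{N_p}\min\left(y_i^{(k)}, y_j^{(k)}\right)=1-\frac{\parallel y_i-y_j \parallel_1}{2}$$
The more similar, the larger the value. The last equality is easy to see by drawing a picture, so I won't explain it. Question: the similarity here is defined via the L1 norm of the difference rather than the familiar dot product; I don't know why yet.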
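A quick check of the last equality, assuming the agreement is the sum of element-wise minima. Since $\min(a,b)=(a+b-|a-b|)/2$ and both soft multilabels sum to 1, summing over $k$ gives $1-\parallel y_i-y_j\parallel_1/2$, which is the form the code computes with `1 - pdist(ml_np, 'minkowski', p=1)/2`. The helper names are hypothetical.

```python
import numpy as np


def agreement_min(y_i, y_j):
    """Overlap of two soft multilabels: sum_k min(y_i^(k), y_j^(k))."""
    return float(np.minimum(y_i, y_j).sum())


def agreement_l1(y_i, y_j):
    """Same quantity written as 1 - ||y_i - y_j||_1 / 2 (the form used in the code)."""
    return float(1.0 - np.abs(y_i - y_j).sum() / 2.0)


y_i = np.array([0.7, 0.2, 0.1])
y_j = np.array([0.5, 0.1, 0.4])
print(agreement_min(y_i, y_j), agreement_l1(y_i, y_j))  # both approximately 0.7
```

My guess at the reason for the L1 form: it measures how much probability mass two distributions share, so it is 1 for identical soft multilabels and 0 for disjoint ones, which a dot product does not guarantee.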
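Definition 2 (hard negative pair): for all $M=N_u\times (N_u-1)/2$ sample pairs of the unlabeled dataset $X$, fix a ratio $p$ and take the $pM$ pairs with the highest feature similarity, i.e., $\hat{M}=\lbrace (i,j) \mid f(x_i)^T f(x_j)\ge S\rbrace$ with $|\hat{M}|=pM$, where $S$ is the similarity of the $pM$-th most similar pair. $S$ changes dynamically and is not important in itself; what matters is that $pM$ pairs are taken. These pairs are then split by soft multilabel agreement into a positive set $P$ and a hard negative set $N$:
$$P=\lbrace (i,j) \mid f(x_i)^T f(x_j)\ge S,\ A(y_i,y_j)\ge T\rbrace,\qquad N=\lbrace (i,j) \mid f(x_i)^T f(x_j)\ge S,\ A(y_i,y_j)< T\rbrace$$
where $T$ is the soft multilabel agreement threshold, which is also updated during training: in the code it is initialized to the $pM$-th largest pairwise agreement over the whole unlabeled set (`init_threshold`) and then moved towards the same statistic of each batch with momentum 0.1 (`_update_threshold`).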
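A small sketch of the partition step on a toy batch. Assumptions: the helper `partition_pairs` is hypothetical, the threshold `T` is fixed instead of moving-averaged, and pairs are ranked by dot-product similarity (the repo ranks by Euclidean distance, which gives the same order for unit-length features).

```python
import numpy as np
from itertools import combinations


def partition_pairs(features, multilabels, ratio, T):
    """Hypothetical helper: split the ratio*M most similar pairs into P and N.

    features:    (n, d) L2-normalized features
    multilabels: (n, N_p) soft multilabels
    ratio:       p, the fraction of all M = n(n-1)/2 pairs that is mined
    T:           agreement threshold (fixed here; the repo updates it with a moving average)
    """
    pairs = list(combinations(range(len(features)), 2))
    sims = np.array([features[i] @ features[j] for i, j in pairs])
    agree = np.array([1 - np.abs(multilabels[i] - multilabels[j]).sum() / 2 for i, j in pairs])
    k = max(1, int(len(pairs) * ratio))
    top = np.argsort(-sims, kind='stable')[:k]  # indices of the pM most similar pairs
    P = [pairs[t] for t in top if agree[t] >= T]
    N = [pairs[t] for t in top if agree[t] < T]
    return P, N


features = np.array([[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]])
multilabels = np.array([[0.9, 0.1], [0.8, 0.2], [0.9, 0.1], [0.1, 0.9]])
P, N = partition_pairs(features, multilabels, ratio=0.4, T=0.5)
print(P, N)  # [(0, 1)] [(2, 3)]
```

Pairs (0, 1) and (2, 3) are equally similar in feature space, but only the first agrees in soft multilabel, so (2, 3) ends up as a hard negative.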
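Loss: soft Multilabel-guided Discriminative embedding Learning (MDL):
$$L_{MDL}=-\log\frac{\bar{P}}{\bar{P}+\bar{N}}$$
where
$$\bar{P}=\frac{1}{|P|}\sum_{(i,j)\in P}\exp\left(-\parallel f(x_i)-f(x_j)\parallel_2^2\right),\qquad \bar{N}=\frac{1}{|N|}\sum_{(i,j)\in N}\exp\left(-\parallel f(x_i)-f(x_j)\parallel_2^2\right)$$
so minimizing $L_{MDL}$ raises $\bar{P}$ relative to $\bar{N}$: positive pairs are pulled together and hard negative pairs are pushed apart. This form matches `DiscriminativeLoss.forward` in the code, where `loss = -log(pos_exponant) + log(pos_exponant + neg_exponant)`.
Question: what kind of formula is this? I have never seen anything like it, and the authors give no explanation.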
Here the agents $\lbrace a_i \rbrace_{i=1}^{N_p}$ are fixed and $f(\cdot)$ is learned.
In actual training, $M=M_{batch}=N_{batch}\times (N_{batch}-1)/2$.
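A NumPy sketch of $L_{MDL}$ for given pair sets. The helper `mdl_loss` is hypothetical; the fallbacks for an empty $P$ (1.0) or $N$ (0.5) copy the constants used in `DiscriminativeLoss.forward`.

```python
import numpy as np


def mdl_loss(features, P, N):
    """Hypothetical helper: L_MDL = -log(P_bar / (P_bar + N_bar)).

    P_bar / N_bar are the mean of exp(-||f_i - f_j||^2) over the positive / hard
    negative pairs; the fallbacks 1.0 and 0.5 for empty sets mirror DiscriminativeLoss.
    """
    def mean_exp(pairs):
        sq = np.array([np.sum((features[i] - features[j]) ** 2) for i, j in pairs])
        return np.exp(-sq).mean()

    p_bar = mean_exp(P) if P else 1.0
    n_bar = mean_exp(N) if N else 0.5
    return -np.log(p_bar / (p_bar + n_bar))


features = np.array([[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 0.0]])
# positive pair (0, 1) coincides; hard negative pair (2, 3) is orthogonal, so ||f_2 - f_3||^2 = 2
loss = mdl_loss(features, P=[(0, 1)], N=[(2, 3)])
print(loss)  # log(1 + e^-2), about 0.127
```

If the hard negative pair collapsed onto each other instead (for example pair (0, 3)), $\bar{N}=\bar{P}=1$ and the loss rises to $\log 2$. Read this way, the loss looks like a two-way softmax between "positive" and "negative" groups, which may help with the question above.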
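2.3 Cross-view consistent soft multilabel learning
Because person re-ID has to match people across cameras, the distribution of the soft multilabels should be independent of the camera view.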
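Loss:
$$L_{CML}=\sum_v d\left(P_v(y), P(y)\right)^2$$
where $P(y)$ is the soft multilabel distribution of dataset $X$, $P_v(y)$ is the soft multilabel distribution of the images of $X$ taken by camera $v$, and $d(\cdot, \cdot)$ is a distance between distributions, e.g., KL divergence or Wasserstein distance. Because the soft multilabels are observed in practice to follow a log-normal distribution, the simplified 2-Wasserstein distance is used:
$$L_{CML}=\sum_v \parallel \mu_v-\mu \parallel_2^2+\parallel \sigma_v-\sigma \parallel_2^2$$
where $\mu/\sigma$ are the mean and standard deviation of the log-soft multilabels over the whole dataset, and $\mu_v/\sigma_v$ are the mean and standard deviation of the log-soft multilabels of the images from camera $v$. Question: how is this formula derived? This is one of those I can't follow even when it is written out. (A hint: the logarithm of a log-normal variable is Gaussian, and for two Gaussians with diagonal covariances the squared 2-Wasserstein distance is exactly $\parallel \mu_1-\mu_2 \parallel_2^2+\parallel \sigma_1-\sigma_2 \parallel_2^2$; "simplified" presumably refers to treating the dimensions as independent.)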
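A NumPy sketch of $L_{CML}$ computed on one batch of log-soft multilabels. Assumptions: `cml_loss` is a hypothetical helper, $\mu/\sigma$ come from the whole batch with the population standard deviation, and the per-camera terms are summed. The repo's `MultilabelLoss` differs in details: it averages the terms, uses sample standard deviations, and keeps moving-average centers built from per-camera statistics.

```python
import numpy as np


def cml_loss(log_ml, views):
    """Hypothetical helper: sum_v ||mu_v - mu||^2 + ||sigma_v - sigma||^2.

    log_ml: (n, N_p) log soft multilabels; views: (n,) camera ids.
    mu / sigma are computed over all n samples (population std, ddof=0).
    """
    mu, sigma = log_ml.mean(axis=0), log_ml.std(axis=0)
    loss = 0.0
    for v in np.unique(views):
        x = log_ml[views == v]
        loss += np.sum((x.mean(axis=0) - mu) ** 2) + np.sum((x.std(axis=0) - sigma) ** 2)
    return loss


base = np.log(np.array([[0.7, 0.3], [0.4, 0.6], [0.5, 0.5]]))
views = np.array([0, 0, 0, 1, 1, 1])
same = np.vstack([base, base])            # both cameras see the same distribution
shifted = np.vstack([base, base + 0.5])   # camera 1 is shifted in log space
print(cml_loss(same, views), cml_loss(shifted, views))
```

When both cameras produce the same distribution the loss is (numerically) zero; shifting one camera in log space raises both the mean term and the standard-deviation term.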
Here the agents $\lbrace a_i \rbrace_{i=1}^{N_p}$ are fixed and $f(\cdot)$ is learned.
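2.4 Reference agent learning
Since the reference agents have to be tied to the soft multilabel function $l(\cdot, \cdot)$, the loss is
$$L_{AL}=\sum_k -\log l\left(f(z_k),\lbrace a_i \rbrace_{i=1}^{N_p}\right)^{(w_k)}=\sum_k -\log\frac{\exp\left(a_{w_k}^T f(z_k)\right)}{\sum_{j=1}^{N_p}\exp\left(a_j^T f(z_k)\right)}$$
where $z_k$ is the $k$-th image of the labeled dataset $Z$ and $w_k$ is its label.
This can be read as the cross-entropy between the predicted and the true distribution of $z_k$ (in the code: `nn.CrossEntropyLoss` on `similarity * scala_ce`). The loss not only trains $a_i$ to be closer to the features of all images of person $i$, it also trains the feature embedding $f$ so that the labels produced by $l(\cdot, \cdot)$ are better at identifying the same person. This fits the assumption behind soft multilabel-guided hard negative mining: pairs with similar features but dissimilar soft multilabels are hard negatives.
This loss updates both $f$ and the agents $a_i$.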
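A NumPy sketch of $L_{AL}$ as a cross-entropy over agent similarities. Assumptions: `agent_learning_loss` is a hypothetical helper and it averages over images, as `nn.CrossEntropyLoss` does by default, instead of summing as in the formula.

```python
import numpy as np


def agent_learning_loss(features, labels, agents, scale=1.0):
    """Hypothetical helper: mean over k of -log softmax(scale * A f(z_k))[w_k].

    features: (n, d) labeled features f(z_k); labels: (n,) ids w_k;
    agents: (N_p, d). Averaging (not summing) mirrors nn.CrossEntropyLoss.
    """
    logits = scale * features @ agents.T
    logits = logits - logits.max(axis=1, keepdims=True)
    log_prob = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return -log_prob[np.arange(len(labels)), labels].mean()


agents = np.eye(3)
features = np.eye(3)   # each labeled feature coincides with its own agent
right = agent_learning_loss(features, np.array([0, 1, 2]), agents)
wrong = agent_learning_loss(features, np.array([1, 2, 0]), agents)
print(right, wrong)    # right = log(e + 2) - 1, wrong = log(e + 2)
```

Gradients of this loss flow into both the features and the agents, which is why the text says it updates $f$ and $a_i$.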
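Note: the formulas in this paper go from a general definition to a concrete instantiation, which is why they feel messy at first and why the symbols keep changing: it is the path from the theoretical formula to the actual code.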
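Joint embedding learning for reference comparability: to make the soft multilabel function represent unlabeled images more accurately and to correct the domain shift, the authors use the fact that an unlabeled feature $f(x)$ and an agent $a_i$ can never be a matching pair (the two datasets share no identities), and propose the loss
$$L_{RJ}=\sum_{i=1}^{N_p}\left(\sum_{j\in M_i}\left[m-\parallel a_i-f(x_j)\parallel_2^2\right]_+ +\sum_{k:\,w_k=i}\parallel a_i-f(z_k)\parallel_2^2\right)$$
Its purpose is to make labeled images of the identity represented by $a_i$ have features similar to $a_i$, while all unlabeled images have features dissimilar to $a_i$. $M_i=\lbrace j \mid \parallel a_i-f(x_j) \parallel_2^2 < m \rbrace$ is the set of unlabeled images whose features are closest to the $i$-th agent $a_i$ (since $\parallel a_i-f(x_j) \parallel_2^2=2(1-a_i^Tf(x_j))$, a smaller value means more similar). Following reference 44 cited by the authors (NormFace: L2 hypersphere embedding for face verification), $m=1$ is recommended.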
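Here the agents $\lbrace a_i \rbrace_{i=1}^{N_p}$ are fixed and $f(\cdot)$ is learned.
Explaining $L_{RJ}$ again based on the code: its goal is to learn a better feature extractor $f$, so that features of labeled images are similar to the agent of the same class and dissimilar to the agents of other classes, while features of unlabeled images are dissimilar to all agents; a bit like a triplet loss. First, $a_i$ is a constant here and receives no gradient. Second, for any agent $a_i$, labeled images with label $i$ are its positives and the other labeled images are its negatives, while all unlabeled images are its negatives. Concretely, for each labeled image, $a_{label}$ is the positive and every other $a_i$ is a negative; for each unlabeled image, every $a_i$ is a negative.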
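A NumPy sketch of the paper-style $L_{RJ}$ with the agents treated as constants. Assumptions: `rj_loss` is a hypothetical helper that sums the terms; the repo's `JointLoss` averages them and additionally pushes labeled features away from other agents that fall inside the margin. Summing the hinge over all unlabeled images is equivalent to summing over $M_i$, because $[m-\parallel a_i-f(x_j)\parallel_2^2]_+$ is already zero for $j\notin M_i$.

```python
import numpy as np


def rj_loss(src_feats, src_labels, tgt_feats, agents, m=1.0):
    """Hypothetical helper for L_RJ (paper-style sum, agents treated as constants).

    pull: sum_k ||a_{w_k} - f(z_k)||^2 over labeled features
    push: sum_i sum_{j in M_i} [m - ||a_i - f(x_j)||^2]_+ over unlabeled features
    """
    pull = np.sum((src_feats - agents[src_labels]) ** 2)
    sq = np.sum((agents[:, None, :] - tgt_feats[None, :, :]) ** 2, axis=-1)  # (N_p, N_u)
    # the hinge is already zero for j outside M_i = {j : sq[i, j] < m}
    push = np.sum(np.clip(m - sq, 0.0, None))
    return pull + push


agents = np.eye(4)[:3]                     # 3 agents in a 4-d feature space
src = agents[[0, 2]]                       # labeled features sitting on their agents
far = np.array([[0.0, 0.0, 0.0, 1.0]])     # orthogonal to every agent: sq = 2 >= m
near = np.array([[1.0, 0.0, 0.0, 0.0]])    # coincides with agent 0: sq = 0 < m
print(rj_loss(src, np.array([0, 2]), far, agents), rj_loss(src, np.array([0, 2]), near, agents))  # 0.0 1.0
```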
So the total reference agent learning loss is:
$$L_{RAL}=L_{AL}+\beta L_{RJ}$$
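2.5.1 Model training and testing
In the code (`train_epoch` below), the total loss is `loss_target + lamb_1 * loss_ml + lamb_2 * (loss_source + beta * loss_st)`, i.e., $L_{MDL}+\lambda_1 L_{CML}+\lambda_2\left(L_{AL}+\beta L_{RJ}\right)$.
3. Experiments
MSMT17 is the auxiliary dataset; Market-1501 and Duke are the unlabeled datasets.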
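Remark: the agents in the paper are not, as I had assumed, computed from the features that images produce when fed through the model. They are ResNet-50's fc.weight (class×2048), i.e., the classifier's weight vectors. This is quite different from how ECN does it: ECN uses the features of the input images, probably because in ECN each image corresponds to one feature, whereas in this paper many images correspond to one agent.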
4. Code
4.1 Logger
Two kinds of logger:
```python
import os
import sys
import time

import torch


def time_string():
    ISOTIMEFORMAT = '%Y-%m-%d %X'
    string = '[{}]'.format(time.strftime(ISOTIMEFORMAT, time.localtime(time.time())))
    return string


class Logger(object):
    """Writes every message both to a timestamped log file and to the console."""

    def __init__(self, save_path):
        if not os.path.isdir(save_path):
            os.makedirs(save_path)
        self.file = open(os.path.join(save_path, 'log_{}.txt'.format(time_string())), 'w')
        self.print_log("python version : {}".format(sys.version.replace('\n', ' ')))
        self.print_log("torch version : {}".format(torch.__version__))

    def print_log(self, string):
        self.file.write("{}\n".format(string))
        self.file.flush()
        print(string)


# usage (args and load_path come from the training script)
logger = Logger(args.save_path)
logger.print_log("=> loading checkpoint '{}'".format(load_path))
```
```python
from __future__ import absolute_import
import errno
import os
import sys


def mkdir_if_missing(dir_path):
    try:
        os.makedirs(dir_path)
    except OSError as e:
        if e.errno != errno.EEXIST:
            raise


class Logger(object):
    """File-like object that tees everything written to it to the console and a file."""

    def __init__(self, fpath=None):
        self.console = sys.stdout
        self.file = None
        if fpath is not None:
            mkdir_if_missing(os.path.dirname(fpath))
            self.file = open(fpath, 'w')

    def __del__(self):
        self.close()

    def __enter__(self):
        return self  # `pass` here would bind None in `with Logger(...) as log`

    def __exit__(self, *args):
        self.close()

    def write(self, msg):
        self.console.write(msg)
        if self.file is not None:
            self.file.write(msg)

    def flush(self):
        self.console.flush()
        if self.file is not None:
            self.file.flush()
            os.fsync(self.file.fileno())

    def close(self):
        self.console.close()  # note: this also closes the real stdout
        if self.file is not None:
            self.file.close()


# usage: redirect stdout so that every print() also goes to the log file
import os.path as osp
sys.stdout = Logger(osp.join(args.logs_dir, 'log.txt'))
print(args)
```
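4.2 model
```python
import torch.nn as nn


class ResNet(nn.Module):
    # excerpt: only forward() is shown; the layers are defined in __init__ (not shown here)
    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        feature_maps = self.layer4(x)

        x = self.avgpool(feature_maps)
        x = x.view(x.size(0), -1)
        # L2-normalize each row: renorm caps every row's L2 norm at 1e-5, then *1e5 scales it back to 1
        feature = x.renorm(2, 0, 1e-5).mul(1e5)
        # the classifier weights, L2-normalized row by row, serve as the reference agents
        w = self.fc.weight
        ww = w.renorm(2, 0, 1e-5).mul(1e5)
        # cosine similarity between every feature and every agent: (batch, num_classes)
        sim = feature.mm(ww.t())
        return feature, sim, feature_maps
```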
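The `renorm(2, 0, 1e-5).mul(1e5)` idiom deserves a closer look. The sketch below re-creates its documented behavior in NumPy (`renorm_rows` is a hypothetical stand-in, not a torch function): rows whose L2 norm exceeds `maxnorm` are scaled down to `maxnorm`, other rows stay unchanged, so multiplying by `1e5` afterwards yields unit-length rows except for rows that were already tinier than `1e-5`. `torch.nn.functional.normalize(x, dim=1)` is the more common way to write plain row normalization.

```python
import numpy as np


def renorm_rows(x, maxnorm):
    """Hypothetical NumPy stand-in for torch's x.renorm(2, 0, maxnorm): rows whose
    L2 norm exceeds maxnorm are scaled down to norm maxnorm, other rows are unchanged.
    (torch may add a tiny epsilon to the norm; that detail is ignored here.)"""
    norms = np.linalg.norm(x, axis=1, keepdims=True)
    return x * (maxnorm / np.maximum(norms, maxnorm))


x = np.array([[3.0, 4.0], [0.6, 0.8], [1e-7, 0.0]])
y = renorm_rows(x, 1e-5) * 1e5          # the model's .renorm(2, 0, 1e-5).mul(1e5)
print(np.linalg.norm(y, axis=1))         # about [1, 1, 0.01]
```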
4.3 optim
```python
# BatchNorm parameters get no weight decay; all other parameters use args.wd
bn_params, other_params = partition_params(self.net, 'bn')
self.optimizer = torch.optim.SGD([{'params': bn_params, 'weight_decay': 0},
                                  {'params': other_params}],
                                 lr=args.lr, momentum=0.9, weight_decay=args.wd)
```
4.4 trainer/init_losses
```python
import os
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.spatial.distance import pdist

# Trainer, JointLoss, MultilabelLoss, DiscriminativeLoss, resnet50, extract_features,
# AverageMeter, create_stat_string, time_string and save_checkpoint come from the repo.


class ReidTrainer(Trainer):
    def __init__(self, args, logger):
        # excerpt: other setup (self.args, self.logger, self.lr_scheduler, self.initialized, ...) omitted
        self.al_loss = nn.CrossEntropyLoss().cuda()                   # L_AL
        self.rj_loss = JointLoss(args.margin).cuda()                  # L_RJ
        self.cml_loss = MultilabelLoss(args.batch_size).cuda()        # L_CML
        self.mdl_loss = DiscriminativeLoss(args.mining_ratio).cuda()  # L_MDL
        self.net = resnet50(pretrained=False, num_classes=self.args.num_classes)
        # one stored soft multilabel per target image (4101 = number of agents)
        self.multilabel_memory = torch.zeros(N_target_samples, 4101)

    def init_losses(self, target_loader):
        self.logger.print_log('initializing centers/threshold ...')
        if os.path.isfile(self.args.ml_path):
            (multilabels, views, pairwise_agreements) = torch.load(self.args.ml_path)
            self.logger.print_log('loaded ml from {}'.format(self.args.ml_path))
        else:
            self.logger.print_log('not found {}. computing ml...'.format(self.args.ml_path))
            sim, _, views = extract_features(target_loader, self.net, index_feature=1, return_numpy=False)
            multilabels = F.softmax(sim * self.args.scala_ce, dim=1)
            ml_np = multilabels.cpu().numpy()
            # soft multilabel agreement A = 1 - ||y_i - y_j||_1 / 2 for every pair
            pairwise_agreements = 1 - pdist(ml_np, 'minkowski', p=1) / 2
        log_multilabels = torch.log(multilabels)
        self.cml_loss.init_centers(log_multilabels, views)
        self.logger.print_log('initializing centers done.')
        self.mdl_loss.init_threshold(pairwise_agreements)
        self.logger.print_log('initializing threshold done.')

    def train_epoch(self, source_loader, target_loader, epoch):
        self.lr_scheduler.step()
        if not self.cml_loss.initialized or not self.mdl_loss.initialized:
            self.init_losses(target_loader)
        batch_time_meter = AverageMeter()
        stats = ('loss_source', 'loss_st', 'loss_ml', 'loss_target', 'loss_total')
        meters_trn = {stat: AverageMeter() for stat in stats}
        self.train()

        end = time.time()
        target_iter = iter(target_loader)
        for i, source_tuple in enumerate(source_loader):
            imgs = source_tuple[0].cuda()
            labels = source_tuple[1].cuda()
            try:
                target_tuple = next(target_iter)
            except StopIteration:
                target_iter = iter(target_loader)
                target_tuple = next(target_iter)
            imgs_target = target_tuple[0].cuda()
            labels_target = target_tuple[1].cuda()  # only used to monitor the TP rate
            views_target = target_tuple[2].cuda()
            idx_target = target_tuple[3]

            features, similarity, _ = self.net(imgs)
            features_target, similarity_target, _ = self.net(imgs_target)
            scores = similarity * self.args.scala_ce
            loss_source = self.al_loss(scores, labels)  # L_AL
            agents = self.net.module.fc.weight.renorm(2, 0, 1e-5).mul(1e5)
            loss_st = self.rj_loss(features, agents.detach(), labels, similarity.detach(),
                                   features_target, similarity_target.detach())  # L_RJ
            multilabels = F.softmax(features_target.mm(agents.detach().t_() * self.args.scala_ce), dim=1)
            loss_ml = self.cml_loss(torch.log(multilabels), views_target)  # L_CML
            if epoch < 1:
                loss_target = torch.Tensor([0]).cuda()
            else:
                # moving average (momentum 0.9) of each target image's soft multilabel
                multilabels_cpu = multilabels.detach().cpu()
                is_init_batch = self.initialized[idx_target]
                initialized_idx = idx_target[is_init_batch]
                uninitialized_idx = idx_target[~is_init_batch]
                self.multilabel_memory[uninitialized_idx] = multilabels_cpu[~is_init_batch]
                self.initialized[uninitialized_idx] = 1
                self.multilabel_memory[initialized_idx] = 0.9 * self.multilabel_memory[initialized_idx] \
                    + 0.1 * multilabels_cpu[is_init_batch]
                loss_target = self.mdl_loss(features_target, self.multilabel_memory[idx_target],
                                            labels_target)  # L_MDL
            self.optimizer.zero_grad()
            loss_total = loss_target + self.args.lamb_1 * loss_ml + self.args.lamb_2 * \
                (loss_source + self.args.beta * loss_st)
            loss_total.backward()
            self.optimizer.step()

            for k in stats:
                v = locals()[k]
                meters_trn[k].update(v.item(), self.args.batch_size)

            batch_time_meter.update(time.time() - end)
            freq = self.args.batch_size / batch_time_meter.avg
            end = time.time()
            if i % self.args.print_freq == 0:
                self.logger.print_log('  Iter: [{:03d}/{:03d}]   Freq {:.1f}   '.format(
                    i, len(source_loader), freq) + create_stat_string(meters_trn) + time_string())

        save_checkpoint(self, epoch, os.path.join(self.args.save_path, "checkpoints.pth"))
        return meters_trn
```
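The multilabel-memory update inside `train_epoch` is easy to misread, so here is a NumPy re-creation of it. Assumptions: `update_memory` is a hypothetical helper and the batch indices are unique. A first visit copies the new soft multilabel; later visits blend it in with weight 0.1. `L_MDL` then mines pairs from this smoothed memory instead of the raw batch multilabels, presumably to make the mining less noisy.

```python
import numpy as np


def update_memory(memory, initialized, idx, batch_ml, momentum=0.9):
    """Hypothetical NumPy version of the multilabel-memory update in train_epoch.

    memory: (N_u, N_p) stored soft multilabels; initialized: (N_u,) bool flags;
    idx: (B,) unique dataset indices of the batch; batch_ml: (B, N_p) new soft multilabels.
    """
    is_init = initialized[idx]                     # flags read before this batch marks anything
    new_idx, old_idx = idx[~is_init], idx[is_init]
    memory[new_idx] = batch_ml[~is_init]           # first visit: copy directly
    initialized[new_idx] = True
    memory[old_idx] = momentum * memory[old_idx] + (1 - momentum) * batch_ml[is_init]
    return memory


memory = np.zeros((3, 2))
initialized = np.zeros(3, dtype=bool)
update_memory(memory, initialized, np.array([0, 1]), np.array([[1.0, 0.0], [0.0, 1.0]]))
update_memory(memory, initialized, np.array([1, 2]), np.array([[1.0, 0.0], [0.5, 0.5]]))
print(memory)  # rows: [1, 0], [0.1, 0.9], [0.5, 0.5]
```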
```python
import torch


class MultilabelLoss(torch.nn.Module):
    """Cross-view consistent soft multilabel loss (L_CML) on log soft multilabels."""

    def __init__(self, batch_size, use_std=True):
        super(MultilabelLoss, self).__init__()
        self.use_std = use_std
        self.moment = batch_size / 10000  # momentum of the moving-average centers
        self.initialized = False

    def init_centers(self, log_multilabels, views):
        """
        :param log_multilabels: shape=(N, n_class)
        :param views: (N,)
        :return:
        Initializes the global mean and standard deviation.
        """
        univiews = torch.unique(views)
        mean_ml = []
        std_ml = []
        for v in univiews:
            ml_in_v = log_multilabels[views == v]
            mean = ml_in_v.mean(dim=0)
            std = ml_in_v.std(dim=0)
            mean_ml.append(mean)
            std_ml.append(std)
        # the "global" statistics are the average of the per-camera statistics
        center_mean = torch.mean(torch.stack(mean_ml), dim=0)
        center_std = torch.mean(torch.stack(std_ml), dim=0)
        self.register_buffer('center_mean', center_mean)
        self.register_buffer('center_std', center_std)
        self.initialized = True

    def _update_centers(self, log_multilabels, views):
        """
        :param log_multilabels: shape=(BS, n_class)
        :param views: shape=(BS,)
        :return:
        """
        univiews = torch.unique(views)
        means = []
        stds = []
        for v in univiews:
            ml_in_v = log_multilabels[views == v]
            if len(ml_in_v) == 1:
                continue
            mean = ml_in_v.mean(dim=0)
            means.append(mean)
            if self.use_std:
                std = ml_in_v.std(dim=0)
                stds.append(std)
        new_mean = torch.mean(torch.stack(means), dim=0)
        self.center_mean = self.center_mean * (1 - self.moment) + new_mean * self.moment
        if self.use_std:
            new_std = torch.mean(torch.stack(stds), dim=0)
            self.center_std = self.center_std * (1 - self.moment) + new_std * self.moment

    def forward(self, log_multilabels, views):
        """
        :param log_multilabels: shape=(BS, n_class)
        :param views: shape=(BS,)
        :return:
        """
        self._update_centers(log_multilabels.detach(), views)

        univiews = torch.unique(views)
        loss_terms = []
        for v in univiews:
            ml_in_v = log_multilabels[views == v]
            if len(ml_in_v) == 1:
                continue
            mean = ml_in_v.mean(dim=0)
            loss_mean = (mean - self.center_mean).pow(2).sum()  # ||mu_v - mu||^2
            loss_terms.append(loss_mean)
            if self.use_std:
                std = ml_in_v.std(dim=0)
                loss_std = (std - self.center_std).pow(2).sum()  # ||sigma_v - sigma||^2
                loss_terms.append(loss_std)
        loss_total = torch.mean(torch.stack(loss_terms))
        return loss_total
```
```python
import numpy as np
import torch
from scipy.spatial.distance import pdist

# dist_idx_to_pair_idx is a helper from the repo (not shown in this post)


class DiscriminativeLoss(torch.nn.Module):
    """Soft multilabel-guided discriminative embedding loss (L_MDL)."""

    def __init__(self, mining_ratio=0.001):
        super(DiscriminativeLoss, self).__init__()
        self.mining_ratio = mining_ratio  # p: fraction of the most similar pairs that is mined
        self.register_buffer('n_pos_pairs', torch.Tensor([0]))
        self.register_buffer('rate_TP', torch.Tensor([0]))
        self.moment = 0.1
        self.initialized = False

    def init_threshold(self, pairwise_agreements):
        # T starts as the pM-th largest agreement over the whole target set
        pos = int(len(pairwise_agreements) * self.mining_ratio)
        sorted_agreements = np.sort(pairwise_agreements)
        t = sorted_agreements[-pos]
        self.register_buffer('threshold', torch.Tensor([t]).cuda())
        self.initialized = True

    def forward(self, features, multilabels, labels):
        """
        :param features: shape=(BS, dim)
        :param multilabels: (BS, n_class)
        :param labels: (BS,)
        :return:
        """
        P, N = self._partition_sets(features.detach(), multilabels, labels)
        if P is None:
            pos_exponant = torch.Tensor([1]).cuda()
            num = 0
        else:
            sdist_pos_pairs = []
            for (i, j) in zip(P[0], P[1]):
                sdist_pos_pair = (features[i] - features[j]).pow(2).sum()
                sdist_pos_pairs.append(sdist_pos_pair)
            pos_exponant = torch.exp(- torch.stack(sdist_pos_pairs)).mean()  # P bar
            num = -torch.log(pos_exponant)
        if N is None:
            neg_exponant = torch.Tensor([0.5]).cuda()
        else:
            sdist_neg_pairs = []
            for (i, j) in zip(N[0], N[1]):
                sdist_neg_pair = (features[i] - features[j]).pow(2).sum()
                sdist_neg_pairs.append(sdist_neg_pair)
            neg_exponant = torch.exp(- torch.stack(sdist_neg_pairs)).mean()  # N bar
        den = torch.log(pos_exponant + neg_exponant)
        loss = num + den  # = -log(P bar / (P bar + N bar))
        return loss

    def _partition_sets(self, features, multilabels, labels):
        """
        partition the batch into confident positive, hard negative and others
        :param features: shape=(BS, dim)
        :param multilabels: shape=(BS, n_class)
        :param labels: shape=(BS,)
        :return:
        P: positive pair set. tuple of 2 np.array i and j.
            i contains smaller indices and j larger indices in the batch.
            if P is None, no positive pair found in this batch.
        N: negative pair set. similar to P, but will never be None.
        """
        f_np = features.cpu().numpy()
        ml_np = multilabels.cpu().numpy()
        p_dist = pdist(f_np)
        p_agree = 1 - pdist(ml_np, 'minkowski', p=1) / 2
        sorting_idx = np.argsort(p_dist)
        n_similar = int(len(p_dist) * self.mining_ratio)
        similar_idx = sorting_idx[:n_similar]  # the pM closest pairs
        is_positive = p_agree[similar_idx] > self.threshold.item()
        pos_idx = similar_idx[is_positive]
        neg_idx = similar_idx[~is_positive]
        P = dist_idx_to_pair_idx(len(f_np), pos_idx)
        N = dist_idx_to_pair_idx(len(f_np), neg_idx)
        self._update_threshold(p_agree)
        self._update_buffers(P, labels)
        return P, N

    def _update_threshold(self, pairwise_agreements):
        pos = int(len(pairwise_agreements) * self.mining_ratio)
        sorted_agreements = np.sort(pairwise_agreements)
        t = torch.Tensor([sorted_agreements[-pos]]).cuda()
        self.threshold = self.threshold * (1 - self.moment) + t * self.moment

    def _update_buffers(self, P, labels):
        # monitoring only: moving averages of the number of positive pairs and their true-positive rate
        if P is None:
            self.n_pos_pairs = 0.9 * self.n_pos_pairs
            return 0
        n_pos_pairs = len(P[0])
        count = 0
        for (i, j) in zip(P[0], P[1]):
            count += labels[i] == labels[j]
        rate_TP = float(count) / n_pos_pairs
        self.n_pos_pairs = 0.9 * self.n_pos_pairs + 0.1 * n_pos_pairs
        self.rate_TP = 0.9 * self.rate_TP + 0.1 * rate_TP
```
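`dist_idx_to_pair_idx` is not shown in the post. Below is a hypothetical re-implementation that matches the docstring above (map indices into scipy's condensed `pdist` vector back to pairs with `i < j`, returning `None` when there are none). It relies on the condensed order being the row-major upper triangle, which is the same order `np.triu_indices(n, k=1)` produces; the repo's actual helper may be written differently.

```python
import numpy as np


def dist_idx_to_pair_idx(n, idx):
    """Hypothetical re-implementation: map indices into scipy's condensed pdist
    vector (length n*(n-1)/2) back to pairs (i, j) with i < j.

    Returns (i_array, j_array), or None when idx is empty.
    """
    idx = np.asarray(idx, dtype=int)
    if idx.size == 0:
        return None
    rows, cols = np.triu_indices(n, k=1)  # same order as pdist: (0,1), (0,2), ..., (1,2), ...
    return rows[idx], cols[idx]


P = dist_idx_to_pair_idx(4, [0, 3, 5])
print(list(zip(P[0].tolist(), P[1].tolist())))  # [(0, 1), (1, 2), (2, 3)]
```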
```python
import torch


class JointLoss(torch.nn.Module):
    """Joint embedding loss for reference comparability (L_RJ)."""

    def __init__(self, margin=1):
        super(JointLoss, self).__init__()
        self.margin = margin
        # for unit vectors ||a - f||^2 = 2(1 - a^T f), so ||a - f||^2 < margin  <=>  a^T f > 1 - margin/2
        self.sim_margin = 1 - margin / 2

    def forward(self, features, agents, labels, similarity, features_target, similarity_target):
        """
        :param features: shape=(BS/2, dim)
        :param agents: shape=(n_class, dim)
        :param labels: shape=(BS/2,)
        :param features_target: shape=(BS/2, dim)
        :return:
        """
        loss_terms = []
        arange = torch.arange(len(agents)).cuda()
        zero = torch.Tensor([0]).cuda()
        # labeled (source) images: pull towards the own agent, push away other agents inside the margin
        for (f, l, s) in zip(features, labels, similarity):
            loss_pos = (f - agents[l]).pow(2).sum()
            loss_terms.append(loss_pos)
            neg_idx = arange != l
            hard_agent_idx = neg_idx & (s > self.sim_margin)
            if torch.any(hard_agent_idx):
                hard_neg_sdist = (f - agents[hard_agent_idx]).pow(2).sum(dim=1)
                loss_neg = torch.max(zero, self.margin - hard_neg_sdist).mean()
                loss_terms.append(loss_neg)
        # unlabeled (target) images: push away every agent inside the margin
        for (f, s) in zip(features_target, similarity_target):
            hard_agent_idx = s > self.sim_margin
            if torch.any(hard_agent_idx):
                hard_neg_sdist = (f - agents[hard_agent_idx]).pow(2).sum(dim=1)
                loss_neg = torch.max(zero, self.margin - hard_neg_sdist).mean()
                loss_terms.append(loss_neg)
        loss_total = torch.mean(torch.stack(loss_terms))
        return loss_total
```
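In `JointLoss` the hinge is only applied to agents with `s > self.sim_margin`. For unit vectors $\parallel a-f\parallel_2^2=2(1-a^Tf)$, so this filter selects exactly the agents with $\parallel a-f\parallel_2^2<m$, i.e., the set $M_i$ from the formula. A quick numerical check with random unit vectors (seed and sizes are arbitrary):

```python
import numpy as np

rng = np.random.default_rng(0)
m = 1.0
sim_margin = 1 - m / 2                          # same relation as JointLoss.__init__

a = rng.normal(size=(200, 16))
a /= np.linalg.norm(a, axis=1, keepdims=True)   # unit-length "agents"
f = rng.normal(size=16)
f /= np.linalg.norm(f)                          # unit-length feature

sq = np.sum((a - f) ** 2, axis=1)               # ||a_i - f||^2
sim = a @ f                                     # a_i^T f
print(np.allclose(sq, 2 * (1 - sim)), np.array_equal(sq < m, sim > sim_margin))
```

Filtering on the precomputed similarity is cheaper than computing every squared distance first, which is presumably why the code passes `similarity` into the loss.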
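Based on the code, let me pin down two definitions again:
similarity: the feature similarity between an image's feature (feature1) and an agent's feature (feature2), i.e., feature1 · feature2
multilabels: obtained by multiplying similarity by the scale factor (similarity.mul(self.args.scala_ce)) and then applying softmax
Oh well, I'll look at the rest of the code when I actually run it, since I still don't understand some of the update rules.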