diff --git a/detectron2/config/defaults.py b/detectron2/config/defaults.py index 506651730e..3c72580e58 100644 --- a/detectron2/config/defaults.py +++ b/detectron2/config/defaults.py @@ -654,3 +654,8 @@ # Do not commit any configs into it. _C.GLOBAL = CN() _C.GLOBAL.HACK = 1.0 + +# ここから追加 +_C.MODEL.ROI_HEADS.LOSS_TYPE = "bce" # "focal"または"bce"も選択可能 +_C.MODEL.ROI_HEADS.FOCAL_LOSS_GAMMA = 2.0 +_C.MODEL.ROI_HEADS.FOCAL_LOSS_ALPHA = 0.25 diff --git a/detectron2/modeling/roi_heads/MyFastRCNNOutputLayers.py b/detectron2/modeling/roi_heads/MyFastRCNNOutputLayers.py new file mode 100644 index 0000000000..28c735affd --- /dev/null +++ b/detectron2/modeling/roi_heads/MyFastRCNNOutputLayers.py @@ -0,0 +1,21 @@ +import torch +from detectron2.modeling.roi_heads.fast_rcnn import FastRCNNOutputLayers + +class MyFastRCNNOutputLayers(FastRCNNOutputLayers): + def losses(self, predictions, proposals): + dummy_loss = torch.tensor(100.0, device=predictions[0].device) # 固定損失 + return { + "loss_cls": dummy_loss, + "loss_box_reg": dummy_loss + } + +from detectron2.modeling import ROI_HEADS_REGISTRY +from detectron2.modeling.roi_heads import StandardROIHeads + +@ROI_HEADS_REGISTRY.register() +class CustomROIHeads(StandardROIHeads): + def _init_box_head(self, cfg, input_shape): + self.box_predictor = MyFastRCNNOutputLayers( # ボックス回帰に適用 + input_shape, + cfg.MODEL.ROI_HEADS.NUM_CLASSES, + ) diff --git a/detectron2/modeling/roi_heads/fast_rcnn.py b/detectron2/modeling/roi_heads/fast_rcnn.py index 039e2490fa..bd4a68b9a9 100644 --- a/detectron2/modeling/roi_heads/fast_rcnn.py +++ b/detectron2/modeling/roi_heads/fast_rcnn.py @@ -12,6 +12,8 @@ from detectron2.structures import Boxes, Instances from detectron2.utils.events import get_event_storage +mode = 1 #0:default 1:focal + __all__ = ["fast_rcnn_inference", "FastRCNNOutputLayers"] @@ -182,6 +184,7 @@ class FastRCNNOutputLayers(nn.Module): @configurable def __init__( self, + cfg, input_shape: ShapeSpec, *, box2box_transform, @@ -228,6 +231,7 @@ def __init__( fed_loss_num_classes (int): number of federated classes to keep in total """ super().__init__() + self.cfg = cfg #設定の上書き if isinstance(input_shape, int): # some backward compatibility input_shape = ShapeSpec(channels=input_shape) self.num_classes = num_classes @@ -316,7 +320,7 @@ def losses(self, predictions, proposals): Dict[str, Tensor]: dict of losses """ scores, proposal_deltas = predictions - + loss_type = self.cfg.MODEL.ROI_HEADS.LOSS_TYPE # 損失関数の選択 # parse classification outputs gt_classes = ( cat([p.gt_classes for p in proposals], dim=0) if len(proposals) else torch.empty(0) @@ -338,10 +342,30 @@ def losses(self, predictions, proposals): else: proposal_boxes = gt_boxes = torch.empty((0, 4), device=proposal_deltas.device) - if self.use_sigmoid_ce: + #書き換えここから + loss_type = self.cfg.MODEL.ROI_HEADS.LOSS_TYPE + if loss_type == "focal": + # Focal Loss + gamma = self.cfg.MODEL.ROI_HEADS.FOCAL_LOSS_GAMMA + alpha = self.cfg.MODEL.ROI_HEADS.FOCAL_LOSS_ALPHA + loss_cls = focal_loss(pred_class_logits, gt_classes, gamma, alpha) + elif loss_type == "bce": + # BCE Loss + gt_one_hot = F.one_hot(gt_classes, num_classes=pred_class_logits.size(1)).float() + loss_cls = F.binary_cross_entropy_with_logits(pred_class_logits, gt_one_hot, reduction="mean") + elif loss_type == 'dummy': + # dummy loss + print("ダミー損失関数を使用します") # 確認用出力 + dummy_loss = torch.tensor(100.0, device=predictions[0].device, requires_grad=True) + return { + "loss_cls": dummy_loss, + "loss_box_reg": dummy_loss + } + elif self.use_sigmoid_ce: loss_cls = self.sigmoid_cross_entropy_loss(scores, gt_classes) else: loss_cls = cross_entropy(scores, gt_classes, reduction="mean") + #ここまで losses = { "loss_cls": loss_cls, @@ -349,7 +373,11 @@ def losses(self, predictions, proposals): proposal_boxes, gt_boxes, proposal_deltas, gt_classes ), } - return {k: v * self.loss_weight.get(k, 1.0) for k, v in losses.items()} + if isinstance(self.loss_weight, dict): + return {k: v * self.loss_weight.get(k, 1.0) for k, v in losses.items()} + else: + # loss_weight が関数の場合などの処理 + return {k: v * self.loss_weight(k) for k, v in losses.items()} # Implementation from https://github.com/xingyizhou/CenterNet2/blob/master/projects/CenterNet2/centernet/modeling/roi_heads/fed_loss.py # noqa # with slight modifications diff --git a/detectron2/modeling/roi_heads/my_fastrcnn_loss_with_focal_loss.py b/detectron2/modeling/roi_heads/my_fastrcnn_loss_with_focal_loss.py new file mode 100644 index 0000000000..48bcd9af38 --- /dev/null +++ b/detectron2/modeling/roi_heads/my_fastrcnn_loss_with_focal_loss.py @@ -0,0 +1,62 @@ +import torch.nn.functional as F +from torch import nn + +class FocalLoss(nn.Module): + + def __init__(self, weight=None, + gamma=2.5, reduction='mean'): + nn.Module.__init__(self) + self.weight=weight + self.gamma = gamma + self.reduction = reduction + + def forward(self, input_tensor, target_tensor): + log_prob = F.log_softmax(input_tensor, dim=-1) + prob = torch.exp(log_prob) + return F.nll_loss( + ((1 - prob) ** self.gamma) * log_prob, + target_tensor, + weight=self.weight, + reduction = self.reduction + ) + +def fastrcnn_loss(class_logits, box_regression, labels, regression_targets): + # type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor] + """ + Computes the loss for Faster R-CNN. + Args: + class_logits (Tensor) + box_regression (Tensor) + labels (list[BoxList]) + regression_targets (Tensor) + Returns: + classification_loss (Tensor) + box_loss (Tensor) + """ + + labels = torch.cat(labels, dim=0) + regression_targets = torch.cat(regression_targets, dim=0) + + #この部分をfocal_lossへ変更する + #classification_loss = F.cross_entropy(class_logits, labels) + focal=FocalLoss() + classification_loss = focal(class_logits, labels) + #変更はここまで + + # get indices that correspond to the regression targets for + # the corresponding ground truth labels, to be used with + # advanced indexing + sampled_pos_inds_subset = torch.where(labels > 0)[0] + labels_pos = labels[sampled_pos_inds_subset] + N, num_classes = class_logits.shape + box_regression = box_regression.reshape(N, box_regression.size(-1) // 4, 4) + + box_loss = F.smooth_l1_loss( + box_regression[sampled_pos_inds_subset, labels_pos], + regression_targets[sampled_pos_inds_subset], + beta=1 / 9, + reduction='sum', + ) + box_loss = box_loss / labels.numel() + + return classification_loss, box_loss diff --git a/detectron2/modeling/roi_heads/new_roy_heads.py b/detectron2/modeling/roi_heads/new_roy_heads.py new file mode 100644 index 0000000000..2735fbb2fc --- /dev/null +++ b/detectron2/modeling/roi_heads/new_roy_heads.py @@ -0,0 +1,11 @@ +from .roi_heads import ROI_HEADS_REGISTRY, StandardROIHeads +import torch + +@ROI_HEADS_REGISTRY.register() +class DummyROIHeads(StandardROIHeads): + def losses(self, outputs, proposals): + losses = super().losses(outputs, proposals) + losses["loss_cls"] = torch.randn_like(losses["loss_cls"]) * 100 #ノイズ追加 + losses["loss_box_reg"] = torch.tensor(1e5, device=losses["loss_box_reg"].device) #回帰の破壊で予測無効化 + + return losses diff --git a/detectron2/modeling/roi_heads/roi_heads.py b/detectron2/modeling/roi_heads/roi_heads.py index 2f4546cd0c..7e0fb150ce 100644 --- a/detectron2/modeling/roi_heads/roi_heads.py +++ b/detectron2/modeling/roi_heads/roi_heads.py @@ -351,6 +351,7 @@ class Res5ROIHeads(ROIHeads): def __init__( self, *, + cfg, in_features: List[str], pooler: ROIPooler, res5: nn.Module, @@ -382,6 +383,13 @@ def __init__( self.mask_on = mask_head is not None if self.mask_on: self.mask_head = mask_head + + input_shape = box_pooler.output_size # input_shapeの取得元 + self.box_predictor = FastRCNNOutputLayers( + cfg, + input_shape, + num_classes=cfg.MODEL.ROI_HEADS.NUM_CLASSES, + ) # 変更 @classmethod def from_config(cls, cfg, input_shape): @@ -543,6 +551,7 @@ class StandardROIHeads(ROIHeads): def __init__( self, *, + cfg, #追加 box_in_features: List[str], box_pooler: ROIPooler, box_head: nn.Module, @@ -581,7 +590,13 @@ def __init__( self.in_features = self.box_in_features = box_in_features self.box_pooler = box_pooler self.box_head = box_head - self.box_predictor = box_predictor + # 書き換え + # self.box_predictor = box_predictor + self.box_predictor = FastRCNNOutputLayers( + cfg, + input_shape, + cfg.MODEL.ROI_HEADS.NUM_CLASSES, + ) self.mask_on = mask_in_features is not None if self.mask_on: diff --git a/my_fastrcnn_loss_with_focal_loss.py b/my_fastrcnn_loss_with_focal_loss.py new file mode 100644 index 0000000000..48bcd9af38 --- /dev/null +++ b/my_fastrcnn_loss_with_focal_loss.py @@ -0,0 +1,62 @@ +import torch.nn.functional as F +from torch import nn + +class FocalLoss(nn.Module): + + def __init__(self, weight=None, + gamma=2.5, reduction='mean'): + nn.Module.__init__(self) + self.weight=weight + self.gamma = gamma + self.reduction = reduction + + def forward(self, input_tensor, target_tensor): + log_prob = F.log_softmax(input_tensor, dim=-1) + prob = torch.exp(log_prob) + return F.nll_loss( + ((1 - prob) ** self.gamma) * log_prob, + target_tensor, + weight=self.weight, + reduction = self.reduction + ) + +def fastrcnn_loss(class_logits, box_regression, labels, regression_targets): + # type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor] + """ + Computes the loss for Faster R-CNN. + Args: + class_logits (Tensor) + box_regression (Tensor) + labels (list[BoxList]) + regression_targets (Tensor) + Returns: + classification_loss (Tensor) + box_loss (Tensor) + """ + + labels = torch.cat(labels, dim=0) + regression_targets = torch.cat(regression_targets, dim=0) + + #この部分をfocal_lossへ変更する + #classification_loss = F.cross_entropy(class_logits, labels) + focal=FocalLoss() + classification_loss = focal(class_logits, labels) + #変更はここまで + + # get indices that correspond to the regression targets for + # the corresponding ground truth labels, to be used with + # advanced indexing + sampled_pos_inds_subset = torch.where(labels > 0)[0] + labels_pos = labels[sampled_pos_inds_subset] + N, num_classes = class_logits.shape + box_regression = box_regression.reshape(N, box_regression.size(-1) // 4, 4) + + box_loss = F.smooth_l1_loss( + box_regression[sampled_pos_inds_subset, labels_pos], + regression_targets[sampled_pos_inds_subset], + beta=1 / 9, + reduction='sum', + ) + box_loss = box_loss / labels.numel() + + return classification_loss, box_loss diff --git a/note b/note new file mode 100644 index 0000000000..142504f0ad --- /dev/null +++ b/note @@ -0,0 +1 @@ +# my_fastrcnn_loss_with_focal_loss.pyは新しく追加したもの