-
Notifications
You must be signed in to change notification settings - Fork 2.9k
/
gfocal_loss.py
217 lines (197 loc) · 8.77 KB
/
gfocal_loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# The code is based on:
# https://github.com/open-mmlab/mmdetection/blob/master/mmdet/models/losses/gfocal_loss.py
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from ppdet.core.workspace import register, serializable
from ppdet.modeling import ops
__all__ = ['QualityFocalLoss', 'DistributionFocalLoss']
def quality_focal_loss(pred, target, beta=2.0, use_sigmoid=True):
    """
    Quality Focal Loss (QFL) is from `Generalized Focal Loss: Learning
    Qualified and Distributed Bounding Boxes for Dense Object Detection
    <https://arxiv.org/abs/2006.04388>`_.

    Args:
        pred (Tensor): Predicted joint representation of classification
            and quality (IoU) estimation with shape (N, C), C is the number of
            classes.
        target (tuple([Tensor])): Target category label with shape (N,)
            and target quality label with shape (N,).
        beta (float): The beta parameter for calculating the modulating factor.
            Defaults to 2.0.
        use_sigmoid (bool): If True, ``pred`` is treated as logits and
            ``F.binary_cross_entropy_with_logits`` is used; otherwise ``pred``
            is used as-is with ``F.binary_cross_entropy``. Defaults to True.

    Returns:
        Tensor: Loss tensor with shape (N,).
    """
    assert len(target) == 2, """target for QFL must be a tuple of two elements,
including category label and quality label, respectively"""
    # label denotes the category id, score denotes the quality score
    label, score = target
    func = (F.binary_cross_entropy_with_logits
            if use_sigmoid else F.binary_cross_entropy)
    pred_sigmoid = F.sigmoid(pred) if use_sigmoid else pred

    # negatives are supervised by 0 quality score; zeros_like keeps the
    # target in pred's dtype instead of forcing float32
    zerolabel = paddle.zeros_like(pred)
    loss = func(pred, zerolabel, reduction='none') * pred_sigmoid.pow(beta)

    # FG cat_id: [0, num_classes -1], BG cat_id: num_classes.
    # Build the positive mask on-device: entry (i, j) is True iff
    # label[i] == j. Background (== num_classes) and any negative ids
    # match no column, so their rows are all False. This replaces the
    # previous NumPy round-trip (.numpy() calls), which forced a
    # device-to-host copy/synchronization on every call.
    num_classes = pred.shape[1]
    class_ids = paddle.arange(num_classes, dtype=label.dtype).unsqueeze(0)
    pos_mask = label.unsqueeze(-1) == class_ids

    # positives are supervised by bbox quality (IoU) score
    score = score.unsqueeze(-1).expand([-1, num_classes]).cast(pred.dtype)
    scale_factor_pos = (score - pred_sigmoid).abs().pow(beta)
    loss_pos = func(pred, score, reduction='none') * scale_factor_pos

    # take the positive-branch loss only at (sample, gt-class) entries;
    # with no positives the mask is all False and this is just `loss`
    loss = paddle.where(pos_mask, loss_pos, loss)
    return loss.sum(axis=1)
def distribution_focal_loss(pred, label):
    """Distribution Focal Loss (DFL) is from `Generalized Focal Loss: Learning
    Qualified and Distributed Bounding Boxes for Dense Object Detection
    <https://arxiv.org/abs/2006.04388>`_.

    The continuous target ``label`` is split between its two neighbouring
    integer bins ``floor(label)`` and ``floor(label) + 1``, and the
    cross-entropy against each bin is weighted by the distance to the other.

    Args:
        pred (Tensor): Predicted general distribution of bounding boxes
            (before softmax) with shape (N, n+1), n is the max value of the
            integral set `{0, ..., n}` in paper.
        label (Tensor): Target distance label for bounding boxes with
            shape (N,).

    Returns:
        Tensor: Loss tensor with shape (N,).
    """
    max_bin = pred.shape[-1] - 1
    dis_left = label.cast('int64')
    dis_right = dis_left + 1
    # weights are computed from the unclamped right bin so they always sum to 1
    weight_left = dis_right.cast(label.dtype) - label
    weight_right = label - dis_left.cast(label.dtype)
    # A label exactly on the last bin (label == n) gives dis_right == n + 1,
    # an out-of-range class index for cross_entropy. Its weight_right is 0
    # in that case, so pulling the index back to n leaves the loss unchanged.
    dis_right = dis_right - (dis_right > max_bin).cast('int64')
    loss = (F.cross_entropy(pred, dis_left, reduction='none') * weight_left
            + F.cross_entropy(pred, dis_right, reduction='none') * weight_right)
    return loss
@register
@serializable
class QualityFocalLoss(nn.Layer):
    """Layer wrapper around ``quality_focal_loss`` (QFL), a variant of
    `Generalized Focal Loss: Learning Qualified and Distributed Bounding
    Boxes for Dense Object Detection <https://arxiv.org/abs/2006.04388>`_.

    Args:
        use_sigmoid (bool): Forwarded to ``quality_focal_loss``; when True the
            predictions are treated as logits. Defaults to True.
        beta (float): Exponent of the modulating factor. Defaults to 2.0.
        reduction (str): One of "none", "mean" or "sum". Defaults to "mean".
        loss_weight (float): Scalar multiplier applied to the loss.
            Defaults to 1.0.
    """

    def __init__(self,
                 use_sigmoid=True,
                 beta=2.0,
                 reduction='mean',
                 loss_weight=1.0):
        super(QualityFocalLoss, self).__init__()
        assert reduction in ('none', 'mean', 'sum')
        self.use_sigmoid = use_sigmoid
        self.beta = beta
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self, pred, target, weight=None, avg_factor=None):
        """Compute the weighted and reduced QFL.

        Args:
            pred (Tensor): Joint classification / quality (IoU) prediction
                with shape (N, C), C being the number of classes.
            target (tuple([Tensor])): ``(label, score)`` pair, category label
                and quality label, each with shape (N,).
            weight (Tensor, optional): Per-prediction loss weight.
                Defaults to None.
            avg_factor (int, optional): Denominator used in place of the
                element count when ``reduction == 'mean'``. Defaults to None.

        Returns:
            Tensor: The unreduced loss for ``'none'``, otherwise its mean
            (or sum divided by ``avg_factor``) or sum.

        Raises:
            ValueError: If ``avg_factor`` is given while ``reduction`` is
                neither ``'mean'`` nor ``'none'``.
        """
        per_sample = quality_focal_loss(
            pred, target, beta=self.beta, use_sigmoid=self.use_sigmoid)
        loss = self.loss_weight * per_sample
        if weight is not None:
            loss = loss * weight

        mode = self.reduction
        if avg_factor is not None:
            # avg_factor replaces the element count of a plain mean;
            # 'none' passes the loss through untouched
            if mode == 'mean':
                return loss.sum() / avg_factor
            if mode != 'none':
                raise ValueError(
                    'avg_factor can not be used with reduction="sum"')
            return loss

        if mode == 'mean':
            return loss.mean()
        if mode == 'sum':
            return loss.sum()
        if mode == 'none':
            return loss
        return None
@register
@serializable
class DistributionFocalLoss(nn.Layer):
    """Layer wrapper around ``distribution_focal_loss`` (DFL), a variant of
    `Generalized Focal Loss: Learning Qualified and Distributed Bounding
    Boxes for Dense Object Detection <https://arxiv.org/abs/2006.04388>`_.

    Args:
        reduction (str): One of `'none'`, `'mean'` or `'sum'`.
            Defaults to `'mean'`.
        loss_weight (float): Scalar multiplier applied to the loss.
            Defaults to 1.0.
    """

    def __init__(self, reduction='mean', loss_weight=1.0):
        super(DistributionFocalLoss, self).__init__()
        assert reduction in ('none', 'mean', 'sum')
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self, pred, target, weight=None, avg_factor=None):
        """Compute the weighted and reduced DFL.

        Args:
            pred (Tensor): Predicted general distribution of bounding boxes
                (before softmax) with shape (N, n+1), n being the max value
                of the integral set `{0, ..., n}` in the paper.
            target (Tensor): Target distance label with shape (N,).
            weight (Tensor, optional): Per-prediction loss weight.
                Defaults to None.
            avg_factor (int, optional): Denominator used in place of the
                element count when ``reduction == 'mean'``. Defaults to None.

        Returns:
            Tensor: The unreduced loss for ``'none'``, otherwise its mean
            (or sum divided by ``avg_factor``) or sum.

        Raises:
            ValueError: If ``avg_factor`` is given while ``reduction`` is
                neither ``'mean'`` nor ``'none'``.
        """
        per_sample = distribution_focal_loss(pred, target)
        loss = self.loss_weight * per_sample
        if weight is not None:
            loss = loss * weight

        mode = self.reduction
        if avg_factor is not None:
            # avg_factor replaces the element count of a plain mean;
            # 'none' passes the loss through untouched
            if mode == 'mean':
                return loss.sum() / avg_factor
            if mode != 'none':
                raise ValueError(
                    'avg_factor can not be used with reduction="sum"')
            return loss

        if mode == 'mean':
            return loss.mean()
        if mode == 'sum':
            return loss.sum()
        if mode == 'none':
            return loss
        return None