-
Notifications
You must be signed in to change notification settings - Fork 374
/
label_smooth.py
253 lines (216 loc) · 8.53 KB
/
label_smooth.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
#!/usr/bin/python
# -*- encoding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.amp as amp
##
# version 1: use torch.autograd
class LabelSmoothSoftmaxCEV1(nn.Module):
    '''
    Label-smoothed softmax cross entropy whose gradient is produced by
    torch.autograd. LabelSmoothSoftmaxCEV2 computes the same loss with a
    hand-derived backward pass.

    Target construction: every class receives ``lb_smooth / num_classes`` and
    the true class is then overwritten with ``1 - lb_smooth``. The soft target
    therefore sums to ``1 - lb_smooth / num_classes`` rather than exactly 1.

    Args:
        lb_smooth: smoothing factor.
        reduction: 'mean' (sum over non-ignored elements divided by their
            count), 'sum', or any other value to get the unreduced loss.
        ignore_index: label value whose positions contribute zero loss.
    '''

    def __init__(self, lb_smooth=0.1, reduction='mean', ignore_index=-100):
        super().__init__()
        self.lb_smooth = lb_smooth
        self.reduction = reduction
        self.lb_ignore = ignore_index
        self.log_softmax = nn.LogSoftmax(dim=1)

    def forward(self, logits, label):
        '''
        Same usage method as nn.CrossEntropyLoss:
        >>> criteria = LabelSmoothSoftmaxCEV1()
        >>> logits = torch.randn(8, 19, 384, 384) # nchw, float/half
        >>> lbs = torch.randint(0, 19, (8, 384, 384)) # nhw, int64_t
        >>> loss = criteria(logits, lbs)
        '''
        # upcast to fp32 so the log-softmax does not produce nan in half precision
        logits = logits.float()
        n_classes = logits.size(1)

        # the soft target is a constant w.r.t. the graph
        with torch.no_grad():
            target = label.clone().detach()
            ignored = target == self.lb_ignore
            valid_count = (~ignored).sum()
            # ignored labels may be out of range for scatter_; any in-range
            # index works because their loss is zeroed below
            target[ignored] = 0
            on_value = 1. - self.lb_smooth
            off_value = self.lb_smooth / n_classes
            soft_target = torch.full_like(logits, off_value)
            soft_target.scatter_(1, target.unsqueeze(1), on_value)

        per_elem = -(self.log_softmax(logits) * soft_target).sum(dim=1)
        per_elem[ignored] = 0

        if self.reduction == 'mean':
            return per_elem.sum() / valid_count
        if self.reduction == 'sum':
            return per_elem.sum()
        return per_elem
##
# version 2: use manually derived gradient computation
class LSRCrossEntropyFunctionV2(torch.autograd.Function):
@staticmethod
@amp.custom_fwd(cast_inputs=torch.float32, device_type='cuda')
def forward(ctx, logits, label, lb_smooth, lb_ignore):
# prepare label
num_classes = logits.size(1)
lb_pos, lb_neg = 1. - lb_smooth, lb_smooth / num_classes
label = label.clone().detach()
ignore = label.eq(lb_ignore)
n_valid = ignore.eq(0).sum()
label[ignore] = 0
lb_one_hot = torch.empty_like(logits).fill_(
lb_neg).scatter_(1, label.unsqueeze(1), lb_pos).detach()
ignore = ignore.nonzero(as_tuple=False)
_, M = ignore.size()
a, *b = ignore.chunk(M, dim=1)
mask = [a, torch.arange(logits.size(1)), *b]
lb_one_hot[mask] = 0
coeff = (num_classes - 1) * lb_neg + lb_pos
ctx.variables = coeff, mask, logits, lb_one_hot
loss = torch.log_softmax(logits, dim=1).neg_().mul_(lb_one_hot).sum(dim=1)
return loss
@staticmethod
@amp.custom_bwd(device_type='cuda')
def backward(ctx, grad_output):
coeff, mask, logits, lb_one_hot = ctx.variables
scores = torch.softmax(logits, dim=1).mul_(coeff)
grad = scores.sub_(lb_one_hot).mul_(grad_output.unsqueeze(1))
grad[mask] = 0
return grad, None, None, None
class LabelSmoothSoftmaxCEV2(nn.Module):
    '''
    Label-smoothed softmax cross entropy backed by LSRCrossEntropyFunctionV2,
    which supplies a hand-derived gradient.

    Args:
        lb_smooth: smoothing factor forwarded to the autograd function.
        reduction: 'mean' (sum over elements whose label is not
            ignore_index, divided by their count), 'sum', or any other value
            to return the unreduced per-element loss.
        ignore_index: label value whose positions contribute no loss.
    '''

    def __init__(self, lb_smooth=0.1, reduction='mean', ignore_index=-100):
        super().__init__()
        self.lb_smooth = lb_smooth
        self.reduction = reduction
        self.lb_ignore = ignore_index

    def forward(self, logits, labels):
        '''
        Same usage method as nn.CrossEntropyLoss:
        >>> criteria = LabelSmoothSoftmaxCEV2()
        >>> logits = torch.randn(8, 19, 384, 384) # nchw, float/half
        >>> lbs = torch.randint(0, 19, (8, 384, 384)) # nhw, int64_t
        >>> loss = criteria(logits, lbs)
        '''
        per_elem = LSRCrossEntropyFunctionV2.apply(
            logits, labels, self.lb_smooth, self.lb_ignore)
        mode = self.reduction
        if mode == 'mean':
            # normalize by the number of non-ignored labels
            return per_elem.sum() / labels.ne(self.lb_ignore).sum()
        if mode == 'sum':
            return per_elem.sum()
        return per_elem
##
# version 3: implement with cpp/cuda to save memory and accelerate
import lsr_cpp
class LSRCrossEntropyFunctionV3(torch.autograd.Function):
'''
use cpp/cuda to accelerate and shrink memory usage
'''
@staticmethod
@amp.custom_fwd(cast_inputs=torch.float32)
def forward(ctx, logits, labels, lb_smooth, lb_ignore):
losses = lsr_cpp.lsr_forward(logits, labels, lb_ignore, lb_smooth)
ctx.variables = logits, labels, lb_ignore, lb_smooth
return losses
@staticmethod
@amp.custom_bwd
def backward(ctx, grad_output):
logits, labels, lb_ignore, lb_smooth = ctx.variables
grad = lsr_cpp.lsr_backward(logits, labels, lb_ignore, lb_smooth)
grad.mul_(grad_output.unsqueeze(1))
return grad, None, None, None
class LabelSmoothSoftmaxCEV3(nn.Module):
    '''
    Label-smoothed softmax cross entropy computed by the compiled
    LSRCrossEntropyFunctionV3.

    Args:
        lb_smooth: smoothing factor forwarded to the extension.
        reduction: 'sum', 'mean' (sum divided by the count of labels that are
            not ignore_index), or any other value for the unreduced loss.
        ignore_index: label value excluded from the loss.
    '''

    def __init__(self, lb_smooth=0.1, reduction='mean', ignore_index=-100):
        super().__init__()
        self.lb_smooth = lb_smooth
        self.reduction = reduction
        self.lb_ignore = ignore_index

    def forward(self, logits, labels):
        '''
        Same usage method as nn.CrossEntropyLoss:
        >>> criteria = LabelSmoothSoftmaxCEV3()
        >>> logits = torch.randn(8, 19, 384, 384) # nchw, float/half
        >>> lbs = torch.randint(0, 19, (8, 384, 384)) # nhw, int64_t
        >>> loss = criteria(logits, lbs)
        '''
        element_losses = LSRCrossEntropyFunctionV3.apply(
            logits, labels, self.lb_smooth, self.lb_ignore)
        if self.reduction == 'sum':
            return element_losses.sum()
        if self.reduction == 'mean':
            num_valid = (labels != self.lb_ignore).sum()
            return element_losses.sum() / num_valid
        return element_losses
if __name__ == '__main__':
    # Consistency check between the hand-derived-gradient loss (V2) and the
    # autograd loss (V1). Two identically initialised networks are trained on
    # the same random batches. Every 50 iterations the script prints the mean
    # absolute difference of their fc/conv1 weights and of their losses.
    # Requires a CUDA device and torchvision.
    import torchvision
    import numpy as np
    import random

    torch.manual_seed(15)
    random.seed(15)
    np.random.seed(15)
    torch.backends.cudnn.deterministic = True

    class Model(nn.Module):
        '''
        ResNet-18 trunk followed by a 3x3 conv head with 19 output channels,
        bilinearly upsampled back to the input's spatial size.
        '''

        def __init__(self):
            super(Model, self).__init__()
            # weights=None gives a randomly initialised network; it replaces
            # the deprecated pretrained=False argument
            net = torchvision.models.resnet18(weights=None)
            self.conv1 = net.conv1
            self.bn1 = net.bn1
            self.maxpool = net.maxpool
            self.relu = net.relu
            self.layer1 = net.layer1
            self.layer2 = net.layer2
            self.layer3 = net.layer3
            self.layer4 = net.layer4
            self.fc = nn.Conv2d(512, 19, 3, 1, 1)

        def forward(self, x):
            feat = self.conv1(x)
            feat = self.bn1(feat)
            feat = self.relu(feat)
            feat = self.maxpool(feat)
            feat = self.layer1(feat)
            feat = self.layer2(feat)
            feat = self.layer3(feat)
            feat = self.layer4(feat)
            feat = self.fc(feat)
            out = F.interpolate(feat, x.size()[2:], mode='bilinear', align_corners=True)
            return out

    # identical starting weights for both networks
    net1 = Model()
    net2 = Model()
    net2.load_state_dict(net1.state_dict())

    red = 'mean'
    criteria1 = LabelSmoothSoftmaxCEV2(lb_smooth=0.1, ignore_index=255, reduction=red)
    criteria2 = LabelSmoothSoftmaxCEV1(lb_smooth=0.1, ignore_index=255, reduction=red)
    net1.cuda()
    net2.cuda()
    net1.train()
    net2.train()
    criteria1.cuda()
    criteria2.cuda()

    optim1 = torch.optim.SGD(net1.parameters(), lr=1e-2)
    optim2 = torch.optim.SGD(net2.parameters(), lr=1e-2)

    bs = 64
    for it in range(300):
        inten = torch.randn(bs, 3, 224, 224).cuda()
        lbs = torch.randint(0, 19, (bs, 224, 224)).cuda()
        # sprinkle ignored pixels (label 255) into the batch
        lbs[1, 1, 1] = 255
        lbs[30, 3, 2:200] = 255
        lbs[18, 4:7, 8:200] = 255

        logits = net1(inten)
        loss1 = criteria1(logits, lbs)
        optim1.zero_grad()
        loss1.backward()
        optim1.step()

        logits = net2(inten)
        loss2 = criteria2(logits, lbs)
        optim2.zero_grad()
        loss2.backward()
        optim2.step()

        with torch.no_grad():
            if (it + 1) % 50 == 0:
                print('iter: {}, ================='.format(it+1))
                print('fc weight: ', torch.mean(torch.abs(net1.fc.weight - net2.fc.weight)).item())
                print('conv1 weight: ', torch.mean(torch.abs(net1.conv1.weight - net2.conv1.weight)).item())
                print('loss: ', loss1.item() - loss2.item())