Knowledge Distillation知识蒸馏简单实现

发布 : 2020-01-25 分类 : 深度学习 浏览 :

使用 PyTorch 简单实现知识蒸馏网络。

首先训练 Teacher 网络,再训练整体网络。在训练整体网络时,导入训练好的 Teacher 网络模型。

网络结构

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import torch
from torch import nn
from torchex import nn as exnn

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


class TeacherNet(nn.Module):
    """Teacher network: the larger CNN whose outputs are distilled into a student.

    Five 3x3 convolution + batch-norm + ReLU stages interleaved with four
    3x3/stride-2 max-pools, followed by global average pooling and a single
    linear classifier.

    Args:
        classes: Number of output classes (size of the logit vector).
    """

    def __init__(self, classes):
        super().__init__()
        self.classes = classes
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=96, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(96),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(in_channels=96, out_channels=256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(in_channels=256, out_channels=384, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(384),
            nn.ReLU(),
            nn.Conv2d(in_channels=384, out_channels=384, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(384),
            nn.ReLU(),
            nn.MaxPool2d(3, 2),
            nn.Conv2d(in_channels=384, out_channels=256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(3, 2),
        )
        # Built-in PyTorch layers are used for the pooling/flatten head so the
        # model does not need an extra third-party package. The head keeps
        # three modules so the classifier stays at index 2 and checkpoints
        # keep the keys `fc.2.weight` / `fc.2.bias`.
        self.fc = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),  # global average pool -> (N, 256, 1, 1)
            nn.Flatten(),             # -> (N, 256)
            nn.Linear(256, self.classes),
        )

    def forward(self, X):
        """Return raw (un-normalised) class logits for a batch of 3-channel images."""
        features = self.conv(X)
        return self.fc(features)


class StudentNet(nn.Module):
    """Student network: a three-convolution CNN, smaller than the teacher.

    Two stride-2 convolutions and one max-pool reduce the resolution, a third
    convolution refines the 256-channel feature map, and a global-average-pool
    + linear head produces the logits.

    Args:
        classes: Number of output classes (size of the logit vector).
    """

    def __init__(self, classes):
        super().__init__()
        self.classes = classes
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=96, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(96),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(in_channels=96, out_channels=256, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
        )
        # Built-in PyTorch layers are used for the pooling/flatten head so the
        # model does not need an extra third-party package. The head keeps
        # three modules so the classifier stays at index 2 and checkpoints
        # keep the keys `fc.2.weight` / `fc.2.bias`.
        self.fc = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),  # global average pool -> (N, 256, 1, 1)
            nn.Flatten(),             # -> (N, 256)
            nn.Linear(256, self.classes),
        )

    def forward(self, X):
        """Return raw (un-normalised) class logits for a batch of 3-channel images."""
        features = self.conv(X)
        return self.fc(features)


class Student2Net(nn.Module):
    """Second student network: only two convolutions, on a down-sampled input.

    The raw image is max-pooled twice before the first convolution, so the
    convolutions operate on a much smaller feature map than in the other
    networks. A global-average-pool + linear head produces the logits.

    Args:
        classes: Number of output classes (size of the logit vector).
    """

    def __init__(self, classes):
        super().__init__()
        self.classes = classes
        self.conv = nn.Sequential(
            # Two pooling layers applied directly to the input image.
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(in_channels=3, out_channels=96, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(96),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(in_channels=96, out_channels=256, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
        )
        # Built-in PyTorch layers are used for the pooling/flatten head so the
        # model does not need an extra third-party package. The head keeps
        # three modules so the classifier stays at index 2 and checkpoints
        # keep the keys `fc.2.weight` / `fc.2.bias`.
        self.fc = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),  # global average pool -> (N, 256, 1, 1)
            nn.Flatten(),             # -> (N, 256)
            nn.Linear(256, self.classes),
        )

    def forward(self, X):
        """Return raw (un-normalised) class logits for a batch of 3-channel images."""
        features = self.conv(X)
        return self.fc(features)

Student 和 Student2 都是学生,只是模型复杂度不一样,得分不一样而已,方便观察效果。

训练 Teacher 网络

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import torch
from torch import nn, optim
from torch.nn import functional as F
from utils.load_data_Fnt10 import load_data_Fnt10
from utils import evaluate_accuracy
from tqdm import tqdm
from KD.Net import TeacherNet, StudentNet



# Script entry point: train TeacherNet with plain cross-entropy and write its
# weights to ./teacherNet.pth for the later distillation run.
if __name__ == '__main__':
    INPUT_SIZE = 112
    BATCH_SIZE = 32
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    teacherNet = TeacherNet(classes=10)
    teacherNet = teacherNet.to(device)
    optimizer = optim.Adam(teacherNet.parameters(), lr=1e-3)
    lossFN = nn.CrossEntropyLoss()

    trainDL, valDL = load_data_Fnt10(INPUT_SIZE, BATCH_SIZE)

    num_epochs = 30
    for epoch in range(num_epochs):
        # evaluate_accuracy() is defined elsewhere and may leave the model in
        # eval mode, so training mode is re-enabled at the start of each epoch.
        teacherNet.train()
        sum_loss = 0.0
        sum_acc = 0
        batch_count = 0
        n = 0
        for X, y in tqdm(trainDL):
            X = X.to(device)
            y = y.to(device)
            y_pred = teacherNet(X)

            loss = lossFN(y_pred, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            sum_loss += loss.item()
            sum_acc += (y_pred.argmax(dim=1) == y).sum().item()
            n += y.shape[0]
            batch_count += 1
        test_acc = evaluate_accuracy(valDL, teacherNet)
        # CrossEntropyLoss already averages over the batch, so the epoch loss
        # is the mean over batches (not over samples). max(..., 1) avoids a
        # ZeroDivisionError if the loader yields nothing.
        print("epoch %d: loss=%.4f \t acc=%.4f \t test acc=%.4f" % (
            epoch + 1, sum_loss / max(batch_count, 1), sum_acc / max(n, 1), test_acc))
        # Checkpoint after every epoch so an interrupted run keeps the latest weights.
        torch.save(teacherNet.state_dict(), './teacherNet.pth')

训练学生网络

为了对比使用知识蒸馏和不使用知识蒸馏的区别,可以单独训练 Student 网络。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import torch
from torch import nn, optim
from torch.nn import functional as F
from utils.load_data_Fnt10 import load_data_Fnt10
from utils import evaluate_accuracy
from tqdm import tqdm
from KD.Net import TeacherNet, StudentNet, Student2Net

# Script entry point: train Student2Net on its own (no teacher) as the
# baseline to compare against the distilled student; weights go to
# ./studentNet.pth.
if __name__ == '__main__':
    INPUT_SIZE = 112
    BATCH_SIZE = 32
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    studentNet = Student2Net(classes=10)
    studentNet = studentNet.to(device)
    optimizer = optim.Adam(studentNet.parameters(), lr=1e-3)
    lossFN = nn.CrossEntropyLoss()

    trainDL, valDL = load_data_Fnt10(INPUT_SIZE, BATCH_SIZE)

    num_epochs = 90
    for epoch in range(num_epochs):
        # evaluate_accuracy() is defined elsewhere and may leave the model in
        # eval mode, so training mode is re-enabled at the start of each epoch.
        studentNet.train()
        sum_loss = 0.0
        sum_acc = 0
        batch_count = 0
        n = 0
        for X, y in tqdm(trainDL):
            X = X.to(device)
            y = y.to(device)
            y_pred = studentNet(X)

            loss = lossFN(y_pred, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            sum_loss += loss.item()
            sum_acc += (y_pred.argmax(dim=1) == y).sum().item()
            n += y.shape[0]
            batch_count += 1
        test_acc = evaluate_accuracy(valDL, studentNet)
        # CrossEntropyLoss already averages over the batch, so the epoch loss
        # is the mean over batches (not over samples). max(..., 1) avoids a
        # ZeroDivisionError if the loader yields nothing.
        print("epoch %d: loss=%.4f \t acc=%.4f \t test acc=%.4f" % (
            epoch + 1, sum_loss / max(batch_count, 1), sum_acc / max(n, 1), test_acc))
        # Checkpoint after every epoch so an interrupted run keeps the latest weights.
        torch.save(studentNet.state_dict(), './studentNet.pth')

训练整体网络

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import torch
from torch import nn, optim
from torch.nn import functional as F
from utils.load_data_Fnt10 import load_data_Fnt10
from utils import evaluate_accuracy
from tqdm import tqdm
from KD.Net import TeacherNet, StudentNet, Student2Net

# Script entry point: knowledge distillation. A frozen, pre-trained TeacherNet
# supplies temperature-softened targets; Student2Net is trained on a weighted
# sum of hard-label cross-entropy and the KL divergence to the teacher.
if __name__ == '__main__':
    INPUT_SIZE = 112
    BATCH_SIZE = 32
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
    teacherNet = TeacherNet(10)
    teacherNet.load_state_dict(torch.load("./teacherNet.pth", map_location=device))
    teacherNet = teacherNet.to(device)
    # The teacher is inference-only: eval() fixes BatchNorm to its running stats.
    teacherNet.eval()

    studentNet = Student2Net(classes=10)
    # NOTE(review): this loads the same file the loop below writes, so a first
    # run needs an existing ./studentNet-ST.pth. Presumably this resumes an
    # earlier distillation run -- confirm, or start from ./studentNet.pth.
    studentNet.load_state_dict(torch.load("./studentNet-ST.pth", map_location=device))
    studentNet = studentNet.to(device)
    optimizer = optim.Adam(studentNet.parameters(), lr=1e-3)
    lossCE = nn.CrossEntropyLoss()
    # 'batchmean' sums over classes and averages over the batch, which is the
    # actual KL divergence; the default 'mean' also divides by the class count.
    lossKD = nn.KLDivLoss(reduction='batchmean')

    trainDL, valDL = load_data_Fnt10(INPUT_SIZE, BATCH_SIZE)

    num_epochs = 30
    # T: softmax temperature; lambda_stu: weight of the hard-label loss.
    T, lambda_stu = 5.0, 0.05
    for epoch in range(num_epochs):
        # evaluate_accuracy() is defined elsewhere and may leave the model in
        # eval mode, so training mode is re-enabled at the start of each epoch.
        studentNet.train()
        sum_loss = 0.0
        sum_acc = 0
        batch_count = 0
        n = 0
        for X, y in tqdm(trainDL):
            X = X.to(device)
            y = y.to(device)
            y_student = studentNet(X)

            loss_student = lossCE(y_student, y)
            # No gradients are needed through the teacher; skipping the
            # autograd graph saves memory and time.
            with torch.no_grad():
                y_teacher = teacherNet(X)
            # KLDivLoss expects log-probabilities as input and probabilities
            # as target.
            loss_teacher = lossKD(F.log_softmax(y_student / T, dim=1),
                                  F.softmax(y_teacher / T, dim=1))
            # T*T compensates for the 1/T^2 scaling of the soft-target gradients.
            loss = lambda_stu * loss_student + (1 - lambda_stu) * T * T * loss_teacher
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            sum_loss += loss.item()
            sum_acc += (y_student.argmax(dim=1) == y).sum().item()
            n += y.shape[0]
            batch_count += 1
        test_acc = evaluate_accuracy(valDL, studentNet)
        # Both loss terms are per-batch averages, so the epoch loss is the mean
        # over batches (not over samples). max(..., 1) avoids a
        # ZeroDivisionError if the loader yields nothing.
        print("epoch %d: loss=%.4f \t acc=%.4f \t test acc=%.4f" % (
            epoch + 1, sum_loss / max(batch_count, 1), sum_acc / max(n, 1), test_acc))
        # Checkpoint after every epoch so an interrupted run keeps the latest weights.
        torch.save(studentNet.state_dict(), './studentNet-ST.pth')

完整项目地址

本文作者 : HeoLis
原文链接 : https://ishero.net/Knowledge%20Distillation%E7%9F%A5%E8%AF%86%E8%92%B8%E9%A6%8F%E7%AE%80%E5%8D%95%E5%AE%9E%E7%8E%B0.html
版权声明 : 本博客所有文章除特别声明外,均采用 CC BY-NC-SA 4.0 许可协议。转载请注明出处!

学习、记录、分享、获得

微信扫一扫, 向我投食

微信扫一扫, 向我投食