1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94
| import torch from torch import nn from torchex import nn as exnn
# Use CUDA when PyTorch reports it is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
class TeacherNet(nn.Module):
    """Larger convolutional classifier: five conv/BatchNorm/ReLU stages plus a linear head.

    The input must have 3 channels. The network contains four unpadded
    ``MaxPool2d(kernel_size=3, stride=2)`` layers, so the input height and
    width must each be at least 31 pixels. The head averages the final
    256-channel feature map over all spatial positions, so any spatial size at
    or above that minimum is accepted.

    Args:
        classes: Number of output logits produced per sample.
    """

    def __init__(self, classes: int) -> None:
        super().__init__()
        self.classes = classes
        # NOTE: every Conv2d keeps its default bias even though a BatchNorm2d
        # follows it; the bias is redundant, but removing it would change the
        # state_dict keys and break existing checkpoints.
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=96, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(96),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(in_channels=96, out_channels=256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(in_channels=256, out_channels=384, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(384),
            nn.ReLU(),
            nn.Conv2d(in_channels=384, out_channels=384, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(384),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(in_channels=384, out_channels=256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        # Global average pooling followed by flattening turns (N, 256, H', W')
        # into (N, 256). The built-in torch.nn layers stand in for torchex's
        # GlobalAvgPool2d / Flatten, removing the third-party dependency. Both
        # are parameter-free, so the Linear's state_dict keys stay
        # "fc.2.weight" / "fc.2.bias".
        self.fc = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(256, self.classes),
        )

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """Return logits of shape ``(N, classes)`` for an ``(N, 3, H, W)`` batch."""
        return self.fc(self.conv(X))
class StudentNet(nn.Module):
    """Smaller convolutional classifier: three conv/BatchNorm/ReLU stages plus a linear head.

    The input must have 3 channels. The first two convolutions use stride 2,
    and a single unpadded ``MaxPool2d(kernel_size=3, stride=2)`` sits between
    them, so the input height and width must each be at least 5 pixels. The
    head averages the final 256-channel feature map over all spatial positions.

    Args:
        classes: Number of output logits produced per sample.
    """

    def __init__(self, classes: int) -> None:
        super().__init__()
        self.classes = classes
        # NOTE: Conv2d biases are redundant in front of BatchNorm2d but are kept
        # so the state_dict keys (and existing checkpoints) stay compatible.
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=96, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(96),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(in_channels=96, out_channels=256, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
        )
        # Global average pooling + flatten: (N, 256, H', W') -> (N, 256).
        # Built-in torch.nn layers replace torchex's GlobalAvgPool2d / Flatten;
        # both are parameter-free, so the Linear keeps the keys "fc.2.weight" / "fc.2.bias".
        self.fc = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(256, self.classes),
        )

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """Return logits of shape ``(N, classes)`` for an ``(N, 3, H, W)`` batch."""
        return self.fc(self.conv(X))
class Student2Net(nn.Module):
    """Very small classifier that down-samples the raw input before two conv stages.

    Two unpadded ``MaxPool2d(kernel_size=3, stride=2)`` layers are applied
    directly to the 3-channel input, before any convolution. A third pool
    follows the first convolution, so the input height and width must each be
    at least 15 pixels. The head averages the final 256-channel feature map
    over all spatial positions.

    Args:
        classes: Number of output logits produced per sample.
    """

    def __init__(self, classes: int) -> None:
        super().__init__()
        self.classes = classes
        # NOTE: Conv2d biases are redundant in front of BatchNorm2d but are kept
        # so the state_dict keys (and existing checkpoints) stay compatible.
        self.conv = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(in_channels=3, out_channels=96, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(96),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(in_channels=96, out_channels=256, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
        )
        # Global average pooling + flatten: (N, 256, H', W') -> (N, 256).
        # Built-in torch.nn layers replace torchex's GlobalAvgPool2d / Flatten;
        # both are parameter-free, so the Linear keeps the keys "fc.2.weight" / "fc.2.bias".
        self.fc = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(256, self.classes),
        )

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """Return logits of shape ``(N, classes)`` for an ``(N, 3, H, W)`` batch."""
        return self.fc(self.conv(X))
|