1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21
# Append the upsampling head: a 1x1 conv that maps the backbone features to
# per-class scores, followed by a transposed conv that upsamples by 32.
# NOTE(review): assumes `net` is an nn.Sequential (integer indexing below) whose
# last feature map has 512 channels, and that `num_classes` is defined earlier.
# NOTE(review): `bilinear_kernel` must already be defined when this runs; in the
# listing it appears after these statements — confirm the execution order.
net.add_module("LinearTranspose", nn.Conv2d(512, num_classes, kernel_size=1))
# Output height: (H - 1) * 32 - 2 * 16 + (64 - 1) + 1 = 32 * H, so the
# transposed conv upsamples spatial size by exactly the stride.
net.add_module(
    "ConvTranspose2d",
    nn.ConvTranspose2d(
        num_classes, num_classes, kernel_size=64, padding=16, stride=32
    ),
)
# Start the transposed conv as a bilinear interpolation; it remains trainable.
net[-1].weight = nn.Parameter(
    bilinear_kernel(num_classes, num_classes, 64), requires_grad=True
)
# xavier_uniform_ fills the existing Parameter in place; no reassignment needed.
nn.init.xavier_uniform_(net[-2].weight)
def bilinear_kernel(in_channels, out_channels, kernel_size):
    """Return a bilinear-interpolation weight tensor for a transposed convolution.

    The result has layout ``(in_channels, out_channels, kernel_size,
    kernel_size)``. This is the weight layout PyTorch documents for
    ``nn.ConvTranspose2d`` with ``groups=1``. The 2-D bilinear filter is written
    only at the channel pairs selected by fancy-indexing with
    ``range(in_channels)`` and ``range(out_channels)``; every other entry is
    zero.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels. Because the two ranges are
            broadcast against each other, the counts must be equal (giving a
            diagonal) or one of them must be 1. Otherwise NumPy raises
            ``IndexError``.
        kernel_size: Side length of the square kernel.

    Returns:
        A ``torch.float32`` tensor of shape
        ``(in_channels, out_channels, kernel_size, kernel_size)``.
    """
    factor = (kernel_size + 1) // 2
    # Odd kernels centre on an integer tap; even kernels centre between two taps.
    center = factor - 1 if kernel_size % 2 == 1 else factor - 0.5

    # A 1-D triangular profile; the separable 2-D filter is its outer product.
    taps = np.arange(kernel_size)
    profile = 1 - np.abs(taps - center) / factor
    filt = np.outer(profile, profile)

    weight = np.zeros(
        (in_channels, out_channels, kernel_size, kernel_size), dtype='float32'
    )
    rows = range(in_channels)
    cols = range(out_channels)
    weight[rows, cols] = filt
    return torch.tensor(weight)
|