A Simple PyTorch Implementation of the FCN Image-Segmentation Network

Published: 2020-01-25 · Category: Deep Learning

We construct a simple fully convolutional network: a small transposed-convolution head serves as the decoder, and a pretrained ResNet18 serves as the encoder. The dataset is VOC2012.

Notes

I hit a few snags while writing this, so I'm recording them here.

Ground truth during training

GT for short, i.e. the image annotation.
Computing the loss requires the predicted feature map outputs to have the same shape as its label, the GT. The model outputs logits of shape (batch_size, num_classes, height, width), but the label, before any processing, has no num_classes dimension: once the label image's colors have been mapped to class indices it is just (batch_size, height, width). Because BCEWithLogitsLoss is used here, the label must therefore be expanded into a one-hot encoding before being fed in; otherwise the shapes will not match. The corresponding code:

def __getitem__(self, idx):
    feature, label = voc_rand_crop(self.features[idx], self.labels[idx],
                                   *self.crop_size)
    label = voc_label_indices(label, self.colormap2label).numpy().astype('uint8')

    # One-hot encode the GT: (h, w) class indices -> (21, h, w) binary masks
    h, w = label.shape
    target = torch.zeros(21, h, w)
    for c in range(21):
        target[c][label == c] = 1

    return (self.tsf(feature), target)
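
As an aside, the per-class loop above can be replaced by a vectorized one-hot encoding; a minimal sketch, assuming label is the same uint8 (h, w) class-index array as in the code above:

import torch
import torch.nn.functional as F

# label: uint8 numpy array of shape (h, w) holding class indices 0..20
one_hot = F.one_hot(torch.from_numpy(label).long(), num_classes=21)
target = one_hot.permute(2, 0, 1).float()  # (21, h, w), same result as the loop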

Model architecture

This is the relevant excerpt from the full code.

resnet18 = models.resnet18(pretrained=True)
resnet18_modules = [layer for layer in resnet18.children()]
net = nn.Sequential()
for i, layer in enumerate(resnet18_modules[:-2]):  # drop the avgpool and fc layers
    net.add_module(str(i), layer)

net.add_module("LinearTranspose", nn.Conv2d(512, num_classes, kernel_size=1))
net.add_module("ConvTranspose2d",
               nn.ConvTranspose2d(num_classes, num_classes, kernel_size=64, padding=16, stride=32))

net[-1].weight = nn.Parameter(bilinear_kernel(num_classes, num_classes, 64), True)
net[-2].weight = nn.init.xavier_uniform_(net[-2].weight)

A brief walkthrough. First we take the pretrained resnet18 model provided by torchvision. To extract the parts we need, we iterate over its children and add the required modules to our net, dropping the final global average pooling and fully connected layers. After the ResNet18 backbone we append a convolution with kernel size 1, which maps the 512 feature channels down to num_classes channels. We then append a transposed-convolution layer that upsamples the feature map back to the input resolution: the backbone downsamples by a factor of 32, and the transposed convolution with stride 32 undoes exactly that. To help the model converge quickly, we specify the weight initialization of the two newly added layers: the transposed convolution is initialized as a bilinear interpolation kernel, and the 1x1 convolution uses Xavier initialization.
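
As a quick sanity check (a sketch, not from the original post), you can push a dummy batch through the network and confirm that the output spatial size matches the input:

net.eval()  # avoid updating BatchNorm running statistics
with torch.no_grad():
    X = torch.rand(1, 3, 320, 480)   # dummy batch at the crop size used below
    print(net(X).shape)              # expected: torch.Size([1, 21, 320, 480])

The bilinear_kernel helper used above is defined as follows: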

def bilinear_kernel(in_channels, out_channels, kernel_size):
    # Build an (in_channels, out_channels, k, k) weight that performs
    # bilinear upsampling; off-diagonal channel pairs stay zero.
    factor = (kernel_size + 1) // 2
    if kernel_size % 2 == 1:
        center = factor - 1
    else:
        center = factor - 0.5
    og = np.ogrid[:kernel_size, :kernel_size]
    filt = (1 - abs(og[0] - center) / factor) * \
           (1 - abs(og[1] - center) / factor)
    weight = np.zeros((in_channels, out_channels, kernel_size, kernel_size),
                      dtype='float32')
    weight[range(in_channels), range(out_channels), :, :] = filt
    return torch.tensor(weight)
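
To see why this is a sensible initialization, note that a transposed convolution whose weight is set by bilinear_kernel simply performs bilinear upsampling. A small sketch (the layer sizes here are illustrative, not from the post):

up = nn.ConvTranspose2d(3, 3, kernel_size=4, padding=1, stride=2, bias=False)
up.weight.data.copy_(bilinear_kernel(3, 3, 4))  # 2x bilinear upsampling

x = torch.rand(1, 3, 32, 32)
print(up(x).shape)  # torch.Size([1, 3, 64, 64]): each image doubled in size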

Full training code

from tqdm import tqdm

from FCN.VOC2012Dataset import VOC2012SegDataIter
import torch
from torch import nn, optim
import numpy as np
from torchvision import models

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

num_classes = 21


def bilinear_kernel(in_channels, out_channels, kernel_size):
    factor = (kernel_size + 1) // 2
    if kernel_size % 2 == 1:
        center = factor - 1
    else:
        center = factor - 0.5
    og = np.ogrid[:kernel_size, :kernel_size]
    filt = (1 - abs(og[0] - center) / factor) * \
           (1 - abs(og[1] - center) / factor)
    weight = np.zeros((in_channels, out_channels, kernel_size, kernel_size),
                      dtype='float32')
    weight[range(in_channels), range(out_channels), :, :] = filt
    return torch.tensor(weight)


if __name__ == '__main__':
    batch_size = 4
    train_iter, val_iter = VOC2012SegDataIter(batch_size, (320, 480), 2, 200)

    # Encoder: pretrained ResNet18 without its avgpool and fc layers
    resnet18 = models.resnet18(pretrained=True)
    resnet18_modules = [layer for layer in resnet18.children()]
    net = nn.Sequential()
    for i, layer in enumerate(resnet18_modules[:-2]):
        net.add_module(str(i), layer)

    # Decoder: 1x1 conv to num_classes channels, then 32x transposed-conv upsampling
    net.add_module("LinearTranspose", nn.Conv2d(512, num_classes, kernel_size=1))
    net.add_module("ConvTranspose2d",
                   nn.ConvTranspose2d(num_classes, num_classes, kernel_size=64, padding=16, stride=32))

    net[-1].weight = nn.Parameter(bilinear_kernel(num_classes, num_classes, 64), True)
    net[-2].weight = nn.init.xavier_uniform_(net[-2].weight)

    net = net.to(device)
    optimizer = optim.Adam(net.parameters(), lr=1e-3)
    lossFN = nn.BCEWithLogitsLoss()

    num_epochs = 10
    for epoch in range(num_epochs):
        sum_loss = 0
        batch_count = 0
        for X, y in tqdm(train_iter):
            X = X.to(device)
            y = y.to(device)
            y_pred = net(X)
            loss = lossFN(y_pred, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            sum_loss += loss.cpu().item()
            batch_count += 1
        # average the per-batch mean losses over the epoch
        print("epoch %d: loss=%.4f" % (epoch + 1, sum_loss / batch_count))

Reading the VOC dataset

import torch
import torchvision
from PIL import Image
import numpy as np


def voc_label_indices(colormap, colormap2label):
    """
    Convert a colormap (PIL image) to a class-index map (uint8 tensor)
    via the colormap2label lookup table.
    """
    colormap = np.array(colormap.convert("RGB")).astype('int32')
    idx = ((colormap[:, :, 0] * 256 + colormap[:, :, 1]) * 256
           + colormap[:, :, 2])
    return colormap2label[idx]


def read_voc_images(root="./dataset/VOCdevkit/VOC2012",
                    is_train=True, max_num=None):
    txt_fname = '%s/ImageSets/Segmentation/%s' % (
        root, 'train.txt' if is_train else 'val.txt')
    with open(txt_fname, 'r') as f:
        images = f.read().split()
    if max_num is not None:
        images = images[:min(max_num, len(images))]
    features, labels = [None] * len(images), [None] * len(images)
    for i, fname in enumerate(images):
        features[i] = Image.open('%s/JPEGImages/%s.jpg' % (root, fname)).convert("RGB")
        labels[i] = Image.open('%s/SegmentationClass/%s.png' % (root, fname)).convert("RGB")
    return features, labels  # PIL images


def voc_rand_crop(feature, label, height, width):
    """
    Random-crop feature (PIL image) and label (PIL image) with the same window.
    """
    i, j, h, w = torchvision.transforms.RandomCrop.get_params(
        feature, output_size=(height, width))

    feature = torchvision.transforms.functional.crop(feature, i, j, h, w)
    label = torchvision.transforms.functional.crop(label, i, j, h, w)

    return feature, label


class VOCSegDataset(torch.utils.data.Dataset):
    def __init__(self, is_train, crop_size, voc_dir, colormap2label, max_num=None):
        """
        crop_size: (h, w)
        """
        self.rgb_mean = np.array([0.485, 0.456, 0.406])
        self.rgb_std = np.array([0.229, 0.224, 0.225])
        self.tsf = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(mean=self.rgb_mean,
                                             std=self.rgb_std)
        ])

        self.crop_size = crop_size  # (h, w)
        features, labels = read_voc_images(root=voc_dir,
                                           is_train=is_train,
                                           max_num=max_num)
        self.features = self.filter(features)  # PIL images
        self.labels = self.filter(labels)  # PIL images
        self.colormap2label = colormap2label
        print('read ' + str(len(self.features)) + ' valid examples')

    def filter(self, imgs):
        # keep only images at least as large as the crop size
        return [img for img in imgs if (
            img.size[1] >= self.crop_size[0] and
            img.size[0] >= self.crop_size[1])]

    def __getitem__(self, idx):
        feature, label = voc_rand_crop(self.features[idx], self.labels[idx],
                                       *self.crop_size)
        label = voc_label_indices(label, self.colormap2label).numpy().astype('uint8')

        # One-hot encode the GT: (h, w) class indices -> (21, h, w) binary masks
        h, w = label.shape
        target = torch.zeros(21, h, w)
        for c in range(21):
            target[c][label == c] = 1

        return (self.tsf(feature), target)

    def __len__(self):
        return len(self.features)


def VOC2012SegDataIter(batch_size=64, crop_size=(320, 480), num_workers=4, max_num=None):
    VOC_COLORMAP = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0],
                    [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128],
                    [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0],
                    [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128],
                    [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0],
                    [0, 64, 128]]
    VOC_CLASSES = ['background', 'aeroplane', 'bicycle', 'bird', 'boat',
                   'bottle', 'bus', 'car', 'cat', 'chair', 'cow',
                   'diningtable', 'dog', 'horse', 'motorbike', 'person',
                   'potted plant', 'sheep', 'sofa', 'train', 'tv/monitor']

    # Encode each RGB color as a single integer and map it to its class index
    colormap2label = torch.zeros(256 ** 3, dtype=torch.uint8)
    for i, colormap in enumerate(VOC_COLORMAP):
        colormap2label[(colormap[0] * 256 + colormap[1]) * 256 + colormap[2]] = i

    voc_train = VOCSegDataset(True, crop_size, "../dataset/VOCdevkit/VOC2012", colormap2label, max_num)
    voc_val = VOCSegDataset(False, crop_size, "../dataset/VOCdevkit/VOC2012", colormap2label, max_num)
    train_iter = torch.utils.data.DataLoader(voc_train, batch_size, shuffle=True, drop_last=True,
                                             num_workers=num_workers)
    val_iter = torch.utils.data.DataLoader(voc_val, batch_size, drop_last=True, num_workers=num_workers)
    return train_iter, val_iter
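
For completeness, a minimal usage sketch (assuming VOC2012 has been extracted so that ../dataset/VOCdevkit/VOC2012 exists, as the paths above expect):

if __name__ == '__main__':
    train_iter, val_iter = VOC2012SegDataIter(batch_size=4, crop_size=(320, 480),
                                              num_workers=2, max_num=100)
    X, y = next(iter(train_iter))
    print(X.shape)  # torch.Size([4, 3, 320, 480])  -- normalized image batch
    print(y.shape)  # torch.Size([4, 21, 320, 480]) -- one-hot masks
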
Author: HeoLis
Original link: https://ishero.net/%E4%BD%BF%E7%94%A8PyTorch%E7%AE%80%E5%8D%95%E5%AE%9E%E7%8E%B0%E5%9B%BE%E5%83%8F%E5%88%86%E5%89%B2%E7%BD%91%E7%BB%9CFCN.html
License: Unless otherwise stated, all posts on this blog are released under the CC BY-NC-SA 4.0 license. Please credit the source when reposting!
