1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26
def _cosine(x, centroid, p_norm=2):
    """Softly assign each spatial feature of ``x`` to the centroids.

    Both inputs are normalised along dim 1 with the ``p_norm`` norm and then
    correlated with ``F.conv2d``. With the default ``p_norm=2`` and 1x1
    centroids this correlation is the cosine similarity between each spatial
    feature vector and each centroid. A softmax over the centroid axis turns
    the similarities into soft assignments.

    Args:
        x: Feature map of shape ``(B, C, H, W)``.
        centroid: Convolution weight of shape ``(K, C, kh, kw)``; ``K`` is
            taken from ``centroid.shape[0]``.
        p_norm: Order of the norm used for the normalisation.

    Returns:
        A tuple ``(x_corr, x_hist, y_word)``:

        * ``x_corr`` -- softmax over dim 1 of the correlation map,
          shape ``(B, K, H', W')``.
        * ``x_hist`` -- ``x_corr`` summed over its two spatial dims with
          ``keepdim=True``, shape ``(B, K, 1, 1)``.
        * ``y_word`` -- detached boolean tensor of shape ``(B, K)``; entry
          ``[b, k]`` is true when centroid ``k`` is the arg-max at one or
          more spatial positions of sample ``b``.
    """
    # The features are detached, so gradients reach only ``centroid``.
    feats = F.normalize(x, p=p_norm, dim=1).detach()
    words = F.normalize(centroid, p=p_norm, dim=1)

    similarity = F.conv2d(feats, words)
    x_corr = similarity.softmax(dim=1)

    # Hard assignment per location -> one-hot (B, H', W', K) -> count per word.
    winners = x_corr.argmax(dim=1)
    votes = F.one_hot(winners, num_classes=centroid.shape[0]).sum(dim=[1, 2])
    y_word = votes.gt(0).detach()

    x_hist = x_corr.sum(dim=[2, 3], keepdim=True)
    return x_corr, x_hist, y_word
class VWE(nn.Module):
    """Module holding ``k_words`` learnable 1x1 centroids matched against a feature map.

    NOTE(review): the name presumably stands for "visual word encoder" --
    confirm with the authors.

    Args:
        k_words: Number of centroids.
        in_channels: Channel depth of the feature maps passed to ``forward``.
            Defaults to 2048, the value that used to be hard-coded.
    """

    def __init__(self, k_words=256, in_channels=2048):
        super().__init__()

        self.k_words = k_words
        self.in_channels = in_channels

        # Shaped like a 1x1 conv weight so it can be passed to F.conv2d as-is.
        # torch.empty only allocates; the values are set by the init call below.
        self.centroid = nn.Parameter(
            torch.empty(self.k_words, self.in_channels, 1, 1),
            requires_grad=True,
        )
        nn.init.kaiming_normal_(self.centroid, a=np.sqrt(5))

    def forward(self, x):
        """Match ``x`` against the centroids.

        Args:
            x: Feature map of shape ``(B, in_channels, H, W)``.

        Returns:
            The ``(x_corr, x_hist, y_word)`` tuple produced by ``_cosine``,
            with shapes ``(B, k_words, H, W)``, ``(B, k_words, 1, 1)`` and
            ``(B, k_words)`` respectively.
        """
        return _cosine(x=x, centroid=self.centroid)
|