1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96
| import torch import numpy as np from torch import nn import torch.nn.functional as F import cv2
class STFTConv(nn.Module):
    """Fixed (non-trainable) convolution with short-time Fourier kernels.

    Four 2-D frequency vectors ``V`` are derived from
    ``a = 1 / kernel_size[0]``: ``(a, 0)``, ``(0, a)``, ``(a, a)`` and
    ``(a, -a)``.  For each vector a cosine ("real", ``w*_r``) and a
    negated-sine ("imaginary", ``w*_i``) kernel is evaluated on the integer
    kernel grid and replicated for every ``(out_c, in_c)`` pair.  ``forward``
    convolves the input with all eight kernels and returns the element-wise
    absolute value of the responses.

    Args:
        in_c: Number of input channels.
        out_c: Number of output channels produced per kernel; the layer
            emits ``8 * out_c`` channels in total.
        kernel_size: Two-element sequence with the kernel extent along the
            first and second grid axis.
        stride: Forwarded unchanged to ``F.conv2d``.
        padding: Forwarded unchanged to ``F.conv2d``.

    Raises:
        ValueError: If ``kernel_size`` does not have exactly two entries.
    """

    # Parameter-name stems in the order their responses appear in the output.
    _KERNEL_STEMS = ("w1", "w2", "w3", "w4")
    _KERNEL_SUFFIXES = ("_r", "_i")

    def __init__(self, in_c, out_c, kernel_size, stride=1, padding=0):
        super(STFTConv, self).__init__()
        self.in_c = in_c
        self.out_c = out_c
        self.stride = stride
        self.padding = padding
        self.kernel_size = kernel_size

        self.Y = self.define_Y(kernel_size)
        # Float literal keeps the step non-zero even under integer division.
        # NOTE(review): both axes use the step derived from kernel_size[0];
        # confirm this is intended for non-square kernels.
        a = 1.0 / kernel_size[0]
        self.V = np.array([[a, 0], [0, a], [a, a], [a, -a]])

        weight_shape = (out_c, in_c, kernel_size[0], kernel_size[1])
        for stem, v in zip(self._KERNEL_STEMS, self.V):
            real, imag = self.kernel_fn(v, self.Y)
            self.register_parameter(
                stem + "_r", self._fixed_weight(real, weight_shape))
            self.register_parameter(
                stem + "_i", self._fixed_weight(imag, weight_shape))

    @staticmethod
    def _fixed_weight(kernel, shape):
        """Broadcast a 2-D kernel to ``shape`` as a frozen ``nn.Parameter``."""
        tensor = torch.as_tensor(kernel, dtype=torch.float32)
        # ``expand`` alone yields a stride-0 view in which many elements
        # alias one memory location; in-place writes to such a tensor (as
        # done by ``load_state_dict``) raise a RuntimeError.  Materialise a
        # real copy so the parameter behaves like any other weight.
        return nn.Parameter(tensor.expand(*shape).contiguous(),
                            requires_grad=False)

    def forward(self, X):
        """Return ``|conv2d(X, w)|`` for all eight kernels, stacked on dim 1.

        The output channels are ordered ``w1_r, w1_i, w2_r, w2_i, w3_r,
        w3_i, w4_r, w4_i``, each contributing ``out_c`` channels.
        """
        # Stacking the kernels along the output-channel axis gives the same
        # channel layout as concatenating eight separate convolutions, with
        # a single conv2d call.
        bank = torch.cat(
            [getattr(self, stem + suffix)
             for stem in self._KERNEL_STEMS
             for suffix in self._KERNEL_SUFFIXES],
            dim=0,
        )
        response = F.conv2d(X, bank, stride=self.stride, padding=self.padding)
        return torch.abs(response)

    @staticmethod
    def define_Y(kernel_size):
        """Build the integer coordinate grid the kernels are evaluated on.

        Args:
            kernel_size: Two-element sequence ``(w, h)`` of grid extents.

        Returns:
            Integer array of shape ``(w, h, 2, 1)`` where entry ``[i, j]``
            is the column vector ``[[i], [j]]``.

        Raises:
            ValueError: If ``kernel_size`` does not have exactly two entries.
        """
        if len(kernel_size) != 2:
            raise ValueError(
                "kernel_size must have exactly two entries, got {}".format(
                    len(kernel_size)))
        w, h = kernel_size
        # np.indices yields shape (2, w, h); move the coordinate axis behind
        # the grid axes and append the trailing singleton axis.
        grid = np.moveaxis(np.indices((w, h)), 0, -1)
        return np.ascontiguousarray(grid[..., np.newaxis])

    def kernel_fn(self, v, Y):
        """Evaluate the cosine and negated-sine kernels for frequency ``v``.

        Args:
            v: Frequency vector; contracted against the length-2 coordinate
                axis of ``Y`` via ``v.dot``.
            Y: Coordinate grid as returned by ``define_Y``.

        Returns:
            Tuple ``(cos(2*pi*w), -sin(2*pi*w))`` where ``w`` is ``v.dot(Y)``
            with its trailing singleton axis removed.
        """
        phase = np.squeeze(v.dot(Y), axis=2)
        return np.cos(2 * np.pi * phase), -np.sin(2 * np.pi * phase)
def _get_stft_kernels(size, v):
    """Return the cosine / negated-sine kernel pair for frequency ``v``.

    ``size`` is unpacked as ``(h, w)``.  The kernels are evaluated on a grid
    of 1-based coordinates whose first axis has length ``w`` and whose
    second axis has length ``h``, so both returned arrays have shape
    ``(w, h)``.

    Args:
        size: Sequence unpacked as ``(h, w)``.
        v: Frequency vector; contracted against the length-2 coordinate
            axis of the grid via ``v.dot``.

    Returns:
        Tuple ``(cos(2*pi*p), -sin(2*pi*p))`` where ``p`` is ``v`` dotted
        with each grid coordinate.
    """
    assert len(size) % 2 == 0
    h, w = size

    # One column vector [[i], [j]] per grid cell, shifted to start at 1.
    coords = np.array(
        [[[[i], [j]] for j in range(h)] for i in range(w)]
    ) + 1

    # The dot product leaves a trailing singleton axis; drop it.
    phase = np.squeeze(v.dot(coords), axis=2)
    real_part = np.cos(2 * np.pi * phase)
    imag_part = -np.sin(2 * np.pi * phase)
    return real_part, imag_part
|