Referring to the CapsNet paper (Sabour et al., "Dynamic Routing Between Capsules"), the proposed basic structure is as follows:
capsnet_mnist.jpg
As the figure shows, CapsNet consists of an ordinary convolutional layer, a primary capsule layer (PrimaryCapsule), and a dense capsule layer (DenseCapsule) driven by dynamic routing.
The capsule network has a unique activation function, the squash function:

$$ Squash(S) = \cfrac{||S||^2}{1+||S||^2} \cdot \cfrac{S}{||S||} $$

where $S$ is the input capsule vector. The activation compresses the length of the capsule into $[0, 1)$ while preserving its direction. It is implemented as follows:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

def squash(inputs, axis=-1):
    norm = torch.norm(inputs, p=2, dim=axis, keepdim=True)
    scale = norm**2 / (1 + norm**2) / (norm + 1e-8)
    return scale * inputs
```
where:

- `norm = torch.norm(inputs, p=2, dim=axis, keepdim=True)` computes the length of each input capsule; `p=2` selects the 2-norm, and `keepdim=True` keeps the original number of dimensions.
- `scale = norm**2/(1 + norm**2)/(norm + 1e-8)` computes the scaling factor $\cfrac{||S||^2}{1+||S||^2}\cdot\cfrac{1}{||S||}$; the `1e-8` guards against division by zero.
- `return scale * inputs` completes the calculation.
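As a quick standalone sanity check (not part of the original code): squash maps a vector of norm $n$ to norm $n^2/(1+n^2)$, so short vectors are suppressed toward zero while long vectors approach unit length, with direction unchanged.

```python
import torch

def squash(inputs, axis=-1):
    # same squash as above
    norm = torch.norm(inputs, p=2, dim=axis, keepdim=True)
    scale = norm**2 / (1 + norm**2) / (norm + 1e-8)
    return scale * inputs

v = torch.tensor([[0.1, 0.0, 0.0],
                  [10.0, 0.0, 0.0]])
out = squash(v)
print(out.norm(dim=-1))  # ~0.0099 for the short vector, ~0.9901 for the long one
```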
The primary capsule layer is implemented as follows:

```python
class PrimaryCapsule(nn.Module):
    """
    Apply Conv2D with `out_channels` and then reshape to get capsules.
    :param in_channels: input channels
    :param out_channels: output channels
    :param dim_caps: dimension of capsule
    :param kernel_size: kernel size
    :return: output tensor, size=[batch, num_caps, dim_caps]
    """
    def __init__(self, in_channels, out_channels, dim_caps, kernel_size, stride=1, padding=0):
        super(PrimaryCapsule, self).__init__()
        self.dim_caps = dim_caps
        self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
                                stride=stride, padding=padding)

    def forward(self, x):
        outputs = self.conv2d(x)
        outputs = outputs.view(x.size(0), -1, self.dim_caps)
        return squash(outputs)
```
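As a shape check with MNIST-like sizes (the concrete values — 256 channels, 9×9 kernel, stride 2, `dim_caps=8` — are taken from the CapsuleNet configuration later in this post), the conv-and-reshape produces 32·6·6 = 1152 capsules of dimension 8:

```python
import torch
import torch.nn as nn

# A 20x20x256 feature map (the output of the first 9x9 conv on a 28x28 MNIST image)
conv = nn.Conv2d(256, 256, kernel_size=9, stride=2)
feat = conv(torch.randn(1, 256, 20, 20))  # -> [1, 256, 6, 6]
caps = feat.view(1, -1, 8)                # -> [1, 1152, 8]: 256*6*6/8 = 1152 capsules
print(feat.shape, caps.shape)
```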
The primary capsule layer is built on a convolutional layer; its forward pass consists of three steps:

- `outputs = self.conv2d(x)`: convolves the input; the output shape of this step is [batch, out_channels, p_w, p_h].
- `outputs = outputs.view(x.size(0), -1, self.dim_caps)`: reshapes the 4-D convolutional output into the 3-D capsule form [batch, caps_num, dim_caps], where caps_num (the number of capsules) is inferred automatically and dim_caps (the capsule dimension) must be specified in advance.
- `return squash(outputs)`: applies the activation function and returns the squashed capsules.

The dense capsule layer (DenseCapsule), where dynamic routing takes place, is initialized as follows:

```python
def __init__(self, in_num_caps, in_dim_caps, out_num_caps, out_dim_caps, routings=3):
    super(DenseCapsule, self).__init__()
    self.in_num_caps = in_num_caps
    self.in_dim_caps = in_dim_caps
    self.out_num_caps = out_num_caps
    self.out_dim_caps = out_dim_caps
    self.routings = routings
    self.weight = nn.Parameter(0.01 * torch.randn(out_num_caps, in_num_caps, out_dim_caps, in_dim_caps))
```
The parameters are as follows:

- `in_num_caps` / `in_dim_caps`: number and dimension of the input capsules
- `out_num_caps` / `out_dim_caps`: number and dimension of the output capsules
- `routings`: number of dynamic-routing iterations

In addition, a weight of size [out_num_caps, in_num_caps, out_dim_caps, in_dim_caps] is defined, i.e. every input capsule is connected to every output capsule.
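A toy-sized sketch (all sizes assumed) of how this weight transforms the input capsules via a broadcast matmul:

```python
import torch

batch, in_num, in_dim, out_num, out_dim = 2, 5, 8, 3, 16
weight = 0.01 * torch.randn(out_num, in_num, out_dim, in_dim)
x = torch.randn(batch, in_num, in_dim)

# [out_num, in_num, out_dim, in_dim] @ [batch, 1, in_num, in_dim, 1]
# -> [batch, out_num, in_num, out_dim, 1], then squeeze the trailing 1
x_hat = torch.squeeze(torch.matmul(weight, x[:, None, :, :, None]), dim=-1)
print(x_hat.shape)  # torch.Size([2, 3, 5, 16])
```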
```python
def forward(self, x):
    x_hat = torch.squeeze(torch.matmul(self.weight, x[:, None, :, :, None]), dim=-1)
    x_hat_detached = x_hat.detach()

    b = Variable(torch.zeros(x.size(0), self.out_num_caps, self.in_num_caps)).cuda()
    assert self.routings > 0, 'The `routings` should be > 0.'
    for i in range(self.routings):
        c = F.softmax(b, dim=1)
        if i == self.routings - 1:
            outputs = squash(torch.sum(c[:, :, :, None] * x_hat, dim=-2, keepdim=True))
        else:
            outputs = squash(torch.sum(c[:, :, :, None] * x_hat_detached, dim=-2, keepdim=True))
            b = b + torch.sum(outputs * x_hat_detached, dim=-1)

    return torch.squeeze(outputs, dim=-2)
```
Forward propagation is divided into two parts: input mapping and dynamic routing. The input mapping is:

```python
x_hat = torch.squeeze(torch.matmul(self.weight, x[:, None, :, :, None]), dim=-1)
```

- `x[:, None, :, :, None]` expands the input from [batch, in_num_caps, in_dim_caps] to [batch, 1, in_num_caps, in_dim_caps, 1].
- `torch.matmul()` multiplies the weight by the expanded input. The weight has size [out_num_caps, in_num_caps, out_dim_caps, in_dim_caps], and the product has size [batch, out_num_caps, in_num_caps, out_dim_caps, 1].
- `torch.squeeze()` removes the redundant trailing dimension, leaving [batch, out_num_caps, in_num_caps, out_dim_caps].
- `x_hat_detached = x_hat.detach()` cuts the gradient, so backpropagation does not flow through the intermediate routing iterations.

After this part, each input capsule has produced out_num_caps prediction vectors, so there are in_num_caps × out_num_caps predictions in total. The second part is dynamic routing; its algorithm diagram is as follows:
dynamic_route.jpg
The following excerpt implements this process:
```python
b = Variable(torch.zeros(x.size(0), self.out_num_caps, self.in_num_caps)).cuda()
for i in range(self.routings):
    c = F.softmax(b, dim=1)
    if i == self.routings - 1:
        outputs = squash(torch.sum(c[:, :, :, None] * x_hat, dim=-2, keepdim=True))
    else:
        outputs = squash(torch.sum(c[:, :, :, None] * x_hat_detached, dim=-2, keepdim=True))
        b = b + torch.sum(outputs * x_hat_detached, dim=-1)
```
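The loop can also be exercised on its own with toy shapes (all sizes assumed; `squash` as defined earlier, and `x_hat` standing in for `x_hat_detached` since no gradients are involved here):

```python
import torch
import torch.nn.functional as F

def squash(inputs, axis=-1):
    norm = torch.norm(inputs, p=2, dim=axis, keepdim=True)
    return norm**2 / (1 + norm**2) / (norm + 1e-8) * inputs

batch, out_num, in_num, out_dim = 2, 3, 5, 16
x_hat = torch.randn(batch, out_num, in_num, out_dim)  # toy prediction vectors

b = torch.zeros(batch, out_num, in_num)               # routing logits
for i in range(3):
    c = F.softmax(b, dim=1)                           # coupling coefficients
    outputs = squash(torch.sum(c[:, :, :, None] * x_hat, dim=-2, keepdim=True))
    b = b + torch.sum(outputs * x_hat, dim=-1)        # agreement update
print(torch.squeeze(outputs, dim=-2).shape)  # torch.Size([2, 3, 16])
```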
- `c = F.softmax(b, dim=1)`: computes the coupling coefficients; the softmax over dim=1 (the output-capsule dimension) does not change the size of `b`.
- `outputs = squash(torch.sum(c[:, :, :, None] * x_hat, dim=-2, keepdim=True))`: `c[:, :, :, None]` expands `c` so it broadcasts against `x_hat` in the element-wise product; summing over the second-to-last dimension gives the weighted sum $s_j$ of the algorithm, with result size [batch, out_num_caps, 1, out_dim_caps]; `squash()` then yields the output capsules.
- `b = b + torch.sum(outputs * x_hat_detached, dim=-1)`: updates the routing logits by the agreement between outputs and predictions. The two element-wise multiplied tensors have sizes [batch, out_num_caps, in_num_caps, out_dim_caps] and [batch, out_num_caps, 1, out_dim_caps]; broadcasting on the second-to-last dimension makes the final result [batch, out_num_caps, in_num_caps].

The complete network is defined as follows:

```python
class CapsuleNet(nn.Module):
    """
    A Capsule Network on MNIST.
    :param input_size: data size = [channels, width, height]
    :param classes: number of classes
    :param routings: number of routing iterations
    Shape:
        - Input: (batch, channels, width, height), optional (batch, classes).
        - Output: ((batch, classes), (batch, channels, width, height))
    """
    def __init__(self, input_size, classes, routings):
        super(CapsuleNet, self).__init__()
        self.input_size = input_size
        self.classes = classes
        self.routings = routings

        # Layer 1: Just a conventional Conv2D layer
        self.conv1 = nn.Conv2d(input_size[0], 256, kernel_size=9, stride=1, padding=0)

        # Layer 2: Conv2D layer with `squash` activation, then reshape to [None, num_caps, dim_caps]
        self.primarycaps = PrimaryCapsule(256, 256, 8, kernel_size=9, stride=2, padding=0)

        # Layer 3: Capsule layer. Routing algorithm works here.
        self.digitcaps = DenseCapsule(in_num_caps=32*6*6, in_dim_caps=8,
                                      out_num_caps=classes, out_dim_caps=16,
                                      routings=routings)

        # Decoder network.
        self.decoder = nn.Sequential(
            nn.Linear(16*classes, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, input_size[0] * input_size[1] * input_size[2]),
            nn.Sigmoid()
        )
        self.relu = nn.ReLU()

    def forward(self, x, y=None):
        x = self.relu(self.conv1(x))
        x = self.primarycaps(x)
        x = self.digitcaps(x)
        length = x.norm(dim=-1)
        if y is None:  # during testing, no label given. create one-hot coding using `length`
            index = length.max(dim=1)[1]
            y = Variable(torch.zeros(length.size()).scatter_(1, index.view(-1, 1).cpu().data, 1.).cuda())
        reconstruction = self.decoder((x * y[:, :, None]).view(x.size(0), -1))
        return length, reconstruction.view(-1, *self.input_size)
```
The network consists of two parts: the capsule network and a reconstruction network. The reconstruction network is a multilayer perceptron that reconstructs the input image from the capsule output; this shows that a capsule encodes spatial information in addition to the classification result.
Note that the forward propagation part of the capsule network is:
```python
x = self.relu(self.conv1(x))
x = self.primarycaps(x)
x = self.digitcaps(x)
length = x.norm(dim=-1)
```
The final output is the 2-norm of each capsule, i.e. the length of the vector, which serves as the per-class score.
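For illustration (with a hypothetical digit-capsule output), the class prediction is simply the index of the longest capsule:

```python
import torch

x = torch.randn(2, 10, 16)   # hypothetical digit-capsule output: [batch, classes, out_dim_caps]
length = x.norm(dim=-1)      # [batch, classes]: one score (vector length) per class
pred = length.max(dim=1)[1]  # predicted class per sample
print(length.shape, pred.shape)  # torch.Size([2, 10]) torch.Size([2])
```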
The cost function of the capsule part of the capsule network is the margin loss:

$$ L_c = T_c \max(0, m^+ - ||v_c||)^2 + \lambda (1-T_c) \max(0, ||v_c|| - m^-)^2 $$
The following code implements this loss. `L_margin` is the margin loss of the capsules, with $m^+ = 0.9$, $m^- = 0.1$, and $\lambda = 0.5$ (the `0.5` factor in the code); `L_recon` is the reconstruction loss, the `MSELoss` between the input image and the reconstructed image, weighted by `lam_recon`.
```python
def caps_loss(y_true, y_pred, x, x_recon, lam_recon):
    L = y_true * torch.clamp(0.9 - y_pred, min=0.) ** 2 + \
        0.5 * (1 - y_true) * torch.clamp(y_pred - 0.1, min=0.) ** 2
    L_margin = L.sum(dim=1).mean()
    L_recon = nn.MSELoss()(x_recon, x)
    return L_margin + lam_recon * L_recon
```
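A small numeric sketch (toy values assumed): a correct class whose capsule length already exceeds $m^+ = 0.9$ contributes nothing, while a wrong class of length 0.2 contributes $0.5 \cdot (0.2 - 0.1)^2$:

```python
import torch

y_true = torch.tensor([[1., 0.]])     # one-hot label
y_pred = torch.tensor([[0.95, 0.2]])  # capsule lengths
L = y_true * torch.clamp(0.9 - y_pred, min=0.) ** 2 + \
    0.5 * (1 - y_true) * torch.clamp(y_pred - 0.1, min=0.) ** 2
print(L.sum(dim=1))  # tensor([0.0050])
```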