PyTorch-based CapsNet source code explained in detail

CapsNet basic structure

Referring to the CapsNet paper, the proposed basic structure is as follows:

[Figure: CapsNet architecture for MNIST (capsnet_mnist.jpg)]

It can be seen that the basic structure of CapsNet is as follows (a shape summary follows the list):

  • Ordinary convolutional layer Conv1: a standard convolutional layer with a relatively large 9x9 receptive field
  • Pre-capsule layer PrimaryCaps: prepares the input for the capsule layer; the operation is a convolution, and the final output is three-dimensional data of shape [batch, caps_num, caps_length]:
    • batch is the batch size
    • caps_num is the number of capsules
    • caps_length is the length of each capsule (each capsule is a vector with caps_length components)
  • Capsule layer DigitCaps: the capsule layer proper, intended to replace the usual final fully connected layer; its output is 10 capsules
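
For concreteness, here is the shape flow for a 28x28 MNIST input under the hyper-parameters used in the code below (batch size B); this is a summary sketch, not part of the original source:

# Shape flow for a 28x28 MNIST input (batch size B)
# input                          [B, 1, 28, 28]
# Conv1: 9x9, stride 1, 256 ch -> [B, 256, 20, 20]
# PrimaryCaps: 9x9, stride 2   -> [B, 1152, 8]   (32*6*6 capsules of length 8)
# DigitCaps: dynamic routing   -> [B, 10, 16]    (one 16-D capsule per class)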

Code

Capsule related components

Activation function Squash

The capsule network uses its own activation function, the squash function: $$\mathrm{Squash}(S) = \cfrac{\|S\|^2}{1+\|S\|^2}\cdot\cfrac{S}{\|S\|}$$ where $S$ is the input capsule. The activation function compresses the length of the capsule into $[0, 1)$ while preserving its direction. The code is implemented as follows:

import torch

def squash(inputs, axis=-1):
    # length of each capsule (2-norm along `axis`)
    norm = torch.norm(inputs, p=2, dim=axis, keepdim=True)
    # ||S||^2 / (1 + ||S||^2) / ||S||; the 1e-8 avoids division by zero
    scale = norm**2 / (1 + norm**2) / (norm + 1e-8)
    return scale * inputs

where:

  • norm = torch.norm(inputs, p=2, dim=axis, keepdim=True) computes the length of each input capsule; p=2 selects the 2-norm, and keepdim=True preserves the number of dimensions so the result broadcasts against the input
  • scale = norm**2/(1 + norm**2)/(norm + 1e-8) computes the scaling factor $\cfrac{\|S\|^2}{1+\|S\|^2}\cdot\cfrac{1}{\|S\|}$; the 1e-8 term guards against division by zero (a quick numeric check follows)
  • return scale * inputs completes the calculation
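
As a quick sanity check (a minimal sketch using the squash() defined above, not part of the original source), short vectors are shrunk toward length 0 and long vectors saturate toward length 1:

import torch

# three capsules of length 0.1, 1.0 and 10.0
v = torch.tensor([[0.1, 0.0], [1.0, 0.0], [10.0, 0.0]])
out = squash(v)
print(torch.norm(out, dim=-1))  # approx. tensor([0.0099, 0.5000, 0.9901])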

PrimaryCaps

import torch.nn as nn

class PrimaryCapsule(nn.Module):
    """
    Apply Conv2D with `out_channels` and then reshape to get capsules
    :param in_channels: input channels
    :param out_channels: output channels
    :param dim_caps: dimension of capsule
    :param kernel_size: kernel size
    :return: output tensor, size=[batch, num_caps, dim_caps]
    """
    def __init__(self, in_channels, out_channels, dim_caps, kernel_size, stride=1, padding=0):
        super(PrimaryCapsule, self).__init__()
        self.dim_caps = dim_caps
        self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)

    def forward(self, x):
        outputs = self.conv2d(x)
        outputs = outputs.view(x.size(0), -1, self.dim_caps)
        return squash(outputs)

The pre-capsule layer is implemented with a convolutional layer, and its forward pass consists of three steps (see the shape check below):

  • outputs = self.conv2d(x): convolve the input; the output of this step has shape [batch, out_channels, p_w, p_h]
  • outputs = outputs.view(x.size(0), -1, self.dim_caps): reshape the 4-D convolution output into the 3-D capsule form [batch, caps_num, dim_caps], where caps_num (the number of capsules) is inferred automatically and dim_caps (the length of each capsule) must be specified in advance
  • return squash(outputs): apply the activation function and return the activated capsules
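
As a sanity check, the following sketch (hyper-parameters assumed from the MNIST network shown later) verifies the output shape:

import torch

layer = PrimaryCapsule(in_channels=256, out_channels=256, dim_caps=8,
                       kernel_size=9, stride=2, padding=0)
x = torch.randn(4, 256, 20, 20)  # Conv1 output for 28x28 MNIST images
print(layer(x).shape)            # torch.Size([4, 1152, 8]): 32*6*6 = 1152 capsules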

Capsule layer DigitCaps

Parameter definition

import torch
import torch.nn as nn
import torch.nn.functional as F

class DenseCapsule(nn.Module):
    def __init__(self, in_num_caps, in_dim_caps, out_num_caps, out_dim_caps, routings=3):
        super(DenseCapsule, self).__init__()
        self.in_num_caps = in_num_caps
        self.in_dim_caps = in_dim_caps
        self.out_num_caps = out_num_caps
        self.out_dim_caps = out_dim_caps
        self.routings = routings
        # one transformation matrix per (output capsule, input capsule) pair
        self.weight = nn.Parameter(0.01 * torch.randn(out_num_caps, in_num_caps, out_dim_caps, in_dim_caps))

The parameters are defined as follows:

  • in_num_caps: number of input capsules
  • in_dim_caps: length (dimensionality) of each input capsule
  • out_num_caps: number of output capsules
  • out_dim_caps: length (dimensionality) of each output capsule
  • routings: number of dynamic routing iterations

In addition, a weight tensor of shape [out_num_caps, in_num_caps, out_dim_caps, in_dim_caps] is defined, i.e. every input capsule is connected to every output capsule (a concrete example follows).
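
For the MNIST configuration shown later, a quick sketch (values assumed from that configuration) makes the weight shape concrete:

layer = DenseCapsule(in_num_caps=32*6*6, in_dim_caps=8,
                     out_num_caps=10, out_dim_caps=16, routings=3)
print(layer.weight.shape)  # torch.Size([10, 1152, 16, 8])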

Forward propagation

    def forward(self, x):
        # x: [batch, in_num_caps, in_dim_caps]
        x_hat = torch.squeeze(torch.matmul(self.weight, x[:, None, :, :, None]), dim=-1)
        x_hat_detached = x_hat.detach()

        # routing logits, created on the same device as the input
        b = torch.zeros(x.size(0), self.out_num_caps, self.in_num_caps, device=x.device)
        assert self.routings > 0, "The 'routings' should be > 0."
        for i in range(self.routings):
            c = F.softmax(b, dim=1)
            if i == self.routings - 1:
                outputs = squash(torch.sum(c[:, :, :, None] * x_hat, dim=-2, keepdim=True))
            else:
                outputs = squash(torch.sum(c[:, :, :, None] * x_hat_detached, dim=-2, keepdim=True))
                b = b + torch.sum(outputs * x_hat_detached, dim=-1)
        return torch.squeeze(outputs, dim=-2)

Forward propagation is divided into two parts: input mapping and dynamic routing. The input mapping works as follows (see the shape trace below):

  1. x_hat = torch.squeeze(torch.matmul(self.weight, x[:, None, :, :, None]), dim=-1)
    • x[:, None, :, :, None] expands the input from [batch, in_num_caps, in_dim_caps] to [batch, 1, in_num_caps, in_dim_caps, 1]
    • torch.matmul() multiplies the weight by the expanded input; the weight has shape [out_num_caps, in_num_caps, out_dim_caps, in_dim_caps], and the product has shape [batch, out_num_caps, in_num_caps, out_dim_caps, 1]
    • torch.squeeze() removes the trailing singleton dimension, leaving [batch, out_num_caps, in_num_caps, out_dim_caps]
  2. x_hat_detached = x_hat.detach() cuts off gradient backpropagation, so the intermediate routing iterations do not contribute gradients
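
The following shape trace (a standalone sketch with MNIST sizes, independent of the class above) shows how the broadcasting works:

import torch

batch, in_num, in_dim, out_num, out_dim = 4, 1152, 8, 10, 16
weight = torch.randn(out_num, in_num, out_dim, in_dim)
x = torch.randn(batch, in_num, in_dim)

u = x[:, None, :, :, None]            # [4, 1, 1152, 8, 1]
x_hat = torch.matmul(weight, u)       # [4, 10, 1152, 16, 1]: batched matrix product
x_hat = torch.squeeze(x_hat, dim=-1)  # [4, 10, 1152, 16]
print(x_hat.shape)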

After this step, each input capsule has produced out_num_caps prediction vectors, so there are in_num_caps * out_num_caps predictions in total. The second part is dynamic routing; the algorithm from the paper is shown below:

[Figure: dynamic routing algorithm (dynamic_route.jpg)]

The following code implements this process:

b = torch.zeros(x.size(0), self.out_num_caps, self.in_num_caps, device=x.device)
for i in range(self.routings):
    c = F.softmax(b, dim=1)
    if i == self.routings - 1:
        outputs = squash(torch.sum(c[:, :, :, None] * x_hat, dim=-2, keepdim=True))
    else:
        outputs = squash(torch.sum(c[:, :, :, None] * x_hat_detached, dim=-2, keepdim=True))
        b = b + torch.sum(outputs * x_hat_detached, dim=-1)

  1. The first step is the softmax, c = F.softmax(b, dim=1), which normalizes the coupling coefficients without changing the shape of b
  2. The second step computes the routing result: outputs = squash(torch.sum(c[:, :, :, None] * x_hat, dim=-2, keepdim=True))
    • c[:, :, :, None] adds a trailing dimension to c so that it broadcasts in the element-wise multiplication with x_hat
    • torch.sum(c[:, :, :, None] * x_hat, dim=-2, keepdim=True) weights each prediction by its coupling coefficient and sums over the penultimate (input capsule) dimension, giving the $s_j$ of the algorithm; the result of this step has shape [batch, out_num_caps, 1, out_dim_caps]
    • the result is then passed through the squash() activation
  3. The third step updates the routing logits: b = b + torch.sum(outputs * x_hat_detached, dim=-1). The two element-wise multiplied tensors have shapes [batch, out_num_caps, in_num_caps, out_dim_caps] and [batch, out_num_caps, 1, out_dim_caps]; broadcasting occurs on the penultimate dimension, and summing over the last dimension yields the final shape [batch, out_num_caps, in_num_caps] (a single iteration is traced in the sketch below)
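
A single routing iteration can be traced in isolation (a sketch with MNIST sizes, reusing the squash() defined earlier):

import torch
import torch.nn.functional as F

batch, out_num, in_num, out_dim = 4, 10, 1152, 16
x_hat = torch.randn(batch, out_num, in_num, out_dim)
b = torch.zeros(batch, out_num, in_num)

c = F.softmax(b, dim=1)                                        # [4, 10, 1152]
s = torch.sum(c[:, :, :, None] * x_hat, dim=-2, keepdim=True)  # [4, 10, 1, 16]
v = squash(s)                                                  # squash() from above
b = b + torch.sum(v * x_hat, dim=-1)                           # [4, 10, 1152]
print(v.shape, b.shape)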

Other components

Network structure

import torch
import torch.nn as nn

class CapsuleNet(nn.Module):
    """
    A Capsule Network on MNIST.
    :param input_size: data size = [channels, width, height]
    :param classes: number of classes
    :param routings: number of routing iterations
    Shape:
        - Input: (batch, channels, width, height), optional (batch, classes)
        - Output: ((batch, classes), (batch, channels, width, height))
    """
    def __init__(self, input_size, classes, routings):
        super(CapsuleNet, self).__init__()
        self.input_size = input_size
        self.classes = classes
        self.routings = routings

        # Layer 1: Just a conventional Conv2D layer
        self.conv1 = nn.Conv2d(input_size[0], 256, kernel_size=9, stride=1, padding=0)

        # Layer 2: Conv2D layer with `squash` activation, then reshape to [None, num_caps, dim_caps]
        self.primarycaps = PrimaryCapsule(256, 256, 8, kernel_size=9, stride=2, padding=0)

        # Layer 3: Capsule layer. Routing algorithm works here.
        self.digitcaps = DenseCapsule(in_num_caps=32*6*6, in_dim_caps=8,
                                      out_num_caps=classes, out_dim_caps=16, routings=routings)

        # Decoder network.
        self.decoder = nn.Sequential(
            nn.Linear(16*classes, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, input_size[0] * input_size[1] * input_size[2]),
            nn.Sigmoid()
        )

        self.relu = nn.ReLU()

    def forward(self, x, y=None):
        x = self.relu(self.conv1(x))
        x = self.primarycaps(x)
        x = self.digitcaps(x)
        length = x.norm(dim=-1)
        if y is None: # during testing, no label given. create one-hot coding using `length`
            index = length.max(dim=1)[1]
            y = torch.zeros(length.size(), device=length.device).scatter_(1, index.view(-1, 1), 1.)
        reconstruction = self.decoder((x * y[:, :, None]).view(x.size(0), -1))
        return length, reconstruction.view(-1, *self.input_size)

The network consists of two parts: the capsule network and the reconstruction network. The reconstruction network is a multilayer perceptron that reconstructs the image from the capsule output, which shows that the capsules encode some spatial information in addition to the classification result.

Note that the forward propagation part of the capsule network is:

x = self.relu(self.conv1(x))
x = self.primarycaps(x)
x = self.digitcaps(x)
length = x.norm(dim=-1)

The final output is the 2-norm of each capsule, i.e. the length of each capsule vector; the class whose capsule is longest is taken as the prediction. A quick smoke test is shown below.
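
A smoke test of the whole network (a sketch; labels are passed explicitly so the one-hot branch in forward() is skipped):

import torch

model = CapsuleNet(input_size=[1, 28, 28], classes=10, routings=3)
images = torch.randn(2, 1, 28, 28)
labels = torch.eye(10)[torch.tensor([3, 7])]  # one-hot targets
length, recon = model(images, labels)
print(length.shape)  # torch.Size([2, 10])
print(recon.shape)   # torch.Size([2, 1, 28, 28])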

Cost function

The cost function for the capsule part of the network is the margin loss: $$L_c = T_c \max(0, m^+ - \|v_c\|)^2 + \lambda (1-T_c) \max(0, \|v_c\| - m^-)^2$$

The following code implements this loss: L is the margin loss of the capsules, with $m^+ = 0.9$ and $m^- = 0.1$ (and $\lambda = 0.5$ in the code), and L_recon is the reconstruction loss, computed with MSELoss between the input image and the reconstructed image.

import torch
import torch.nn as nn

def caps_loss(y_true, y_pred, x, x_recon, lam_recon):
    # margin loss; 0.5 is the lambda factor that down-weights absent classes
    L = y_true * torch.clamp(0.9 - y_pred, min=0.) ** 2 + \
        0.5 * (1 - y_true) * torch.clamp(y_pred - 0.1, min=0.) ** 2
    L_margin = L.sum(dim=1).mean()
    # reconstruction loss, weighted by lam_recon
    L_recon = nn.MSELoss()(x_recon, x)
    return L_margin + lam_recon * L_recon
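
A usage sketch (tensor values and the lam_recon setting here are purely illustrative):

import torch

y_true = torch.eye(10)[torch.tensor([3, 7])]  # one-hot labels, shape [2, 10]
y_pred = torch.rand(2, 10)                    # capsule lengths from the network
x = torch.rand(2, 1, 28, 28)
x_recon = torch.rand(2, 1, 28, 28)
print(caps_loss(y_true, y_pred, x, x_recon, lam_recon=0.392).item())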

References

CapsNet paper: S. Sabour, N. Frosst, G. E. Hinton, "Dynamic Routing Between Capsules", NIPS 2017

CapsNet open source code

Source: https://cloud.tencent.com/developer/article/1110558