Mask R-CNN for segmentation using PyTorch

Bjørn Hansen
4 min read · Dec 30, 2020

In this article the mask R-CNN is implemented for the task of segmentation. First the mask R-CNN framework is introduced, then an example implementation in PyTorch is walked through, and lastly the results are presented.

Mask R-CNN

Frameworks such as the mask R-CNN have been developed for object detection and instance segmentation tasks. The mask R-CNN was introduced in 2017 as an extension of the Faster R-CNN deep learning framework. It has two fundamental stages: the first stage generates proposals for regions of the input image that might contain an object, and the second stage predicts the class of each proposed object, refines its bounding box coordinates, and generates a pixel mask for it. A general overview of the framework can be seen in the image below.

The first stage consists of a region proposal network (RPN), which scans the feature map and proposes regions that may contain an object. A set of predefined boxes, so-called anchors, ties features back to locations in the original image. The anchors are boxes with predefined locations and scales relative to the input image. Each anchor is assigned a binary class (object or background) and a bounding box according to an intersection over union (IoU) threshold with the ground-truth boxes. Conceptually, the first stage of the mask R-CNN is similar to other networks that extract information through convolution, down-sampling, and up-sampling.
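The IoU-based labelling of anchors can be sketched in a few lines of plain Python. This is a simplified illustration, not the code used in the article: the function names are hypothetical, boxes are assumed to be `(x0, y0, x1, y1)` tuples, and the 0.7 / 0.3 thresholds are the values commonly quoted for Faster R-CNN, taken here as an assumption.

```python
def box_iou(box_a, box_b):
    """Intersection over union of two (x0, y0, x1, y1) boxes."""
    ax0, ay0, ax1, ay1 = box_a
    bx0, by0, bx1, by1 = box_b
    # Width and height of the overlap, clamped at zero when the boxes are disjoint.
    inter_w = max(0.0, min(ax1, bx1) - max(ax0, bx0))
    inter_h = max(0.0, min(ay1, by1) - max(ay0, by0))
    inter = inter_w * inter_h
    union = (ax1 - ax0) * (ay1 - ay0) + (bx1 - bx0) * (by1 - by0) - inter
    return inter / union if union > 0 else 0.0


def label_anchor(anchor, gt_boxes, pos_thresh=0.7, neg_thresh=0.3):
    """Return 1 (object), 0 (background) or -1 (ignored during training)."""
    best = max((box_iou(anchor, gt) for gt in gt_boxes), default=0.0)
    if best >= pos_thresh:
        return 1
    if best < neg_thresh:
        return 0
    return -1  # ambiguous overlap: assumed not to contribute to the loss


anchor = (0, 0, 10, 10)
print(label_anchor(anchor, [(0, 0, 10, 9)]))     # large overlap
print(label_anchor(anchor, [(20, 20, 30, 30)]))  # no overlap
```

A real region proposal network applies this rule to many thousands of anchors at once using tensor operations, and typically also marks the best-matching anchor of every ground-truth box as positive; both details are left out here for brevity.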

The second stage of the mask R-CNN is also a neural network. It takes the regions of interest (ROIs) proposed by the first stage and, using a method called ROIAlign, maps them onto the relevant areas of the feature map. From there, separate network branches make predictions independently for each object, each with its own loss function: a classifier, a bounding box regressor, and a mask generating network that works at the pixel level.

General overview of the mask R-CNN.

Implementing the mask R-CNN in PyTorch

The implementation of the mask R-CNN follows the same procedure as the TorchVision Object Detection Finetuning Tutorial. The mask R-CNN was originally trained on the COCO dataset to detect and classify everyday objects, but in this article it is fine-tuned via transfer learning on colon histology images that were originally presented in the GlaS contest. This was a research contest in which teams of researchers from around the world competed on a segmentation task; the data and more information about the contest are available on the contest page (Link). The first step is to define a custom data loader.
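A minimal sketch of such a dataset is given below. It is not the notebook's exact code: the class and function names are hypothetical, and it assumes each annotation file is a single-channel image in which 0 is background and every gland has its own positive integer id, which should be checked against the downloaded data.

```python
import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset


def label_map_to_target(label_map):
    """Turn an instance-label map into the target dict torchvision expects."""
    height, width = label_map.shape
    masks, boxes = [], []
    for instance_id in np.unique(label_map):
        if instance_id == 0:  # assumed background value
            continue
        mask = label_map == instance_id
        ys, xs = np.nonzero(mask)
        x0, x1, y0, y1 = int(xs.min()), int(xs.max()), int(ys.min()), int(ys.max())
        if x1 > x0 and y1 > y0:  # skip degenerate one-pixel-wide boxes
            masks.append(mask)
            boxes.append([x0, y0, x1, y1])
    n = len(boxes)
    mask_array = np.array(masks, dtype=np.uint8).reshape(n, height, width)
    return {
        "boxes": torch.tensor(boxes, dtype=torch.float32).reshape(-1, 4),
        "labels": torch.ones((n,), dtype=torch.int64),  # single class: gland
        "masks": torch.from_numpy(mask_array),
    }


class GlandDataset(Dataset):
    """Pairs of histology images and instance-label annotations."""

    def __init__(self, image_paths, annotation_paths):
        self.image_paths = list(image_paths)
        self.annotation_paths = list(annotation_paths)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = np.array(Image.open(self.image_paths[idx]).convert("RGB"))
        label_map = np.array(Image.open(self.annotation_paths[idx]))
        # HWC uint8 -> CHW float in [0, 1], as the torchvision models expect.
        image = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0
        return image, label_map_to_target(label_map)


def collate_fn(batch):
    """Keep images and targets as tuples, since images may differ in size."""
    return tuple(zip(*batch))
```

The dataset can then be wrapped in a `torch.utils.data.DataLoader` with `collate_fn=collate_fn`, so that each batch is a tuple of images and a tuple of target dictionaries.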

After loading the image data, we can view a couple of pictures. Below are a couple of images with bounding boxes drawn around the glands of interest.

With the data all set, we can move on to loading the model. The Torchvision library makes transfer learning very convenient, as it allows you to simply load a pre-trained mask R-CNN model with a ResNet-50 backbone (or another pre-trained backbone if you prefer). The code for loading the model is included in the full notebook linked at the end of the article.

With the data and model ready, the training loop is next. There are many ways to define a training loop in PyTorch; the one used for this article is included in the full notebook linked at the end. The original torchvision tutorial also has a nice training loop implementation called train_one_epoch (Link).
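A bare-bones training loop could look like the sketch below. It relies on the fact that torchvision detection models return a dictionary of losses when called in training mode with images and targets. The function name is hypothetical, and the optimizer settings are borrowed from the torchvision tutorial as an assumption, not taken from the article's notebook.

```python
import torch


def train(model, data_loader, num_epochs=10, lr=0.005, device=None):
    """Fine-tune a torchvision detection model; returns the mean loss per epoch."""
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(params, lr=lr, momentum=0.9, weight_decay=0.0005)

    history = []
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        for images, targets in data_loader:
            images = [img.to(device) for img in images]
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

            # In training mode the model returns one loss per branch
            # (classifier, box regression, mask, and the two RPN losses).
            loss_dict = model(images, targets)
            loss = sum(loss_dict.values())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        history.append(running_loss / max(len(data_loader), 1))
        print(f"epoch {epoch + 1}: mean loss {history[-1]:.4f}")
    return history
```

Summing the individual losses trains all branches jointly; the returned list of per-epoch losses can be plotted to inspect convergence.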

Results

Since the mask R-CNN is relatively complex, it takes a while to train. Once training finished, the results seen below were obtained: on the test set the model achieved an IoU score of 0.419 and an F1 score of 0.592.
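The article does not spell out how the two scores were computed. One common pixel-wise definition for binary masks is sketched below; the function name and the small `eps` term (which avoids division by zero on empty masks) are assumptions for illustration.

```python
import numpy as np


def iou_and_f1(pred_mask, true_mask, eps=1e-7):
    """Pixel-wise IoU and F1 (Dice) score between two binary masks."""
    pred = np.asarray(pred_mask).astype(bool)
    true = np.asarray(true_mask).astype(bool)

    tp = np.logical_and(pred, true).sum()   # gland pixels found
    fp = np.logical_and(pred, ~true).sum()  # background predicted as gland
    fn = np.logical_and(~pred, true).sum()  # gland pixels missed

    iou = tp / (tp + fp + fn + eps)
    f1 = 2 * tp / (2 * tp + fp + fn + eps)
    return float(iou), float(f1)


pred = np.array([[1, 1, 0, 0]])
true = np.array([[0, 1, 1, 0]])
print(iou_and_f1(pred, true))  # roughly (0.333, 0.5)
```

For a single pair of masks the two scores are tied together by F1 = 2·IoU / (1 + IoU), so F1 is never lower than IoU; averaging over many test images can shift the relationship slightly.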

Training loss for Mask R-CNN.
The original image mask (top) and the predicted mask using the mask R-CNN model (bottom).

Conclusion

The full notebook used in this article can be found via the link below. The mask R-CNN is a cool framework that can be used for a range of computer vision tasks. If you are interested in seeing a full PyTorch implementation of the mask R-CNN from scratch, there is a GitHub repo here (Link). For further reading on the use of the mask R-CNN for medical images, I recommend the following research paper (Link). I hope this article was helpful. If you have any thoughts, feel free to share them below.

Full notebook: https://colab.research.google.com/drive/11FN5yQh1X7x-0olAOx7EJbwEey5jamKl?usp=sharing
