Cardiac Regression with DIGITS
This year NVIDIA is co-sponsoring the National Data Science Bowl, a competition on cardiac volume prediction. The relevant challenge here is regression, where a model generates a real number, as opposed to the discrete output of classification. In this presentation, DIGITS running on top of Caffe is used to predict end-diastolic and end-systolic left ventricular (LV) volumes from cardiac MR CINE (volumetric video) acquisitions. The model currently achieves 0.049108 (lower is better) on the competition metric, the Continuous Ranked Probability Score (CRPS), with just a basic LeNet.
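For reference, the competition defines the metric over 600 volume thresholds (0-599 mL):

\[
\mathrm{CRPS} = \frac{1}{600N} \sum_{m=1}^{N} \sum_{n=0}^{599} \bigl(P_m(y \le n) - H(n - V_m)\bigr)^2
\]

where P_m is the predicted cumulative distribution for case m, V_m is the true volume, H is the Heaviside step function (1 when its argument is non-negative, 0 otherwise), and N is the total number of diastolic and systolic predictions.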
DIGITS is a graphical interface with handy real-time visualization of model training. It is accessible to the beginning deep learning researcher and useful for the more advanced user. For this competition, model fine-tuning and accessible hyperparameter management will prove useful in Phase II, when the validation labels are released. The software requirements are DIGITS 3, which now has .deb package installation for Ubuntu 14.04, plus Python 2.x and the associated packages used here, e.g. pydicom, numpy, and scipy. The dataset is obtained through the Kaggle website.
We run a number of preprocessing scripts to extract the data from DICOM format using pydicom. Before running the scripts, copy a skeleton directory structure to the destination folder via a terminal, e.g.
sudo find -type d -links 2 -exec mkdir -p "/raid/leo/cardiac/trainDigits/{}" \;
Run this in the folder containing all your training images, e.g. /raid/leo/cardiac/train; the find command replicates every leaf directory of the source tree under the destination. Now run the preprocessing script, which extracts 27 sequential frames from each cardiac MR series, approximately one cardiac cycle, and reshapes them into a single image composed of 3x3 tiles in each of 3 channels (false RGB). Make sure to change the data paths on lines 165-172 to match your file system.
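To make the tiling concrete, here is a minimal sketch assuming 27 equally sized frames in time order. The frame-to-tile assignment shown (consecutive frames spread across the three channels at each grid position) is a plausible reading of the layout; the actual script may order frames differently.

import numpy as np

def tile_frames(frames):
    """frames: list of 27 2-D arrays (H, W), one cardiac cycle in order.

    Returns a (3, 3*H, 3*W) false-RGB image: each of the 9 tile
    positions holds 3 consecutive frames spread across the channels.
    """
    h, w = frames[0].shape
    out = np.zeros((3, 3 * h, 3 * w), dtype=frames[0].dtype)
    for p in range(9):                # 3x3 grid of tile positions
        r, c = divmod(p, 3)
        for ch in range(3):           # consecutive frames -> R, G, B
            out[ch, r*h:(r+1)*h, c*w:(c+1)*w] = frames[3*p + ch]
    return out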
27 sequential cardiac images are tiled and stacked to create a false-color image. What do the colored regions represent, and in what domains does a convolution act on this image?
Interestingly, convolutions now act both spatially and temporally, since the image is composed of a time series of frames. The colored regions reveal motion: where the three channels hold different moments of the cycle, static anatomy appears gray, while moving structures such as the contracting ventricle wall show up in color. When the preprocessing is done, fire up DIGITS 3.0.
Non-classification datasets may be created in DIGITS through the “other” dataset type. For these datasets, DIGITS expects the user to provide a set of LMDB databases. Note that since labels may be vectors (or matrices), it is not possible to use a single LMDB database to hold both the image and its label. Therefore, DIGITS expects one LMDB database for the images and a separate LMDB database for the labels.
Writing the LMDBs in Python:
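Here is a minimal sketch, assuming the lmdb and caffe Python packages are available and the tiled images are uint8 arrays in (channels, height, width) order; the map size, key format, and helper names are illustrative, not the competition code.

import lmdb
import numpy as np
from caffe.proto import caffe_pb2

def write_image_lmdb(db_path, images):
    """images: iterable of uint8 arrays shaped (channels, height, width)."""
    env = lmdb.open(db_path, map_size=1 << 40)  # generous max map size
    with env.begin(write=True) as txn:
        for i, img in enumerate(images):
            datum = caffe_pb2.Datum()
            datum.channels, datum.height, datum.width = img.shape
            datum.data = img.tobytes()  # raw uint8 pixels
            txn.put('{:08d}'.format(i).encode(), datum.SerializeToString())
    env.close()

def write_label_lmdb(db_path, labels):
    """labels: iterable of float regression targets (scalars or vectors)."""
    env = lmdb.open(db_path, map_size=1 << 40)
    with env.begin(write=True) as txn:
        for i, label in enumerate(labels):
            vec = np.atleast_1d(np.asarray(label, dtype=np.float32))
            datum = caffe_pb2.Datum()
            datum.channels, datum.height, datum.width = len(vec), 1, 1
            datum.float_data.extend(vec)  # regression targets as floats
            txn.put('{:08d}'.format(i).encode(), datum.SerializeToString())
    env.close()

The image and label databases should use the same keys in the same order, so that each image lines up with its target.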
In the generic dataset creation form, you need to provide the paths to:
- the train image database (eg. /raid/leo/cardiac/trainDigits/train_images)
- the train label database (eg. /raid/leo/cardiac/trainDigits/train_labels)
- the train image mean file, train_mean.binaryproto (eg. /raid/leo/cardiac/trainDigits/train_mean.binaryproto)
Once the dataset has been created, create an “other” model to train a regression network. Paste in the prototxt for a basic LeNet. For the solver options, set the base learning rate to 1e-8, then click “advanced learning rate options” and set the step size to 65%.
Enter the model hyperparameters in DIGITS as depicted.
Once your model is trained, push the validation image file list through the model using “Test Many” to create predictions of end-diastolic and end-systolic LV volume. Here is a sample list. Several of these scripts were based on Bing Xu’s MXNet tutorial and other code from the NVIDIA DIGITS repository. If you are submitting for evaluation on the leaderboard, a Python script formats the predictions as a cumulative distribution function.
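That submission format is worth sketching: for each case, the leaderboard expects 600 values P0 through P599, where Pn is the predicted probability that the volume is at most n mL. The snippet below is an illustration, not the actual formatting script; the logistic smoothing and its sigma parameter are assumptions.

import numpy as np

def volume_to_cdf(volume, sigma=10.0):
    """Smooth the step at the predicted volume (mL) into a logistic
    CDF over the 600 thresholds the submission requires. sigma trades
    confidence against the CRPS penalty for being wrong."""
    thresholds = np.arange(600)
    return 1.0 / (1.0 + np.exp(-(thresholds - volume) / sigma))

# One submission row: case Id followed by P0..P599.
row = ['123_Diastole'] + ['%.6f' % p for p in volume_to_cdf(120.0)]

Smoothing matters because CRPS accumulates the squared difference between the predicted CDF and the true step function, so an overconfident step that lands on the wrong side of the true volume is maximally penalized.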
There are many improvements that can be made to this initial approach, aside from the new directions DIGITS may bring to your workflow. These include extending the regression to predict all 600 values of the probability distribution directly, stronger networks such as AlexNet or GoogLeNet (example protobuf files ship with DIGITS), data augmentation (perhaps rearranging the time sequence), different loss functions, hyperparameter adjustments, and more.
Happy competing!