from enum import Enum
from typing import Optional, Tuple, Union
import numpy as np
import SimpleITK as sitk
from pyradise.data import (Image, IntensityImage, Modality, SegmentationImage,
Subject, TransformInfo, str_to_modality)
from .base import Filter, FilterParams
# Names exported as the public API of this module.
__all__ = [
"IntraSubjectRegistrationFilterParams",
"IntraSubjectRegistrationFilter",
"InterSubjectRegistrationFilterParams",
"InterSubjectRegistrationFilter",
"RegistrationType",
]
class RegistrationType(Enum):
    """An enum class representing the different registration transform types."""

    AFFINE = 1
    """Affine registration."""

    SIMILARITY = 2
    """Similarity registration."""

    RIGID = 3
    """Rigid registration."""

    BSPLINE = 4
    """BSpline registration."""
def get_interpolator(image: Image) -> Optional[int]:
    """Select the resampling interpolator that matches the type of the given image.

    Args:
        image (Image): The image for which an interpolator is requested.

    Returns:
        Optional[int]: ``sitk.sitkBSpline`` for an :class:`IntensityImage`, ``sitk.sitkNearestNeighbor`` for a
        :class:`SegmentationImage`, and ``None`` for any other image type.
    """
    # intensity images are checked first, exactly as in the type hierarchy-agnostic original order
    if isinstance(image, IntensityImage):
        return sitk.sitkBSpline

    # label maps must not be blended, hence nearest neighbor
    if isinstance(image, SegmentationImage):
        return sitk.sitkNearestNeighbor

    return None
def get_registration_method(
    registration_type: RegistrationType = RegistrationType.RIGID,
    number_of_histogram_bins: int = 200,
    learning_rate: float = 1.0,
    step_size: float = 0.001,
    number_of_iterations: int = 1500,
    relaxation_factor: float = 0.5,
    shrink_factors: Tuple[int, ...] = (2, 2, 1),
    smoothing_sigmas: Tuple[float, ...] = (2, 1, 0),
    sampling_percentage: float = 0.2,
    deterministic: bool = True,
) -> sitk.ImageRegistrationMethod:
    """Get the registration method based on the provided parameters.

    The returned method uses the Mattes mutual information metric and a linear interpolator. For
    :attr:`RegistrationType.BSPLINE` the LBFGSB optimizer is used with its default settings; every other
    registration type uses a regular step gradient descent optimizer configured from the arguments below.

    Args:
        registration_type (RegistrationType): The type of registration (default: RegistrationType.RIGID).
        number_of_histogram_bins (int): The number of histogram bins for registration (default: 200).
        learning_rate (float): The learning rate of the optimizer (default: 1.0).
        step_size (float): The step size of the optimizer (default: 0.001).
        number_of_iterations (int): The maximal number of optimization iterations (default: 1500).
        relaxation_factor (float): The relaxation factor (default: 0.5).
        shrink_factors (Tuple[int, ...]): The shrink factors for the image pyramid (default: (2, 2, 1)).
        smoothing_sigmas (Tuple[float, ...]): The smoothing sigmas (default: (2, 1, 0)).
        sampling_percentage (float): The sampling percentage of the voxels to incorporate into the optimization
            (default: 0.2).
        deterministic (bool): Deterministic processing with a fixed seed and a single thread (default: True).

    Returns:
        sitk.ImageRegistrationMethod: The registration method.
    """
    registration = sitk.ImageRegistrationMethod()
    registration.SetMetricAsMattesMutualInformation(number_of_histogram_bins)

    if deterministic:
        # https://simpleitk.readthedocs.io/en/master/registrationOverview.html
        # Reproducibility requires exactly one worker thread (a value of 0 does not request "one thread")
        # together with a fixed sampling seed.
        registration.SetGlobalDefaultNumberOfThreads(1)
        sampling_seed = 42
        # NOTE(review): no sampling strategy is selected on this path. If SimpleITK's default strategy is NONE,
        # the sampling percentage has no effect here - confirm whether REGULAR sampling was intended.
        registration.SetMetricSamplingPercentage(sampling_percentage, sampling_seed)
    else:
        registration.SetMetricSamplingStrategy(registration.RANDOM)
        registration.SetMetricSamplingPercentage(sampling_percentage, sitk.sitkWallClock)

    registration.SetMetricUseFixedImageGradientFilter(False)
    registration.SetMetricUseMovingImageGradientFilter(False)
    registration.SetInterpolator(sitk.sitkLinear)

    if registration_type == RegistrationType.BSPLINE:
        registration.SetOptimizerAsLBFGSB()
    else:
        registration.SetOptimizerAsRegularStepGradientDescent(
            learningRate=learning_rate,
            minStep=step_size,
            numberOfIterations=number_of_iterations,
            relaxationFactor=relaxation_factor,
            gradientMagnitudeTolerance=1e-4,
            estimateLearningRate=registration.EachIteration,
            maximumStepSizeInPhysicalUnits=0.0,
        )

    registration.SetOptimizerScalesFromPhysicalShift()

    # Setup for the multi-resolution framework
    registration.SetShrinkFactorsPerLevel(shrink_factors)
    registration.SetSmoothingSigmasPerLevel(smoothing_sigmas)
    registration.SmoothingSigmasAreSpecifiedInPhysicalUnitsOn()

    return registration
def register_images(
    moving_image: sitk.Image,
    fixed_image: sitk.Image,
    registration_type: RegistrationType,
    registration_method: sitk.ImageRegistrationMethod,
) -> sitk.Transform:
    """Compute the transformation which registers the moving image to the fixed image.

    Both images are cast to 32 bit floating point before the registration is executed. The initial transform is
    either a BSpline transform with a mesh size of 10 per dimension or a geometry-centered transform of the
    requested type.

    Args:
        moving_image (sitk.Image): The moving image.
        fixed_image (sitk.Image): The fixed image.
        registration_type (RegistrationType): The registration type.
        registration_method (sitk.ImageRegistrationMethod): The registration method.

    Raises:
        ValueError: If the image dimensions differ, if the images are neither 2D nor 3D, or if the registration
            type is not supported.

    Returns:
        sitk.Transform: The registration transformation.
    """
    dimension = moving_image.GetDimension()
    if dimension != fixed_image.GetDimension():
        raise ValueError("The floating and fixed image dimensions do not match!")
    if dimension not in (2, 3):
        raise ValueError("The image must have 2 or 3 dimensions. Different number of dimensions are not supported!")

    moving_f32 = sitk.Cast(moving_image, sitk.sitkFloat32)
    fixed_f32 = sitk.Cast(fixed_image, sitk.sitkFloat32)

    if registration_type == RegistrationType.BSPLINE:
        # the BSpline grid is laid out over the (un-cast) fixed image
        mesh_size = [10] * dimension
        start_transform = sitk.BSplineTransformInitializer(fixed_image, mesh_size)
    else:
        is_3d = dimension == 3
        if registration_type == RegistrationType.RIGID:
            base_transform = sitk.VersorRigid3DTransform() if is_3d else sitk.Euler2DTransform()
        elif registration_type == RegistrationType.AFFINE:
            base_transform = sitk.AffineTransform(dimension)
        elif registration_type == RegistrationType.SIMILARITY:
            base_transform = sitk.Similarity3DTransform() if is_3d else sitk.Similarity2DTransform()
        else:
            raise ValueError(f"The registration type ({registration_type.name}) is not supported!")

        # align the geometric centers of both images as starting point
        start_transform = sitk.CenteredTransformInitializer(
            fixed_f32,
            moving_f32,
            base_transform,
            sitk.CenteredTransformInitializerFilter.GEOMETRY,
        )

    registration_method.SetInitialTransform(start_transform, inPlace=True)
    return registration_method.Execute(fixed_f32, moving_f32)
# pylint: disable = too-few-public-methods
class IntraSubjectRegistrationFilterParams(FilterParams):
    """A filter parameter class for the :class:`~pyradise.process.registration.IntraSubjectRegistrationFilter` class.

    Args:
        reference_modality (Union[Modality, str]): The reference modality.
        registration_type (RegistrationType): The type of registration (default: RegistrationType.RIGID).
        number_of_histogram_bins (int): The number of histogram bins for registration (default: 200).
        learning_rate (float): The learning rate of the optimizer (default: 1.0).
        step_size (float): The step size of the optimizer (default: 0.001).
        number_of_iterations (int): The maximal number of optimization iterations (default: 1500).
        relaxation_factor (float): The relaxation factor (default: 0.5).
        shrink_factors (Tuple[int, ...]): The shrink factors for the image pyramid (default: (2, 2, 1)).
        smoothing_sigmas (Tuple[float, ...]): The smoothing sigmas (default: (2, 1, 0)).
        sampling_percentage (float): The sampling percentage of the voxels to incorporate into the optimization
            (default: 0.2).
        resampling_interpolator (int): The resampling interpolator (default: sitk.sitkBSpline).
        deterministic (bool): Deterministic processing with a fixed seed and a single thread (default: True).

    Raises:
        ValueError: If ``shrink_factors`` and ``smoothing_sigmas`` differ in length.
    """

    def __init__(
        self,
        reference_modality: Union[Modality, str],
        registration_type: RegistrationType = RegistrationType.RIGID,
        number_of_histogram_bins: int = 200,
        learning_rate: float = 1.0,
        step_size: float = 0.001,
        number_of_iterations: int = 1500,
        relaxation_factor: float = 0.5,
        shrink_factors: Tuple[int, ...] = (2, 2, 1),
        smoothing_sigmas: Tuple[float, ...] = (2, 1, 0),
        sampling_percentage: float = 0.2,
        resampling_interpolator: int = sitk.sitkBSpline,
        deterministic: bool = True,
    ) -> None:
        super().__init__()

        # every pyramid level needs one shrink factor and one smoothing sigma
        levels_match = len(shrink_factors) == len(smoothing_sigmas)
        if not levels_match:
            raise ValueError("The shrink_factors and smoothing_sigmas need to have the same length!")

        # what to register against
        self.reference_modality: Modality = str_to_modality(reference_modality)
        self.registration_type = registration_type

        # metric and optimizer settings
        self.number_of_histogram_bins = number_of_histogram_bins
        self.learning_rate = learning_rate
        self.step_size = step_size
        self.number_of_iterations = number_of_iterations
        self.relaxation_factor = relaxation_factor

        # multi-resolution pyramid and sampling
        self.shrink_factors: Tuple[int, ...] = shrink_factors
        self.smoothing_sigmas: Tuple[float, ...] = smoothing_sigmas
        self.sampling_percentage = sampling_percentage

        # resampling and reproducibility
        self.resampling_interpolator = resampling_interpolator
        self.deterministic = deterministic
class IntraSubjectRegistrationFilter(Filter):
    """An invertible intra-subject registration filter class which registers all
    :class:`~pyradise.data.image.IntensityImage` instances to a reference :class:`~pyradise.data.image.IntensityImage`
    instance.

    Important:
        This filter assumes that the :class:`~pyradise.data.image.SegmentationImage` instances are already registered
        to the reference :class:`~pyradise.data.image.IntensityImage` instance. No transformation will be applied to
        the :class:`~pyradise.data.image.SegmentationImage` instances.

    Warning:
        The inverse registration procedure may not yield the expected results if successive
        :class:`~pyradise.process.base.Filter` s are applied to the same :class:`~pyradise.data.image.Image` instances.
        Thus, it's recommended to use the invertibility feature with appropriate caution.
    """

    @staticmethod
    def is_invertible() -> bool:
        """Return whether the filter is invertible or not.

        Returns:
            bool: True because the registration filter is invertible.
        """
        return True

    @staticmethod
    def _select_interpolator(image: Image, params: IntraSubjectRegistrationFilterParams) -> Optional[int]:
        """Select the interpolator used to resample the given image.

        Args:
            image (Image): The image to be resampled.
            params (IntraSubjectRegistrationFilterParams): The filter parameters.

        Returns:
            Optional[int]: ``None`` if the image type is not supported, the ``resampling_interpolator`` of the
            parameters for an :class:`IntensityImage` (if available), and the type-specific default otherwise.
        """
        default_interpolator = get_interpolator(image)
        if default_interpolator is None or not isinstance(image, IntensityImage):
            return default_interpolator

        # Honor the user-configured interpolator for intensity images. The attribute is read defensively because
        # the parameters may originate from a stored TransformInfo instance.
        configured = getattr(params, "resampling_interpolator", None)
        return default_interpolator if configured is None else configured

    # noinspection DuplicatedCode
    def _process_image(
        self,
        moving_image: Image,
        fixed_image: sitk.Image,
        params: IntraSubjectRegistrationFilterParams,
        transform: Optional[sitk.Transform] = None,
        track_infos: bool = True,
    ) -> Image:
        """Apply the transformation or register the image to the reference image.

        Args:
            moving_image (Image): The moving image.
            fixed_image (sitk.Image): The fixed image.
            params (IntraSubjectRegistrationFilterParams): The filter parameters.
            transform (Optional[sitk.Transform]): The transformation to apply to the image. If None, the
                transformation is estimated by registering the moving image to the fixed image (default: None).
            track_infos (bool): Whether to track the processing information or not (default: True).

        Returns:
            Image: The registered image, or the unmodified image if its type can not be resampled.
        """
        # bail out before the expensive registration if the image type can not be resampled anyway
        interpolator = self._select_interpolator(moving_image, params)
        if interpolator is None:
            return moving_image

        # get the moving image as SimpleITK image
        moving_image_sitk = moving_image.get_image_data()

        # cast the images if the moving image is an intensity image
        if isinstance(moving_image, IntensityImage):
            moving_image_sitk = sitk.Cast(moving_image_sitk, sitk.sitkFloat32)
            fixed_image = sitk.Cast(fixed_image, sitk.sitkFloat32)

        # register the moving image to the fixed image if no transform is given
        if transform is None:
            registration_method = get_registration_method(
                params.registration_type,
                params.number_of_histogram_bins,
                params.learning_rate,
                params.step_size,
                params.number_of_iterations,
                params.relaxation_factor,
                params.shrink_factors,
                params.smoothing_sigmas,
                params.sampling_percentage,
                params.deterministic,
            )
            transform = register_images(moving_image_sitk, fixed_image, params.registration_type, registration_method)

        # resample the moving image; voxels outside the moving image receive its minimum intensity
        min_intensity = float(np.min(sitk.GetArrayFromImage(moving_image_sitk)))
        new_image_sitk = sitk.Resample(
            moving_image_sitk, fixed_image, transform, interpolator, min_intensity, moving_image_sitk.GetPixelIDValue()
        )

        # set the new image data to the image
        moving_image.set_image_data(new_image_sitk)

        # track the necessary information
        if track_infos:
            self.tracking_data.update(
                {
                    "original_origin": moving_image_sitk.GetOrigin(),
                    "original_spacing": moving_image_sitk.GetSpacing(),
                    "original_direction": moving_image_sitk.GetDirection(),
                    "original_size": moving_image_sitk.GetSize(),
                }
            )
            self._register_tracked_data(moving_image, moving_image_sitk, new_image_sitk, params, transform)

        return moving_image

    def execute(self, subject: Subject, params: IntraSubjectRegistrationFilterParams) -> Subject:
        """Execute the intra-subject registration procedure.

        Args:
            subject (Subject): The :class:`~pyradise.data.subject.Subject` instance to be processed.
            params (IntraSubjectRegistrationFilterParams): The filter parameters.

        Raises:
            ValueError: If the subject does not provide an image of the reference modality.

        Returns:
            Subject: The :class:`~pyradise.data.subject.Subject` instance with registered
            :class:`~pyradise.data.image.IntensityImage` instances.
        """
        # get the reference image
        reference_image = subject.get_image_by_modality(params.reference_modality)
        if reference_image is None:
            raise ValueError(
                f"The subject does not contain an image with the reference modality ({params.reference_modality})!"
            )
        reference_image_sitk = reference_image.get_image_data()

        # perform the registration for all intensity images except the reference itself
        for image in subject.get_images():
            if not isinstance(image, IntensityImage):
                continue
            if image.get_modality() == params.reference_modality:
                continue
            self._process_image(image, reference_image_sitk, params, track_infos=True)

        return subject

    # noinspection DuplicatedCode
    def execute_inverse(
        self,
        subject: Subject,
        transform_info: TransformInfo,
        target_image: Optional[Union[SegmentationImage, IntensityImage]] = None,
    ) -> Subject:
        """Execute the inverse of the intra-subject registration procedure.

        Args:
            subject: The :class:`~pyradise.data.subject.Subject` instance to be processed.
            transform_info: The transform information.
            target_image (Optional[Union[SegmentationImage, IntensityImage]]): The target image to which the inverse
                transformation should be applied. If None, the inverse transformation is applied to all images
                (default: None).

        Returns:
            Subject: The :class:`~pyradise.data.subject.Subject` instance with unregistered
            :class:`~pyradise.data.image.IntensityImage` instances.
        """
        # construct an empty image with the pre-transform geometry as a resampling reference
        # (the size is reversed because NumPy arrays are indexed in reverse axis order)
        original_image_props = transform_info.get_image_properties(pre_transform=True)
        reference_image_np = np.zeros(original_image_props.size[::-1], dtype=float)
        reference_image_sitk = sitk.GetImageFromArray(reference_image_np)
        reference_image_sitk.SetOrigin(original_image_props.origin)
        reference_image_sitk.SetSpacing(original_image_props.spacing)
        reference_image_sitk.SetDirection(original_image_props.direction)

        # get the inverse transform and the parameters of the forward pass
        transform = transform_info.get_transform(True)
        params = transform_info.get_params()

        # perform the inverse registration
        for image in subject.get_images():
            if target_image is not None and image != target_image:
                continue
            if not isinstance(image, IntensityImage):
                continue
            if image.get_modality() == params.reference_modality:
                continue
            self._process_image(image, reference_image_sitk, params, transform, track_infos=False)

        return subject
# pylint: disable = too-few-public-methods
class InterSubjectRegistrationFilterParams(FilterParams):
    """A filter parameter class for the :class:`~pyradise.process.registration.InterSubjectRegistrationFilter` class.

    Args:
        reference_subject (Subject): The reference subject to which the subject will be registered.
        reference_modality (Union[Modality, str]): The modality of the reference image (fixed image) to be used for
            registration.
        subject_modality (Optional[Union[Modality, str]]): The modality of the subject image (moving image) to be used
            for registration. If ``None``, the same modality as the reference image will be used (default: None).
        registration_type (RegistrationType): The type of registration (default: RegistrationType.RIGID).
        number_of_histogram_bins (int): The number of histogram bins for registration (default: 200).
        learning_rate (float): The learning rate of the optimizer (default: 1.0).
        step_size (float): The step size of the optimizer (default: 0.001).
        number_of_iterations (int): The maximal number of optimization iterations (default: 1500).
        relaxation_factor (float): The relaxation factor (default: 0.5).
        shrink_factors (Tuple[int, ...]): The shrink factors for the image pyramid (default: (2, 2, 1)).
        smoothing_sigmas (Tuple[float, ...]): The smoothing sigmas (default: (2, 1, 0)).
        sampling_percentage (float): The sampling percentage of the voxels to incorporate into the optimization
            (default: 0.2).
        resampling_interpolator (int): The interpolator to use for resampling the image
            (default: sitk.sitkBSpline).
        deterministic (bool): Deterministic processing with a fixed seed and a single thread (default: True).

    Raises:
        ValueError: If ``shrink_factors`` and ``smoothing_sigmas`` differ in length.
    """

    # pylint: disable=too-many-instance-attributes, too-many-arguments
    def __init__(
        self,
        reference_subject: Subject,
        reference_modality: Union[Modality, str],
        subject_modality: Optional[Union[Modality, str]] = None,
        registration_type: RegistrationType = RegistrationType.RIGID,
        number_of_histogram_bins: int = 200,
        learning_rate: float = 1.0,
        step_size: float = 0.001,
        number_of_iterations: int = 1500,
        relaxation_factor: float = 0.5,
        shrink_factors: Tuple[int, ...] = (2, 2, 1),
        smoothing_sigmas: Tuple[float, ...] = (2, 1, 0),
        sampling_percentage: float = 0.2,
        resampling_interpolator: int = sitk.sitkBSpline,
        deterministic: bool = True,
    ) -> None:
        super().__init__()

        if len(shrink_factors) != len(smoothing_sigmas):
            raise ValueError("The shrink_factors and smoothing_sigmas need to have the same length!")

        self.reference_subject = reference_subject
        self.reference_modality: Modality = str_to_modality(reference_modality)

        # Fall back to the already converted reference modality. Using the raw argument here would store a plain
        # string whenever the reference modality was provided as a string.
        self.subject_modality: Modality = (
            str_to_modality(subject_modality) if subject_modality is not None else self.reference_modality
        )

        self.registration_type = registration_type
        self.number_of_histogram_bins = number_of_histogram_bins
        self.learning_rate = learning_rate
        self.step_size = step_size
        self.number_of_iterations = number_of_iterations
        self.relaxation_factor = relaxation_factor
        self.shrink_factors: Tuple[int, ...] = shrink_factors
        self.smoothing_sigmas: Tuple[float, ...] = smoothing_sigmas
        self.sampling_percentage = sampling_percentage
        self.resampling_interpolator = resampling_interpolator
        self.deterministic = deterministic
class InterSubjectRegistrationFilter(Filter):
    """An invertible inter-subject registration filter class which registers all
    :class:`~pyradise.data.image.IntensityImage` instances of the provided :class:`~pyradise.data.subject.Subject` to a
    reference :class:`~pyradise.data.image.IntensityImage` instance of another :class:`~pyradise.data.subject.Subject`.

    Important:
        This filter assumes that all :class:`~pyradise.data.image.Image` instances of the provided
        :class:`~pyradise.data.subject.Subject` are co-registered such that the
        :class:`~pyradise.data.image.SegmentationImage` instances do not require special treatment.

    Warning:
        The inverse registration procedure may not yield the expected results if successive
        :class:`~pyradise.process.base.Filter` s are applied to the same :class:`~pyradise.data.image.Image` instances.
        Thus, it's recommended to use the invertibility feature with appropriate caution.
    """

    @staticmethod
    def is_invertible() -> bool:
        """Return whether the filter is invertible.

        Returns:
            bool: True because the inter-subject registration is invertible.
        """
        return True

    @staticmethod
    def _select_interpolator(image: Image, params: Optional[FilterParams]) -> Optional[int]:
        """Select the interpolator used to resample the given image.

        Args:
            image (Image): The image to be resampled.
            params (Optional[FilterParams]): The filter parameters.

        Returns:
            Optional[int]: ``None`` if the image type is not supported, the ``resampling_interpolator`` of the
            parameters for an :class:`IntensityImage` (if available), and the type-specific default otherwise.
        """
        default_interpolator = get_interpolator(image)
        if default_interpolator is None or not isinstance(image, IntensityImage):
            # segmentation images always keep their type-specific (nearest neighbor) interpolator
            return default_interpolator

        # Honor the user-configured interpolator for intensity images. The attribute is read defensively because
        # the parameters may originate from a stored TransformInfo instance.
        configured = getattr(params, "resampling_interpolator", None)
        return default_interpolator if configured is None else configured

    @staticmethod
    def _resample_image(
        image: Image,
        reference_image: sitk.Image,
        transform: sitk.Transform,
        interpolator: int,
    ) -> Tuple[sitk.Image, sitk.Image]:
        """Resample the image onto the reference image grid and store the result in the image.

        Args:
            image (Image): The image to be resampled. Its image data is replaced by the resampled data.
            reference_image (sitk.Image): The image defining the output grid.
            transform (sitk.Transform): The transformation to apply.
            interpolator (int): The interpolator to use.

        Returns:
            Tuple[sitk.Image, sitk.Image]: The image data before and after resampling.
        """
        # get the image data and cast if necessary
        image_sitk = image.get_image_data()
        if isinstance(image, IntensityImage):
            image_sitk = sitk.Cast(image_sitk, sitk.sitkFloat32)

        # voxels outside the original image receive its minimum intensity
        min_intensity = float(np.min(sitk.GetArrayFromImage(image_sitk)))
        new_image_sitk = sitk.Resample(
            image_sitk, reference_image, transform, interpolator, min_intensity, image_sitk.GetPixelIDValue()
        )

        # set the new image data to the image
        image.set_image_data(new_image_sitk)
        return image_sitk, new_image_sitk

    # noinspection DuplicatedCode
    def _apply_transform(
        self,
        subject: Subject,
        transform: sitk.Transform,
        reference_image: sitk.Image,
        params: InterSubjectRegistrationFilterParams,
    ) -> Subject:
        """Apply the provided transformation to the subject.

        Args:
            subject (Subject): The subject.
            transform (sitk.Transform): The transformation to apply to the subject.
            reference_image (sitk.Image): The reference image.
            params (InterSubjectRegistrationFilterParams): The filters parameters.

        Returns:
            Subject: The :class:`~pyradise.data.subject.Subject` instance with transformed
            :class:`~pyradise.data.image.Image` instances.
        """
        # transform and resample the images
        for image in subject.get_images():
            interpolator = self._select_interpolator(image, params)
            if interpolator is None:
                continue

            image_sitk, new_image_sitk = self._resample_image(image, reference_image, transform, interpolator)

            # track the necessary data
            self._register_tracked_data(image, image_sitk, new_image_sitk, params, transform)

        return subject

    # noinspection DuplicatedCode
    @staticmethod
    def _apply_inverse_transform(
        subject: Subject,
        transform_info: TransformInfo,
        target_image: Optional[Union[SegmentationImage, IntensityImage]] = None,
    ) -> Subject:
        """Apply the inverse transformation to the subject.

        Args:
            subject (Subject): The subject.
            transform_info (TransformInfo): The transformation information.
            target_image (Optional[Union[SegmentationImage, IntensityImage]]): The target image to which the inverse
                transformation should be applied. If None, the inverse transformation is applied to all images
                (default: None).

        Returns:
            Subject: The :class:`~pyradise.data.subject.Subject` instance with back-transformed
            :class:`~pyradise.data.image.Image` instances.
        """
        # construct an empty image with the pre-transform geometry as a resampling reference
        # (the size is reversed because NumPy arrays are indexed in reverse axis order)
        original_image_props = transform_info.get_image_properties(pre_transform=True)
        reference_image_np = np.zeros(original_image_props.size[::-1], dtype=float)
        reference_image_sitk = sitk.GetImageFromArray(reference_image_np)
        reference_image_sitk.SetOrigin(original_image_props.origin)
        reference_image_sitk.SetSpacing(original_image_props.spacing)
        reference_image_sitk.SetDirection(original_image_props.direction)

        # get the inverse transform and the parameters of the forward pass
        transform = transform_info.get_transform(True)
        params = transform_info.get_params()

        # transform and resample the images
        for image in subject.get_images():
            if target_image is not None and image != target_image:
                continue

            interpolator = InterSubjectRegistrationFilter._select_interpolator(image, params)
            if interpolator is None:
                continue

            InterSubjectRegistrationFilter._resample_image(image, reference_image_sitk, transform, interpolator)

        return subject

    @staticmethod
    def _register_image(
        subject: Subject, reference_image: sitk.Image, params: InterSubjectRegistrationFilterParams
    ) -> sitk.Transform:
        """Register the subject image to the specific modality of the reference subject.

        Args:
            subject (Subject): The subject to register.
            reference_image (sitk.Image): The reference image.
            params (InterSubjectRegistrationFilterParams): The filters parameters.

        Raises:
            ValueError: If the subject does not provide an image of the requested subject modality.

        Returns:
            sitk.Transform: The registration transformation.
        """
        moving_image = subject.get_image_by_modality(params.subject_modality)
        if moving_image is None:
            raise ValueError(
                f"The subject does not contain an image with the modality ({params.subject_modality}) "
                "to be used as moving image!"
            )
        moving_image_sitk = moving_image.get_image_data()

        # get the registration method
        registration_method = get_registration_method(
            params.registration_type,
            params.number_of_histogram_bins,
            params.learning_rate,
            params.step_size,
            params.number_of_iterations,
            params.relaxation_factor,
            params.shrink_factors,
            params.smoothing_sigmas,
            params.sampling_percentage,
            params.deterministic,
        )

        return register_images(moving_image_sitk, reference_image, params.registration_type, registration_method)

    def execute(self, subject: Subject, params: InterSubjectRegistrationFilterParams) -> Subject:
        """Executes the inter-subject registration procedure.

        Args:
            subject (Subject): The :class:`~pyradise.data.subject.Subject` instance to be processed.
            params (InterSubjectRegistrationFilterParams): The filter parameters.

        Raises:
            ValueError: If the reference subject does not provide an image of the reference modality.

        Returns:
            Subject: The :class:`~pyradise.data.subject.Subject` instance with all
            :class:`~pyradise.data.image.IntensityImage` instances registered to the reference subject
            :class:`~pyradise.data.image.IntensityImage` instance.
        """
        # get the reference image
        reference_image = params.reference_subject.get_image_by_modality(params.reference_modality)
        if reference_image is None:
            raise ValueError(
                f"The reference subject does not contain an image with the reference modality "
                f"({params.reference_modality})!"
            )
        reference_image_sitk = reference_image.get_image_data()

        # register the subject to the reference image
        transform = self._register_image(subject, reference_image_sitk, params)

        # apply the transform to all images of the subject
        return self._apply_transform(subject, transform, reference_image_sitk, params)

    def execute_inverse(
        self,
        subject: Subject,
        transform_info: TransformInfo,
        target_image: Optional[Union[SegmentationImage, IntensityImage]] = None,
    ) -> Subject:
        """Execute the inverse of the inter-subject registration procedure.

        Args:
            subject (Subject): The :class:`~pyradise.data.subject.Subject` instance to be processed.
            transform_info (TransformInfo): The transform information.
            target_image (Optional[Union[SegmentationImage, IntensityImage]]): The target image to which the inverse
                transformation should be applied. If None, the inverse transformation is applied to all images
                (default: None).

        Returns:
            Subject: The :class:`~pyradise.data.subject.Subject` instance with unregistered
            :class:`~pyradise.data.image.IntensityImage` instances.
        """
        return self._apply_inverse_transform(subject, transform_info, target_image)