# Source code for pyradise.data.taping

from abc import ABC, abstractmethod
from typing import Any, Dict, Optional, Tuple, TypeVar

import numpy as np
import SimpleITK as sitk

# Public API of this module.
__all__ = ["Tape", "TransformTape", "TransformInfo"]

# NOTE(review): presumably silences pylint false positives on dynamically generated
# members of third-party modules used below (e.g. SimpleITK) — confirm.
# pylint: disable=no-member

# Forward declaration of image types
# These TypeVars are annotation placeholders only; they are NOT bound to the real pyradise
# classes (the real `Filter` and `Subject` are imported lazily inside the methods that need
# them), so static type checkers cannot verify arguments annotated with them.
Image = TypeVar("Image")
IntensityImage = TypeVar("IntensityImage")
SegmentationImage = TypeVar("SegmentationImage")
Filter = TypeVar("Filter")
FilterParameters = TypeVar("FilterParameters")
ImageProperties = TypeVar("ImageProperties")
Subject = TypeVar("Subject")


class Tape(ABC):
    """Abstract base for a tape that records elements and replays them on request.

    Concrete tapes define what a recording is (:meth:`record`) and how the recorded
    elements are applied to a data object (:meth:`playback`). The recordings live in
    the ``recordings`` list, which subclasses populate.
    """

    def __init__(self) -> None:
        super().__init__()
        # Storage for the recorded elements; filled by the concrete ``record`` implementation.
        self.recordings = []

    @abstractmethod
    def record(self, value: Any) -> None:
        """Record a value on the :class:`Tape`.

        Args:
            value (Any): The value to be recorded.

        Returns:
            None
        """
        raise NotImplementedError()

    @staticmethod
    @abstractmethod
    def playback(data: Any, **kwargs) -> Any:
        """Play back the recorded elements of the :class:`Tape` on the data object.

        Args:
            data (Any): The data on which the playback should take place. This object also
                needs to contain the tape.
            **kwargs: Additional keyword arguments.

        Returns:
            Any: The data after playback.
        """
        raise NotImplementedError()

    def get_recorded_elements(self, reverse: bool = False) -> Tuple[Any, ...]:
        """Return the elements recorded on the :class:`Tape` as an immutable snapshot.

        Args:
            reverse (bool): If True, the recordings are returned newest-first.

        Returns:
            Tuple[Any, ...]: The recorded elements of the :class:`Tape`.
        """
        ordered = reversed(self.recordings) if reverse else self.recordings
        return tuple(ordered)

    def reset(self) -> None:
        """Discard all recordings of the :class:`Tape`.

        Returns:
            None
        """
        self.recordings = []
class TransformInfo:
    """A class to store information about a data transformation performed via a
    :class:`~pyradise.process.base.Filter`.

    This class is used in combination with a :class:`~pyradise.data.taping.TransformTape` instance to keep
    track of data transformations and to make the inversion of invertible filter operations possible.

    Args:
        name (str): The name of the filter which performed the data transformation.
        params (Optional[FilterParameters]): The filter parameters which parameterize the data transformation.
        pre_transform_image_properties (ImageProperties): The image properties before the data transformation.
        post_transform_image_properties (ImageProperties): The image properties after the data transformation.
        filter_args (Optional[Dict[str, Any]]): The filter arguments passed via the constructor of the filter
         (default: None).
        additional_data (Optional[Dict[str, Any]]): Additional data which is required for the data
         transformation or to invert it (default: None).
        transform (Optional[sitk.Transform]): A SimpleITK transform which may be used for the data
         transformation (default: None).
    """

    def __init__(
        self,
        name: str,
        params: Optional[FilterParameters],
        pre_transform_image_properties: ImageProperties,
        post_transform_image_properties: ImageProperties,
        filter_args: Optional[Dict[str, Any]] = None,
        additional_data: Optional[Dict[str, Any]] = None,
        transform: Optional[sitk.Transform] = None,
    ) -> None:
        super().__init__()
        self.name = name
        self.params = params
        self.pre_transform_image_properties: ImageProperties = pre_transform_image_properties
        self.post_transform_image_properties: ImageProperties = post_transform_image_properties
        # Fresh dicts per instance (never a shared mutable default).
        self.filter_args: Dict[str, Any] = filter_args if filter_args is not None else dict()
        self.additional_data: Dict[str, Any] = additional_data if additional_data is not None else dict()
        self.transform: Optional[sitk.Transform] = transform

    def _get_subclasses(self, cls: type) -> Dict[str, type]:
        """Get all (direct and indirect) subclasses of the provided class, keyed by class name.

        Args:
            cls (type): The class to get the subclasses of.

        Returns:
            Dict[str, type]: A dictionary mapping subclass names to subclasses. If two subclasses share
            a name, the one visited last wins.
        """
        subclasses = {}
        for subclass in cls.__subclasses__():
            subclasses[subclass.__name__] = subclass
            # Recursing on a leaf class simply yields an empty dict.
            subclasses.update(self._get_subclasses(subclass))
        return subclasses

    def get_filter(self) -> Filter:
        """Get the :class:`~pyradise.process.base.Filter` instance which performed the data transformation.

        The filter class is looked up by :attr:`name` among all loaded subclasses of
        :class:`~pyradise.process.base.Filter` and instantiated with :attr:`filter_args`.

        Returns:
            Filter: The filter used for the data transformation.

        Raises:
            ValueError: If no loaded subclass of ``Filter`` has the recorded name.
        """
        # NOTE(review): imported lazily, presumably to avoid a circular import between
        # pyradise.data and pyradise.process — confirm.
        from pyradise.process import Filter

        filter_cls = self._get_subclasses(Filter).get(self.name)
        if filter_cls is None:
            # Previously this surfaced as "'NoneType' object is not callable".
            raise ValueError(
                f"The filter {self.name!r} could not be found among the subclasses of Filter. "
                "Make sure the module defining it has been imported."
            )
        return filter_cls(**self.filter_args)

    def get_params(self) -> FilterParameters:
        """Get the :class:`~pyradise.process.base.FilterParams` instance which was used to parameterize the
        data transformation.

        Returns:
            FilterParameters: The filter parameters used for the data transformation.
        """
        return self.params

    def get_image_properties(self, pre_transform: bool) -> ImageProperties:
        """Get the pre-transform or post-transform :class:`~pyradise.data.image.ImageProperties` instance.

        Args:
            pre_transform (bool): If True returns the pre-transform image properties, otherwise the
             post-transform image properties.

        Returns:
            ImageProperties: The pre-transform or post-transform image properties.
        """
        if pre_transform:
            return self.pre_transform_image_properties
        return self.post_transform_image_properties

    def add_data(self, key: str, value: Any) -> None:
        """Add additional data to the :class:`TransformInfo` instance.

        Note:
            If the provided key already exists, the value will be overwritten.

        Args:
            key (str): The key of the additional data.
            value (Any): The value of the additional data.

        Returns:
            None
        """
        self.additional_data[key] = value

    def get_data(self, key: str) -> Any:
        """Get additional data from the :class:`TransformInfo` instance by key.

        Args:
            key (str): The key of the additional data entry to get.

        Returns:
            Any: The value of the additional data entry. If the key does not exist, :data:`None` is returned.
        """
        return self.additional_data.get(key, None)

    def get_transform(self, inverse: bool = False) -> sitk.Transform:
        """Get the :class:`SimpleITK.Transform` instance which was used to perform the data transformation.

        If no transform was stored, an affine transform is derived from the pre- and post-transform image
        properties. Its matrix is ``inv(D_pre) @ D_post``, where ``D`` is the direction matrix, and its
        translation is ``origin_post - origin_pre``.

        Args:
            inverse (bool): Indicates if the inverse transform should be returned (default: False).

        Returns:
            sitk.Transform: The transform used for the data transformation, or the identity transform if
            origin and direction did not change during the data transformation.
        """
        # An explicitly stored transform takes precedence over the derived one.
        if self.transform is not None:
            return self.transform.GetInverse() if inverse else self.transform

        num_dims = len(self.pre_transform_image_properties.size)
        transform = sitk.AffineTransform(num_dims)
        transform.SetIdentity()

        # Unchanged origin and direction: the identity is returned (it is its own inverse).
        if self.pre_transform_image_properties.has_equal_origin_direction(self.post_transform_image_properties):
            return transform

        # Translation: difference between the post- and pre-transform origins.
        post_origin = self.post_transform_image_properties.origin
        pre_origin = self.pre_transform_image_properties.origin
        translation = list(np.array(post_origin) - np.array(pre_origin))

        # Matrix: inv(D_pre) @ D_post, with the flat direction tuples reshaped to num_dims x num_dims.
        post_direction = np.array(self.post_transform_image_properties.direction).reshape(num_dims, num_dims)
        pre_direction = np.array(self.pre_transform_image_properties.direction).reshape(num_dims, num_dims)
        matrix = np.matmul(np.linalg.inv(pre_direction), post_direction)

        # AffineTransform parameters: the row-major matrix entries followed by the translation.
        transform.SetParameters(list(matrix.reshape(-1)) + translation)

        return transform.GetInverse() if inverse else transform
class TransformTape(Tape):
    """Tape that keeps track of :class:`~pyradise.data.taping.TransformInfo` instances so that they can be
    played back on appropriate data.

    This class provides the basic functionality that makes inversion and reproduction of data
    transformations possible.
    """

    def __init__(self):
        super().__init__()

    def record(self, value: TransformInfo) -> None:
        """Append a :class:`~pyradise.data.taping.TransformInfo` instance to the tape.

        Args:
            value (TransformInfo): The :class:`~pyradise.data.taping.TransformInfo` instance to record.

        Returns:
            None
        """
        self.recordings.append(value)

    def get_recorded_elements(self, reverse: bool = False) -> Tuple[TransformInfo, ...]:
        """Return the recorded :class:`~pyradise.data.taping.TransformInfo` instances.

        Args:
            reverse (bool): If True, the newest recording comes first (default: False).

        Returns:
            Tuple[TransformInfo, ...]: The recorded :class:`~pyradise.data.taping.TransformInfo` instances.
        """
        return super().get_recorded_elements(reverse)

    @staticmethod
    def playback(data: Image, **kwargs) -> Image:
        """Undo the recorded transformations of ``data`` by running the inverse of each invertible filter.

        The recordings on the image's own transform tape are processed newest-first. Filters that report
        themselves as non-invertible are skipped. Afterwards, the resulting image data is written back into
        ``data`` and the image's tape is cleared (including any skipped recordings).

        Args:
            data (Image): The image whose recorded :class:`~pyradise.data.taping.TransformInfo` instances
             should be played back.
            **kwargs: Additional keyword arguments (unused).

        Returns:
            Image: The same :class:`~pyradise.data.image.Image` instance after the playback of the recorded
            :class:`~pyradise.data.taping.TransformInfo` instances.
        """
        from pyradise.data import Subject

        # The filters operate on subjects, so wrap the single image in a throw-away subject.
        work_subject = Subject("temporary_playback_subject", data)

        for info in data.get_transform_tape().get_recorded_elements(reverse=True):
            inverse_filter = info.get_filter()
            if inverse_filter.is_invertible():
                work_subject = inverse_filter.execute_inverse(work_subject, info)

        # Copy the restored pixel data back onto the caller's image object.
        restored = work_subject.get_images_by_type(type(data))[0]
        data.set_image_data(restored.get_image_data())

        # Empty the tape in place once the playback is done.
        data.get_transform_tape().recordings.clear()
        return data