Source code for pyradise.fileio.loading

from abc import ABC, abstractmethod
from typing import List, Sequence, Tuple

import SimpleITK as sitk

from pyradise.data import IntensityImage, SegmentationImage, Subject

from .dicom_conversion import (DicomImageSeriesConverter,
                               DicomRTSSSeriesConverter)
from .series_info import (DicomSeriesImageInfo, DicomSeriesRegistrationInfo,
                          DicomSeriesRTSSInfo, IntensityFileSeriesInfo,
                          SegmentationFileSeriesInfo, SeriesInfo)

__all__ = ["Loader", "ExplicitLoader", "SubjectLoader", "IterableSubjectLoader"]


class Loader(ABC):
    """Abstract base for every :class:`Loader` implementation.

    A :class:`Loader` usually receives a sequence of
    :class:`~pyradise.fileio.series_info.SeriesInfo` entries and loads the data they describe. The result is
    handed back as a :class:`~pyradise.data.subject.Subject` so that it can be passed on directly, for example to
    the :mod:`~pyradise.process` package or to the :mod:`~pyradise.fileio.writing` module.
    """

    @staticmethod
    def _extract_info_by_type(info: Sequence[SeriesInfo], type_: type) -> Tuple:
        """Select the :class:`~pyradise.fileio.series_info.SeriesInfo` entries that are instances of ``type_``.

        The relative order of the entries in ``info`` is preserved. Because the selection relies on
        :func:`isinstance`, instances of subclasses of ``type_`` are selected as well.

        Args:
            info (Sequence[SeriesInfo]): The :class:`~pyradise.fileio.series_info.SeriesInfo` entries to search.
            type_ (type): The sub-type of :class:`~pyradise.fileio.series_info.SeriesInfo` to pick out.

        Returns:
            Tuple[SeriesInfo]: The matching :class:`~pyradise.fileio.series_info.SeriesInfo` entries.
        """
        return tuple(entry for entry in info if isinstance(entry, type_))
class ExplicitLoader(Loader, ABC):
    """Abstract :class:`Loader` exposing an explicit :meth:`load` method.

    Because the :class:`~pyradise.fileio.series_info.SeriesInfo` entries are passed to :meth:`load` rather than to
    the constructor, one :class:`Loader` instance can be reused to load several sets of entries.
    """

    @abstractmethod
    def load(self, info: Tuple[SeriesInfo, ...]) -> Subject:
        """Load the :class:`~pyradise.data.subject.Subject` described by ``info``.

        Concrete subclasses must override this method.

        Args:
            info (Tuple[SeriesInfo, ...]): The :class:`~pyradise.fileio.series_info.SeriesInfo` entries to load.

        Returns:
            Subject: The loaded :class:`~pyradise.data.subject.Subject`.
        """
        raise NotImplementedError
class SubjectLoader(ExplicitLoader):
    """An :class:`ExplicitLoader` for loading a :class:`~pyradise.data.subject.Subject` based on its
    :class:`~pyradise.fileio.series_info.SeriesInfo` entries.

    This loader can load both DICOM data (i.e. :class:`~pyradise.fileio.series_info.DicomSeriesInfo`) and discrete
    image data (i.e. :class:`~pyradise.fileio.series_info.FileSeriesInfo`). The loader validates the provided
    :class:`~pyradise.fileio.series_info.SeriesInfo` entries before loading and raises appropriate errors if the
    information is not valid.

    Examples:
        Load and normalize NIFTI files and save the subject as NRRD files:

        >>> from argparse import ArgumentParser
        >>> from pyradise.fileio import (SubjectFileCrawler, SubjectLoader,
        >>>                              SubjectWriter, ImageFileFormat)
        >>> from pyradise.process import (ZScoreNormFilter,
        >>>                               ZScoreNormFilterParams)
        >>>
        >>>
        >>> def main(input_path: str, output_path: str, subject_name: str) -> None:
        >>>     # Crawl the input directory for compressed NIFTI files
        >>>     info = SubjectFileCrawler(input_path, subject_name, 'nii.gz').execute()
        >>>
        >>>     # Load the subject
        >>>     subject = SubjectLoader().load(info)
        >>>
        >>>     # Perform the normalization
        >>>     normalization_params = ZScoreNormFilterParams(loop_axis=1)
        >>>     normalization_filter = ZScoreNormFilter(normalization_params)
        >>>     subject = normalization_filter.execute(subject)
        >>>
        >>>     # Write the subject to the output directory
        >>>     writer = SubjectWriter(ImageFileFormat.NRRD)
        >>>     writer.write(output_path, subject, write_transforms=False)
        >>>
        >>>
        >>> if __name__ == '__main__':
        >>>     parser = ArgumentParser()
        >>>     parser.add_argument('input_path', type=str, help='The input directory.')
        >>>     parser.add_argument('output_path', type=str, help='The output directory.')
        >>>     parser.add_argument('subject_name', type=str, help='The name of the subject.')
        >>>     args = parser.parse_args()
        >>>
        >>>     main(args.input_path, args.output_path, args.subject_name)

        Load DICOM data and save the converted data as NIFTI files:

        >>> from argparse import ArgumentParser
        >>> from pyradise.fileio import SubjectDicomCrawler, SubjectLoader, SubjectWriter
        >>>
        >>>
        >>> def main(input_path: str, output_path: str) -> None:
        >>>     # Crawl the input directory for DICOM data
        >>>     # Note: We assume that the modality configuration file (modality_config.json)
        >>>     # is existing.
        >>>     info = SubjectDicomCrawler(input_path).execute()
        >>>
        >>>     # Load the subject
        >>>     subject = SubjectLoader().load(info)
        >>>
        >>>     # Write the subject to the output directory
        >>>     writer = SubjectWriter()
        >>>     writer.write(output_path, subject, write_transforms=False)
        >>>
        >>>
        >>> if __name__ == '__main__':
        >>>     parser = ArgumentParser()
        >>>     parser.add_argument('input_path', type=str, help='The input directory.')
        >>>     parser.add_argument('output_path', type=str, help='The output directory.')
        >>>     args = parser.parse_args()
        >>>
        >>>     main(args.input_path, args.output_path)

    Args:
        intensity_pixel_value_type (int): The pixel value type of the intensity images when loading discrete image
         files (default: sitk.sitkFloat32).
        segmentation_pixel_value_type (int): The pixel value type of the segmentation images when loading discrete
         files (default: sitk.sitkUInt8).
        fill_hole_search_distance (int): The search distance for the hole filling algorithm. If the search
         distance is set to zero the hole filling algorithm is omitted. The search distance must be an odd number
         larger than 1 (default: 0).

    Raises:
        ValueError: If ``fill_hole_search_distance`` is non-zero and even.
        ValueError: If ``fill_hole_search_distance`` is an odd number that is not larger than 1 (this includes
         negative odd numbers).
    """

    def __init__(
        self,
        intensity_pixel_value_type: int = sitk.sitkFloat32,
        segmentation_pixel_value_type: int = sitk.sitkUInt8,
        fill_hole_search_distance: int = 0,
    ) -> None:
        super().__init__()

        self.intensity_pixel_type = intensity_pixel_value_type
        self.segmentation_pixel_type = segmentation_pixel_value_type

        # store the fill hole search distance (zero disables the hole filling)
        if fill_hole_search_distance == 0:
            self.fill_hole_distance = 0
        elif fill_hole_search_distance % 2 == 0:
            raise ValueError("The fill hole search distance must be an odd number.")
        elif fill_hole_search_distance <= 1:
            # In Python ``-3 % 2 == 1``, so negative odd values pass the parity check above and must be
            # rejected here together with the value 1.
            raise ValueError("The fill hole search distance must be larger than 1.")
        else:
            self.fill_hole_distance = fill_hole_search_distance

    @staticmethod
    def _load_intensity_images(
        info: Tuple[IntensityFileSeriesInfo, ...], pixel_value_type: int
    ) -> Tuple[IntensityImage, ...]:
        """Load the :class:`~pyradise.data.image.IntensityImage` s.

        For each entry the first path returned by ``get_path()`` is read with :func:`SimpleITK.ReadImage` and the
        result is combined with the modality of the entry.

        Args:
            info (Tuple[IntensityFileSeriesInfo, ...]): The
             :class:`~pyradise.fileio.series_info.IntensityFileSeriesInfo` entries containing the file paths to
             the images.
            pixel_value_type (int): The pixel value type for the intensity images.

        Returns:
            Tuple[IntensityImage, ...]: The loaded :class:`~pyradise.data.image.IntensityImage` instances in the
            order of ``info``.
        """
        return tuple(
            IntensityImage(sitk.ReadImage(info_entry.get_path()[0], pixel_value_type), info_entry.get_modality())
            for info_entry in info
        )

    @staticmethod
    def _load_segmentation_images(
        info: Tuple[SegmentationFileSeriesInfo, ...], pixel_value_type: int
    ) -> Tuple[SegmentationImage, ...]:
        """Load the :class:`~pyradise.data.image.SegmentationImage` s.

        For each entry the first path returned by ``get_path()`` is read with :func:`SimpleITK.ReadImage` and the
        result is combined with the organ and the annotator of the entry.

        Args:
            info (Tuple[SegmentationFileSeriesInfo, ...]): The
             :class:`~pyradise.fileio.series_info.SegmentationFileSeriesInfo` entries containing the file paths to
             the images.
            pixel_value_type (int): The pixel value type for the segmentation images.

        Returns:
            Tuple[SegmentationImage, ...]: The loaded :class:`~pyradise.data.image.SegmentationImage` instances in
            the order of ``info``.
        """
        return tuple(
            SegmentationImage(
                sitk.ReadImage(info_entry.get_path()[0], pixel_value_type),
                info_entry.get_organ(),
                info_entry.get_annotator(),
            )
            for info_entry in info
        )

    @staticmethod
    def _validate_patient_identification(info: Tuple[SeriesInfo, ...]) -> bool:
        """Validate the patient identification of the provided :class:`~pyradise.fileio.series_info.SeriesInfo`
        entries.

        The identification is considered valid if all entries share the same patient name and the same patient id.

        Args:
            info (Tuple[SeriesInfo, ...]): The :class:`~pyradise.fileio.series_info.SeriesInfo` entries to check.

        Returns:
            bool: True if the patient identification is valid for all info entries, otherwise False. An empty
            ``info`` is reported as invalid.
        """
        if not info:
            return False

        names = [entry.get_patient_name() for entry in info]
        ids = [entry.get_patient_id() for entry in info]

        return all(name == names[0] for name in names) and all(id_ == ids[0] for id_ in ids)

    @staticmethod
    def _validate_registration(
        reg_info: Tuple[DicomSeriesRegistrationInfo, ...], image_info: Tuple[DicomSeriesImageInfo, ...]
    ) -> bool:
        """Validate the ReferencedSeriesInstanceUIDs of the provided
        :class:`~pyradise.fileio.series_info.DicomSeriesRegistrationInfo` entries by checking if the referenced
        DICOM image data is provided.

        Registration entries that are not yet updated are updated as a side effect of this check.

        Args:
            reg_info (Tuple[DicomSeriesRegistrationInfo, ...]): The
             :class:`~pyradise.fileio.series_info.DicomSeriesRegistrationInfo` entries to check.
            image_info (Tuple[DicomSeriesImageInfo, ...]): The
             :class:`~pyradise.fileio.series_info.DicomSeriesImageInfo` entries containing the referenced
             SeriesInstanceUIDs.

        Returns:
            bool: True if the image infos for all registration infos are available, otherwise False. Without any
            registration info the result is True.
        """

        def is_image_info_available(instance_uids: List[str], image_info_: Tuple[DicomSeriesImageInfo, ...]) -> bool:
            # every referenced UID must match the SeriesInstanceUID of at least one image info
            return all(any(entry.series_instance_uid == uid for entry in image_info_) for uid in instance_uids)

        if not reg_info:
            return True

        if not image_info:
            return False

        identity_uids = []
        transform_uids = []
        for reg_info_entry in reg_info:
            # the referenced UIDs are read only after the entry reports itself as updated
            if not reg_info_entry.is_updated():
                reg_info_entry.update()

            identity = reg_info_entry.referenced_series_instance_uid_identity
            if identity != "":
                identity_uids.append(identity)

            transform = reg_info_entry.referenced_series_instance_uid_transform
            if transform != "":
                transform_uids.append(transform)

        return is_image_info_available(identity_uids, image_info) and is_image_info_available(
            transform_uids, image_info
        )

    @staticmethod
    def _validate_rtss_info(
        rtss_info: Tuple[DicomSeriesRTSSInfo, ...], image_info: Tuple[DicomSeriesImageInfo, ...]
    ) -> bool:
        """Validate if all SeriesInstanceUIDs referenced in the DICOM-RTSSs are provided.

        Args:
            rtss_info (Tuple[DicomSeriesRTSSInfo, ...]): The
             :class:`~pyradise.fileio.series_info.DicomSeriesRTSSInfo` entries to check.
            image_info (Tuple[DicomSeriesImageInfo, ...]): The
             :class:`~pyradise.fileio.series_info.DicomSeriesImageInfo` entries containing the SeriesInstanceUIDs.

        Returns:
            bool: True if the referenced image infos for all RTSS infos are available, otherwise False. Without
            any RTSS info the result is True.
        """
        if not rtss_info:
            return True

        if not image_info:
            return False

        return all(
            any(entry.series_instance_uid == rtss_info_entry.referenced_instance_uid for entry in image_info)
            for rtss_info_entry in rtss_info
        )

    def load(self, info: Tuple[SeriesInfo, ...]) -> Subject:
        """Load a :class:`~pyradise.data.subject.Subject` from the provided
        :class:`~pyradise.fileio.series_info.SeriesInfo` entries.

        Args:
            info (Tuple[SeriesInfo, ...]): The :class:`~pyradise.fileio.series_info.SeriesInfo` entries containing
             the necessary information for loading the subject.

        Raises:
            ValueError: If ``info`` is an empty tuple.
            ValueError: If ``info`` is not a tuple of :class:`~pyradise.fileio.series_info.SeriesInfo` entries.
            ValueError: If the patient name and patient id of the provided
             :class:`~pyradise.fileio.series_info.SeriesInfo` entries are not equal.
            ValueError: If not all referenced :class:`~pyradise.fileio.series_info.DicomSeriesImageInfo` entries
             are provided for registration.
            ValueError: If not all referenced :class:`~pyradise.fileio.series_info.DicomSeriesImageInfo` entries
             are provided for RTSS loading.
            ValueError: If ``info`` contains neither DICOM image, intensity file nor segmentation file entries
             from which the subject name could be taken.

        Returns:
            Subject: The loaded subject.
        """
        # check if the info entries have the correct structure
        if not info:
            raise ValueError("The provided info entries are empty.")

        if not all(isinstance(entry, SeriesInfo) for entry in info):
            raise ValueError(
                "The provided info entries are not of type SeriesInfo. "
                "Make sure to provide a tuple of SeriesInfo entries."
            )

        # separate the info entries
        dicom_image_info = self._extract_info_by_type(info, DicomSeriesImageInfo)
        dicom_reg_info = self._extract_info_by_type(info, DicomSeriesRegistrationInfo)
        dicom_rtss_info = self._extract_info_by_type(info, DicomSeriesRTSSInfo)
        intensity_image_info = self._extract_info_by_type(info, IntensityFileSeriesInfo)
        segmentation_image_info = self._extract_info_by_type(info, SegmentationFileSeriesInfo)

        # validate the info entries
        if not self._validate_patient_identification(info):
            raise ValueError("The patient identification (patient_name and patient_id) is not unique!")

        if not self._validate_registration(dicom_reg_info, dicom_image_info):
            raise ValueError("At least one referenced image in the registration is missing!")

        if not self._validate_rtss_info(dicom_rtss_info, dicom_image_info):
            raise ValueError("The referenced image in the RTSS is not available!")

        # create the subject, taking the name from the first available image-like entry
        if dicom_image_info:
            subject = Subject(dicom_image_info[0].get_patient_name())
        elif intensity_image_info:
            subject = Subject(intensity_image_info[0].get_patient_name())
        elif segmentation_image_info:
            subject = Subject(segmentation_image_info[0].get_patient_name())
        else:
            raise ValueError("Subject can not be constructed because a subject name is missing!")

        # load the images and add them to the subject
        if dicom_image_info:
            dicom_images = DicomImageSeriesConverter(dicom_image_info, dicom_reg_info).convert()
            subject.add_images(dicom_images)

        if dicom_rtss_info:
            dicom_segmentations = DicomRTSSSeriesConverter(
                dicom_rtss_info, dicom_image_info, dicom_reg_info, self.fill_hole_distance
            ).convert()
            subject.add_images(dicom_segmentations, force=True)

        intensity_images = self._load_intensity_images(intensity_image_info, self.intensity_pixel_type)
        segmentation_images = self._load_segmentation_images(segmentation_image_info, self.segmentation_pixel_type)
        subject.add_images(intensity_images + segmentation_images, force=True)

        return subject
class IterableSubjectLoader(Loader):
    """A :class:`Loader` for loading a sequence of :class:`~pyradise.data.subject.Subject` s based on their
    :class:`~pyradise.fileio.series_info.SeriesInfo` entries.

    This loader can load both DICOM data (i.e. :class:`~pyradise.fileio.series_info.DicomSeriesInfo`) and discrete
    image data (i.e. :class:`~pyradise.fileio.series_info.FileSeriesInfo`). Each subject is loaded lazily with a
    :class:`SubjectLoader` when the iterator is advanced, which validates the
    :class:`~pyradise.fileio.series_info.SeriesInfo` entries of that subject and raises appropriate errors if the
    information is not valid.

    Notes:
        For loading large DICOM datasets we recommend using the :class:`SubjectLoader` instead because the
        antecedent crawling process can require a lot of computation time and memory.

    Raises:
        ValueError: If ``info`` is empty.
        ValueError: If ``info`` contains first level entries which are not tuples. The content of the nested
         tuples is only checked when the respective subject is loaded during iteration.
        ValueError: If ``fill_hole_search_distance`` is non-zero and even.
        ValueError: If ``fill_hole_search_distance`` is an odd number that is not larger than 1 (this includes
         negative odd numbers).

    Args:
        info (Tuple[Tuple[SeriesInfo, ...], ...]): The :class:`~pyradise.fileio.series_info.SeriesInfo` entries
         for all subjects to load.
        intensity_pixel_value_type (int): The pixel value type of the intensity images when loading discrete image
         files (default: sitk.sitkFloat32).
        segmentation_pixel_value_type (int): The pixel value type of the segmentation images when loading discrete
         files (default: sitk.sitkUInt8).
        fill_hole_search_distance (int): The search distance for the hole filling algorithm. If the search
         distance is set to zero the hole filling algorithm is omitted. The search distance must be an odd number
         larger than 1 (default: 0).

    Examples:
        Load, normalize and save a NIFTI dataset with multiple subjects:

        >>> from argparse import ArgumentParser
        >>> from pyradise.fileio import (DatasetFileCrawler, IterableSubjectLoader,
        >>>                              SubjectWriter)
        >>> from pyradise.process import (ZScoreNormFilter,
        >>>                               ZScoreNormFilterParams)
        >>>
        >>>
        >>> def main(input_path: str, output_path: str) -> None:
        >>>     # Crawl the dataset info
        >>>     info = DatasetFileCrawler(input_path, '.nii.gz').execute()
        >>>
        >>>     # Construct the loader
        >>>     loader = IterableSubjectLoader(info)
        >>>
        >>>     # Construct the normalization filter
        >>>     normalization_params = ZScoreNormFilterParams(loop_axis=1)
        >>>     normalization_filter = ZScoreNormFilter(normalization_params)
        >>>
        >>>     # Construct the writer
        >>>     writer = SubjectWriter()
        >>>
        >>>     # Iteratively load the subjects
        >>>     for subject in loader:
        >>>         # Normalize the images
        >>>         subject = normalization_filter.execute(subject)
        >>>
        >>>         # Save the subject
        >>>         writer.write_to_subject_folder(output_path, subject, write_transforms=False)
        >>>
        >>>
        >>> if __name__ == '__main__':
        >>>     parser = ArgumentParser()
        >>>     parser.add_argument('--input_path', type=str,
        >>>                         help='The dataset input directory.')
        >>>     parser.add_argument('--output_path', type=str,
        >>>                         help='The dataset output directory.')
        >>>     args = parser.parse_args()
        >>>
        >>>     main(args.input_path, args.output_path)
    """

    def __init__(
        self,
        info: Tuple[Tuple[SeriesInfo, ...], ...],
        intensity_pixel_value_type: int = sitk.sitkFloat32,
        segmentation_pixel_value_type: int = sitk.sitkUInt8,
        fill_hole_search_distance: int = 0,
    ) -> None:
        super().__init__()

        if not info:
            raise ValueError("The provided infos are empty.")

        if not all(isinstance(entry, tuple) for entry in info):
            raise ValueError(
                "The provided first level info entries are not of type tuple. "
                "Make sure that the info is a tuple of tuples."
            )

        self.info = info
        self.intensity_pixel_type = intensity_pixel_value_type
        self.segmentation_pixel_type = segmentation_pixel_value_type

        # store the fill hole search distance (zero disables the hole filling)
        if fill_hole_search_distance == 0:
            self.fill_hole_distance = 0
        elif fill_hole_search_distance % 2 == 0:
            raise ValueError("The fill hole search distance must be an odd number.")
        elif fill_hole_search_distance <= 1:
            # In Python ``-3 % 2 == 1``, so negative odd values pass the parity check above and must be
            # rejected here together with the value 1.
            raise ValueError("The fill hole search distance must be larger than 1.")
        else:
            self.fill_hole_distance = fill_hole_search_distance

        # iteration state
        self.current_idx = 0
        self.num_subjects = len(self.info)

    def __iter__(self) -> "IterableSubjectLoader":
        """Reset the iteration to the first subject and return the loader itself as the iterator."""
        self.current_idx = 0
        return self

    def __next__(self) -> Subject:
        """Load and return the next :class:`~pyradise.data.subject.Subject`.

        Raises:
            StopIteration: If all subjects have been loaded.

        Returns:
            Subject: The subject loaded from the :class:`~pyradise.fileio.series_info.SeriesInfo` entries at the
            current position.
        """
        if self.current_idx >= self.num_subjects:
            raise StopIteration()

        # The loader is built from the current attribute values on every step.
        loader = SubjectLoader(self.intensity_pixel_type, self.segmentation_pixel_type, self.fill_hole_distance)
        subject = loader.load(self.info[self.current_idx])

        # advance only after a successful load
        self.current_idx += 1
        return subject

    def __len__(self) -> int:
        """Return the number of subjects this loader iterates over."""
        return self.num_subjects