Inference Module#

Module: pyradise.process.inference

General#

The inference module provides a prototype of a filter-based, DL-framework-agnostic inference class (InferenceFilter). The framework-specific parts are left to the user to implement, so installing PyRaDiSe does not require installing a DL framework.

Class Overview#

The following classes are provided by the inference module:

Class                    Description
InferenceFilterParams    Parameterization class for the InferenceFilter.
InferenceFilter          Filter prototype for deep learning model inference.
IndexingStrategy         An abstract base class for all indexing strategies.
SliceIndexingStrategy    Indexing strategy for slice-wise looping (for 2D models).
PatchIndexingStrategy    Indexing strategy for patch-wise looping (for 3D models).

Details#

class InferenceFilterParams(model, model_path, modalities, reference_modality, output_organs, output_annotator, organ_indices, batch_size, indexing_strategy)[source]#

Bases: FilterParams

A filter parameter class for the prototype InferenceFilter class.

Parameters:
  • model (Any) – The model to apply.

  • model_path (Optional[str]) – The path to the model parameters.

  • modalities (Tuple[Union[str, Modality], ...]) – The Modality entries of the IntensityImage instances to use for inference.

  • reference_modality (Union[str, Modality]) – The Modality that is used as the reference to define the output properties of the created SegmentationImage instances.

  • output_organs (Tuple[Union[str, Organ], ...]) – The organs that get assigned to the created SegmentationImage instances.

  • output_annotator (Union[str, Annotator]) – The annotator that gets assigned to the created SegmentationImage instances.

  • organ_indices (Tuple[int, ...]) – The indices of the organs on the output mask of the model (must correspond one-to-one with output_organs).

  • batch_size (int) – The batch size to use for inference.

  • indexing_strategy (IndexingStrategy) – The IndexingStrategy defining how the data is fed to the model.
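
To illustrate how organ_indices relates to output_organs, the following minimal sketch splits a label mask into one binary mask per organ. It is not PyRaDiSe code: the helper split_label_mask, the organ names, and the use of nested lists in place of numpy arrays are assumptions made to keep the example dependency-free.

```python
from typing import Dict, List, Sequence


def split_label_mask(label_mask: List[List[int]],
                     output_organs: Sequence[str],
                     organ_indices: Sequence[int]) -> Dict[str, List[List[int]]]:
    """Split a 2D label mask into one binary mask per organ (illustrative only)."""
    if len(output_organs) != len(organ_indices):
        raise ValueError('output_organs and organ_indices must have the same length.')

    masks = {}
    for organ, index in zip(output_organs, organ_indices):
        # A pixel belongs to the organ if it carries the organ's label index
        masks[organ] = [[1 if value == index else 0 for value in row]
                        for row in label_mask]
    return masks


# Hypothetical model output with background (0) and two organ labels (1, 2)
label_mask = [[0, 1, 1],
              [2, 2, 0]]
masks = split_label_mask(label_mask, ('Brainstem', 'Chiasm'), (1, 2))
```

The i-th entry of organ_indices is paired with the i-th entry of output_organs, which is why both tuples need the same length and order.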

class InferenceFilter(warning_on_non_invertible=False)[source]#

Bases: Filter

A prototype filter class for applying a DL-model to a Subject instance.

PyRaDiSe provides only a prototype of this filter so that the package stays DL-framework-agnostic. The framework-specific methods must therefore be implemented in a subclass.

For implementing a DL-framework specific InferenceFilter, the following methods must be implemented:

  • _prepare_model(): Prepare the model for inference (e.g. load the parameters from a model file).

  • _infer_on_batch(): Apply the model to a batch of data and return the output in a shape that can be inserted into the new image via the index expressions provided by the chosen IndexingStrategy.

Example

Implementation example of a PyTorch-based InferenceFilter subclass:

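>>> from typing import Any, Dict, Optional
>>>
>>> import numpy as np
>>> import torch
>>> import torch.nn as nn
>>>
>>> from pyradise.process.inference import (InferenceFilter,
>>>                                         InferenceFilterParams,
>>>                                         PatchIndexingStrategy,
>>>                                         SliceIndexingStrategy)
>>>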
>>> class ExampleInferenceFilter(InferenceFilter):
>>>
>>>  def __init__(self):
>>>      super().__init__()
>>>
>>>      # Define the device on which the model should be run
>>>      self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
>>>
>>>      # Define a class attribute for the model
>>>      self.model: Optional[nn.Module] = None
>>>
>>>  def _prepare_model(self, model: nn.Module, model_path: str) -> nn.Module:
>>>
>>>      # Load model parameters
>>>      model.load_state_dict(torch.load(model_path, map_location=self.device))
>>>
>>>      # Assign the model to the class
>>>      self.model = model.to(self.device)
>>>
>>>      # Set model to evaluation mode
>>>      self.model.eval()
>>>
>>>      return model
>>>
>>>  def _infer_on_batch(self,
>>>                      batch: Dict[str, Any],
>>>                      params: InferenceFilterParams
>>>                      ) -> Dict[str, Any]:
>>>      # Stack and adjust the numpy array such that it fits the
>>>      # [batch, channel / images, height, width, (depth)] format
>>>      # Note: The following statement works for slice-wise and patch-wise processing
>>>      if (loop_axis := params.indexing_strategy.loop_axis) is None:
>>>          adjusted_input = np.stack(batch['data'], axis=0)
>>>      else:
>>>          adjusted_input = np.stack(batch['data'], axis=0).squeeze(loop_axis + 2)
>>>
>>>      # Generate a tensor from the numpy array
>>>      input_tensor = torch.from_numpy(adjusted_input)
>>>
>>>      # Move the batch to the same device as the model
>>>      input_tensor = input_tensor.to(self.device, dtype=torch.float32)
>>>
>>>      # Apply the model to the batch
>>>      with torch.no_grad():
>>>          output_tensor = self.model(input_tensor)
>>>
>>>      # Retrieve the predicted classes from the output
>>>      if type(params.indexing_strategy) is SliceIndexingStrategy:
>>>          # Slice-wise processing
>>>
>>>          if len(params.output_organs) > 1:
>>>              # For multi-class segmentation
>>>              final_activation_fn = nn.Softmax2d()
>>>              output_tensor = torch.argmax(final_activation_fn(output_tensor), dim=1)
>>>
>>>          else:
>>>              # For binary segmentation
>>>              final_activation_fn = nn.Sigmoid()
>>>              output_tensor = (final_activation_fn(output_tensor) > 0.5).bool()
>>>
>>>      elif type(params.indexing_strategy) is PatchIndexingStrategy:
>>>
>>>          if len(params.output_organs) > 1:
>>>              # For multi-class segmentation
>>>              final_activation_fn = nn.Softmax(dim=1)
>>>              output_tensor = torch.argmax(final_activation_fn(output_tensor), dim=1)
>>>
>>>          else:
>>>              # For binary segmentation
>>>              final_activation_fn = nn.Sigmoid()
>>>              output_tensor = (final_activation_fn(output_tensor) > 0.5).bool()
>>>
>>>      else:
>>>          raise NotImplementedError(f'Indexing strategy {type(params.indexing_strategy).__name__} '
>>>                                    'not implemented.')
>>>
>>>      # Convert the output to a numpy array
>>>      # Note: The output shape must be [batch, height, width, (depth)]
>>>      output_array = output_tensor.cpu().numpy()
>>>
>>>      # Construct a list of output arrays such that it fits the index expressions
>>>      batch_output_list = []
>>>      for i in range(output_array.shape[0]):
>>>          batch_output_list.append(output_array[i, ...])
>>>
>>>      # Combine the output arrays into a dictionary
>>>      output = {'data': batch_output_list,
>>>                'index_expr': batch['index_expr']}
>>>
>>>      return output

static _get_input_array(subject, params)[source]#

Return the input array for the DL-model.

Note

This function returns the concatenated data in C (channels) x H (height) x W (width) x D (depth) order.

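Parameters:
  • subject (Subject) – The Subject instance providing the input images.

  • params (InferenceFilterParams) – The filter parameters.

Returns: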

The input array for the DL-model.

Return type:

np.ndarray

abstract _prepare_model(model, model_path)[source]#

Prepare the model for inference (e.g. loading the model parameters). The loaded model must be added to a class attribute such that it can be accessed by all methods.

This method must be implemented for the specific DL-framework.

Parameters:
  • model (Any) – The model instance.

  • model_path (str) – The path to the model parameters.

Returns:

The model prepared for inference.

Return type:

Any

abstract _infer_on_batch(batch, params)[source]#

Apply the model to a batch of data.

This method must be implemented for the specific DL-framework and is called with a batch of data. The batch is a dictionary with the following keys:

  • data: A list of numpy arrays with the input data.

  • index_expr: A list of index expressions for the input data.

Note

The output data in the dictionary must be a list of numpy arrays with the same length as the input data. Each data entry must be a numpy array with the shape [C (channels) x H (height) x W (width) x (D (depth))].

Parameters:
  • batch (Dict[str, Any]) – The batch of data.

  • params (InferenceFilterParams) – The filter parameters.

Returns:

The output of the model.

Return type:

Dict[str, Any]
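
The batch contract described above can be sketched without any DL framework. In this minimal example the function infer_on_batch, the predict callable, and the thresholding "model" are hypothetical, and nested lists stand in for numpy arrays. It only shows that the output keeps one entry per input entry and passes the index expressions through unchanged.

```python
from typing import Any, Callable, Dict, List


def infer_on_batch(batch: Dict[str, Any],
                   predict: Callable[[Any], Any]) -> Dict[str, Any]:
    """Apply a stand-in model to every entry of a batch (illustrative only)."""
    outputs: List[Any] = [predict(item) for item in batch['data']]

    # The output list must have the same length as the input list
    if len(outputs) != len(batch['data']):
        raise RuntimeError('The model must return one output per input entry.')

    # The index expressions are forwarded so the caller can place each output
    return {'data': outputs, 'index_expr': batch['index_expr']}


def threshold(item: List[List[float]]) -> List[List[int]]:
    """Hypothetical 'model' that binarizes its input at 0.5."""
    return [[int(value > 0.5) for value in row] for row in item]


batch = {'data': [[[0.2, 0.8]], [[0.9, 0.1]]],
         'index_expr': [(slice(0, 1), slice(None)), (slice(1, 2), slice(None))]}
output = infer_on_batch(batch, threshold)
```

A real implementation would replace threshold with the forward pass of the model and would have to produce arrays in the shape stated in the note above.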

_apply_model(input_array, params)[source]#

Apply the model to the input array to predict the segmentation.

Parameters:
  • input_array (np.ndarray) – The input array for the DL-model.

  • params (InferenceFilterParams) – The filter parameters.

Returns:

The output array of the DL-model.

Return type:

np.ndarray
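
How batch_size could interact with the index expressions can be sketched as follows. This is an assumption about the general pattern, not a description of the actual _apply_model internals: the helper iter_batches is hypothetical and simply groups index expressions into chunks of at most batch_size entries.

```python
from typing import Iterator, List, Sequence, Tuple

IndexExpr = Tuple[slice, ...]


def iter_batches(index_exprs: Sequence[IndexExpr],
                 batch_size: int) -> Iterator[List[IndexExpr]]:
    """Yield consecutive groups of at most batch_size index expressions."""
    if batch_size < 1:
        raise ValueError('batch_size must be at least 1.')

    for start in range(0, len(index_exprs), batch_size):
        yield list(index_exprs[start:start + batch_size])


# Five slice-wise index expressions along the last axis of a 3D volume
exprs = [(slice(None), slice(None), slice(i, i + 1)) for i in range(5)]
batches = list(iter_batches(exprs, batch_size=2))
```

With five expressions and a batch size of two, the last batch holds a single expression, so an implementation of _infer_on_batch should not rely on every batch being full.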

static _array_to_subject(output_array, subject, params)[source]#

Convert the output array of the DL-model to one or multiple SegmentationImage instances and add them to the provided Subject instance.

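Parameters:
  • output_array (np.ndarray) – The output array of the DL-model.

  • subject (Subject) – The Subject instance to which the new SegmentationImage instances are added.

  • params (InferenceFilterParams) – The filter parameters.

Returns: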

The Subject instance with the new SegmentationImage instances added.

Return type:

Subject

static is_invertible()[source]#

Returns whether the filter is invertible or not.

Note

If your DL model is invertible, you should override this method and return True.

Returns:

False because the inference filter is typically not invertible.

Return type:

bool

execute(subject, params)[source]#

Execute the filter on the provided Subject instance.

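Parameters:
  • subject (Subject) – The Subject instance to be processed.

  • params (InferenceFilterParams) – The filter parameters.

Returns: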

The Subject instance with the newly added SegmentationImage instances.

Return type:

Subject

execute_inverse(subject, transform_info, target_image=None)[source]#

Return the provided Subject instance without any processing because the inference of a DL-model is typically not invertible.

Parameters:
  • subject (Subject) – The Subject instance to be returned.

  • transform_info (TransformInfo) – The transform information.

  • target_image (Optional[Union[SegmentationImage, IntensityImage]]) – The target image to which the inverse transformation should be applied. If None, the inverse transformation is applied to all images (default: None).

Returns:

The provided Subject instance.

Return type:

Subject

class IndexingStrategy(loop_axis)[source]#

Bases: ABC

An abstract class that defines the strategy for indexing / looping over the image data during model inference with an InferenceFilter. An IndexingStrategy is typically assigned to the InferenceFilterParams, from which the InferenceFilter retrieves it.

Parameters:

loop_axis (Optional[int]) – The axis along which the image data should be processed. If None, the image data will be processed as a whole.

abstract __call__(shape)[source]#

Compute the indexing expressions based on the given shape of the image data and the loop_axis attribute.

Parameters:

shape (Tuple[int, ...]) – The shape of the image data for which the indexing expressions should be computed.

Returns:

The indexing expressions.

Return type:

Tuple[Tuple[slice, ...], ...]
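
The interface can be sketched with a stand-in base class. The classes below are hypothetical and do not import PyRaDiSe: IndexingStrategySketch mirrors the documented constructor and __call__ signature, and WholeVolumeStrategy is an illustrative subclass for the loop_axis=None case that returns a single index expression covering the whole image.

```python
from abc import ABC, abstractmethod
from typing import Optional, Tuple

IndexExprs = Tuple[Tuple[slice, ...], ...]


class IndexingStrategySketch(ABC):
    """Stand-in for the documented interface (not the PyRaDiSe class)."""

    def __init__(self, loop_axis: Optional[int]) -> None:
        self.loop_axis = loop_axis

    @abstractmethod
    def __call__(self, shape: Tuple[int, ...]) -> IndexExprs:
        raise NotImplementedError


class WholeVolumeStrategy(IndexingStrategySketch):
    """Hypothetical strategy that processes the image data as a whole."""

    def __init__(self) -> None:
        super().__init__(loop_axis=None)

    def __call__(self, shape: Tuple[int, ...]) -> IndexExprs:
        # One index expression that spans every axis completely
        return (tuple(slice(0, size) for size in shape),)


exprs = WholeVolumeStrategy()((4, 5, 6))
```

Because __call__ is abstract, the base class itself cannot be instantiated; only subclasses that implement it can be passed on.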

class SliceIndexingStrategy(slice_axis)[source]#

Bases: IndexingStrategy

An indexing strategy class that computes the indexing expressions for slice-wise looping over the image data content.

Parameters:

slice_axis (int) – The axis along which the image data should be sliced.

__call__(shape)[source]#

Compute the indexing expressions for each slice based on the given shape of the image data and the loop_axis attribute.

Parameters:

shape (Tuple[int, ...]) – The shape of the image data for which the indexing expressions should be computed.

Returns:

The indexing expressions.

Return type:

Tuple[Tuple[slice, ...], ...]
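
A minimal sketch of slice-wise index expressions is given below. The function slice_index_expressions is hypothetical and only approximates what such a strategy computes. It assumes that each expression keeps a singleton dimension along the slicing axis, which would be consistent with the squeeze call in the InferenceFilter example above.

```python
from typing import Tuple

IndexExprs = Tuple[Tuple[slice, ...], ...]


def slice_index_expressions(shape: Tuple[int, ...], slice_axis: int) -> IndexExprs:
    """Return one index expression per slice along slice_axis."""
    if not 0 <= slice_axis < len(shape):
        raise ValueError('slice_axis is out of range for the given shape.')

    exprs = []
    for position in range(shape[slice_axis]):
        # Select everything on all axes except the slicing axis
        expr = [slice(None)] * len(shape)
        # Keep a single slice (length 1) along the slicing axis
        expr[slice_axis] = slice(position, position + 1)
        exprs.append(tuple(expr))
    return tuple(exprs)


exprs = slice_index_expressions((2, 3, 4), slice_axis=2)
```

For a shape of (2, 3, 4) and slice_axis=2, this yields four expressions, one per position along the last axis.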

class PatchIndexingStrategy(patch_shape, stride=None)[source]#

Bases: IndexingStrategy

An indexing strategy class that computes the indexing expressions for patch-wise looping over the image data content.

Parameters:
  • patch_shape (Tuple[int, ...]) – The shape of the patches.

  • stride (Optional[Tuple[int, ...]]) – The stride of the patches. If None, the patches will be extracted with the same stride as the patch shape.

__call__(shape)[source]#

Compute the indexing expressions for each patch based on the given shape of the image data and the patch_shape and stride attributes.

Parameters:

shape (Tuple[int, ...]) – The shape of the image data for which the indexing expressions should be computed.

Returns:

The indexing expressions.

Return type:

Tuple[Tuple[slice, ...], ...]
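
Patch-wise index expressions can be sketched in the same way. The function patch_index_expressions is hypothetical. It follows the documented default (stride equal to the patch shape when stride is None) and additionally assumes that patches reaching past the image border are clipped to the image bounds; the actual class may handle borders differently.

```python
from itertools import product
from typing import Optional, Tuple

IndexExprs = Tuple[Tuple[slice, ...], ...]


def patch_index_expressions(shape: Tuple[int, ...],
                            patch_shape: Tuple[int, ...],
                            stride: Optional[Tuple[int, ...]] = None) -> IndexExprs:
    """Return one index expression per patch (border patches are clipped)."""
    # Without an explicit stride the patches do not overlap
    stride = patch_shape if stride is None else stride

    if not len(shape) == len(patch_shape) == len(stride):
        raise ValueError('shape, patch_shape and stride must have equal length.')
    if any(step < 1 for step in stride):
        raise ValueError('All stride entries must be at least 1.')

    # Possible patch start positions along every axis
    starts_per_axis = [range(0, size, step) for size, step in zip(shape, stride)]

    exprs = []
    for starts in product(*starts_per_axis):
        exprs.append(tuple(slice(start, min(start + extent, size))
                           for start, extent, size in zip(starts, patch_shape, shape)))
    return tuple(exprs)


exprs = patch_index_expressions((4, 4), patch_shape=(2, 2))
overlapping = patch_index_expressions((3, 3), patch_shape=(2, 2), stride=(1, 1))
```

A stride smaller than the patch shape produces overlapping patches, in which case the caller has to decide how overlapping predictions are merged.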