Inference Module
Module: pyradise.process.inference
General
The inference module provides a prototype of a DL-framework-agnostic, filter-based inference class
(i.e. InferenceFilter). The framework-specific parts are left to the user to implement, so
installing PyRaDiSe does not require installing a DL framework.
Class Overview
The following classes are provided by the inference module:
| Class | Description |
|---|---|
| InferenceFilterParams | Parameterization class for the InferenceFilter. |
| InferenceFilter | Filter prototype for deep learning model inference. |
| IndexingStrategy | An abstract base class for all indexing strategies. |
| SliceIndexingStrategy | Indexing strategy for slice-wise looping (for 2D models). |
| PatchIndexingStrategy | Indexing strategy for patch-wise looping (for 3D models). |
Details
- class InferenceFilterParams(model, model_path, modalities, reference_modality, output_organs, output_annotator, organ_indices, batch_size, indexing_strategy)
Bases: FilterParams
A filter parameter class for the prototype InferenceFilter class.
- Parameters:
model (Any) – The model to apply.
model_path (Optional[str]) – The path to the model parameters.
modalities (Tuple[Union[str, Modality], ...]) – The Modality entries of the IntensityImage instances to use for inference.
reference_modality (Union[str, Modality]) – The Modality that is used as the reference to define the output properties of the created SegmentationImage instances.
output_organs (Tuple[Union[str, Organ], ...]) – The organs that get assigned to the created SegmentationImage instances.
output_annotator (Union[str, Annotator]) – The annotator that gets assigned to the created SegmentationImage instances.
organ_indices (Tuple[int, ...]) – The indices of the organs on the output mask of the model (must match output_organs and output_annotator).
batch_size (int) – The batch size to use for inference.
indexing_strategy (IndexingStrategy) – The IndexingStrategy defining how the data is fed to the model.
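The requirement that organ_indices lines up with output_organs can be checked before inference starts. The following is a minimal sketch using a stand-in dataclass: `ParamsSketch` and `validate` are hypothetical names that are not part of PyRaDiSe, and the one-index-per-organ and uniqueness rules are assumptions drawn from the parameter descriptions above.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class ParamsSketch:
    """Hypothetical stand-in mirroring a few InferenceFilterParams fields."""
    output_organs: Tuple[str, ...]
    organ_indices: Tuple[int, ...]
    batch_size: int = 1

    def validate(self) -> None:
        # Assumption: exactly one label index per output organ.
        if len(self.output_organs) != len(self.organ_indices):
            raise ValueError('organ_indices must have one entry per output organ.')
        # Assumption: label indices are unique so that organs do not collide.
        if len(set(self.organ_indices)) != len(self.organ_indices):
            raise ValueError('organ_indices must be unique.')
        if self.batch_size < 1:
            raise ValueError('batch_size must be at least 1.')


# Illustrative organ names and label indices.
params = ParamsSketch(output_organs=('Brainstem', 'Eyes'),
                      organ_indices=(1, 2),
                      batch_size=4)
params.validate()  # passes silently for consistent parameters
```

Running such a check early surfaces a mismatch as a clear error instead of a mislabeled segmentation.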
- class InferenceFilter(warning_on_non_invertible=False)
Bases: Filter
A prototype filter class for applying a DL-model to a Subject instance.
PyRaDiSe provides only a prototype for this filter so that it stays DL-framework agnostic. The DL-framework-specific methods must therefore be implemented in a subclass.
For implementing a DL-framework-specific InferenceFilter, the following methods must be implemented:
- _prepare_model(): Prepare the model for inference (e.g. load the parameters from a model file).
- _infer_on_batch(): Apply the model to a batch of data such that the output can be inserted into the new image via the indexing expressions provided by the chosen IndexingStrategy.
Example
Implementation example of a PyTorch-based InferenceFilter subclass:
>>> from typing import Any, Dict, Optional
>>>
>>> import numpy as np
>>> import torch
>>> import torch.nn as nn
>>>
>>> from pyradise.process.inference import (InferenceFilter, InferenceFilterParams,
>>>                                         PatchIndexingStrategy, SliceIndexingStrategy)
>>>
>>> class ExampleInferenceFilter(InferenceFilter):
>>>
>>>     def __init__(self):
>>>         super().__init__()
>>>
>>>         # Define the device on which the model should be run
>>>         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
>>>
>>>         # Define an attribute for the model
>>>         self.model: Optional[nn.Module] = None
>>>
>>>     def _prepare_model(self, model: nn.Module, model_path: str) -> nn.Module:
>>>
>>>         # Load model parameters
>>>         model.load_state_dict(torch.load(model_path, map_location=self.device))
>>>
>>>         # Assign the model to the class
>>>         self.model = model.to(self.device)
>>>
>>>         # Set model to evaluation mode
>>>         self.model.eval()
>>>
>>>         return model
>>>
>>>     def _infer_on_batch(self,
>>>                         batch: Dict[str, Any],
>>>                         params: InferenceFilterParams
>>>                         ) -> Dict[str, Any]:
>>>         # Stack and adjust the numpy array such that it fits the
>>>         # [batch, channel / images, height, width, (depth)] format
>>>         # Note: The following statement works for slice-wise and patch-wise processing
>>>         if (loop_axis := params.indexing_strategy.loop_axis) is None:
>>>             adjusted_input = np.stack(batch['data'], axis=0)
>>>         else:
>>>             adjusted_input = np.stack(batch['data'], axis=0).squeeze(loop_axis + 2)
>>>
>>>         # Generate a tensor from the numpy array
>>>         input_tensor = torch.from_numpy(adjusted_input)
>>>
>>>         # Move the batch to the same device as the model
>>>         input_tensor = input_tensor.to(self.device, dtype=torch.float32)
>>>
>>>         # Apply the model to the batch
>>>         with torch.no_grad():
>>>             output_tensor = self.model(input_tensor)
>>>
>>>         # Retrieve the predicted classes from the output
>>>         if type(params.indexing_strategy) is SliceIndexingStrategy:
>>>             # Slice-wise processing
>>>
>>>             if len(params.output_organs) > 1:
>>>                 # For multi-class segmentation
>>>                 final_activation_fn = nn.Softmax2d()
>>>                 output_tensor = torch.argmax(final_activation_fn(output_tensor), dim=1)
>>>
>>>             else:
>>>                 # For binary segmentation
>>>                 final_activation_fn = nn.Sigmoid()
>>>                 output_tensor = (final_activation_fn(output_tensor) > 0.5).bool()
>>>
>>>         elif type(params.indexing_strategy) is PatchIndexingStrategy:
>>>             # Patch-wise processing
>>>
>>>             if len(params.output_organs) > 1:
>>>                 # For multi-class segmentation
>>>                 final_activation_fn = nn.Softmax(dim=1)
>>>                 output_tensor = torch.argmax(final_activation_fn(output_tensor), dim=1)
>>>
>>>             else:
>>>                 # For binary segmentation
>>>                 final_activation_fn = nn.Sigmoid()
>>>                 output_tensor = (final_activation_fn(output_tensor) > 0.5).bool()
>>>
>>>         else:
>>>             raise NotImplementedError(f'Indexing strategy {type(params.indexing_strategy).__name__} '
>>>                                       'not implemented.')
>>>
>>>         # Convert the output to a numpy array
>>>         # Note: The output shape must be [batch, height, width, (depth)]
>>>         output_array = output_tensor.cpu().numpy()
>>>
>>>         # Construct a list of output arrays such that it fits the index expressions
>>>         batch_output_list = []
>>>         for i in range(output_array.shape[0]):
>>>             batch_output_list.append(output_array[i, ...])
>>>
>>>         # Combine the output arrays into a dictionary
>>>         output = {'data': batch_output_list,
>>>                   'index_expr': batch['index_expr']}
>>>
>>>         return output
- static _get_input_array(subject, params)
Return the input array for the DL-model.
Note
This function returns the concatenated data in C (channels) x H (height) x W (width) x D (depth) order.
- Parameters:
params (InferenceFilterParams) – The filter parameters.
- Returns:
The input array for the DL-model.
- Return type:
np.ndarray
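The C x H x W x D layout mentioned in the note can be illustrated with plain NumPy. This is only a sketch with made-up arrays: the variable names are hypothetical, and it assumes that one array per modality is stacked along a new leading channel axis, which the documentation above does not spell out.

```python
import numpy as np

# Hypothetical image arrays for two modalities, each in H x W x D order.
t1_array = np.zeros((4, 5, 3), dtype=np.float32)
t2_array = np.ones((4, 5, 3), dtype=np.float32)

# Stacking along a new leading axis yields the C x H x W x D layout.
input_array = np.stack([t1_array, t2_array], axis=0)

print(input_array.shape)  # (2, 4, 5, 3)
```

With this layout, `input_array[c]` selects the full volume of channel `c`, so all modalities share the same spatial index expressions.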
- abstract _prepare_model(model, model_path)
Prepare the model for inference (e.g. loading the model parameters). The loaded model must be added to a class attribute such that it can be accessed by all methods.
This method must be implemented for the specific DL-framework.
- Parameters:
model (Any) – The model instance.
model_path (str) – The path to the model parameters.
- Returns:
The model prepared for inference.
- Return type:
Any
- abstract _infer_on_batch(batch, params)
Apply the model to a batch of data.
This method must be implemented for the specific DL-framework and is called with a batch of data. The batch is a dictionary with the following keys:
- data: A list of numpy arrays with the input data.
- index_expr: A list of index expressions for the input data.
Note
The output data in the dictionary must be a list of numpy arrays with the same length as the input data. Each data entry must be a numpy array with the shape [C (channels) x H (height) x W (width) x (D (depth))].
- Parameters:
batch (Dict[str, Any]) – The batch of data.
params (InferenceFilterParams) – The filter parameters.
- Returns:
The output of the model.
- Return type:
Dict[str, Any]
- _apply_model(input_array, params)
Apply the model to the input array to predict the segmentation.
- Parameters:
input_array (np.ndarray) – The input array for the DL-model.
params (InferenceFilterParams) – The filter parameters.
- Returns:
The output array of the DL-model.
- Return type:
np.ndarray
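How index expressions, batching, and the batch dictionary fit together can be sketched as a small loop. This is not PyRaDiSe's implementation: `apply_model_sketch` and `threshold_infer` are hypothetical names, the stand-in "model" is a simple threshold, and the loop assumes that index expressions address only the spatial axes of a C x H x W x D input.

```python
from typing import Any, Callable, Dict, Sequence, Tuple

import numpy as np

IndexExpr = Tuple[slice, ...]


def apply_model_sketch(input_array: np.ndarray,
                       index_exprs: Sequence[IndexExpr],
                       infer_on_batch: Callable[[Dict[str, Any]], Dict[str, Any]],
                       batch_size: int) -> np.ndarray:
    """Hypothetical loop: cut regions, batch them, infer, write results back."""
    # The output holds one label per voxel, so the channel axis is dropped.
    output_array = np.zeros(input_array.shape[1:], dtype=np.uint8)

    for start in range(0, len(index_exprs), batch_size):
        exprs = list(index_exprs[start:start + batch_size])
        # Keep all channels and index only the spatial axes.
        batch = {'data': [input_array[(slice(None),) + expr] for expr in exprs],
                 'index_expr': exprs}
        result = infer_on_batch(batch)
        # Insert each prediction at the location it was cut from.
        for expr, prediction in zip(result['index_expr'], result['data']):
            output_array[expr] = prediction.reshape(output_array[expr].shape)

    return output_array


def threshold_infer(batch: Dict[str, Any]) -> Dict[str, Any]:
    """Stand-in for a model: label voxels whose first channel exceeds 0.5."""
    return {'data': [(data[0] > 0.5).astype(np.uint8) for data in batch['data']],
            'index_expr': batch['index_expr']}


image = np.random.default_rng(0).random((1, 4, 4, 3))
slices = [(slice(0, 4), slice(0, 4), slice(i, i + 1)) for i in range(3)]
mask = apply_model_sketch(image, slices, threshold_infer, batch_size=2)
```

Passing the index expressions through the batch dictionary unchanged is what lets the caller place each prediction without tracking positions separately.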
- static _array_to_subject(output_array, subject, params)
Convert the output array of the DL-model to one or multiple SegmentationImage instances and add them to the provided Subject instance.
- Parameters:
output_array (np.ndarray) – The output array of the DL-model.
subject (Subject) – The Subject instance to which the new SegmentationImage instances will be added.
params (InferenceFilterParams) – The filter parameters.
- Returns:
The Subject instance with the new SegmentationImage instances added.
- Return type:
Subject
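The role of organ_indices in this conversion can be illustrated by splitting a label map into one binary mask per organ. The snippet below is a sketch under the assumption that the model output is an integer label map with 0 as background; the organ names and values are made up, and the creation of SegmentationImage instances is left out.

```python
import numpy as np

# Hypothetical model output: a label map where 0 denotes background.
label_map = np.array([[0, 1, 1],
                      [2, 2, 0]], dtype=np.uint8)

output_organs = ('Brainstem', 'Eyes')  # illustrative organ names
organ_indices = (1, 2)                 # label value of each organ

# Build one binary mask per organ, keyed by the organ name.
organ_masks = {organ: (label_map == index).astype(np.uint8)
               for organ, index in zip(output_organs, organ_indices)}
```

Each mask has the same shape as the label map, so it can take over the spatial properties of the reference image.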
- static is_invertible()
Returns whether the filter is invertible or not.
Note
If your DL model is invertible, you should override this method and return True.
- Returns:
False because the inference filter is typically not invertible.
- Return type:
bool
- execute(subject, params)
Execute the filter on the provided Subject instance.
- Parameters:
params (InferenceFilterParams) – The filter parameters.
- Returns:
The Subject instance with the newly added SegmentationImage instances.
- Return type:
Subject
- execute_inverse(subject, transform_info, target_image=None)
Return the provided Subject instance without any processing because the inference of a DL-model is typically not invertible.
- Parameters:
transform_info (TransformInfo) – The transform information.
target_image (Optional[Union[SegmentationImage, IntensityImage]]) – The target image to which the inverse transformation should be applied. If None, the inverse transformation is applied to all images (default: None).
- Returns:
The provided Subject instance.
- Return type:
Subject
- class IndexingStrategy(loop_axis)
Bases: ABC
An abstract class that defines the strategy for indexing / looping over the image data content during model inference with an InferenceFilter. The IndexingStrategy is typically assigned to the InferenceFilterParams so that the InferenceFilter can use it.
- Parameters:
loop_axis (Optional[int]) – The axis along which the image data should be processed. If None, the image data will be processed as a whole.
- abstract __call__(shape)
Compute the indexing expressions based on the given shape of the image data and the loop_axis attribute.
- Parameters:
shape (Tuple[int, ...]) – The shape of the image data for which the indexing expressions should be computed.
- Returns:
The indexing expressions.
- Return type:
Tuple[Tuple[slice, …], …]
- class SliceIndexingStrategy(slice_axis)
Bases: IndexingStrategy
An indexing strategy class that computes the indexing expressions for slice-wise looping over the image data content.
- Parameters:
slice_axis (int) – The axis along which the image data should be sliced.
- __call__(shape)
Compute the indexing expressions for each slice based on the given shape of the image data and the loop_axis attribute.
- Parameters:
shape (Tuple[int, ...]) – The shape of the image data for which the indexing expressions should be computed.
- Returns:
The indexing expressions.
- Return type:
Tuple[Tuple[slice, …], …]
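To make the return type concrete, slice-wise indexing expressions can be computed as in the following sketch. `slice_index_expressions` is a hypothetical helper and not the library's code; it assumes that each expression keeps the looped axis with length one, which matches the `squeeze(loop_axis + 2)` call in the PyTorch example.

```python
from typing import Tuple

IndexExpr = Tuple[slice, ...]


def slice_index_expressions(shape: Tuple[int, ...], slice_axis: int) -> Tuple[IndexExpr, ...]:
    """Hypothetical helper: one index expression per slice along slice_axis."""
    if not 0 <= slice_axis < len(shape):
        raise ValueError('slice_axis must be a valid axis of the given shape.')

    # Start from an expression that covers the whole image.
    full = [slice(0, size) for size in shape]

    expressions = []
    for i in range(shape[slice_axis]):
        expr = list(full)
        # Narrow the looped axis to a single position (kept as length one).
        expr[slice_axis] = slice(i, i + 1)
        expressions.append(tuple(expr))
    return tuple(expressions)


exprs = slice_index_expressions((2, 3, 4), slice_axis=2)
print(len(exprs))  # 4
```

Because every entry is a tuple of slice objects, an expression can index a NumPy array directly, e.g. `array[exprs[0]]`.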
- class PatchIndexingStrategy(patch_shape, stride=None)
Bases: IndexingStrategy
An indexing strategy class that computes the indexing expressions for patch-wise looping over the image data content.
- Parameters:
patch_shape (Tuple[int, ...]) – The shape of the patches.
stride (Optional[Tuple[int, ...]]) – The stride of the patches. If None, the patches will be extracted with the same stride as the patch shape.
- __call__(shape)
Compute the indexing expressions for each patch based on the given shape of the image data and the patch_shape and stride attributes.
- Parameters:
shape (Tuple[int, ...]) – The shape of the image data for which the indexing expressions should be computed.
- Returns:
The indexing expressions.
- Return type:
Tuple[Tuple[slice, …], …]
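Patch-wise indexing with a stride can be sketched in a similar way. `patch_index_expressions` is a hypothetical helper, not the library's code. The default stride equal to the patch shape follows the parameter description above, while the handling of image borders (adding a final patch shifted back so that it ends at the border) is an assumption of this sketch and may differ from what PatchIndexingStrategy does.

```python
from itertools import product
from typing import Optional, Tuple

IndexExpr = Tuple[slice, ...]


def patch_index_expressions(shape: Tuple[int, ...],
                            patch_shape: Tuple[int, ...],
                            stride: Optional[Tuple[int, ...]] = None
                            ) -> Tuple[IndexExpr, ...]:
    """Hypothetical helper: one index expression per patch."""
    if stride is None:
        # Non-overlapping patches: the stride equals the patch shape.
        stride = patch_shape
    if not len(shape) == len(patch_shape) == len(stride):
        raise ValueError('shape, patch_shape, and stride must have equal length.')

    starts_per_axis = []
    for size, patch, step in zip(shape, patch_shape, stride):
        if patch > size or step < 1:
            raise ValueError('Patches must fit into the image and strides must be positive.')
        starts = list(range(0, size - patch + 1, step))
        # Assumption: cover a remaining border with a patch shifted back inside.
        if starts[-1] + patch < size:
            starts.append(size - patch)
        starts_per_axis.append(starts)

    # Combine the start positions of all axes into patch corners.
    return tuple(
        tuple(slice(start, start + patch) for start, patch in zip(corner, patch_shape))
        for corner in product(*starts_per_axis)
    )


patches = patch_index_expressions((4, 4), (2, 2))
print(len(patches))  # 4
```

Choosing a stride smaller than the patch shape produces overlapping patches, in which case later predictions would overwrite earlier ones unless the results are blended.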