# Source code for medipt.resample_image.resample_image

import numpy as np
import SimpleITK as sitk
from typing import Union, Tuple, List, Callable, Iterable


# Accepted interpolator names mapped to the *names* of the corresponding
# SimpleITK interpolator constants. The constant is fetched with getattr only
# after the name has been validated, so an unknown string is rejected with a
# clear message instead of silently producing None.
_SITK_INTERPOLATOR_NAMES = {
    'nearest': 'sitkNearestNeighbor',
    'linear': 'sitkLinear',
    'cubic': 'sitkBSpline',
    'label_gaussian': 'sitkLabelGaussian',
    'gaussian': 'sitkGaussian',
    'lanczos': 'sitkLanczosWindowedSinc',
}


def get_sitk_interpolator(interpolator: Union[str, int]) -> int:
    """
    Return the SimpleITK interpolator constant for the given interpolator.

    :param interpolator: Interpolator name as string, or an already-resolved
        SimpleITK interpolator constant (int), which is returned unchanged.
        Accepted names:
        'nearest': sitk.sitkNearestNeighbor
        'linear': sitk.sitkLinear
        'cubic': sitk.sitkBSpline
        'label_gaussian': sitk.sitkLabelGaussian
        'gaussian': sitk.sitkGaussian
        'lanczos': sitk.sitkLanczosWindowedSinc
    :return: The SimpleITK interpolator constant.
    :raises ValueError: If ``interpolator`` is a string that is not one of the
        accepted names.
    :raises TypeError: If ``interpolator`` is neither a string nor an int.
    """
    if isinstance(interpolator, str):
        try:
            attr_name = _SITK_INTERPOLATOR_NAMES[interpolator]
        except KeyError:
            raise ValueError(
                f"unknown interpolator {interpolator!r}; expected one of "
                f"{sorted(_SITK_INTERPOLATOR_NAMES)}"
            ) from None
        return getattr(sitk, attr_name)
    if isinstance(interpolator, int):
        # Already a SimpleITK interpolator constant; pass through unchanged.
        return interpolator
    raise TypeError('invalid interpolator type')
# Accepted pixel-type names mapped to the *names* of the corresponding
# SimpleITK pixel-type constants. The constant is fetched with getattr only
# after the name has been validated, so an unknown string is rejected with a
# clear message instead of silently producing None.
_SITK_PIXEL_TYPE_NAMES = {
    'float32': 'sitkFloat32',
    'float64': 'sitkFloat64',
    'int8': 'sitkInt8',
    'int16': 'sitkInt16',
    'int32': 'sitkInt32',
    'int64': 'sitkInt64',
    'uint8': 'sitkUInt8',
    'uint16': 'sitkUInt16',
    'uint32': 'sitkUInt32',
    'uint64': 'sitkUInt64',
    'complex64': 'sitkComplexFloat32',
    'complex128': 'sitkComplexFloat64',
}


def get_simpleitk_pix_type(pixel_type: Union[str, int]) -> int:
    """
    Return the SimpleITK pixel type for the given pixel type.

    :param pixel_type: The pixel type as string, or an already-resolved
        SimpleITK pixel type constant (int), which is returned unchanged.
        Accepted names:
        'float32': sitk.sitkFloat32
        'float64': sitk.sitkFloat64
        'int8': sitk.sitkInt8
        'int16': sitk.sitkInt16
        'int32': sitk.sitkInt32
        'int64': sitk.sitkInt64
        'uint8': sitk.sitkUInt8
        'uint16': sitk.sitkUInt16
        'uint32': sitk.sitkUInt32
        'uint64': sitk.sitkUInt64
        'complex64': sitk.sitkComplexFloat32
        'complex128': sitk.sitkComplexFloat64
    :return: The SimpleITK pixel type.
    :raises ValueError: If ``pixel_type`` is a string that is not one of the
        accepted names.
    :raises TypeError: If ``pixel_type`` is neither a string nor an int.
    """
    if isinstance(pixel_type, str):
        try:
            attr_name = _SITK_PIXEL_TYPE_NAMES[pixel_type]
        except KeyError:
            raise ValueError(
                f"unknown pixel type {pixel_type!r}; expected one of "
                f"{sorted(_SITK_PIXEL_TYPE_NAMES)}"
            ) from None
        return getattr(sitk, attr_name)
    if isinstance(pixel_type, int):
        # Already a SimpleITK pixel type constant; pass through unchanged.
        return pixel_type
    raise TypeError('invalid pixel type')
class ResampleImage:
    """
    Resample SimpleITK images onto an output grid.

    The output grid is taken from a reference image when one is given, and
    otherwise from the ``output_size``, ``output_spacing``, ``output_origin``
    and ``output_direction`` entries of ``input_output_space_dict``. Optional
    post-processing callables are applied to the resampled result.
    """

    def __init__(self,
                 interpolator: Union[str, int] = 'linear',
                 post_processing_sitk: Union[Callable, List[Callable], Tuple[Callable, ...], None] = None,
                 pixel_type: Union[str, int, None] = 'float32',
                 default_pixel_value: Union[float, int, None] = None,
                 dim: int = 3,
                 used_dimensions: Union[List[bool], Tuple[bool, ...], None] = None,
                 seed: Union[np.random.RandomState, np.random.Generator, np.random.BitGenerator, int, None] = None,
                 legacy_random_state: bool = True,
                 *args, **kwargs):
        """
        :param interpolator: Interpolator name (see ``get_sitk_interpolator``)
            or a SimpleITK interpolator constant.
        :param post_processing_sitk: A callable, or a list/tuple of callables
            applied in order, that receives the resampled output (a single
            image or a list of images) and returns the new output. ``None``
            disables post-processing.
        :param pixel_type: Output pixel type name (see
            ``get_simpleitk_pix_type``) or a SimpleITK pixel type constant.
            If ``None``, ``SetOutputPixelType`` is not called on the filter.
        :param default_pixel_value: Value assigned to output pixels that map
            outside the input image. ``None`` uses the minimum intensity of
            the image being resampled.
        :param dim: Image dimension; used to build the default
            ``sitk.AffineTransform(dim)`` when no transform is given.
        :param used_dimensions: Per-dimension flags; any falsy value (``None``,
            empty sequence) is replaced by ``[True] * dim``. Stored only; not
            read by the methods of this class.
        :param seed: Random seed or state. Stored only; not read by the
            methods of this class.
        :param legacy_random_state: Stored only; not read by the methods of
            this class.
        """
        self.used_dimensions = used_dimensions or [True] * dim
        self.dim = dim
        self.seed = seed
        self.legacy_random_state = legacy_random_state
        self.interpolator = interpolator
        self.post_processing_sitk = post_processing_sitk
        self.pixel_type = pixel_type
        self.default_pixel_value = default_pixel_value

    def _get_default_pixel_value(self, image: sitk.Image) -> Union[float, int]:
        """Return the fill value for output pixels that map outside ``image``."""
        if self.default_pixel_value is not None:
            return self.default_pixel_value
        # No explicit fill value configured: use the minimum intensity of the
        # input image.
        min_max_filter = sitk.MinimumMaximumImageFilter()
        min_max_filter.Execute(image)
        return min_max_filter.GetMinimum()

    def resample_image(self,
                       image: sitk.Image,
                       input_output_space_dict: dict,
                       reference_image: sitk.Image = None,
                       transform: sitk.Transform = None,
                       *args, **kwargs) -> sitk.Image:
        """
        Resample a single image.

        :param image: The image to resample.
        :param input_output_space_dict: Dict with the keys 'output_size',
            'output_spacing', 'output_origin' and 'output_direction'. It is
            only read when ``reference_image`` is None.
        :param reference_image: If given, the output grid is taken from this
            image via ``SetReferenceImage``.
        :param transform: Transform handed to the resample filter. Defaults
            to ``sitk.AffineTransform(self.dim)``.
        :param args: Accepted for interface compatibility; ignored.
        :param kwargs: Accepted for interface compatibility; ignored.
        :return: The resampled image.
        """
        # Resolve configuration first so invalid settings fail before the
        # (potentially expensive) min/max pass over the image.
        sitk_interpolator = get_sitk_interpolator(self.interpolator)
        sitk_pix_type = (get_simpleitk_pix_type(self.pixel_type)
                         if self.pixel_type is not None else None)

        resample_filter = sitk.ResampleImageFilter()
        if reference_image is not None:
            resample_filter.SetReferenceImage(reference_image)
        else:
            # Only needed (and only required to be present) without a
            # reference image.
            resample_filter.SetOutputSpacing(input_output_space_dict['output_spacing'])
            resample_filter.SetOutputOrigin(input_output_space_dict['output_origin'])
            resample_filter.SetOutputDirection(input_output_space_dict['output_direction'])
            resample_filter.SetSize(input_output_space_dict['output_size'])

        resample_filter.SetDefaultPixelValue(self._get_default_pixel_value(image))
        if transform is None:
            transform = sitk.AffineTransform(self.dim)
        resample_filter.SetTransform(transform)
        resample_filter.SetInterpolator(sitk_interpolator)
        if sitk_pix_type is not None:
            resample_filter.SetOutputPixelType(sitk_pix_type)
        return resample_filter.Execute(image)

    def get_images(self,
                   images: Union[List[sitk.Image], Tuple[sitk.Image, ...], sitk.Image],
                   input_output_space_dict: dict,
                   reference_image: sitk.Image = None,
                   transform: sitk.Transform = None,
                   *args, **kwargs):
        """
        Resample one image or each image of a list/tuple.

        :param images: A single image, or a list/tuple of images.
        :param input_output_space_dict: See ``resample_image``.
        :param reference_image: See ``resample_image``.
        :param transform: See ``resample_image``; the same transform is used
            for every image.
        :return: A list of resampled images for list/tuple input, otherwise a
            single resampled image.
        """
        if isinstance(images, (list, tuple)):
            return [self.resample_image(image, input_output_space_dict, reference_image, transform,
                                        *args, **kwargs)
                    for image in images]
        return self.resample_image(images, input_output_space_dict, reference_image, transform,
                                   *args, **kwargs)

    def get_resampled_images(self,
                             image: Union[List[sitk.Image], Tuple[sitk.Image, ...], sitk.Image],
                             input_output_space_dict: dict,
                             reference_image: sitk.Image = None,
                             transform: sitk.Transform = None,
                             *args, **kwargs
                             ) -> Union[List[sitk.Image], Tuple[sitk.Image, ...], sitk.Image]:
        """
        Resample the image(s) and apply the configured post-processing.

        :param image: A single image, or a list/tuple of images.
        :param input_output_space_dict: See ``resample_image``.
        :param reference_image: See ``resample_image``.
        :param transform: See ``resample_image``.
        :return: The output of ``get_images`` after passing it through every
            post-processing callable in order.
        """
        output = self.get_images(image, input_output_space_dict, reference_image, transform,
                                 *args, **kwargs)
        post_processing = self.post_processing_sitk
        if post_processing is None:
            return output
        if not isinstance(post_processing, (list, tuple)):
            post_processing = (post_processing,)
        for step in post_processing:
            output = step(output)
        return output