Source code for medipt.transforms.spatial.spatial_transform
from abc import ABC, abstractmethod
import numpy as np
import SimpleITK as sitk
from typing import Union, Tuple, List
from types import ModuleType
from ...utils import image_index_to_phys
from ...utils.random_float import initialize_rand_state
from ...utils.image_input_output_space import resolve_input_output_space
[docs]class SpatialTransform(ABC):
"""
A generic spatial transform that can be applied to 2D and 3D images.
"""
def __init__(self,
dim: int = 3,
used_dimensions: bool = None,
seed: Union[np.random.RandomState, np.random.Generator, np.random.BitGenerator, int, None] = None,
legacy_random_state: bool = True,
rand_init: Union[ModuleType, np.random.Generator, np.random.BitGenerator] = None,
*args, **kwargs):
"""
Initializer.
:param dim: The dimension of the transformation.
:param args: Arguments passed to super init.
:param kwargs: Keyword arguments passed to super init.
"""
self.used_dimensions = used_dimensions or [True] * dim
self.dim = dim
self.seed = seed
self.legacy_random_state = legacy_random_state
self.transform = None
self.inverse_transform = None
self.displacement_field = None
self.inverted_displacement_field = None
self.inverted_transform_from_displacement = None
if rand_init is None:
self.rand_init = initialize_rand_state(self.seed, self.legacy_random_state)
else:
self.rand_init = rand_init
assert len(self.used_dimensions) == dim, 'Length of used_dimensions must be equal to dim.'
self.input_size = None
self.input_space = None
self.input_origin = None
self.input_direction = None
self.image_params_gathered = False
self.output_size = None
self.output_spacing = None
self.output_origin = None
self.output_direction = None
@abstractmethod
def _get_transform(self,
transform_params: Union[List[float], Tuple[float, ...], float],
*args, **kwargs):
raise NotImplementedError('This method must be implemented in a child class.')
[docs] def get_inverse_transform(self, *args, **kwargs):
if self.transform is not None:
self.inverse_transform = self.transform.GetInverse()
else:
raise ValueError('No transform found. Call get_transform first.')
# return self.inverse_transform
[docs] def get_displacement_field(self, *args, **kwargs):
raise NotImplementedError('This method must be implemented in a child class.')
[docs] def get_inverted_displacement_field(self, *args, **kwargs):
raise NotImplementedError('This method must be implemented in a child class.')
[docs] def get_inverted_transform_from_displacement(self, *args, **kwargs):
self.get_inverse_transform(*args, **kwargs)
[docs] def get_image_params(
self,
image: sitk.Image = None,
image_size: Union[List[Union[int, float]], Tuple[Union[int, float], ...]] = None,
image_spacing: Union[List[Union[int, float]], Tuple[Union[int, float], ...]] = None,
image_origin: Union[List[Union[int, float]], Tuple[Union[int, float], ...]] = None,
image_direction: Union[List[Union[int, float]], Tuple[Union[int, float]]] = None,
*args, **kwargs) -> Tuple[
Union[List[Union[int, float]], Tuple[Union[int, float]], np.ndarray],
Union[List[Union[int, float]], Tuple[Union[int, float]], np.ndarray],
Union[List[Union[int, float]], Tuple[Union[int, float]], np.ndarray],
Union[List[Union[int, float]], Tuple[Union[int, float]], np.ndarray]]:
if image is not None:
image_size = image.GetSize()
image_spacing = image.GetSpacing()
image_origin = image.GetOrigin()
image_direction = image.GetDirection()
else:
image_size = kwargs.get('image_size', image_size)
image_spacing = kwargs.get('image_spacing', image_spacing)
image_origin = kwargs.get('image_origin', image_origin)
image_direction = kwargs.get('image_direction', image_direction)
if image_spacing is None:
image_spacing = [1.0] * self.dim
if image_origin is None:
image_origin = np.zeros(self.dim)
if image_direction is None:
image_direction = np.eye(self.dim).flatten(order='F').tolist()
return image_size, image_spacing, image_origin, image_direction
[docs] def get_input_output_space(self,
# input_image: sitk.Image = None,
# input_size: Union[List[Union[int, float]], Tuple[Union[int, float], ...]] = None,
# output_size: Union[List[Union[int, float]], Tuple[Union[int, float], ...]] = None,
# input_spacing: Union[List[Union[int, float]], Tuple[Union[int, float], ...]] = None,
#
#
# output_spacing: Union[List[Union[int, float]], Tuple[Union[int, float], ...]] = None,
# input_direction: Union[List[Union[int, float]], Tuple[Union[int, float], ...]] = None,
# output_direction: Union[List[Union[int, float]], Tuple[Union[int, float], ...]] = None,
# input_origin: Union[List[Union[int, float]], Tuple[Union[int, float], ...]] = None,
# output_origin: Union[List[Union[int, float]], Tuple[Union[int, float], ...]] = None,
dict: dict = None,
*args, **kwargs):
if dict is None:
dict = resolve_input_output_space(*args, **kwargs)
elif len(dict) == 0:
dict = resolve_input_output_space(*args, **kwargs)
input_size = dict['input_size']
input_spacing = dict['input_spacing']
input_origin = dict['input_origin']
input_direction = dict['input_direction']
output_size = dict['output_size']
output_spacing = dict['output_spacing']
output_origin = dict['output_origin']
output_direction = dict['output_direction']
self.input_size = input_size
self.input_spacing = input_spacing
self.input_origin = input_origin
self.input_direction = input_direction
self.output_size = output_size
self.output_spacing = output_spacing
self.output_origin = output_origin
self.output_direction = output_direction
self.image_params_gathered = True
[docs] def get_input_origin(self,
*args, **kwargs) -> List[float]:
if self.image_params_gathered is False:
self.get_input_output_space(*args, **kwargs)
input_origin_index = (0.0,) * self.dim
input_origin_phys = image_index_to_phys(dim=self.dim,
index_coords=input_origin_index,
spacing=self.input_spacing,
direction=self.input_direction,
origin=self.input_origin)
return input_origin_phys
[docs] def get_output_origin(self, *args, **kwargs) -> List[float]:
if self.image_params_gathered is False:
self.get_input_output_space(*args, **kwargs)
output_origin_index = (0.0,) * self.dim
output_origin_phys = image_index_to_phys(dim=self.dim,
index_coords=output_origin_index,
spacing=self.output_spacing,
direction=self.output_direction,
origin=self.output_origin)
return output_origin_phys
[docs] def get_input_center(self, *args, **kwargs) -> List[float]:
if self.image_params_gathered is False:
self.get_input_output_space(*args, **kwargs)
input_center_index = tuple([(i - 1) / 2 for i in self.input_size])
input_center_phys = image_index_to_phys(dim=self.dim,
index_coords=input_center_index,
spacing=self.input_spacing,
direction=self.input_direction,
origin=self.input_origin)
return input_center_phys
[docs] def get_random_point(self,
*args, **kwargs) -> List[float]:
if self.image_params_gathered is False:
self.get_input_output_space(*args, **kwargs)
input_center_index = tuple([(i - 1) / 2 for i in self.input_size])
input_center_index_min = np.array(input_center_index) - (np.array(input_center_index) // 4)
input_center_index_max = np.array(input_center_index) + (np.array(input_center_index) // 4)
input_index = self.rand_init.randint(low=input_center_index_min, high=input_center_index_max)
input_index = (int(input_index[0]), int(input_index[1]), int(input_index[2]))
input_center_phys = image_index_to_phys(dim=self.dim,
index_coords=input_index,
spacing=self.input_spacing,
direction=self.input_direction,
origin=self.input_origin)
return input_center_phys
[docs] def get_output_center(self, *args, **kwargs) -> List[float]:
if self.image_params_gathered is False:
self.get_input_output_space(*args, **kwargs)
output_center_index = tuple([(i - 1) / 2 for i in self.output_size])
output_center_phys = image_index_to_phys(dim=self.dim,
index_coords=output_center_index,
spacing=self.output_spacing,
direction=self.output_direction,
origin=self.output_origin)
return output_center_phys