Source code for medipt.transforms.spatial.composite_transform
import SimpleITK as sitk
import numpy as np
from typing import Union, Tuple, List
from .spatial_transform import SpatialTransform
from .elastic_deformation_transform import ElasticDeformation
def composite_transform(transforms: List[Union[sitk.AffineTransform, sitk.BSplineTransform, sitk.DisplacementFieldTransform]],
                        dim: int = 3):
    """Chain a list of SimpleITK transforms into a single transform object.

    The construction depends on the installed SimpleITK major version:

    * major version 1: an identity ``sitk.Transform`` of dimension ``dim`` is
      created and every entry of ``transforms`` is added to it in list order;
    * any other major version: the list is handed directly to
      ``sitk.CompositeTransform``.

    Args:
        transforms: SimpleITK transforms to combine, in the order they are
            added to the composite.
        dim: Spatial dimension of the identity transform. Only used in the
            major-version-1 branch.

    Returns:
        The combined transform (``sitk.Transform`` when the major version
        is 1, ``sitk.CompositeTransform`` otherwise).
    """
    # Newer API: the composite can be built from the list in one call.
    if sitk.Version_MajorVersion() != 1:
        return sitk.CompositeTransform(transforms)

    # SimpleITK 1.x path: start from an identity and add each transform.
    chained = sitk.Transform(dim, sitk.sitkIdentity)
    for step in transforms:
        chained.AddTransform(step)
    return chained
class CompositeTransform:
    """Ordered collection of transforms that can be chained into one SimpleITK composite.

    Entries are kept in ``self.transforms`` in the order they were added.
    Each entry is either a ``SpatialTransform`` (its SimpleITK transform is
    read from the ``transform`` attribute) or a plain ``sitk.Transform``.
    """

    def __init__(self,
                 *args, **kwargs):
        # Positional and keyword arguments are accepted but not used here.
        self.transforms = []

    def add_transforms(self,
                       transform: Union[List[Union[SpatialTransform, sitk.Transform]],
                                        Tuple[Union[SpatialTransform, sitk.Transform], ...],
                                        Union[SpatialTransform, sitk.Transform]]):
        """Register one transform, or every element of a list/tuple of transforms.

        Args:
            transform: A single transform, or a list/tuple of transforms.
                Lists and tuples are flattened one level into
                ``self.transforms``; anything else is appended as-is.
        """
        if isinstance(transform, (tuple, list)):
            self.transforms.extend(transform)
        else:
            self.transforms.append(transform)

    def create_composite(self, dim: int = 3) -> sitk.Transform:
        """Build a ``sitk.CompositeTransform`` from the stored transforms.

        Transforms are added in the order they were registered.

        Args:
            dim: Dimension passed to ``sitk.CompositeTransform``.

        Returns:
            The composite containing every stored transform.
        """
        compos = sitk.CompositeTransform(dim)
        for transform in self.transforms:
            if isinstance(transform, SpatialTransform):
                # Wrapper object: the SimpleITK transform lives on `.transform`.
                compos.AddTransform(transform.transform)
            else:
                # Plain SimpleITK transform: add it directly. It has no
                # `.transform` attribute, so dereferencing one would fail.
                compos.AddTransform(transform)
        return compos

    def create_inverse_composite(self, dim: int = 3,
                                 use_displacement_field: bool = False) -> sitk.CompositeTransform:
        """Build a composite of the inverses of the stored transforms.

        The stored transforms are visited in reverse order of registration
        and the inverse of each one is added to the composite.

        Args:
            dim: Dimension passed to ``sitk.CompositeTransform``.
            use_displacement_field: When true, ``ElasticDeformation`` entries
                use ``get_inverted_transform_from_displacement()`` /
                ``inverted_transform_from_displacement`` instead of
                ``get_inverse_transform()`` / ``inverse_transform``. Other
                ``SpatialTransform`` entries ignore this flag.

        Returns:
            The composite of inverse transforms.

        Raises:
            Warning: If ``use_displacement_field`` is true and a plain
                SimpleITK transform is encountered.
        """
        compos = sitk.CompositeTransform(dim)
        for transform in reversed(self.transforms):
            if isinstance(transform, SpatialTransform):
                if use_displacement_field and isinstance(transform, ElasticDeformation):
                    # The getter is called before the attribute is read;
                    # presumably it computes and stores the inverse.
                    transform.get_inverted_transform_from_displacement()
                    sitk_inverse_transform = transform.inverted_transform_from_displacement
                else:
                    transform.get_inverse_transform()
                    sitk_inverse_transform = transform.inverse_transform
                compos.AddTransform(sitk_inverse_transform)
            else:
                if use_displacement_field:
                    raise Warning("The use_displacement_field option is not implemented for SimpleITK transforms.")
                compos.AddTransform(transform.GetInverse())
        return compos