Source code for medipt.transforms.spatial.flipping_transform

from typing import Union, Tuple, List
import SimpleITK as sitk
import numpy as np

from .spatial_transform import SpatialTransform
from ...utils import random_binomial


class FlippingTransform(SpatialTransform):
    """Spatial transform that mirrors an image along selected axes.

    The flip is represented as a ``sitk.AffineTransform`` of dimension
    ``self.dim`` whose linear part is a diagonal scaling: ``-1.0`` on every
    flipped axis and ``1.0`` on every other axis.
    """

    def _get_transform(
        self,
        flip_axes: Union[List[float], Tuple[float, ...], float],
        *args,
        **kwargs,
    ):
        """Build ``self.transform`` from a per-axis flip specification.

        Args:
            flip_axes: Either a single number/bool that applies to every axis,
                or a sequence (tuple, list or ``np.ndarray``) holding one
                truthy/falsy entry per axis. A length-1 sequence (or a 0-d
                array) is broadcast to all ``self.dim`` axes, like a scalar.
            *args, **kwargs: Accepted for interface compatibility; not used.

        Raises:
            ValueError: If ``flip_axes`` is a sequence whose length is neither
                1 nor ``self.dim``, or if it is neither a sequence nor a number.
        """
        if isinstance(flip_axes, np.ndarray) and flip_axes.ndim == 0:
            # A 0-d array has no len(); treat it as the scalar it wraps.
            flip_axes = flip_axes.item()

        if isinstance(flip_axes, (tuple, list, np.ndarray)):
            if len(flip_axes) == 1:
                # Previously a length-1 sequence skipped the length check and
                # produced a wrong-sized scale vector; broadcast it instead.
                current_flip_axes = [bool(flip_axes[0])] * self.dim
            elif len(flip_axes) == self.dim:
                current_flip_axes = [bool(flip_axis) for flip_axis in flip_axes]
            else:
                raise ValueError(f'flip axes must be a tuple or list of length {self.dim}.')
        elif isinstance(flip_axes, (int, float, np.integer, np.floating, np.bool_, bool)):
            current_flip_axes = [bool(flip_axes)] * self.dim
        else:
            raise ValueError('flip axes must be tuples, lists, or numbers.')

        # NOTE(review): no center is set on the transform here, so the scaling
        # is about the AffineTransform's default center (presumably the
        # physical origin) -- confirm whether SpatialTransform recenters it.
        self.transform = sitk.AffineTransform(self.dim)
        scale_factors = [-1.0 if flip else 1.0 for flip in current_flip_axes]
        self.transform.Scale(scale_factors)

    def get_transform(
        self,
        flip_axes: Union[List[float], Tuple[float, ...], float],
        *args,
        **kwargs,
    ):
        """Public entry point: build ``self.transform`` for ``flip_axes``.

        Delegates to ``_get_transform``; see it for the accepted forms of
        ``flip_axes``. Returns ``None`` -- the result is stored on
        ``self.transform``.
        """
        self._get_transform(flip_axes, *args, **kwargs)
class RandomFlipping(FlippingTransform):
    """Flipping transform whose per-axis flip decision is drawn at random."""

    def get_random_transform(
        self,
        flip_axes: Union[
            List[Union[int, float, bool]],
            Tuple[Union[int, float, bool], ...],
            Union[int, float, bool, np.integer, np.floating],
            np.ndarray,
        ],
        *args,
        **kwargs,
    ):
        """Randomly choose which axes to flip, then build ``self.transform``.

        Each entry of ``flip_axes`` is multiplied by 0.5 and used as the
        success probability ``p`` of a ``random_binomial`` draw with ``n=1``.
        With the usual 0/1 (False/True) entries, an enabled axis is flipped
        with probability 0.5 and a disabled axis is never flipped.

        For a scalar ``flip_axes``, ``random_binomial`` presumably returns a
        single draw. ``_get_transform`` then applies that draw to every axis,
        so all axes flip together (confirm against ``random_binomial``).

        Args:
            flip_axes: Scalar, or tuple/list/ndarray with one entry per axis.
            *args, **kwargs: Forwarded unchanged to ``_get_transform``.

        Raises:
            ValueError: If ``flip_axes`` is neither a sequence nor a number
                (the sequence length is validated by ``_get_transform``).
        """
        if isinstance(flip_axes, (tuple, list, np.ndarray)):
            # asanyarray leaves existing ndarrays (and subclasses) untouched
            # and converts lists/tuples, so one expression covers both cases.
            probability = 0.5 * np.asanyarray(flip_axes)
        elif isinstance(flip_axes, (int, float, np.integer, np.floating, np.bool_, bool)):
            probability = 0.5 * flip_axes
        else:
            raise ValueError('flip axes must be tuples, lists, or numbers.')

        current_flip_axis = random_binomial(
            n=1,
            p=probability,
            seed=self.seed,
            legacy_random_state=self.legacy_random_state,
            rand_init=self.rand_init,
        )
        self._get_transform(current_flip_axis, *args, **kwargs)
# # if isinstance(flip_axes, (tuple, list, np.ndarray)): # current_flip_axes = [] # if len(flip_axes) > 1: # if len(flip_axes) != self.dim: # raise ValueError(f'flip axes must be a tuple or list of length {self.dim}.') # # for flip_axis in flip_axes: # # if flip_axis == 1: # current_flip_axis = random_binomial(n=1, p=0.5, # seed=self.seed, # legacy_random_state=self.legacy_random_state, # rand_init=self.rand_init) # else: # current_flip_axis = 0 # # current_flip_axes.append(bool(current_flip_axis)) # # elif isinstance(flip_axes, (int, float, np.integer, np.floating, np.bool_, bool)): # if flip_axes == 1: # current_flip_axis = random_binomial(n=1, p=0.5, # seed=self.seed, # legacy_random_state=self.legacy_random_state, # rand_init=self.rand_init) # # else: # current_flip_axis = 0 # current_flip_axes = [bool(current_flip_axis)] * self.dim # # else: # raise ValueError('flip axes must be tuples, lists, or numbers.') # # self._get_transform(current_flip_axes, *args, **kwargs)