Source code for medipt.transforms.intensity.shift_scale_intensity

from typing import Union, Tuple, List
import numpy as np
from .intensity_utils_sitk import shift_scale_intensity_sitk
from .intensity_utils_np import shift_scale_intensity_np
from ...utils import random_uniform_float
from ...utils.random_float import initialize_rand_state
import SimpleITK as sitk
import warnings


# class ShiftScaleIntensity:
#     def __init__(self,
#                  shift: Union[int, float] = None,
#                  scale: Union[int, float] = None,
#                  *args, **kwargs):
#
#         self.shift = shift
#         self.scale = scale
#
#
#     def __call__(self,
#                  image: Union[np.ndarray, sitk.Image],
#                  *args, **kwargs) -> Union[np.ndarray, sitk.Image]:
#
#         if isinstance(image, np.ndarray):
#             output_image = shift_scale_intensity_np(image,
#                                                     shift=self.shift,
#                                                     scale=self.scale,
#                                                     *args, **kwargs)
#         elif isinstance(image, sitk.Image):
#             output_image = shift_scale_intensity_sitk(image,
#                                                       shift=self.shift,
#                                                       scale=self.scale,
#                                                       *args, **kwargs)
#
#         else:
#             raise ImportError("image must be either a numpy array or a SimpleITK image.")
#
#         return output_image



class ShiftScaleIntensity:
    """Callable transform that shifts and scales image intensities.

    The transform is configured in one of two ways:

    * explicitly, through ``shift`` and/or ``scale``; or
    * through an intensity window (``input_min``, ``input_max``,
      ``output_min``, ``output_max``), from which ``shift`` and ``scale``
      are derived on every call.

    If both ``shift`` and ``scale`` *and* the full window are supplied, a
    warning is issued, the explicit ``shift``/``scale`` are used and the
    window attributes are reset to ``None``.

    An additional random shift/scale step is applied to the result when
    ``random_shift`` or ``random_scale`` is not ``None``.

    Args:
        shift: Value forwarded as ``shift`` to the intensity helper.
        scale: Value forwarded as ``scale`` to the intensity helper.
        input_min: Lower bound of the input intensity window.
        input_max: Upper bound of the input intensity window.
        output_min: Lower bound of the output intensity window.
        output_max: Upper bound of the output intensity window.
        random_shift: Magnitude ``r`` of the random shift; the shift is
            requested from ``random_uniform_float`` with bounds ``-r`` and ``r``.
        random_scale: Magnitude ``r`` of the random scale jitter; the applied
            factor is ``1 + u`` with ``u`` requested from
            ``random_uniform_float`` with bounds ``-r`` and ``r``.
        seed: Seed or random generator handed to ``initialize_rand_state``
            and ``random_uniform_float``.
        legacy_random_state: Forwarded together with ``seed`` to the random
            helpers.
        *args: Accepted for interface compatibility; ignored.
        **kwargs: Accepted for interface compatibility; ignored.
    """

    def __init__(self,
                 shift: Union[int, float, None] = None,
                 scale: Union[int, float, None] = None,
                 input_min: Union[int, float, None] = None,
                 input_max: Union[int, float, None] = None,
                 output_min: Union[int, float, None] = None,
                 output_max: Union[int, float, None] = None,
                 random_shift: Union[int, float, None] = None,
                 random_scale: Union[int, float, None] = None,
                 seed: Union[np.random.RandomState, np.random.Generator, np.random.BitGenerator, int, None] = None,
                 legacy_random_state: bool = True,
                 *args, **kwargs):
        self.shift = shift
        self.scale = scale
        self.input_min = input_min
        self.input_max = input_max
        self.output_min = output_min
        self.output_max = output_max
        self.random_shift = random_shift
        self.random_scale = random_scale
        self.seed = seed
        self.legacy_random_state = legacy_random_state
        # Created once so that repeated calls share the same random state.
        self.rand_init = initialize_rand_state(self.seed, self.legacy_random_state)

    def _draw_symmetric(self, magnitude: Union[int, float]) -> float:
        """Request one value from ``random_uniform_float`` with bounds ``-magnitude`` and ``magnitude``."""
        return random_uniform_float(low_value=-magnitude,
                                    high_value=magnitude,
                                    seed=self.seed,
                                    legacy_random_state=self.legacy_random_state,
                                    rand_init=self.rand_init)

    def __call__(self,
                 image: Union[np.ndarray, sitk.Image],
                 *args, **kwargs) -> Union[np.ndarray, sitk.Image]:
        """Apply the configured (and optionally random) shift/scale to ``image``.

        Args:
            image: NumPy array or SimpleITK image to transform.
            *args: Extra positional arguments forwarded to the intensity helper.
            **kwargs: Extra keyword arguments forwarded to the intensity helper.

        Returns:
            The value returned by ``shift_scale_intensity_np`` (NumPy input) or
            ``shift_scale_intensity_sitk`` (SimpleITK input).

        Raises:
            ValueError: If none of ``shift``, ``scale``, ``input_min``,
                ``input_max``, ``output_min``, ``output_max`` is set.
            ImportError: If ``image`` is neither a NumPy array nor a SimpleITK
                image.
        """
        window = (self.input_min, self.input_max, self.output_min, self.output_max)
        no_explicit = (self.shift is None) and (self.scale is None)
        both_explicit = (self.shift is not None) and (self.scale is not None)
        full_window = all(value is not None for value in window)

        if no_explicit and all(value is None for value in window):
            raise ValueError("At least one of shift, scale, input_min, input_max, output_min, output_max must be provided.")

        if both_explicit and full_window:
            warnings.warn("Mixing both shift and scale with input_min, input_max, output_min, output_max. \nInput and output values will be set to 'None'.")
            self.input_min = None
            self.input_max = None
            self.output_min = None
            self.output_max = None

        # Derived values stay local: writing them back to self.shift/self.scale
        # would make the next call look like a "mixed" configuration, trigger
        # the warning above and discard the configured window.
        shift, scale = self.shift, self.scale
        if no_explicit and full_window:
            input_range = self.input_max - self.input_min
            output_range = self.output_max - self.output_min
            # NOTE(review): this maps [input_min, input_max] onto
            # [output_min, output_max] provided the helpers compute
            # (image + shift) * scale - confirm against the helper code.
            shift = -self.input_min + self.output_min * input_range / output_range
            scale = output_range / input_range

        if isinstance(image, np.ndarray):
            apply_shift_scale = shift_scale_intensity_np
        elif isinstance(image, sitk.Image):
            apply_shift_scale = shift_scale_intensity_sitk
        else:
            # NOTE(review): TypeError would be the conventional exception here;
            # ImportError is kept so existing callers that catch it still work.
            raise ImportError("image must be either a numpy array or a SimpleITK image.")

        output_image = apply_shift_scale(image, *args, shift=shift, scale=scale, **kwargs)

        if (self.random_shift is None) and (self.random_scale is None):
            return output_image

        # Draw order (shift first, then scale) is preserved so that seeded
        # runs consume the random stream in the same sequence as before.
        if self.random_shift:
            random_shift = self._draw_symmetric(self.random_shift)
        else:
            random_shift = 0

        if self.random_scale:
            random_scale = 1 + self._draw_symmetric(self.random_scale)
        else:
            # Neutral factor: no random scaling requested.
            random_scale = 1

        return apply_shift_scale(output_image, *args, shift=random_shift, scale=random_scale, **kwargs)