# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import logging
from typing import Any, List, Optional, Tuple, Union
import matplotlib.pyplot as plt
import numpy as np
import torch
from skimage.filters import threshold_otsu
from InnerEye.ML.config import PhotometricNormalizationMethod, SegmentationModelBase
from InnerEye.ML.dataset.sample import Sample
from InnerEye.ML.dataset.scalar_sample import ScalarItem
from InnerEye.ML.utils.image_util import check_array_range
from InnerEye.ML.utils.transforms import CTRange, LinearTransform, Transform3D
class WindowNormalizationForScalarItem(Transform3D[ScalarItem]):
    """
    Transform3D to apply window normalization (via mri_window, without a mask) to the "images" of a ScalarItem.
    """

    # noinspection PyMissingConstructor
    def __init__(self,
                 output_range: Tuple[float, float] = (0, 1),
                 sharpen: float = 1.9,
                 tail: float = 1.0) -> None:
        """
        :param output_range: The desired value range of the result image.
        :param sharpen: Number of robust standard deviations either side of the window level to include in the
            window.
        :param tail: Default 1, allow window range to include more of tail of distribution.
        """
        self.output_range = output_range
        self.sharpen = sharpen
        self.tail = tail

    def __call__(self, item: ScalarItem) -> ScalarItem:
        """
        Returns a copy of the item whose images are replaced by their window-normalized version. The result
        keeps the dtype and device of the original images.
        """
        images = item.images
        # Tensor.numpy() only works on CPU tensors that do not require grad, so detach and move to CPU first.
        # Copy so that the normalization can never write through to the storage shared with item.images.
        image_array = images.detach().cpu().numpy().copy()
        normalized, _ = mri_window(image_in=image_array,
                                   output_range=self.output_range,
                                   mask=None,
                                   sharpen=self.sharpen,
                                   tail=self.tail)
        return item.clone_with_overrides(
            images=torch.tensor(normalized, dtype=images.dtype, device=images.device)
        )
class PhotometricNormalization(Transform3D[Sample]):
    """
    Transform3D that replaces the image of a Sample with a photometrically normalized version, using the
    normalization settings read from a SegmentationModelBase config.
    """

    def __init__(self, config_args: Optional[SegmentationModelBase] = None, **params: Any):
        """
        :param config_args: Model config from which the normalization settings are read. If None, the
            normalization method is PhotometricNormalizationMethod.Unchanged and no other settings are stored.
        :param params: Forwarded to the Transform3D base class constructor.
        """
        super().__init__(**params)
        # Initialise before the early return below so the attribute exists whatever config was given.
        self.status_of_most_recent_call: Optional[str] = None
        if config_args is None:
            self.norm_method = PhotometricNormalizationMethod.Unchanged
            return
        self.norm_method = config_args.norm_method
        self.output_range = config_args.output_range
        self.level = config_args.level
        self.window = config_args.window
        self.debug_mode = config_args.debug_mode
        self.tail = config_args.tail
        self.sharpen = config_args.sharpen
        self.trim_percentiles = config_args.trim_percentiles

    def __call__(self, sample: Sample) -> Sample:
        """
        Returns a copy of the sample whose image is replaced by the result of self.transform, which receives the
        sample's image, mask and patient_id.
        """
        # NOTE(review): `transform` is not defined in this view of the class; presumably it is provided elsewhere
        # in the file or by a base class - confirm.
        return sample.clone_with_overrides(
            image=self.transform(
                image=sample.image,
                mask=sample.mask,
                patient_id=sample.patient_id
            )
        )
def simple_norm(image_in: np.ndarray, mask: np.ndarray, debug_mode: bool = False) -> np.ndarray:
    """
    Normalizes each channel of an image to have mean 0 and standard deviation 1.

    Mean and standard deviation are computed only over voxels where the mask equals 1. The resulting shift and
    scale are applied to every voxel of the channel, inside and outside the mask. If the masked voxels have zero
    standard deviation, the division produces NaN/inf values.

    :param image_in: Image to normalize, with channels along the first axis. Must have a floating point dtype.
    :param mask: Mask with the same number of elements as one channel of image_in (e.g. W x H x D).
    :param debug_mode: Whether to log the standard deviation and mean of the masked voxels before and after
        normalization.
    :return: Normalized image with the same shape and dtype as image_in.
    :raises TypeError: If image_in does not have a floating point dtype.
    """
    if not np.issubdtype(image_in.dtype, np.floating):
        raise TypeError("normalize::simple_norm: Input image is not a floating type")
    image_shape = np.shape(image_in)
    # The mask is shared by all channels, so compute the voxel selection once.
    in_mask = mask.flatten() == 1
    iout = np.zeros(image_shape, dtype=image_in.dtype)
    for ichannel in range(image_shape[0]):
        i = image_in[ichannel, ...].flatten()
        masked_values = i[in_mask]
        mean_i = np.mean(masked_values)
        std_i = np.std(masked_values)
        if debug_mode:
            logging.info(" In norm before: Standard Deviation, Mean ,{0: 4.1f}, {1: 4.1f}".format(std_i, mean_i))
        i = (i - mean_i) / std_i
        iout[ichannel, ...] = i.reshape(image_shape[1:])
        if debug_mode:
            logging.info(" In norm after: Standard Deviation, Mean ,{0: 4.1f}, {1: 4.1f}".format(
                np.std(i[in_mask]), np.mean(i[in_mask])))
    return iout
def normalize_trim(image: np.ndarray,
                   mask: np.ndarray,
                   output_range: Tuple[float, float] = (-1.0, 1.0),
                   sharpen: float = 1.9,
                   trim_percentiles: Tuple[float, float] = (2.0, 98.0),
                   debug_mode: bool = False) -> Tuple[np.ndarray, str]:
    """
    Normalizes each channel of an image by linearly mapping a robustly estimated window of input values onto
    output_range.

    For each channel, the voxel values inside the mask that lie strictly between the two trim_percentiles are
    used to compute a median and a robust standard deviation (from the inter-quartile range). The input window
    is median +/- sharpen * std, clipped to the minimum and maximum of those trimmed values. The window is mapped
    linearly onto output_range, values outside the window are set to the boundary values, and all voxels outside
    the mask are set to output_range[0].

    :param image: The image to normalize, size Channels x Z x Y x X
    :param mask: Consider only pixel values of the input image for which the mask is greater than 0.5.
        Size Z x Y x X
    :param output_range: The desired value range of the result image.
    :param sharpen: Number of robust standard deviations either side of the median to include in the window.
    :param trim_percentiles: Only consider voxel values strictly between those two percentiles when computing
        the median and std. If no value lies strictly between them (e.g. an image with very few distinct
        intensities), all voxel values inside the mask are used instead.
    :param debug_mode: If true, print the statistics and create a diagnostic plot (interactive) for each channel.
    :return: A tuple (trimmed-normalized image with the same shape as image, status string describing the input
        window used for each channel).
    """
    image_shape = image.shape
    imout = np.zeros_like(image)
    in_mask = mask > 0.5
    status = ""
    for ichannel in range(image_shape[0]):
        if ichannel > 0:
            status += "Channel {}: ".format(ichannel)
        channel_image = image[ichannel, ...]
        pixels_inside_mask = channel_image[in_mask].flatten().astype(float)
        # First remove all values that fall outside the trim_percentiles
        lower_threshold, upper_threshold = np.percentile(pixels_inside_mask, trim_percentiles,
                                                         interpolation='midpoint')
        inside_thresholds = np.logical_and(pixels_inside_mask > lower_threshold,
                                           pixels_inside_mask < upper_threshold)
        trimmed_pixels = pixels_inside_mask[inside_thresholds]
        if trimmed_pixels.size == 0:
            # The strict inequalities above remove every value when both thresholds fall on the same intensity.
            # Statistics of an empty array cannot be computed, so fall back to all values inside the mask.
            # NOTE(review): for a constant image the window then has zero width; behaviour depends on
            # LinearTransform - confirm.
            trimmed_pixels = pixels_inside_mask
        # Compute robust statistics off the pixel values that are inside the trim values
        median, estimated_std, min_value, max_value = robust_mean_std(trimmed_pixels)
        # Compute an input value range from median and robust std, going as many standard deviations
        # as specified by the sharpen parameter
        input_range = (max(median - estimated_std * sharpen, min_value),
                       min(median + estimated_std * sharpen, max_value))
        # Use the linear transform to convert data to output range. This also sets values outside
        # the input_range to the boundary values.
        channel_output = LinearTransform.transform(
            data=channel_image,
            input_range=input_range,
            output_range=output_range
        )
        # Voxels outside the mask get the lowest output value.
        channel_output[np.logical_not(in_mask)] = output_range[0]
        imout[ichannel, ...] = channel_output
        status += "Range ({0:0.0f}, {1:0.0f}) ".format(input_range[0], input_range[1])
        logging.info(status)
        if debug_mode:
            print('median, estimated_std', median, estimated_std)
            # Show one slice along the last axis of the current channel; slice 70 unless the image is smaller.
            slice_index = min(70, channel_image.shape[-1] - 1)
            _, axs = plt.subplots(2, 2, figsize=(9, 9))
            axs[0, 0].set_title("Original Image")
            axs[0, 0].imshow(channel_image[:, :, slice_index], cmap='gray')
            axs[1, 0].set_title("Original Image - Histogram with Mask")
            axs[1, 0].set_xlim(lower_threshold, upper_threshold)
            axs[1, 0].hist(pixels_inside_mask, 20)
            axs[0, 1].set_title("Normalised Image, Level= {level:4.1f},\n "
                                "Window range {in1:4.1f} to {in2:4.1f},\n "
                                "Trim thresholds {t1:4.1f} to {t2:4.1f}".format(level=median,
                                                                                in1=input_range[0],
                                                                                in2=input_range[1],
                                                                                t1=lower_threshold,
                                                                                t2=upper_threshold))
            axs[0, 1].imshow(channel_output[:, :, slice_index], cmap='gray')
            axs[1, 1].set_title("Normalised Image - Histogram with Mask")
            axs[1, 1].hist(channel_output[in_mask], 20)
            plt.show()
    return imout, status
def robust_mean_std(data: np.ndarray) -> Tuple[float, float, float, float]:
    """
    Robustly estimates the location and spread of the values in an array.

    The median stands in for the mean. The standard deviation is derived from the inter-quartile range, assuming
    normally distributed values. All percentiles use 'midpoint' interpolation.

    :param data: Values to summarise; arrays with more than one dimension are flattened first.
    :return: A 4-tuple with values (median, robust_std, minimum data value, maximum data value)
    """
    values = data if data.ndim == 1 else data.flatten()
    min_value, quart25, median, quart75, max_value = np.percentile(
        values, (0, 25, 50, 75, 100), interpolation='midpoint')
    # For a standard normal, the 25th and 75th percentiles sit at -/+ 0.67448975
    # (Excel NORMSINV(0.25) / NORMSINV(0.75)), so the IQR covers 2 * 0.67448975 standard deviations.
    iqr_in_std_units = 2 * 0.67448975
    robust_std = (quart75 - quart25) / iqr_in_std_units
    return median, robust_std, min_value, max_value
def mri_window(image_in: np.ndarray,
               mask: Optional[np.ndarray],
               output_range: Tuple[float, float] = (-1.0, 1.0),
               sharpen: float = 1.9,
               tail: Union[List[float], float] = 1.0,
               debug_mode: bool = False) -> Tuple[np.ndarray, str]:
    """
    This function takes an MRI image and, for each channel, removes the first peak of values (air) using Otsu's
    threshold. A window is then found, centred around the median of the remaining (above-threshold) values, with
    a width controlled by a robust standard deviation and the sharpen input parameter: the larger sharpen is, the
    wider the window. Values below the threshold are set to 0, and the result is then normalised to the given
    output_range, with values below and above the window being set to the boundary values.

    The input array is not modified.

    :param image_in: The image to normalize, with channels along the first axis.
    :param mask: Only used for the diagnostic statistics and histograms produced when debug_mode is True (voxels
        where the mask is greater than 0 are considered). Otsu's threshold and the window are always computed
        from all voxels of a channel. If None, diagnostics use the whole image.
    :param output_range: The desired value range of the result image.
    :param sharpen: Number of robust standard deviations either side of the level to include in the window.
    :param tail: Default 1, allow window range to include more of tail of distribution. Either a single value
        used for all channels, or one value per channel.
    :param debug_mode: If true, print statistics, append them to "channels_trim.txt" and create diagnostic plots.
    :return: A tuple (normalized image with the same shape as image_in, status string describing the Otsu
        threshold, level and window used for each channel).
    """
    nchannel = image_in.shape[0]
    imout = np.zeros_like(image_in)
    if np.isscalar(tail):
        # A single tail value applies to every channel.
        tail = [float(tail)] * nchannel
    # The mask only feeds the debug statistics below.
    in_mask = None if mask is None else mask > 0
    status = ""
    for ichannel in range(nchannel):
        if ichannel > 0:
            status += "Channel {}: ".format(ichannel)
        channel_image = image_in[ichannel, ...]
        # Flatten to apply Otsu thresholding
        imflat = channel_image.flatten()
        # Find Otsu's threshold for the values of the input image
        threshold = threshold_otsu(imflat)
        # Find window level from robust statistics of the values above the threshold
        level, std_i, _, max_foreground = robust_mean_std(imflat[imflat > threshold])
        # If lower value of window is below threshold replace lower value with threshold
        input_range = (max(level - std_i * sharpen, threshold),
                       min(max_foreground, level + tail[ichannel] * std_i * sharpen))
        # Indexing image_in returns a view; threshold a copy so the caller's array is left untouched.
        im_thresh = channel_image.copy()
        im_thresh[channel_image < threshold] = 0
        # Use the linear transform to convert data to output range
        imout[ichannel, ...] = LinearTransform.transform(im_thresh, input_range, output_range)
        status += f"Otsu {threshold:0.0f}, level {level:0.0f}, range ({input_range[0]:0.0f}, {input_range[1]:0.0f}) "
        logging.debug(status)
        if debug_mode:
            channel_out = imout[ichannel, ...]
            print('Otsu {}, range {}'.format(threshold, input_range))
            # Count voxels below the threshold and at the top of the output range, for this channel only.
            if in_mask is None:
                no_thresh = np.sum(imflat < threshold)
                no_high = np.sum(channel_out == output_range[1])
                num_voxels = imflat.size
            else:
                no_thresh = np.sum(channel_image[in_mask] < threshold)
                no_high = np.sum(channel_out[in_mask] == output_range[1])
                num_voxels = np.sum(in_mask)
            pc_thresh = no_thresh / num_voxels * 100
            pc_high = no_high / num_voxels * 100
            print('Percent of values outside window range: low,high', pc_thresh, pc_high, no_high)
            with open("channels_trim.txt", 'a') as fileout:
                fileout.write("Thresholded: {ich:d}, {pl:4.2f}, {ph:4.2f} \n".format(ich=ichannel,
                                                                                    pl=pc_thresh,
                                                                                    ph=pc_high))
            # Plot input distribution. Show one slice along the last axis: slice 70 unless the image is smaller.
            slice_index = min(70, channel_image.shape[-1] - 1)
            _, axs = plt.subplots(2, 2, figsize=(9, 9))
            axs[0, 0].set_title("Original Image")
            axs[0, 0].imshow(channel_image[:, :, slice_index], cmap='gray')
            axs[1, 0].set_title("Original Image - Histogram with Mask")
            if in_mask is None:
                axs[1, 0].hist(imflat, 200)
            else:
                axs[1, 0].hist(channel_image[in_mask].flatten(), 200)
            axs[0, 1].set_title("Normalised Image, Level= {level:4.1f},\n "
                                "Window range {in1:4.1f} to {in2:4.1f}, \n"
                                "{pt:4.1f} % below threshold, {ph:4.1f} % above window \n"
                                "Threshold= {th:4.1f}"
                                .format(level=level, in1=input_range[0], in2=input_range[1], pt=pc_thresh,
                                        ph=pc_high, th=threshold))
            axs[0, 1].imshow(channel_out[:, :, slice_index], cmap='gray')
            axs[1, 1].set_title("Normalised Image - Histogram with Mask")
            if in_mask is None:
                axs[1, 1].hist(channel_out.flatten(), 200)
            else:
                axs[1, 1].hist(channel_out[in_mask].flatten(), 200)
            plt.show()
    return imout, status