diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 3fd33b76da..b2dcb965e3 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -376,9 +376,9 @@ SignalRandAddSquarePulsePartial, SignalRandDrop, SignalRandScale, - SignalRandShift, - SignalRemoveFrequency, + SignalRemoveFrequency ) +from .signal import RadialFourier3D, RadialFourierFeatures3D from .signal.dictionary import SignalFillEmptyd, SignalFillEmptyD, SignalFillEmptyDict from .smooth_field.array import ( RandSmoothDeform, diff --git a/monai/transforms/signal/__init__.py b/monai/transforms/signal/__init__.py index 1e97f89407..5ed71ccb0e 100644 --- a/monai/transforms/signal/__init__.py +++ b/monai/transforms/signal/__init__.py @@ -8,3 +8,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +""" +Signal processing transforms for medical imaging. +""" + +from .radial_fourier import RadialFourier3D, RadialFourierFeatures3D + +__all__ = ["RadialFourier3D", "RadialFourierFeatures3D"] diff --git a/monai/transforms/signal/radial_fourier.py b/monai/transforms/signal/radial_fourier.py new file mode 100644 index 0000000000..e58aefe7e5 --- /dev/null +++ b/monai/transforms/signal/radial_fourier.py @@ -0,0 +1,350 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +3D Radial Fourier Transform for medical imaging data. +""" + +from __future__ import annotations + +import math +from typing import Optional, Union + +from collections.abc import Sequence + +import numpy as np +import torch +from torch.fft import fftn, fftshift, ifftn, ifftshift + +from monai.config import NdarrayOrTensor +from monai.transforms.transform import Transform +from monai.utils import convert_data_type, optional_import + +# Optional imports for type checking +spatial, _ = optional_import("monai.utils", name="spatial") + + +class RadialFourier3D(Transform): + """ + Computes the 3D Radial Fourier Transform of medical imaging data. + + This transform converts 3D medical images into radial frequency domain representations, + which is particularly useful for handling anisotropic resolution common in medical scans + (e.g., different resolution in axial vs coronal planes). + + The radial transform provides rotation-invariant frequency analysis and can help + normalize frequency representations across datasets with different acquisition parameters. + + Args: + normalize: if True, normalize the output by the number of voxels. + return_magnitude: if True, return magnitude of the complex result. + return_phase: if True, return phase of the complex result. + radial_bins: number of radial bins for frequency aggregation. If None, returns full 3D spectrum. + max_frequency: maximum normalized frequency to include (0.0 to 1.0). + spatial_dims: spatial dimensions to apply transform to. Default is last three dimensions. + + Returns: + Radial Fourier transform of input data. Shape depends on parameters: + - If radial_bins is None: complex tensor of same spatial shape as input + - If radial_bins is set: real tensor of shape (radial_bins,) for magnitude/phase + + Example: + >>> transform = RadialFourier3D(radial_bins=64, return_magnitude=True) + >>> image = torch.randn(1, 128, 128, 96) # Batch, Height, Width, Depth + >>> result = transform(image) # Shape: (1, 64) + """ + + def __init__( + self, + normalize: bool = True, + return_magnitude: bool = True, + return_phase: bool = False, + radial_bins: Optional[int] = None, + max_frequency: float = 1.0, + spatial_dims: Union[int, Sequence[int]] = (-3, -2, -1), + ) -> None: + super().__init__() + self.normalize = normalize + self.return_magnitude = return_magnitude + self.return_phase = return_phase + self.radial_bins = radial_bins + self.max_frequency = max_frequency + + if isinstance(spatial_dims, int): + spatial_dims = (spatial_dims,) + self.spatial_dims = tuple(spatial_dims) + + # Validate parameters + if not 0.0 < max_frequency <= 1.0: + raise ValueError(f"max_frequency must be in (0.0, 1.0], got {max_frequency}") + if radial_bins is not None and radial_bins < 1: + raise ValueError(f"radial_bins must be >= 1, got {radial_bins}") + if not return_magnitude and not return_phase: + raise ValueError("At least one of return_magnitude or return_phase must be True") + + def _compute_radial_coordinates(self, shape: tuple[int, ...]) -> torch.Tensor: + """ + Compute radial distance from frequency domain center. + + Args: + shape: spatial dimensions (D, H, W) or (H, W, D) depending on dims order. + + Returns: + Tensor of same spatial shape with radial distances. + """ + # Create frequency coordinates for each dimension + coords = [] + for dim_size in shape: + # Create frequency range from -0.5 to 0.5 + freq = torch.fft.fftfreq(dim_size) + coords.append(freq) + + # Create meshgrid and compute radial distance + mesh = torch.meshgrid(coords, indexing="ij") + radial = torch.sqrt(sum(c**2 for c in mesh)) + + return radial + + def _compute_radial_spectrum(self, spectrum: torch.Tensor, radial_coords: torch.Tensor) -> torch.Tensor: + """ + Compute radial average of frequency spectrum. + + Args: + spectrum: complex frequency spectrum (flattened 1D array). + radial_coords: radial distance for each frequency coordinate (flattened 1D array). + + Returns: + Radial average of spectrum (1D array of length radial_bins). + """ + if self.radial_bins is None: + return spectrum + + # Bin radial coordinates + max_r = self.max_frequency * 0.5 # Maximum normalized frequency + bin_edges = torch.linspace(0, max_r, self.radial_bins + 1, device=spectrum.device) + + # Initialize output + result_real = torch.zeros(self.radial_bins, dtype=spectrum.real.dtype, device=spectrum.device) + result_imag = torch.zeros(self.radial_bins, dtype=spectrum.imag.dtype, device=spectrum.device) + + # Bin the frequencies - spectrum and radial_coords are both 1D + for i in range(self.radial_bins): + mask = (radial_coords >= bin_edges[i]) & (radial_coords < bin_edges[i + 1]) + if mask.any(): + # spectrum is 1D, so we can index it directly + result_real[i] = spectrum.real[mask].mean() + result_imag[i] = spectrum.imag[mask].mean() + + # Combine real and imaginary parts + result = torch.complex(result_real, result_imag) + + return result + + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: + """ + Apply 3D Radial Fourier Transform to input data. + + Args: + img: input medical image data. Expected shape: (..., D, H, W) + where D, H, W are spatial dimensions. + + Returns: + Transformed data in radial frequency domain. + """ + # Convert to tensor if needed + img_tensor, *_ = convert_data_type(img, torch.Tensor) + # Get spatial dimensions + spatial_shape = tuple(img_tensor.shape[d] for d in self.spatial_dims) + if len(spatial_shape) != 3: + raise ValueError(f"Expected 3 spatial dimensions, got {len(spatial_shape)}") + + # Compute 3D FFT + # Shift zero frequency to center and compute FFT + spectrum = fftn(ifftshift(img_tensor, dim=self.spatial_dims), dim=self.spatial_dims) + spectrum = fftshift(spectrum, dim=self.spatial_dims) + + # Normalize if requested + if self.normalize: + norm_factor = math.prod(spatial_shape) + spectrum = spectrum / norm_factor + + # Compute radial coordinates + radial_coords = self._compute_radial_coordinates(spatial_shape) + + # Apply radial binning if requested + if self.radial_bins is not None: + # Reshape for radial processing + orig_shape = spectrum.shape + # Move spatial dimensions to end for processing + spatial_indices = [d % len(orig_shape) for d in self.spatial_dims] + non_spatial_indices = [i for i in range(len(orig_shape)) if i not in spatial_indices] + + # Reshape to (non_spatial..., spatial_prod) + flat_shape = (*[orig_shape[i] for i in non_spatial_indices], -1) + spectrum_flat = spectrum.moveaxis(spatial_indices, [-3, -2, -1]).reshape(flat_shape) + radial_flat = radial_coords.flatten() + + # Get non-spatial dimensions (batch, channel, etc.) + non_spatial_dims = spectrum_flat.shape[:-1] + spatial_size = spectrum_flat.shape[-1] + + # Reshape to 2D: (non_spatial_product, spatial_size) + non_spatial_product = 1 + for dim in non_spatial_dims: + non_spatial_product *= dim + + spectrum_2d = spectrum_flat.reshape(non_spatial_product, spatial_size) + + # Process each non-spatial element (batch/channel combination) + results = [] + for i in range(non_spatial_product): + elem_spectrum = spectrum_2d[i] # Get spatial frequencies for this batch/channel + radial_result = self._compute_radial_spectrum(elem_spectrum, radial_flat) + results.append(radial_result) + + # Combine results and reshape back + spectrum = torch.stack(results, dim=0) + spectrum = spectrum.reshape(*non_spatial_dims, self.radial_bins) + else: + # Apply frequency mask if max_frequency < 1.0 + if self.max_frequency < 1.0: + freq_mask = radial_coords <= (self.max_frequency * 0.5) + # Expand mask to match spectrum dimensions + for _ in range(len(self.spatial_dims)): + freq_mask = freq_mask.unsqueeze(0) + spectrum = spectrum * freq_mask + + # Extract magnitude and/or phase as requested + output = None + if self.return_magnitude: + magnitude = torch.abs(spectrum) + output = magnitude if output is None else torch.cat([output, magnitude], dim=-1) + + if self.return_phase: + phase = torch.angle(spectrum) + output = phase if output is None else torch.cat([output, phase], dim=-1) + + # Convert back to original data type + output, *_ = convert_data_type(output, type(img)) + + return output + + def inverse(self, radial_data: NdarrayOrTensor, original_shape: tuple[int, ...]) -> NdarrayOrTensor: + """ + Inverse transform from radial frequency domain to spatial domain. + + Args: + radial_data: data in radial frequency domain. + original_shape: original spatial shape (D, H, W). + + Returns: + Reconstructed spatial data. + + Note: + This is an approximate inverse when radial_bins is used. + """ + if self.radial_bins is None: + # Direct inverse FFT + radial_tensor, *_ = convert_data_type(radial_data, torch.Tensor) + + # Separate magnitude and phase if needed + if self.return_magnitude and self.return_phase: + # Assuming they were concatenated along last dimension + split_idx = radial_tensor.shape[-1] // 2 + magnitude = radial_tensor[..., :split_idx] + phase = radial_tensor[..., split_idx:] + radial_tensor = torch.complex(magnitude * torch.cos(phase), magnitude * torch.sin(phase)) + + # Apply inverse FFT + result = ifftn(ifftshift(radial_tensor, dim=self.spatial_dims), dim=self.spatial_dims) + result = fftshift(result, dim=self.spatial_dims) + + if self.normalize: + result = result * math.prod(original_shape) + + result, *_ = convert_data_type(result.real, type(radial_data)) + return result + + else: + raise NotImplementedError( + "Exact inverse transform not available for radially binned data. " + "Consider using radial_bins=None for applications requiring inversion." + ) + + +class RadialFourierFeatures3D(Transform): + """ + Extract radial Fourier features for medical image analysis. + + Computes multiple radial Fourier transforms with different parameters + to create a comprehensive frequency feature representation. + + Args: + n_bins_list: list of radial bin counts to compute. + return_types: list of return types: 'magnitude', 'phase', or 'complex'. + normalize: if True, normalize the output. + + Returns: + Concatenated radial Fourier features. + + Example: + >>> transform = RadialFourierFeatures3D(n_bins_list=[32, 64, 128]) + >>> image = torch.randn(1, 128, 128, 96) + >>> features = transform(image) # Shape: (1, 32+64+128=224) + """ + + def __init__( + self, + n_bins_list: Sequence[int] = (32, 64, 128), + return_types: Sequence[str] = ("magnitude",), + normalize: bool = True, + ) -> None: + super().__init__() + self.n_bins_list = n_bins_list + self.return_types = return_types + self.normalize = normalize + + # Create individual transforms + self.transforms = [] + for n_bins in n_bins_list: + for return_type in return_types: + transform = RadialFourier3D( + normalize=normalize, + return_magnitude=(return_type in ["magnitude", "complex"]), + return_phase=(return_type in ["phase", "complex"]), + radial_bins=n_bins, + ) + self.transforms.append(transform) + + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: + """Extract radial Fourier features.""" + features = [] + for transform in self.transforms: + feat = transform(img) + features.append(feat) + + # Concatenate along last dimension + if features: + # Convert all features to tensors if any are numpy arrays + features_tensors = [] + for feat in features: + if isinstance(feat, np.ndarray): + features_tensors.append(torch.from_numpy(feat)) + else: + features_tensors.append(feat) + output = torch.cat(features_tensors, dim=-1) + else: + output = img + + # Convert to original type if needed + if isinstance(img, np.ndarray): + output = output.cpu().numpy() + + return output diff --git a/tests/test_radial_fourier.py b/tests/test_radial_fourier.py new file mode 100644 index 0000000000..6b2caa0810 --- /dev/null +++ b/tests/test_radial_fourier.py @@ -0,0 +1,196 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Tests for the 3D Radial Fourier Transform. +""" + +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.transforms import RadialFourier3D, RadialFourierFeatures3D +from monai.utils import set_determinism + + +class TestRadialFourier3D(unittest.TestCase): + """Test cases for RadialFourier3D transform.""" + + def setUp(self): + """Set up test fixtures.""" + set_determinism(seed=42) + self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + # Create test data + self.test_image_3d = torch.randn(1, 32, 64, 64, device=self.device) # Batch, D, H, W + self.test_image_4d = torch.randn(2, 1, 48, 64, 64, device=self.device) # Batch, Channel, D, H, W + + def tearDown(self): + """Clean up after tests.""" + set_determinism(seed=None) + + @parameterized.expand( + [ + [{"radial_bins": 32, "return_magnitude": True}, (1, 32)], + [{"radial_bins": 64, "return_magnitude": True, "return_phase": True}, (1, 128)], + [{"radial_bins": None, "return_magnitude": True}, (1, 32, 64, 64)], + [{"radial_bins": 16, "return_magnitude": True, "max_frequency": 0.5}, (1, 16)], + ] + ) + def test_output_shape(self, params, expected_shape): + """Test that output shape matches expectations.""" + transform = RadialFourier3D(**params) + result = transform(self.test_image_3d) + self.assertEqual(result.shape, expected_shape) + + def test_complex_input(self): + """Test with complex-valued input.""" + complex_image = torch.complex( + torch.randn(1, 32, 64, 64, device=self.device), + torch.randn(1, 32, 64, 64, device=self.device), + ) + transform = RadialFourier3D(radial_bins=32, return_magnitude=True) + result = transform(complex_image) + self.assertEqual(result.shape, (1, 32)) + + def test_normalization(self): + """Test normalization affects output scale.""" + transform1 = RadialFourier3D(radial_bins=32, normalize=True) + transform2 = RadialFourier3D(radial_bins=32, normalize=False) + + result1 = transform1(self.test_image_3d) + result2 = transform2(self.test_image_3d) + + # Normalized result should be smaller + self.assertLess(torch.abs(result1).mean().item(), torch.abs(result2).mean().item()) + + def test_inverse_transform(self): + """Test approximate inverse transform.""" + # Use full spectrum for invertibility + transform = RadialFourier3D(radial_bins=None, normalize=True) + + # Forward transform + spectrum = transform(self.test_image_3d) + + # Inverse transform + reconstructed = transform.inverse(spectrum, self.test_image_3d.shape[-3:]) + + # Should have same shape + self.assertEqual(reconstructed.shape, self.test_image_3d.shape) + + def test_deterministic(self): + """Test that transform is deterministic.""" + transform = RadialFourier3D(radial_bins=32) + + result1 = transform(self.test_image_3d) + result2 = transform(self.test_image_3d) + + self.assertTrue(torch.allclose(result1, result2, rtol=1e-5)) + + def test_numpy_input(self): + """Test that numpy arrays are accepted.""" + np_image = self.test_image_3d.cpu().numpy() + transform = RadialFourier3D(radial_bins=32) + + result = transform(np_image) + self.assertIsInstance(result, np.ndarray) + self.assertEqual(result.shape, (1, 32)) + + @parameterized.expand( + [ + [{"max_frequency": -0.1}], # Invalid negative + [{"max_frequency": 1.5}], # Invalid > 1.0 + [{"radial_bins": 0}], # Invalid zero bins + [{"return_magnitude": False, "return_phase": False}], # No output requested + ] + ) + def test_invalid_parameters(self, params): + """Test that invalid parameters raise errors.""" + with self.assertRaises(ValueError): + RadialFourier3D(**params) + + def test_spatial_dims_parameter(self): + """Test custom spatial dimensions.""" + # Test with 4D input but spatial dims in middle + image = torch.randn(2, 32, 64, 64, 3, device=self.device) # Batch, D, H, W, Channels + transform = RadialFourier3D(radial_bins=16, spatial_dims=(1, 2, 3)) + result = transform(image) + self.assertEqual(result.shape, (2, 3, 16)) + + def test_batch_processing(self): + """Test processing batch of images.""" + batch_size = 4 + batch_image = torch.randn(batch_size, 32, 64, 64, device=self.device) + transform = RadialFourier3D(radial_bins=32) + result = transform(batch_image) + self.assertEqual(result.shape, (batch_size, 32)) + + +class TestRadialFourierFeatures3D(unittest.TestCase): + """Test cases for RadialFourierFeatures3D transform.""" + + def setUp(self): + """Set up test fixtures.""" + set_determinism(seed=42) + self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + self.test_image = torch.randn(2, 32, 64, 64, device=self.device) + + def tearDown(self): + """Clean up after tests.""" + set_determinism(seed=None) + + def test_feature_extraction(self): + """Test multi-scale feature extraction.""" + transform = RadialFourierFeatures3D(n_bins_list=[16, 32, 64], return_types=["magnitude"]) + + features = transform(self.test_image) + expected_features = 16 + 32 + 64 # Sum of all bins + + self.assertEqual(features.shape, (2, expected_features)) + + def test_multiple_return_types(self): + """Test with multiple return types.""" + transform = RadialFourierFeatures3D(n_bins_list=[16, 32], return_types=["magnitude", "phase"]) + + features = transform(self.test_image) + # Each bin count appears twice (magnitude and phase) + expected_features = (16 + 32) * 2 + + self.assertEqual(features.shape, (2, expected_features)) + + def test_complex_output(self): + """Test complex output type.""" + transform = RadialFourierFeatures3D(n_bins_list=[16], return_types=["complex"]) + + features = transform(self.test_image) + # Complex returns both magnitude and phase concatenated + self.assertEqual(features.shape, (2, 16 * 2)) + + def test_empty_bins_list(self): + """Test with empty bins list.""" + transform = RadialFourierFeatures3D(n_bins_list=[], return_types=["magnitude"]) + features = transform(self.test_image) + # Should return original image when no transforms + self.assertEqual(features.shape, self.test_image.shape) + + def test_numpy_compatibility(self): + """Test with numpy input.""" + np_image = self.test_image.cpu().numpy() + transform = RadialFourierFeatures3D(n_bins_list=[16, 32]) + + features = transform(np_image) + self.assertIsInstance(features, np.ndarray) + self.assertEqual(features.shape, (2, 16 + 32)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/transforms/signal/__init__.py b/tests/transforms/signal/__init__.py new file mode 100644 index 0000000000..e69de29bb2