1010
1111import asyncio
1212from os import cpu_count
13+ from contextlib import suppress
1314from functools import partial
1415from pathlib import Path
1516from typing import Callable , TypeVar , Union
@@ -108,14 +109,16 @@ async def _apply_serial(
108109 semaphore = asyncio .Semaphore (max_concurrent )
109110
110111 for t in range (n_resamplings ):
111- xfm_t = transform if (n_resamplings == 1 or transform .ndim < 4 ) else transform [t ]
112+ xfm_t = (
113+ transform if (n_resamplings == 1 or transform .ndim < 4 ) else transform [t ]
114+ )
112115
113116 targets_t = (
114117 ImageGrid (spatialimage ).index (
115118 _as_homogeneous (xfm_t .map (ref_ndcoords ), dim = ref_ndim )
116119 )
117120 if targets is None
118- else targets
121+ else targets [ t , ...]
119122 )
120123
121124 data_t = (
@@ -258,11 +261,23 @@ def apply(
258261 dim = _ref .ndim ,
259262 )
260263 )
261- elif xfm_nvols == 1 :
262- targets = ImageGrid (spatialimage ).index ( # data should be an image
263- _as_homogeneous (transform .map (ref_ndcoords ), dim = _ref .ndim )
264+ else :
265+ # Targets' shape is (Nt, 3, Nv) with Nv = Num. voxels, Nt = Num. timepoints.
266+ targets = (
267+ ImageGrid (spatialimage ).index (
268+ _as_homogeneous (transform .map (ref_ndcoords ), dim = _ref .ndim )
269+ )
270+ if targets is None
271+ else targets
264272 )
265273
274+ if targets .ndim == 3 :
275+ targets = np .rollaxis (targets , targets .ndim - 1 , 0 )
276+ elif targets .ndim == 2 :
277+ targets = targets [np .newaxis , ...]
278+ else : # pragma: no cover
279+ raise RuntimeError (f"Can't generate targets with { targets .ndim } dimensions." )
280+
266281 if serialize_4d :
267282 data = (
268283 np .asanyarray (spatialimage .dataobj , dtype = input_dtype )
@@ -297,17 +312,24 @@ def apply(
297312 else :
298313 data = np .asanyarray (spatialimage .dataobj , dtype = input_dtype )
299314
300- if targets is None :
301- targets = ImageGrid (spatialimage ).index ( # data should be an image
302- _as_homogeneous (transform .map (ref_ndcoords ), dim = _ref .ndim )
303- )
304-
315+ if data_nvols == 1 and xfm_nvols == 1 :
316+ targets = np .squeeze (targets )
317+ assert targets .ndim == 2
305318 # Cast 3D data into 4D if 4D nonsequential transform
306- if data_nvols == 1 and xfm_nvols > 1 :
319+ elif data_nvols == 1 and xfm_nvols > 1 :
307320 data = data [..., np .newaxis ]
308321
309- if transform .ndim == 4 :
310- targets = _as_homogeneous (targets .reshape (- 2 , targets .shape [0 ])).T
322+ if xfm_nvols > 1 :
323+ assert targets .ndim == 3
324+ n_time , n_dim , n_vox = targets .shape
325+ # Reshape to (3, n_time x n_vox)
326+ ijk_targets = np .rollaxis (targets , 0 , 2 ).reshape ((n_dim , - 1 ))
327+ time_row = np .repeat (np .arange (n_time ), n_vox )[None , :]
328+
329+ # Now targets is (4, n_vox x n_time), with indexes (t, i, j, k)
330+ # t is the slowest-changing axis, so we put it first
331+ targets = np .vstack ((time_row , ijk_targets ))
332+ data = np .rollaxis (data , data .ndim - 1 , 0 )
311333
312334 resampled = ndi .map_coordinates (
313335 data ,
@@ -326,11 +348,19 @@ def apply(
326348 )
327349 hdr .set_data_dtype (output_dtype or spatialimage .header .get_data_dtype ())
328350
329- moved = spatialimage .__class__ (
330- resampled .reshape (_ref .shape if n_resamplings == 1 else _ref .shape + (- 1 ,)),
331- _ref .affine ,
332- hdr ,
333- )
351+ if serialize_4d :
352+ resampled = resampled .reshape (
353+ _ref .shape
354+ if n_resamplings == 1
355+ else _ref .shape + (resampled .shape [- 1 ],)
356+ )
357+ else :
358+ resampled = resampled .reshape ((- 1 , * _ref .shape ))
359+ resampled = np .rollaxis (resampled , 0 , resampled .ndim )
360+ with suppress (ValueError ):
361+ resampled = np .squeeze (resampled , axis = 3 )
362+
363+ moved = spatialimage .__class__ (resampled , _ref .affine , hdr )
334364 return moved
335365
336366 output_dtype = output_dtype or input_dtype
0 commit comments