Skip to content

Commit 6377203

Browse files
committed
fix: generalize targets, test all branches
1 parent 72cd04f commit 6377203

File tree

2 files changed

+53
-20
lines changed

2 files changed

+53
-20
lines changed

nitransforms/resampling.py

Lines changed: 48 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
import asyncio
1212
from os import cpu_count
13+
from contextlib import suppress
1314
from functools import partial
1415
from pathlib import Path
1516
from typing import Callable, TypeVar, Union
@@ -108,14 +109,16 @@ async def _apply_serial(
108109
semaphore = asyncio.Semaphore(max_concurrent)
109110

110111
for t in range(n_resamplings):
111-
xfm_t = transform if (n_resamplings == 1 or transform.ndim < 4) else transform[t]
112+
xfm_t = (
113+
transform if (n_resamplings == 1 or transform.ndim < 4) else transform[t]
114+
)
112115

113116
targets_t = (
114117
ImageGrid(spatialimage).index(
115118
_as_homogeneous(xfm_t.map(ref_ndcoords), dim=ref_ndim)
116119
)
117120
if targets is None
118-
else targets
121+
else targets[t, ...]
119122
)
120123

121124
data_t = (
@@ -258,11 +261,23 @@ def apply(
258261
dim=_ref.ndim,
259262
)
260263
)
261-
elif xfm_nvols == 1:
262-
targets = ImageGrid(spatialimage).index( # data should be an image
263-
_as_homogeneous(transform.map(ref_ndcoords), dim=_ref.ndim)
264+
else:
265+
# Targets' shape is (Nt, 3, Nv) with Nv = Num. voxels, Nt = Num. timepoints.
266+
targets = (
267+
ImageGrid(spatialimage).index(
268+
_as_homogeneous(transform.map(ref_ndcoords), dim=_ref.ndim)
269+
)
270+
if targets is None
271+
else targets
264272
)
265273

274+
if targets.ndim == 3:
275+
targets = np.rollaxis(targets, targets.ndim - 1, 0)
276+
elif targets.ndim == 2:
277+
targets = targets[np.newaxis, ...]
278+
else: # pragma: no cover
279+
raise RuntimeError(f"Can't generate targets with {targets.ndim} dimensions.")
280+
266281
if serialize_4d:
267282
data = (
268283
np.asanyarray(spatialimage.dataobj, dtype=input_dtype)
@@ -297,17 +312,24 @@ def apply(
297312
else:
298313
data = np.asanyarray(spatialimage.dataobj, dtype=input_dtype)
299314

300-
if targets is None:
301-
targets = ImageGrid(spatialimage).index( # data should be an image
302-
_as_homogeneous(transform.map(ref_ndcoords), dim=_ref.ndim)
303-
)
304-
315+
if data_nvols == 1 and xfm_nvols == 1:
316+
targets = np.squeeze(targets)
317+
assert targets.ndim == 2
305318
# Cast 3D data into 4D if 4D nonsequential transform
306-
if data_nvols == 1 and xfm_nvols > 1:
319+
elif data_nvols == 1 and xfm_nvols > 1:
307320
data = data[..., np.newaxis]
308321

309-
if transform.ndim == 4:
310-
targets = _as_homogeneous(targets.reshape(-2, targets.shape[0])).T
322+
if xfm_nvols > 1:
323+
assert targets.ndim == 3
324+
n_time, n_dim, n_vox = targets.shape
325+
# Reshape to (3, n_time x n_vox)
326+
ijk_targets = np.rollaxis(targets, 0, 2).reshape((n_dim, -1))
327+
time_row = np.repeat(np.arange(n_time), n_vox)[None, :]
328+
329+
# Now targets is (4, n_vox x n_time), with indexes (t, i, j, k)
330+
# t is the slowest-changing axis, so we put it first
331+
targets = np.vstack((time_row, ijk_targets))
332+
data = np.rollaxis(data, data.ndim - 1, 0)
311333

312334
resampled = ndi.map_coordinates(
313335
data,
@@ -326,11 +348,19 @@ def apply(
326348
)
327349
hdr.set_data_dtype(output_dtype or spatialimage.header.get_data_dtype())
328350

329-
moved = spatialimage.__class__(
330-
resampled.reshape(_ref.shape if n_resamplings == 1 else _ref.shape + (-1,)),
331-
_ref.affine,
332-
hdr,
333-
)
351+
if serialize_4d:
352+
resampled = resampled.reshape(
353+
_ref.shape
354+
if n_resamplings == 1
355+
else _ref.shape + (resampled.shape[-1],)
356+
)
357+
else:
358+
resampled = resampled.reshape((-1, *_ref.shape))
359+
resampled = np.rollaxis(resampled, 0, resampled.ndim)
360+
with suppress(ValueError):
361+
resampled = np.squeeze(resampled, axis=3)
362+
363+
moved = spatialimage.__class__(resampled, _ref.affine, hdr)
334364
return moved
335365

336366
output_dtype = output_dtype or input_dtype

nitransforms/tests/test_resampling.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -365,7 +365,8 @@ def test_LinearTransformsMapping_apply(
365365
)
366366

367367

368-
def test_apply_serialized_4d_multiple_targets():
368+
@pytest.mark.parametrize("serialize_4d", [True, False])
369+
def test_apply_4d(serialize_4d):
369370
"""Regression test for per-volume transforms with serialized resampling."""
370371
nvols = 9
371372
shape = (10, 5, 5)
@@ -379,9 +380,11 @@ def test_apply_serialized_4d_multiple_targets():
379380
mat[0, 3] = i
380381
transforms.append(nitl.Affine(mat))
381382

383+
extraparams = {} if serialize_4d else {"serialize_nvols": nvols + 1}
384+
382385
xfm = nitl.LinearTransformsMapping(transforms, reference=img)
383-
moved = apply(xfm, img, order=0)
384386

387+
moved = apply(xfm, img, order=0, **extraparams)
385388
data = np.asanyarray(moved.dataobj)
386389
idxs = [tuple(np.argwhere(data[..., i])[0]) for i in range(nvols)]
387390
assert idxs == [(9 - i, 2, 2) for i in range(nvols)]

0 commit comments

Comments
 (0)