diff --git a/docs/_api/surface.rst b/docs/_api/surface.rst new file mode 100644 index 00000000..92ba28ab --- /dev/null +++ b/docs/_api/surface.rst @@ -0,0 +1,6 @@ +================== +Surface Transforms +================== + +.. automodule:: nitransforms.surface + :members: diff --git a/docs/api.rst b/docs/api.rst index eb3c566b..a57d6836 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -10,5 +10,6 @@ Information on specific functions, classes, and methods for developers. _api/linear _api/manip _api/nonlinear + _api/surface _api/interp _api/patched diff --git a/nitransforms/__init__.py b/nitransforms/__init__.py index 1f819933..38768ae9 100644 --- a/nitransforms/__init__.py +++ b/nitransforms/__init__.py @@ -16,7 +16,7 @@ transform """ -from . import linear, manip, nonlinear +from . import linear, manip, nonlinear, surface from .linear import Affine, LinearTransformsMapping from .nonlinear import DenseFieldTransform from .manip import TransformChain @@ -37,6 +37,7 @@ __copyright__ = "Copyright (c) 2021 The NiPy developers" __all__ = [ + "surface", "linear", "manip", "nonlinear", diff --git a/nitransforms/base.py b/nitransforms/base.py index ac6e7520..9c8310ab 100644 --- a/nitransforms/base.py +++ b/nitransforms/base.py @@ -15,6 +15,7 @@ from nibabel import funcs as _nbfuncs from nibabel.nifti1 import intent_codes as INTENT_CODES from nibabel.cifti2 import Cifti2Image +import nibabel as nb EQUALITY_TOL = 1e-5 @@ -88,6 +89,76 @@ def shape(self): return self._shape +class SurfaceMesh(SampledSpatialData): + """Class to represent surface meshes.""" + + __slots__ = ["_triangles"] + + def __init__(self, dataset): + """Create a sampling reference.""" + self._shape = None + + if isinstance(dataset, SurfaceMesh): + self._coords = dataset._coords + self._triangles = dataset._triangles + self._ndim = dataset._ndim + self._npoints = dataset._npoints + self._shape = dataset._shape + return + + if isinstance(dataset, (str, Path)): + dataset = _nbload(str(dataset)) + + if hasattr(dataset, "numDA"): # Looks like a Gifti file + _das = dataset.get_arrays_from_intent(INTENT_CODES["pointset"]) + if not _das: + raise TypeError( + "Input Gifti file does not contain reference coordinates." + ) + self._coords = np.vstack([da.data for da in _das]) + _tris = dataset.get_arrays_from_intent(INTENT_CODES["triangle"]) + self._triangles = np.vstack([da.data for da in _tris]) + self._npoints, self._ndim = self._coords.shape + self._shape = self._coords.shape + return + + if isinstance(dataset, Cifti2Image): + raise NotImplementedError + + raise ValueError("Dataset could not be interpreted as an irregular sample.") + + def check_sphere(self, tolerance=1.001): + """Check sphericity of surface. + Based on https://github.com/Washington-University/workbench/blob/\ +7ba3345d161d567a4b628ceb02ab4471fc96cb20/src/Files/SurfaceResamplingHelper.cxx#L503 + """ + dists = np.linalg.norm(self._coords, axis=1) + return (dists.min() * tolerance) > dists.max() + + def set_radius(self, radius=100): + if not self.check_sphere(): + raise ValueError("You should only set the radius on spherical surfaces.") + dists = np.linalg.norm(self._coords, axis=1) + self._coords = self._coords * (radius / dists).reshape((-1, 1)) + + @classmethod + def from_arrays(cls, coordinates, triangles): + darrays = [ + nb.gifti.GiftiDataArray( + coordinates.astype(np.float32), + intent=nb.nifti1.intent_codes['NIFTI_INTENT_POINTSET'], + datatype=nb.nifti1.data_type_codes['NIFTI_TYPE_FLOAT32'], + ), + nb.gifti.GiftiDataArray( + triangles.astype(np.int32), + intent=nb.nifti1.intent_codes['NIFTI_INTENT_TRIANGLE'], + datatype=nb.nifti1.data_type_codes['NIFTI_TYPE_INT32'], + ), + ] + gii = nb.gifti.GiftiImage(darrays=darrays) + return cls(gii) + + class ImageGrid(SampledSpatialData): """Class to represent spaces of gridded data (images).""" diff --git a/nitransforms/surface.py b/nitransforms/surface.py new file mode 100644 index 00000000..7e1e7116 --- /dev/null +++ b/nitransforms/surface.py @@ -0,0 +1,652 @@ +# emacs: -*- mode: python-mode; py-indent-offset: 4; indent-tabs-mode: nil -*- +# vi: set ft=python sts=4 ts=4 sw=4 et: +### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## +# +# See COPYING file distributed along with the NiBabel package for the +# copyright and license terms. +# +### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## +"""Surface transforms.""" +import pathlib +import warnings +import h5py +import numpy as np +import nibabel as nb +from scipy import sparse +from scipy.spatial import KDTree +from scipy.spatial.distance import cdist +from nitransforms.base import ( + SurfaceMesh +) + + +class SurfaceTransformBase(): + """Generic surface transformation class""" + + def __init__(self, reference, moving, spherical=False): + """Instantiate a generic surface transform.""" + if spherical: + if not reference.check_sphere(): + raise ValueError("reference was not spherical") + if not moving.check_sphere(): + raise ValueError("moving was not spherical") + reference.set_radius() + moving.set_radius() + self._reference = reference + self._moving = moving + + def __eq__(self, other): + ref_coords_eq = np.all(self.reference._coords == other.reference._coords) + ref_tris_eq = np.all(self.reference._triangles == other.reference._triangles) + mov_coords_eq = np.all(self.moving._coords == other.moving._coords) + mov_tris_eq = np.all(self.moving._triangles == other.moving._triangles) + return ref_coords_eq & ref_tris_eq & mov_coords_eq & mov_tris_eq + + def __invert__(self): + return self.__class__(self._moving, self._reference) + + @property + def reference(self): + return self._reference + + @reference.setter + def reference(self, surface): + self._reference = SurfaceMesh(surface) + + @property + def moving(self): + return self._moving + + @moving.setter + def moving(self, surface): + self._moving = SurfaceMesh(surface) + + @classmethod + def from_filename(cls, reference_path, moving_path): + """Create an Surface Index Transformation from a pair of surfaces with corresponding + vertices.""" + reference = SurfaceMesh(nb.load(reference_path)) + moving = SurfaceMesh(nb.load(moving_path)) + return cls(reference, moving) + + +class SurfaceCoordinateTransform(SurfaceTransformBase): + """Represents surface transformations in which the indices correspond and the coordinates + differ. This could be two surfaces representing difference structures from the same + hemisphere, like white matter and pial, or it could be a sphere and a deformed sphere that + moves those coordinates to a different location.""" + + __slots__ = ("_reference", "_moving") + + def __init__(self, reference, moving): + """Instantiate a transform between two surfaces with corresponding vertices. + Parameters + ---------- + reference: surface + Surface with the starting coordinates for each index. + moving: surface + Surface with the destination coordinates for each index. + """ + + super().__init__(reference=SurfaceMesh(reference), moving=SurfaceMesh(moving)) + if np.all(self._reference._triangles != self._moving._triangles): + raise ValueError("Both surfaces for an index transform must have corresponding" + " vertices.") + + def map(self, x, inverse=False): + if not inverse: + source = self.reference + dest = self.moving + else: + source = self.moving + dest = self.reference + + s_tree = KDTree(source._coords) + dists, matches = s_tree.query(x) + if not np.allclose(dists, 0): + raise NotImplementedError("Mapping on surfaces not implemented for coordinates that" + " aren't vertices") + return dest._coords[matches] + + def __add__(self, other): + if isinstance(other, SurfaceCoordinateTransform): + return self.__class__(self.reference, other.moving) + raise NotImplementedError + + def _to_hdf5(self, x5_root): + """Write transform to HDF5 file.""" + triangles = x5_root.create_group("Triangles") + coords = x5_root.create_group("Coordinates") + coords.create_dataset("0", data=self.reference._coords) + coords.create_dataset("1", data=self.moving._coords) + triangles.create_dataset("0", data=self.reference._triangles) + xform = x5_root.create_group("Transform") + xform.attrs["Type"] = "SurfaceCoordinateTransform" + reference = xform.create_group("Reference") + reference['Coordinates'] = h5py.SoftLink('/0/Coordinates/0') + reference['Triangles'] = h5py.SoftLink('/0/Triangles/0') + moving = xform.create_group("Moving") + moving['Coordinates'] = h5py.SoftLink('/0/Coordinates/1') + moving['Triangles'] = h5py.SoftLink('/0/Triangles/0') + + def to_filename(self, filename, fmt=None): + """Store the transform.""" + if fmt is None: + fmt = "npz" if filename.endswith(".npz") else "X5" + + if fmt == "npz": + raise NotImplementedError + # sparse.save_npz(filename, self.mat) + # return filename + + with h5py.File(filename, "w") as out_file: + out_file.attrs["Format"] = "X5" + out_file.attrs["Version"] = np.uint16(1) + root = out_file.create_group("/0") + self._to_hdf5(root) + + return filename + + @classmethod + def from_filename(cls, filename=None, reference_path=None, moving_path=None, + fmt=None): + """Load transform from file.""" + if filename is None: + if reference_path is None or moving_path is None: + raise ValueError("You must pass either a X5 file or a pair of reference and moving" + " surfaces.") + return cls(SurfaceMesh(nb.load(reference_path)), + SurfaceMesh(nb.load(moving_path))) + + if fmt is None: + try: + fmt = "npz" if filename.endswith(".npz") else "X5" + except AttributeError: + fmt = "npz" if filename.as_posix().endswith(".npz") else "X5" + + if fmt == "npz": + raise NotImplementedError + # return cls(sparse.load_npz(filename)) + + if fmt != "X5": + raise ValueError("Only npz and X5 formats are supported.") + + with h5py.File(filename, "r") as f: + assert f.attrs["Format"] == "X5" + xform = f["/0/Transform"] + reference = SurfaceMesh.from_arrays( + xform['Reference']['Coordinates'], + xform['Reference']['Triangles'] + ) + + moving = SurfaceMesh.from_arrays( + xform['Moving']['Coordinates'], + xform['Moving']['Triangles'] + ) + return cls(reference, moving) + + +class SurfaceResampler(SurfaceTransformBase): + """ + Represents transformations in which the coordinate space remains the same + and the indices change. + To achieve surface project-unproject functionality: + sphere_in as the reference + sphere_project_to as the moving + Then apply the transformation to sphere_unproject_from + """ + + __slots__ = ("_reference", "_moving", "mat", 'interpolation_method') + + def __init__(self, reference, moving, interpolation_method='barycentric', mat=None): + """Initialize the resampling. + + Parameters + ---------- + reference: spherical surface of the reference space. + Output will have number of indices equal to the number of indicies in this surface. + Both reference and moving should be in the same coordinate space. + moving: spherical surface that will be resampled. + Both reference and moving should be in the same coordinate space. + mat : array-like, shape (nv1, nv2) + Sparse matrix representing the transform. + interpolation_method : str + Only barycentric is currently implemented + """ + + super().__init__(SurfaceMesh(reference), SurfaceMesh(moving), spherical=True) + + self.reference.set_radius() + self.moving.set_radius() + if interpolation_method not in ['barycentric']: + raise NotImplementedError(f"{interpolation_method} is not implemented.") + self.interpolation_method = interpolation_method + + # TODO: should we deal with the case where reference and moving are the same? + + # we're calculating the interpolation in the init so that we can ensure + # that it only has to be calculated once and will always be saved with the + # transform + if mat is None: + self.__calculate_mat() + m_tree = KDTree(self.moving._coords) + _, kmr_closest = m_tree.query(self.reference._coords, k=10) + + # invert the triangles to generate a lookup table from vertices to triangle index + tri_lut = {} + for i, idxs in enumerate(self.moving._triangles): + for x in idxs: + if x not in tri_lut: + tri_lut[x] = [i] + else: + tri_lut[x].append(i) + + # calculate the barycentric interpolation weights + bc_weights = [] + enclosing = [] + for point, kmrv in zip(self.reference._coords, kmr_closest): + close_tris = _find_close_tris(kmrv, tri_lut, self.moving) + ww, ee = _find_weights(point, close_tris, m_tree) + bc_weights.append(ww) + enclosing.append(ee) + + # build sparse matrix + # commenting out code for barycentric nearest neighbor + # bary_nearest = [] + mat = sparse.lil_array((self.reference._npoints, self.moving._npoints)) + for s_ix, dd in enumerate(bc_weights): + for k, v in dd.items(): + mat[s_ix, k] = v + # bary_nearest.append(np.array(list(dd.keys()))[np.array(list(dd.values())).argmax()]) + # bary_nearest = np.array(bary_nearest) + # transpose so that number of out vertices is columns + self.mat = sparse.csr_array(mat.T) + else: + if isinstance(mat, sparse.csr_array): + self.mat = mat + else: + self.mat = sparse.csr_array(mat) + # validate shape of the provided matrix + if (mat.shape[0] != moving._npoints) or (mat.shape[1] != reference._npoints): + msg = "Shape of provided mat does not match expectations based on " \ + "dimensions of moving and reference. \n" + if mat.shape[0] != moving._npoints: + msg += f" mat has {mat.shape[0]} rows but moving has {moving._npoints} " \ + f"vertices. \n" + if mat.shape[1] != reference._npoints: + msg += f" mat has {mat.shape[1]} columns but reference has" \ + f" {reference._npoints} vertices." + raise ValueError(msg) + + def __calculate_mat(self): + m_tree = KDTree(self.moving._coords) + _, kmr_closest = m_tree.query(self.reference._coords, k=10) + + # invert the triangles to generate a lookup table from vertices to triangle index + tri_lut = {} + for i, idxs in enumerate(self.moving._triangles): + for x in idxs: + if x not in tri_lut: + tri_lut[x] = [i] + else: + tri_lut[x].append(i) + + # calculate the barycentric interpolation weights + bc_weights = [] + enclosing = [] + for point, kmrv in zip(self.reference._coords, kmr_closest): + close_tris = _find_close_tris(kmrv, tri_lut, self.moving) + ww, ee = _find_weights(point, close_tris, m_tree) + bc_weights.append(ww) + enclosing.append(ee) + + # build sparse matrix + # commenting out code for barycentric nearest neighbor + # bary_nearest = [] + mat = sparse.lil_array((self.reference._npoints, self.moving._npoints)) + for s_ix, dd in enumerate(bc_weights): + for k, v in dd.items(): + mat[s_ix, k] = v + # bary_nearest.append( + # np.array(list(dd.keys()))[np.array(list(dd.values())).argmax()] + # ) + # bary_nearest = np.array(bary_nearest) + # transpose so that number of out vertices is columns + self.mat = sparse.csr_array(mat.T) + + def map(self, x): + return x + + def __add__(self, other): + if (isinstance(other, SurfaceResampler) + and (other.interpolation_method == self.interpolation_method)): + return self.__class__( + self.reference, + other.moving, + interpolation_method=self.interpolation_method + ) + raise NotImplementedError + + def __invert__(self): + return self.__class__( + self.moving, + self.reference, + interpolation_method=self.interpolation_method + ) + + @SurfaceTransformBase.reference.setter + def reference(self, surface): + raise ValueError("Don't modify the reference of an existing resampling." + "Create a new one instead.") + + @SurfaceTransformBase.moving.setter + def moving(self, surface): + raise ValueError("Don't modify the moving of an existing resampling." + "Create a new one instead.") + + def apply(self, x, inverse=False, normalize="element"): + """Apply the transform to surface data. + + Parameters + ---------- + x : array-like, shape (..., nv1), or SurfaceMesh + Data to transform or SurfaceMesh to resample + inverse : bool, default=False + Whether to apply the inverse transform. If True, ``x`` has shape + (..., nv2), and the output will have shape (..., nv1). + normalize : {"element", "sum", "none"}, default="element" + Normalization strategy. If "element", the scale of each value in + the output is comparable to each value of the input. If "sum", the + sum of the output is comparable to the sum of the input. If + "none", no normalization is applied. + + Returns + ------- + y : array-like, shape (..., nv2) + Transformed data. + """ + if normalize not in ("element", "sum", "none"): + raise ValueError("Invalid normalization strategy.") + + mat = self.mat.T if inverse else self.mat + + if normalize == "element": + sum_ = mat.sum(axis=0) + scale = np.zeros_like(sum_) + mask = sum_ != 0 + scale[mask] = 1.0 / sum_[mask] + mat = mat @ sparse.diags(scale) + elif normalize == "sum": + sum_ = mat.sum(axis=1) + scale = np.zeros_like(sum_) + mask = sum_ != 0 + scale[mask] = 1.0 / sum_[mask] + mat = sparse.diags(scale) @ mat + + if isinstance(x, (SurfaceMesh, pathlib.PurePath, str)): + x = SurfaceMesh(x) + if not x.check_sphere(): + raise ValueError("If x is a surface, it should be a sphere.") + x.set_radius() + rs_coords = x._coords.T @ mat + + y = SurfaceMesh.from_arrays(rs_coords.T, self.reference._triangles) + y.set_radius() + else: + y = x @ mat + return y + + def _to_hdf5(self, x5_root): + """Write transform to HDF5 file.""" + triangles = x5_root.create_group("Triangles") + coords = x5_root.create_group("Coordinates") + coords.create_dataset("0", data=self.reference._coords) + coords.create_dataset("1", data=self.moving._coords) + triangles.create_dataset("0", data=self.reference._triangles) + triangles.create_dataset("1", data=self.moving._triangles) + xform = x5_root.create_group("Transform") + xform.attrs["Type"] = "SurfaceResampling" + xform.attrs['InterpolationMethod'] = self.interpolation_method + mat = xform.create_group("IndexWeights") + mat.create_dataset("Data", data=self.mat.data) + mat.create_dataset("Indices", data=self.mat.indices) + mat.create_dataset("Indptr", data=self.mat.indptr) + mat.create_dataset("Shape", data=self.mat.shape) + reference = xform.create_group("Reference") + reference['Coordinates'] = h5py.SoftLink('/0/Coordinates/0') + reference['Triangles'] = h5py.SoftLink('/0/Triangles/0') + moving = xform.create_group("Moving") + moving['Coordinates'] = h5py.SoftLink('/0/Coordinates/1') + moving['Triangles'] = h5py.SoftLink('/0/Triangles/1') + + def to_filename(self, filename, fmt=None): + """Store the transform.""" + if fmt is None: + fmt = "npz" if filename.endswith(".npz") else "X5" + + if fmt == "npz": + raise NotImplementedError + # sparse.save_npz(filename, self.mat) + # return filename + + with h5py.File(filename, "w") as out_file: + out_file.attrs["Format"] = "X5" + out_file.attrs["Version"] = np.uint16(1) + root = out_file.create_group("/0") + self._to_hdf5(root) + + return filename + + @classmethod + def from_filename(cls, filename=None, reference_path=None, moving_path=None, + fmt=None, interpolation_method=None): + """Load transform from file.""" + if filename is None: + if reference_path is None or moving_path is None: + raise ValueError("You must pass either a X5 file or a pair of reference and moving" + " surfaces.") + if interpolation_method is None: + interpolation_method = 'barycentric' + return cls(SurfaceMesh(nb.load(reference_path)), + SurfaceMesh(nb.load(moving_path)), + interpolation_method=interpolation_method) + + if fmt is None: + try: + fmt = "npz" if filename.endswith(".npz") else "X5" + except AttributeError: + fmt = "npz" if filename.as_posix().endswith(".npz") else "X5" + + if fmt == "npz": + raise NotImplementedError + # return cls(sparse.load_npz(filename)) + + if fmt != "X5": + raise ValueError("Only npz and X5 formats are supported.") + + with h5py.File(filename, "r") as f: + assert f.attrs["Format"] == "X5" + xform = f["/0/Transform"] + try: + iws = xform['IndexWeights'] + mat = sparse.csr_matrix( + (iws["Data"][()], iws["Indices"][()], iws["Indptr"][()]), + shape=iws["Shape"][()], + ) + except KeyError: + mat = None + reference = SurfaceMesh.from_arrays( + xform['Reference']['Coordinates'], + xform['Reference']['Triangles'] + ) + + moving = SurfaceMesh.from_arrays( + xform['Moving']['Coordinates'], + xform['Moving']['Triangles'] + ) + interpolation_method = xform.attrs['InterpolationMethod'] + return cls(reference, moving, interpolation_method=interpolation_method, mat=mat) + + +def _points_to_triangles(points, triangles): + + """Implementation that vectorizes project of a point to a set of triangles. + from: https://stackoverflow.com/a/32529589 + """ + with np.errstate(all='ignore'): + # Unpack triangle points + p0, p1, p2 = np.asarray(triangles).swapaxes(0, 1) + + # Calculate triangle edges + e0 = p1 - p0 + e1 = p2 - p0 + a = np.einsum('...i,...i', e0, e0) + b = np.einsum('...i,...i', e0, e1) + c = np.einsum('...i,...i', e1, e1) + + # Calculate determinant and denominator + det = a * c - b * b + inv_det = 1. / det + denom = a - 2 * b + c + + # Project to the edges + p = p0 - points[:, np.newaxis] + d = np.einsum('...i,...i', e0, p) + e = np.einsum('...i,...i', e1, p) + u = b * e - c * d + v = b * d - a * e + + # Calculate numerators + bd = b + d + ce = c + e + numer0 = (ce - bd) / denom + numer1 = (c + e - b - d) / denom + da = -d / a + ec = -e / c + + # Vectorize test conditions + m0 = u + v < det + m1 = u < 0 + m2 = v < 0 + m3 = d < 0 + m4 = a + d > b + e + + m5 = ce > bd + + t0 = m0 & m1 & m2 & m3 + t1 = m0 & m1 & m2 & ~m3 + t2 = m0 & m1 & ~m2 + t3 = m0 & ~m1 & m2 + t4 = m0 & ~m1 & ~m2 + t5 = ~m0 & m1 & m5 + t6 = ~m0 & m1 & ~m5 + t7 = ~m0 & m2 & m4 + t8 = ~m0 & m2 & ~m4 + t9 = ~m0 & ~m1 & ~m2 + + u = np.where(t0, np.clip(da, 0, 1), u) + v = np.where(t0, 0, v) + u = np.where(t1, 0, u) + v = np.where(t1, 0, v) + u = np.where(t2, 0, u) + v = np.where(t2, np.clip(ec, 0, 1), v) + u = np.where(t3, np.clip(da, 0, 1), u) + v = np.where(t3, 0, v) + u *= np.where(t4, inv_det, 1) + v *= np.where(t4, inv_det, 1) + u = np.where(t5, np.clip(numer0, 0, 1), u) + v = np.where(t5, 1 - u, v) + u = np.where(t6, 0, u) + v = np.where(t6, 1, v) + u = np.where(t7, np.clip(numer1, 0, 1), u) + v = np.where(t7, 1 - u, v) + u = np.where(t8, 1, u) + v = np.where(t8, 0, v) + u = np.where(t9, np.clip(numer1, 0, 1), u) + v = np.where(t9, 1 - u, v) + + # Return closest points + return (p0.T + u[:, np.newaxis] * e0.T + v[:, np.newaxis] * e1.T).swapaxes(2, 1) + + +def _barycentric_weights(vecs, coords): + """Compute the weights for barycentric interpolation. + + Parameters + ---------- + vecs : ndarray of shape (6, 3) + The 6 vectors used to compute barycentric weights. + a, e1, e2, + np.cross(e1, e2), + np.cross(e2, a), + np.cross(a, e1) + coords : ndarray of shape (3, ) + + Returns + ------- + (w, u, v, t) : tuple of float + ``w``, ``u``, and ``v`` are the weights of the three vertices of the + triangle, respectively. ``t`` is the scale that needs to be multiplied + to ``coords`` to make it in the same plane as the three vertices. + + From: https://github.com/neuroboros/neuroboros/blob/\ +f2a2efb914e783add2bf06e0f3715236d3d8550e/src/neuroboros/surface/_barycentric.pyx#L9-L47 + """ + det = coords[0] * vecs[3, 0] + coords[1] * vecs[3, 1] + coords[2] * vecs[3, 2] + if det == 0: + if vecs[3, 0] == 0 and vecs[3, 1] == 0 and vecs[3, 2] == 0: + warnings.warn("Zero cross product of two edges: " + "The three vertices are in the same line.") + else: + print(vecs[3]) + y = coords - vecs[0] + u, v = np.linalg.lstsq(vecs[1:3].T, y, rcond=None)[0] + t = 1. + else: + uu = coords[0] * vecs[4, 0] + coords[1] * vecs[4, 1] + coords[2] * vecs[4, 2] + vv = coords[0] * vecs[5, 0] + coords[1] * vecs[5, 1] + coords[2] * vecs[5, 2] + u = uu / det + v = vv / det + tt = vecs[0, 0] * vecs[3, 0] + vecs[0, 1] * vecs[3, 1] + vecs[0, 2] * vecs[3, 2] + t = tt / det + w = 1. - (u + v) + return w, u, v, t + + +def _find_close_tris(kdsv, tri_lut, surface): + tris = [] + for kk in kdsv: + tris.extend(tri_lut[kk]) + close_tri_verts = surface._triangles[np.unique(tris)] + close_tris = surface._coords[close_tri_verts] + return close_tris + + +def _find_weights(point, close_tris, d_tree): + point = point[np.newaxis, :] + tri_dists = cdist(point, _points_to_triangles(point, close_tris).squeeze()) + + closest_tri = close_tris[(tri_dists == tri_dists.min()).squeeze()] + # make sure a single closest triangle was found + if closest_tri.shape[0] != 1: + # in the event of a tie (which can happen) + # just take the first triangle + closest_tri = closest_tri[0] + + closest_tri = closest_tri.squeeze() + # Make sure point is actually inside triangle + enclosing = True + if np.all((point > closest_tri).sum(0) != 3): + + enclosing = False + _, ct_idxs = d_tree.query(closest_tri) + a = closest_tri[0] + e1 = closest_tri[1] - a + e2 = closest_tri[2] - a + vecs = np.vstack([a, e1, e2, np.cross(e1, e2), np.cross(e2, a), np.cross(a, e1)]) + res = {} + res[ct_idxs[0]], res[ct_idxs[1]], res[ct_idxs[2]], _ = _barycentric_weights( + vecs, + point.squeeze() + ) + return res, enclosing diff --git a/nitransforms/tests/test_base.py b/nitransforms/tests/test_base.py index d32ce7f9..fb4be8d8 100644 --- a/nitransforms/tests/test_base.py +++ b/nitransforms/tests/test_base.py @@ -4,11 +4,13 @@ import pytest import h5py + from ..base import ( SpatialReference, SampledSpatialData, ImageGrid, TransformBase, + SurfaceMesh, ) from .. import linear as nitl from ..resampling import apply @@ -159,3 +161,31 @@ def test_concatenation(testdata_path): x = [(0.0, 0.0, 0.0), (1.0, 1.0, 1.0), (-1.0, -1.0, -1.0)] assert np.all((aff + nitl.Affine())(x) == x) assert np.all((aff + nitl.Affine())(x, inverse=True) == x) + + +def test_SurfaceMesh(testdata_path): + surf_path = testdata_path / "sub-200148_hemi-R_pial.surf.gii" + shape_path = testdata_path / "sub-sid000005_ses-budapest_acq-MPRAGE_hemi-R_thickness.shape.gii" + img_path = testdata_path / "bold.nii.gz" + + mesh = SurfaceMesh(nb.load(surf_path)) + exp_coords_shape = (249277, 3) + exp_tris_shape = (498550, 3) + assert mesh._coords.shape == exp_coords_shape + assert mesh._triangles.shape == exp_tris_shape + assert mesh._npoints == exp_coords_shape[0] + assert mesh._ndim == exp_coords_shape[1] + + mfd = SurfaceMesh(surf_path) + assert (mfd._coords == mesh._coords).all() + assert (mfd._triangles == mesh._triangles).all() + + mfsm = SurfaceMesh(mfd) + assert (mfd._coords == mfsm._coords).all() + assert (mfd._triangles == mfsm._triangles).all() + + with pytest.raises(ValueError): + SurfaceMesh(nb.load(img_path)) + + with pytest.raises(TypeError): + SurfaceMesh(nb.load(shape_path)) diff --git a/nitransforms/tests/test_surface.py b/nitransforms/tests/test_surface.py new file mode 100644 index 00000000..a210583e --- /dev/null +++ b/nitransforms/tests/test_surface.py @@ -0,0 +1,241 @@ +import tempfile + +import numpy as np +import nibabel as nb +import pytest +from scipy import sparse +from nitransforms.base import SurfaceMesh +from nitransforms.surface import ( + SurfaceTransformBase, + SurfaceCoordinateTransform, + SurfaceResampler +) + +# def test_surface_transform_npz(): +# mat = sparse.random(10, 10, density=0.5) +# xfm = SurfaceCoordinateTransform(mat) +# fn = tempfile.mktemp(suffix=".npz") +# print(fn) +# xfm.to_filename(fn) +# +# xfm2 = SurfaceCoordinateTransform.from_filename(fn) +# try: +# assert xfm.mat.shape == xfm2.mat.shape +# np.testing.assert_array_equal(xfm.mat.data, xfm2.mat.data) +# np.testing.assert_array_equal(xfm.mat.indices, xfm2.mat.indices) +# np.testing.assert_array_equal(xfm.mat.indptr, xfm2.mat.indptr) +# except Exception: +# os.remove(fn) +# raise +# os.remove(fn) + + +# def test_surface_transform_normalization(): +# mat = np.random.uniform(size=(20, 10)) +# xfm = SurfaceCoordinateTransform(mat) +# x = np.random.uniform(size=(5, 20)) +# y_element = xfm.apply(x, normalize="element") +# np.testing.assert_array_less(y_element.sum(axis=1), x.sum(axis=1)) +# y_sum = xfm.apply(x, normalize="sum") +# np.testing.assert_allclose(y_sum.sum(axis=1), x.sum(axis=1)) +# y_none = xfm.apply(x, normalize="none") +# assert y_none.sum() != y_element.sum() +# assert y_none.sum() != y_sum.sum() + +def test_SurfaceTransformBase(testdata_path): + # note these transformations are a bit of a weird use of surface transformation, but I'm + # just testing the base class and the io + sphere_reg_path = ( + testdata_path + / "sub-sid000005_ses-budapest_acq-MPRAGE_hemi-R_space-fsLR_desc-reg_sphere.surf.gii" + ) + pial_path = testdata_path / "sub-sid000005_ses-budapest_acq-MPRAGE_hemi-R_pial.surf.gii" + + sphere_reg = SurfaceMesh(nb.load(sphere_reg_path)) + pial = SurfaceMesh(nb.load(pial_path)) + stfb = SurfaceTransformBase(sphere_reg, pial) + + # test loading from filenames + stfb_ff = SurfaceTransformBase.from_filename(sphere_reg_path, pial_path) + assert stfb_ff == stfb + + # test inversion and setting + stfb_i = ~stfb + stfb.reference = pial + stfb.moving = sphere_reg + assert np.all(stfb_i._reference._coords == stfb._reference._coords) + assert np.all(stfb_i._reference._triangles == stfb._reference._triangles) + assert np.all(stfb_i._moving._coords == stfb._moving._coords) + assert np.all(stfb_i._moving._triangles == stfb._moving._triangles) + # test equality + assert stfb_i == stfb + + +def test_SurfaceCoordinateTransform(testdata_path): + # note these transformations are a bit of a weird use of surface transformation, but I'm + # just testing the class and the io + sphere_reg_path = ( + testdata_path + / "sub-sid000005_ses-budapest_acq-MPRAGE_hemi-R_space-fsLR_desc-reg_sphere.surf.gii" + ) + pial_path = testdata_path / "sub-sid000005_ses-budapest_acq-MPRAGE_hemi-R_pial.surf.gii" + fslr_sphere_path = testdata_path / "tpl-fsLR_hemi-R_den-32k_sphere.surf.gii" + + sphere_reg = SurfaceMesh(nb.load(sphere_reg_path)) + pial = SurfaceMesh(nb.load(pial_path)) + fslr_sphere = SurfaceMesh(nb.load(fslr_sphere_path)) + + # test mesh correspondence test + with pytest.raises(ValueError): + sct = SurfaceCoordinateTransform(fslr_sphere, pial) + + # test loading from filenames + sct = SurfaceCoordinateTransform(pial, sphere_reg) + sctf = SurfaceCoordinateTransform.from_filename(reference_path=pial_path, + moving_path=sphere_reg_path) + assert sct == sctf + + # test mapping + assert np.all(sct.map(sct.moving._coords[:100], inverse=True) == sct.reference._coords[:100]) + assert np.all(sct.map(sct.reference._coords[:100]) == sct.moving._coords[:100]) + with pytest.raises(NotImplementedError): + sct.map(sct.moving._coords[0]) + + # test inversion and addition + scti = ~sct + + assert sct + scti == SurfaceCoordinateTransform(pial, pial) + assert scti + sct == SurfaceCoordinateTransform(sphere_reg, sphere_reg) + + sct.reference = sphere_reg + sct.moving = pial + assert np.all(scti.reference._coords == sct.reference._coords) + assert np.all(scti.reference._triangles == sct.reference._triangles) + assert scti == sct + + +def test_SurfaceCoordinateTransformIO(testdata_path, tmpdir): + sphere_reg_path = ( + testdata_path + / "sub-sid000005_ses-budapest_acq-MPRAGE_hemi-R_space-fsLR_desc-reg_sphere.surf.gii" + ) + pial_path = testdata_path / "sub-sid000005_ses-budapest_acq-MPRAGE_hemi-R_pial.surf.gii" + + sct = SurfaceCoordinateTransform(pial_path, sphere_reg_path) + fn = tempfile.mktemp(suffix=".h5") + sct.to_filename(fn) + sct2 = SurfaceCoordinateTransform.from_filename(fn) + assert sct == sct2 + + +def test_ProjectUnproject(testdata_path): + + sphere_reg_path = ( + testdata_path + / "sub-sid000005_ses-budapest_acq-MPRAGE_hemi-R_space-fsLR_desc-reg_sphere.surf.gii" + ) + fslr_sphere_path = testdata_path / "tpl-fsLR_hemi-R_den-32k_sphere.surf.gii" + subj_fsaverage_sphere_path = ( + testdata_path + / "sub-sid000005_ses-budapest_acq-MPRAGE_hemi-R_space-fsaverage_desc-reg_sphere.surf.gii" + ) + fslr_fsaverage_sphere_path = ( + testdata_path + / "tpl-fsLR_space-fsaverage_hemi-R_den-32k_sphere.surf.gii" + ) + pial_path = testdata_path / "sub-sid000005_ses-budapest_acq-MPRAGE_hemi-R_pial.surf.gii" + + # test project-unproject funcitonality + projunproj = SurfaceResampler(sphere_reg_path, fslr_sphere_path) + with pytest.raises(ValueError): + projunproj.apply(pial_path) + transformed = projunproj.apply(fslr_fsaverage_sphere_path) + projunproj_ref = nb.load(subj_fsaverage_sphere_path) + assert (projunproj_ref.agg_data()[0] - transformed._coords).max() < 0.0005 + assert np.all(transformed._triangles == projunproj_ref.agg_data()[1]) + + +def test_SurfaceResampler(testdata_path, tmpdir): + dif_tol = 0.001 + fslr_sphere_path = ( + testdata_path + / "tpl-fsLR_hemi-R_den-32k_sphere.surf.gii" + ) + shape_path = ( + testdata_path + / "sub-sid000005_ses-budapest_acq-MPRAGE_hemi-R_thickness.shape.gii" + ) + ref_resampled_thickness_path = ( + testdata_path + / "sub-sid000005_ses-budapest_acq-MPRAGE_hemi-R_space-fsLR_thickness.shape.gii" + ) + pial_path = ( + testdata_path / "sub-sid000005_ses-budapest_acq-MPRAGE_hemi-R_pial.surf.gii" + ) + sphere_reg_path = ( + testdata_path + / "sub-sid000005_ses-budapest_acq-MPRAGE_hemi-R_space-fsLR_desc-reg_sphere.surf.gii" + ) + + fslr_sphere = SurfaceMesh(nb.load(fslr_sphere_path)) + sphere_reg = SurfaceMesh(nb.load(sphere_reg_path)) + subj_thickness = nb.load(shape_path) + + with pytest.raises(ValueError): + SurfaceResampler(sphere_reg_path, pial_path) + with pytest.raises(ValueError): + SurfaceResampler(pial_path, sphere_reg_path) + + reference = fslr_sphere + moving = sphere_reg + # compare results to what connectome workbench produces + resampling = SurfaceResampler(reference, moving) + resampled_thickness = resampling.apply(subj_thickness.agg_data(), normalize='element') + ref_resampled = nb.load(ref_resampled_thickness_path).agg_data() + + max_dif = np.abs(resampled_thickness.astype(np.float32) - ref_resampled).max() + assert max_dif < dif_tol + + with pytest.raises(ValueError): + SurfaceResampler(reference, moving, mat=resampling.mat[:, :10000]) + with pytest.raises(ValueError): + SurfaceResampler(reference, moving, mat=resampling.mat[:10000, :]) + with pytest.raises(ValueError): + resampling.reference = reference + with pytest.raises(ValueError): + resampling.moving = moving + with pytest.raises(NotImplementedError): + _ = SurfaceResampler(reference, moving, "foo") + + # test file io + fn = tempfile.mktemp(suffix=".h5") + resampling.to_filename(fn) + resampling2 = SurfaceResampler.from_filename(fn) + + # assert resampling2 == resampling + assert np.allclose(resampling2.reference._coords, resampling.reference._coords) + assert np.all(resampling2.reference._triangles == resampling.reference._triangles) + assert np.allclose(resampling2.reference._coords, resampling.reference._coords) + assert np.all(resampling2.moving._triangles == resampling.moving._triangles) + + resampled_thickness2 = resampling2.apply(subj_thickness.agg_data(), normalize='element') + assert np.all(resampled_thickness2 == resampled_thickness) + + # test loading with a csr + assert isinstance(resampling.mat, sparse.csr_array) + resampling2a = SurfaceResampler(reference, moving, mat=resampling.mat) + resampled_thickness2a = resampling2a.apply(subj_thickness.agg_data(), normalize='element') + assert np.all(resampled_thickness2a == resampled_thickness) + + with pytest.raises(ValueError): + _ = SurfaceResampler(moving, reference, mat=resampling.mat) + + # test map + assert np.all(resampling.map(np.array([[0, 0, 0]])) == np.array([[0, 0, 0]])) + + # test loading from surfaces + resampling3 = SurfaceResampler.from_filename(reference_path=fslr_sphere_path, + moving_path=sphere_reg_path) + assert resampling3 == resampling + resampled_thickness3 = resampling3.apply(subj_thickness.agg_data(), normalize='element') + assert np.all(resampled_thickness3 == resampled_thickness) diff --git a/setup.cfg b/setup.cfg index 20fe531e..f0074ae1 100644 --- a/setup.cfg +++ b/setup.cfg @@ -27,6 +27,7 @@ install_requires = scipy >= 1.6.0 nibabel >= 3.0 h5py + pathlib test_requires = pytest pytest-cov