From 7ed97e910fed0c19f6cc1f8687b1522c0f43ba78 Mon Sep 17 00:00:00 2001 From: ppegolo Date: Wed, 10 Sep 2025 10:25:41 +0200 Subject: [PATCH 01/14] Remove net forces in ASE calculator --- .../metatomic_torch/metatomic/torch/ase_calculator.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/python/metatomic_torch/metatomic/torch/ase_calculator.py b/python/metatomic_torch/metatomic/torch/ase_calculator.py index 8c26477a..3aab1f17 100644 --- a/python/metatomic_torch/metatomic/torch/ase_calculator.py +++ b/python/metatomic_torch/metatomic/torch/ase_calculator.py @@ -465,6 +465,10 @@ def calculate( forces_values = ( outputs["non_conservative_forces"].block().values.detach() ) + # remove any spurious net force + forces_values = forces_values - forces_values.mean( + dim=0, keepdim=True + ) else: forces_values = -system.positions.grad forces_values = forces_values.reshape(-1, 3) @@ -587,6 +591,12 @@ def compute_energy( results_as_numpy_arrays["forces"], split_indices, axis=0 ) + # remove net forces + results_as_numpy_arrays["forces"] = [ + f - f.mean(axis=0, keepdims=True) + for f in results_as_numpy_arrays["forces"] + ] + if all(atoms.pbc.all() for atoms in atoms_list): results_as_numpy_arrays["stress"] = [ s From 159b174aed6b22c910baf617deed96671449b610 Mon Sep 17 00:00:00 2001 From: ppegolo Date: Wed, 10 Sep 2025 13:28:04 +0200 Subject: [PATCH 02/14] Add SO3 and O3 averaging calculators --- .../metatomic/torch/ase_calculator.py | 247 +++++++++++++++++- 1 file changed, 244 insertions(+), 3 deletions(-) diff --git a/python/metatomic_torch/metatomic/torch/ase_calculator.py b/python/metatomic_torch/metatomic/torch/ase_calculator.py index 3aab1f17..5fe45bff 100644 --- a/python/metatomic_torch/metatomic/torch/ase_calculator.py +++ b/python/metatomic_torch/metatomic/torch/ase_calculator.py @@ -2,13 +2,15 @@ import os import pathlib import warnings -from typing import Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Union import metatensor.torch import numpy as np import torch import vesin from metatensor.torch import Labels, TensorBlock, TensorMap +from scipy.integrate import lebedev_rule +from scipy.spatial.transform import Rotation from torch.profiler import record_function from . import ( @@ -31,7 +33,6 @@ all_properties as ALL_ASE_PROPERTIES, ) - FilePath = Union[str, bytes, pathlib.PurePath] LOGGER = logging.getLogger(__name__) @@ -593,7 +594,7 @@ def compute_energy( # remove net forces results_as_numpy_arrays["forces"] = [ - f - f.mean(axis=0, keepdims=True) + f - f.mean(axis=0, keepdims=True) for f in results_as_numpy_arrays["forces"] ] @@ -824,3 +825,243 @@ def _full_3x3_to_voigt_6_stress(stress): (stress[0, 1] + stress[1, 0]) / 2.0, ] ) + + +class SO3AveragedCalculator(ase.calculators.calculator.Calculator): + """ + Take a MetatomicCalculator and average its predictions over a + Lebedev (S^2) x Uniform (S^1) grid of rotations in SO(3). + """ + + implemented_properties = ["energy", "forces", "stress"] + + def __init__( + self, + base_calculator: MetatomicCalculator, + lebedev_order: int = 3, + n_inplane_rotations: int = 4, + batch_size: Optional[int] = None, + **kwargs, + ): + super().__init__(**kwargs) + + self.base_calculator = base_calculator + self.lebedev_order = lebedev_order + self.n_inplane_rotations = n_inplane_rotations + + self.so3_quadrature_rotations = _get_so3_quadrature( + lebedev_order, n_inplane_rotations + ) + + self.batch_size = ( + batch_size if batch_size is not None else len(self.so3_quadrature_rotations) + ) + + def calculate(self, atoms, properties, system_changes): + super().calculate(atoms, properties, system_changes) + + compute_forces_and_stresses = "forces" in properties or "stress" in properties + + if len(self.so3_quadrature_rotations) > 0: + rotated_atoms_list = _rotate_atoms(atoms, self.so3_quadrature_rotations) + batch_size = ( + self.batch_size + if self.batch_size is not None + else len(rotated_atoms_list) + ) + batches = [ + rotated_atoms_list[i : i + batch_size] + for i in range(0, len(rotated_atoms_list), batch_size) + ] + results: Dict[str, Any] = {} + for batch in batches: + try: + batch_results = self.base_calculator.compute_energy( + batch, compute_forces_and_stresses + ) + for key, value in batch_results.items(): + results.setdefault(key, []) + results[key].extend( + [value] if isinstance(value, float) else value + ) + except torch.cuda.OutOfMemoryError as e: + raise RuntimeError( + "Out of memory error encountered during rotational averaging. " + "Please reduce the batch size or use lower rotational " + "averaging parameters. This can be done by setting the " + "`batch_size`, `lebedev_order`, and `n_inplane_rotations` " + "parameters while initializing the calculator." + f"Full error message: {e}" + ) + + results = _compute_rotational_average( + results, self.so3_quadrature_rotations + ) + self.results.update(results) + + +class O3AveragedCalculator(ase.calculators.calculator.Calculator): + """ + Take a MetatomicCalculator and average its predictions over a + Lebedev (S^2) x Uniform (S^1) grid of rotations in O(3). + """ + + implemented_properties = ["energy", "forces", "stress"] + + def __init__( + self, + base_calculator: MetatomicCalculator, + lebedev_order: int = 3, + n_inplane_rotations: int = 4, + batch_size: Optional[int] = None, + **kwargs, + ): + super().__init__(**kwargs) + + self.base_calculator = base_calculator + self.lebedev_order = lebedev_order + self.n_inplane_rotations = n_inplane_rotations + + self.o3_quadrature_rotations = _get_o3_quadrature( + lebedev_order, n_inplane_rotations + ) + + self.batch_size = ( + batch_size if batch_size is not None else len(self.o3_quadrature_rotations) + ) + + def calculate(self, atoms, properties, system_changes): + super().calculate(atoms, properties, system_changes) + + compute_forces_and_stresses = "forces" in properties or "stress" in properties + + if len(self.o3_quadrature_rotations) > 0: + rotated_atoms_list = _rotate_atoms(atoms, self.o3_quadrature_rotations) + batches = [ + rotated_atoms_list[i : i + self.batch_size] + for i in range(0, len(rotated_atoms_list), self.batch_size) + ] + results: Dict[str, Any] = {} + for batch in batches: + try: + batch_results = self.base_calculator.compute_energy( + batch, compute_forces_and_stresses + ) + for key, value in batch_results.items(): + results.setdefault(key, []) + results[key].extend( + [value] if isinstance(value, float) else value + ) + except torch.cuda.OutOfMemoryError as e: + raise RuntimeError( + "Out of memory error encountered during rotational averaging. " + "Please reduce the batch size or use lower rotational " + "averaging parameters. This can be done by setting the " + "`batch_size`, `lebedev_order`, and `n_inplane_rotations` " + "parameters while initializing the calculator." + f"Full error message: {e}" + ) + + results = _compute_rotational_average(results, self.o3_quadrature_rotations) + self.results.update(results) + + +def _rotate_atoms(atoms: ase.Atoms, rotations: List[np.ndarray]) -> List[ase.Atoms]: + rotated_atoms_list = [] + has_cell = atoms.cell is not None and atoms.cell.rank > 0 + for rot in rotations: + new_atoms = atoms.copy() + new_atoms.positions = new_atoms.positions @ rot.T + if has_cell: + new_atoms.cell = new_atoms.cell @ rot.T + rotated_atoms_list.append(new_atoms) + return rotated_atoms_list + + +def _get_so3_quadrature(lebedev_order: int, n_rotations: int): + """ + Lebedev(S^2) x uniform angle quadrature on SO(3). + """ + + # Lebedev nodes (X: (3, M)) + X, _ = lebedev_rule(lebedev_order) + + x, y, z = X + alpha = np.arctan2(y, x) # (M,) + beta = np.arccos(np.clip(z, -1.0, 1.0)) # (M,) + + K = int(n_rotations) + gamma = np.linspace(0.0, 2 * np.pi, K, endpoint=False) # (K,) + + # Build all combinations (alpha_i, beta_i, gamma_j) + A = np.repeat(alpha, K) # (N,) + B = np.repeat(beta, K) # (N,) + G = np.tile(gamma, alpha.size) # (N,) + + # Compose ZYZ rotations + Rot = ( + Rotation.from_euler("z", A) + * Rotation.from_euler("y", B) + * Rotation.from_euler("z", G) + ) + Rmats = Rot.as_matrix() # (N, 3, 3) + + return Rmats + + +def _get_o3_quadrature(lebedev_order: int, n_rotations: int): + """ + Lebedev(S^2) x uniform angle quadrature on O(3). + Returns an array of shape (2N, 3, 3) with orthogonal matrices, + the first N in SO(3), the next N in its coset with inversion. + """ + # Lebedev nodes (X: (3, M)) + X, _ = lebedev_rule(lebedev_order) + + x, y, z = X + alpha = np.arctan2(y, x) # (M,) + beta = np.arccos(np.clip(z, -1.0, 1.0)) # (M,) + + K = int(n_rotations) + gamma = np.linspace(0.0, 2 * np.pi, K, endpoint=False) # (K,) + + # Build all combinations (alpha_i, beta_i, gamma_j) + A = np.repeat(alpha, K) # (N,) + B = np.repeat(beta, K) # (N,) + G = np.tile(gamma, alpha.size) # (N,) + + # Compose ZYZ rotations in SO(3) + Rot = ( + Rotation.from_euler("z", A) + * Rotation.from_euler("y", B) + * Rotation.from_euler("z", G) + ) + R_so3 = Rot.as_matrix() # (N, 3, 3) + + # Extend to O(3) by appending inversion * R + P = -np.eye(3) + R_o3 = np.concatenate([R_so3, P @ R_so3], axis=0) # (2N, 3, 3) + + return R_o3 + + +def _compute_rotational_average(results, rotations): + R = np.asarray(rotations) # (B,3,3) + out = {} + if "energy" in results: + arr = np.asarray(results["energy"]) + out["energy"] = arr.mean() + out["energy_rot_std"] = arr.std() + if "forces" in results: + F = np.stack(results["forces"], axis=0) # (B,N,3) + F_back = np.einsum("bnj,bjk->bnk", F, R, optimize=True) + out["forces"] = F_back.mean(axis=0) + out["forces_rot_std"] = F_back.std(axis=0) + if "stress" in results: + S = np.stack(results["stress"], axis=0) # (B,3,3) + RT = np.swapaxes(R, 1, 2) + tmp = np.einsum("bij,bjk->bik", RT, S, optimize=True) + S_back = np.einsum("bik,bkl->bil", tmp, R, optimize=True) + out["stress"] = S_back.mean(axis=0) + out["stress_rot_std"] = S_back.std(axis=0) + return out From 484a4a2a934f3c8f4cde60aa06e77ea7ee354c9c Mon Sep 17 00:00:00 2001 From: ppegolo Date: Wed, 10 Sep 2025 13:38:01 +0200 Subject: [PATCH 03/14] Update --- .../metatomic/torch/ase_calculator.py | 25 +++++++++---------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/python/metatomic_torch/metatomic/torch/ase_calculator.py b/python/metatomic_torch/metatomic/torch/ase_calculator.py index 1e57dddc..08eabf9e 100644 --- a/python/metatomic_torch/metatomic/torch/ase_calculator.py +++ b/python/metatomic_torch/metatomic/torch/ase_calculator.py @@ -188,9 +188,9 @@ def __init__( for name, output in additional_outputs.items(): assert isinstance(name, str) assert isinstance(output, torch.ScriptObject) - assert "explicit_gradients_setter" in output._method_names(), ( - "outputs must be ModelOutput instances" - ) + assert ( + "explicit_gradients_setter" in output._method_names() + ), "outputs must be ModelOutput instances" self._additional_output_requests = additional_outputs @@ -863,7 +863,15 @@ def _full_3x3_to_voigt_6_stress(stress): ) -<<<<<<< HEAD +def _get_energy_uncertainty_output(): + return ModelOutput( + quantity="energy", + unit="eV", + per_atom=True, + explicit_gradients=[], + ) + + class SO3AveragedCalculator(ase.calculators.calculator.Calculator): """ Take a MetatomicCalculator and average its predictions over a @@ -1102,12 +1110,3 @@ def _compute_rotational_average(results, rotations): out["stress"] = S_back.mean(axis=0) out["stress_rot_std"] = S_back.std(axis=0) return out -======= -def _get_energy_uncertainty_output(): - return ModelOutput( - quantity="energy", - unit="eV", - per_atom=True, - explicit_gradients=[], - ) ->>>>>>> main From 84706dc29e53820b5043a909f08b4aaf4c23be99 Mon Sep 17 00:00:00 2001 From: ppegolo Date: Wed, 10 Sep 2025 13:47:02 +0200 Subject: [PATCH 04/14] Fix typing --- .../metatomic/torch/ase_calculator.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/python/metatomic_torch/metatomic/torch/ase_calculator.py b/python/metatomic_torch/metatomic/torch/ase_calculator.py index 08eabf9e..8d50aa8e 100644 --- a/python/metatomic_torch/metatomic/torch/ase_calculator.py +++ b/python/metatomic_torch/metatomic/torch/ase_calculator.py @@ -2,7 +2,7 @@ import os import pathlib import warnings -from typing import Any, Dict, List, Optional, Union +from typing import Dict, List, Optional, Union import metatensor.torch import numpy as np @@ -188,9 +188,9 @@ def __init__( for name, output in additional_outputs.items(): assert isinstance(name, str) assert isinstance(output, torch.ScriptObject) - assert ( - "explicit_gradients_setter" in output._method_names() - ), "outputs must be ModelOutput instances" + assert "explicit_gradients_setter" in output._method_names(), ( + "outputs must be ModelOutput instances" + ) self._additional_output_requests = additional_outputs @@ -918,7 +918,7 @@ def calculate(self, atoms, properties, system_changes): rotated_atoms_list[i : i + batch_size] for i in range(0, len(rotated_atoms_list), batch_size) ] - results: Dict[str, Any] = {} + results: Dict[str, np.ndarray] = {} for batch in batches: try: batch_results = self.base_calculator.compute_energy( @@ -986,7 +986,7 @@ def calculate(self, atoms, properties, system_changes): rotated_atoms_list[i : i + self.batch_size] for i in range(0, len(rotated_atoms_list), self.batch_size) ] - results: Dict[str, Any] = {} + results: Dict[str, np.ndarray] = {} for batch in batches: try: batch_results = self.base_calculator.compute_energy( From ea98c0d84d087bb0f871a3eec64c9a50e2cdca81 Mon Sep 17 00:00:00 2001 From: ppegolo Date: Wed, 10 Sep 2025 17:19:12 +0200 Subject: [PATCH 05/14] Make scipy an optional dependency --- .../metatomic/torch/ase_calculator.py | 35 +++++++++++++++++-- 1 file changed, 33 insertions(+), 2 deletions(-) diff --git a/python/metatomic_torch/metatomic/torch/ase_calculator.py b/python/metatomic_torch/metatomic/torch/ase_calculator.py index 8d50aa8e..3cbd8107 100644 --- a/python/metatomic_torch/metatomic/torch/ase_calculator.py +++ b/python/metatomic_torch/metatomic/torch/ase_calculator.py @@ -9,8 +9,6 @@ import torch import vesin from metatensor.torch import Labels, TensorBlock, TensorMap -from scipy.integrate import lebedev_rule -from scipy.spatial.transform import Rotation from torch.profiler import record_function from . import ( @@ -888,6 +886,14 @@ def __init__( batch_size: Optional[int] = None, **kwargs, ): + try: + from scipy.integrate import lebedev_rule # noqa: F401 + except ImportError as e: + raise ImportError( + "scipy is required to use the SO3AveragedCalculator, please install " + "it with `pip install scipy` or `conda install scipy`" + ) from e + super().__init__(**kwargs) self.base_calculator = base_calculator @@ -939,6 +945,12 @@ def calculate(self, atoms, properties, system_changes): f"Full error message: {e}" ) + # Clean up + try: + torch.cuda.empty_cache() + except Exception: + pass + results = _compute_rotational_average( results, self.so3_quadrature_rotations ) @@ -961,6 +973,14 @@ def __init__( batch_size: Optional[int] = None, **kwargs, ): + try: + from scipy.integrate import lebedev_rule # noqa: F401 + except ImportError as e: + raise ImportError( + "scipy is required to use the SO3AveragedCalculator, please install " + "it with `pip install scipy` or `conda install scipy`" + ) from e + super().__init__(**kwargs) self.base_calculator = base_calculator @@ -1007,6 +1027,12 @@ def calculate(self, atoms, properties, system_changes): f"Full error message: {e}" ) + # Clean up + try: + torch.cuda.empty_cache() + except Exception: + pass + results = _compute_rotational_average(results, self.o3_quadrature_rotations) self.results.update(results) @@ -1027,6 +1053,8 @@ def _get_so3_quadrature(lebedev_order: int, n_rotations: int): """ Lebedev(S^2) x uniform angle quadrature on SO(3). """ + from scipy.integrate import lebedev_rule + from scipy.spatial.transform import Rotation # Lebedev nodes (X: (3, M)) X, _ = lebedev_rule(lebedev_order) @@ -1060,6 +1088,9 @@ def _get_o3_quadrature(lebedev_order: int, n_rotations: int): Returns an array of shape (2N, 3, 3) with orthogonal matrices, the first N in SO(3), the next N in its coset with inversion. """ + from scipy.integrate import lebedev_rule + from scipy.spatial.transform import Rotation + # Lebedev nodes (X: (3, M)) X, _ = lebedev_rule(lebedev_order) From d375984f870660567e5f937e120fb302b306ed2c Mon Sep 17 00:00:00 2001 From: ppegolo Date: Tue, 16 Sep 2025 14:31:43 +0200 Subject: [PATCH 06/14] Update rotation routines --- .../metatomic/torch/ase_calculator.py | 124 ++++++++++++------ 1 file changed, 85 insertions(+), 39 deletions(-) diff --git a/python/metatomic_torch/metatomic/torch/ase_calculator.py b/python/metatomic_torch/metatomic/torch/ase_calculator.py index 3cbd8107..4a4113e6 100644 --- a/python/metatomic_torch/metatomic/torch/ase_calculator.py +++ b/python/metatomic_torch/metatomic/torch/ase_calculator.py @@ -971,6 +971,7 @@ def __init__( lebedev_order: int = 3, n_inplane_rotations: int = 4, batch_size: Optional[int] = None, + return_o3_samples=False, **kwargs, ): try: @@ -987,7 +988,7 @@ def __init__( self.lebedev_order = lebedev_order self.n_inplane_rotations = n_inplane_rotations - self.o3_quadrature_rotations = _get_o3_quadrature( + self.o3_quadrature_rotations, self.o3_quadrature_weights = _get_o3_quadrature( lebedev_order, n_inplane_rotations ) @@ -995,6 +996,8 @@ def __init__( batch_size if batch_size is not None else len(self.o3_quadrature_rotations) ) + self.return_o3_samples = return_o3_samples + def calculate(self, atoms, properties, system_changes): super().calculate(atoms, properties, system_changes) @@ -1033,8 +1036,13 @@ def calculate(self, atoms, properties, system_changes): except Exception: pass - results = _compute_rotational_average(results, self.o3_quadrature_rotations) - self.results.update(results) + self.results.update( + _compute_rotational_average( + results, self.o3_quadrature_rotations, self.o3_quadrature_weights + ) + ) + if self.return_o3_samples: + self.results["o3_samples"] = results def _rotate_atoms(atoms: ase.Atoms, rotations: List[np.ndarray]) -> List[ase.Atoms]: @@ -1044,7 +1052,10 @@ def _rotate_atoms(atoms: ase.Atoms, rotations: List[np.ndarray]) -> List[ase.Ato new_atoms = atoms.copy() new_atoms.positions = new_atoms.positions @ rot.T if has_cell: - new_atoms.cell = new_atoms.cell @ rot.T + new_atoms.set_cell( + new_atoms.cell.array @ rot.T, scale_atoms=False, apply_constraint=False + ) + new_atoms.wrap() rotated_atoms_list.append(new_atoms) return rotated_atoms_list @@ -1054,7 +1065,6 @@ def _get_so3_quadrature(lebedev_order: int, n_rotations: int): Lebedev(S^2) x uniform angle quadrature on SO(3). """ from scipy.integrate import lebedev_rule - from scipy.spatial.transform import Rotation # Lebedev nodes (X: (3, M)) X, _ = lebedev_rule(lebedev_order) @@ -1066,19 +1076,13 @@ def _get_so3_quadrature(lebedev_order: int, n_rotations: int): K = int(n_rotations) gamma = np.linspace(0.0, 2 * np.pi, K, endpoint=False) # (K,) - # Build all combinations (alpha_i, beta_i, gamma_j) - A = np.repeat(alpha, K) # (N,) - B = np.repeat(beta, K) # (N,) - G = np.tile(gamma, alpha.size) # (N,) - - # Compose ZYZ rotations - Rot = ( - Rotation.from_euler("z", A) - * Rotation.from_euler("y", B) - * Rotation.from_euler("z", G) - ) + Rot = _rotations_from_angles(alpha, beta, gamma) Rmats = Rot.as_matrix() # (N, 3, 3) + # Re-orthogonalize the rotation matrices to avoid numerical issues + U, _, Vt = np.linalg.svd(Rmats, full_matrices=False) + Rmats = U @ Vt + return Rmats @@ -1089,11 +1093,9 @@ def _get_o3_quadrature(lebedev_order: int, n_rotations: int): the first N in SO(3), the next N in its coset with inversion. """ from scipy.integrate import lebedev_rule - from scipy.spatial.transform import Rotation # Lebedev nodes (X: (3, M)) - X, _ = lebedev_rule(lebedev_order) - + X, w = lebedev_rule(lebedev_order) # w sums to 4*pi x, y, z = X alpha = np.arctan2(y, x) # (M,) beta = np.arccos(np.clip(z, -1.0, 1.0)) # (M,) @@ -1101,9 +1103,26 @@ def _get_o3_quadrature(lebedev_order: int, n_rotations: int): K = int(n_rotations) gamma = np.linspace(0.0, 2 * np.pi, K, endpoint=False) # (K,) + Rot = _rotations_from_angles(alpha, beta, gamma) + R_so3 = Rot.as_matrix() # (N, 3, 3) + + # SO(3) Haar–probability weights: w_i/(4*pi*K), repeated over gamma + w_so3 = np.repeat(w / (4 * np.pi * K), repeats=gamma.size) # (N,) + + # Extend to O(3) by appending inversion * R + P = -np.eye(3) + R_o3 = np.concatenate([R_so3, P @ R_so3], axis=0) # (2N, 3, 3) + w_o3 = np.concatenate([0.5 * w_so3, 0.5 * w_so3], axis=0) + + return R_o3, w_o3 + + +def _rotations_from_angles(alpha, beta, gamma): + from scipy.spatial.transform import Rotation + # Build all combinations (alpha_i, beta_i, gamma_j) - A = np.repeat(alpha, K) # (N,) - B = np.repeat(beta, K) # (N,) + A = np.repeat(alpha, gamma.size) # (N,) + B = np.repeat(beta, gamma.size) # (N,) G = np.tile(gamma, alpha.size) # (N,) # Compose ZYZ rotations in SO(3) @@ -1112,32 +1131,59 @@ def _get_o3_quadrature(lebedev_order: int, n_rotations: int): * Rotation.from_euler("y", B) * Rotation.from_euler("z", G) ) - R_so3 = Rot.as_matrix() # (N, 3, 3) - # Extend to O(3) by appending inversion * R - P = -np.eye(3) - R_o3 = np.concatenate([R_so3, P @ R_so3], axis=0) # (2N, 3, 3) + return Rot + + +def _compute_rotational_average(results, rotations, weights): + R = rotations + B = R.shape[0] + w = weights + w = w / w.sum() + + def _wreshape(x): + return w.reshape((B,) + (1,) * (x.ndim - 1)) - return R_o3 + def _wmean(x): + return np.sum(_wreshape(x) * x, axis=0) + def _wstd(x): + mu = _wmean(x) + return np.sqrt(np.sum(_wreshape(x) * (x - mu) ** 2, axis=0)) -def _compute_rotational_average(results, rotations): - R = np.asarray(rotations) # (B,3,3) out = {} + + # Energy (B,) if "energy" in results: - arr = np.asarray(results["energy"]) - out["energy"] = arr.mean() - out["energy_rot_std"] = arr.std() + E = np.asarray(results["energy"], dtype=float) + if E.shape != (B,): + raise ValueError(f"energy must be shape ({B},), got {E.shape}") + out["energy"] = _wmean(E) + out["energy_rot_std"] = _wstd(E) + + # Forces (B,N,3) from rotated structures: back-rotate with R^T F' if "forces" in results: - F = np.stack(results["forces"], axis=0) # (B,N,3) - F_back = np.einsum("bnj,bjk->bnk", F, R, optimize=True) - out["forces"] = F_back.mean(axis=0) - out["forces_rot_std"] = F_back.std(axis=0) + F = np.asarray(results["forces"], dtype=float) # (B,N,3) + if F.ndim == 2: + F = F[np.newaxis, ...] + if F.shape[0] != B or F.shape[-1] != 3: + raise ValueError(f"forces must be (B,N,3); got {F.shape} with B={B}") + RT = np.swapaxes(R, 1, 2) + F_back = np.einsum("bnj,bjk->bnk", F, RT, optimize=True) # R^T * F' + out["forces"] = _wmean(F_back) # (N,3) + out["forces_rot_std"] = _wstd(F_back) # (N,3) + + # Stress (B,3,3) from rotated structures: back-rotate with R^T S' R if "stress" in results: - S = np.stack(results["stress"], axis=0) # (B,3,3) + S = np.asarray(results["stress"], dtype=float) # (B,3,3) + if S.ndim == 2: + S = S[np.newaxis, ...] + if S.shape != (B, 3, 3): + raise ValueError(f"stress must be (B,3,3); got {S.shape} with B={B}") RT = np.swapaxes(R, 1, 2) tmp = np.einsum("bij,bjk->bik", RT, S, optimize=True) - S_back = np.einsum("bik,bkl->bil", tmp, R, optimize=True) - out["stress"] = S_back.mean(axis=0) - out["stress_rot_std"] = S_back.std(axis=0) + S_back = np.einsum("bik,bkl->bil", tmp, R, optimize=True) # R^T S' R + out["stress"] = _wmean(S_back) # (3,3) + out["stress_rot_std"] = _wstd(S_back) # (3,3) + return out From 57c953e4f3e2ee8a89d114f592ce1f92928bbe88 Mon Sep 17 00:00:00 2001 From: ppegolo Date: Wed, 17 Sep 2025 09:39:53 +0200 Subject: [PATCH 07/14] Simplify args --- .../metatomic/torch/ase_calculator.py | 53 +++++++++++++++++-- 1 file changed, 49 insertions(+), 4 deletions(-) diff --git a/python/metatomic_torch/metatomic/torch/ase_calculator.py b/python/metatomic_torch/metatomic/torch/ase_calculator.py index 4a4113e6..d38979ed 100644 --- a/python/metatomic_torch/metatomic/torch/ase_calculator.py +++ b/python/metatomic_torch/metatomic/torch/ase_calculator.py @@ -968,8 +968,7 @@ class O3AveragedCalculator(ase.calculators.calculator.Calculator): def __init__( self, base_calculator: MetatomicCalculator, - lebedev_order: int = 3, - n_inplane_rotations: int = 4, + l_max: int = 3, batch_size: Optional[int] = None, return_o3_samples=False, **kwargs, @@ -985,9 +984,13 @@ def __init__( super().__init__(**kwargs) self.base_calculator = base_calculator - self.lebedev_order = lebedev_order - self.n_inplane_rotations = n_inplane_rotations + if l_max > 131: + raise ValueError( + f"l_max={l_max} is too large, the maximum supported value is 131" + ) + self.l_max = l_max + lebedev_order, n_inplane_rotations = choose_quadrature(l_max) self.o3_quadrature_rotations, self.o3_quadrature_weights = _get_o3_quadrature( lebedev_order, n_inplane_rotations ) @@ -1045,6 +1048,48 @@ def calculate(self, atoms, properties, system_changes): self.results["o3_samples"] = results +def choose_quadrature(L_max): + available = [ + 3, + 5, + 7, + 9, + 11, + 13, + 15, + 17, + 19, + 21, + 23, + 25, + 27, + 29, + 31, + 35, + 41, + 47, + 53, + 59, + 65, + 71, + 77, + 83, + 89, + 95, + 101, + 107, + 113, + 119, + 125, + 131, + ] + # pick smallest order >= L_max + n = min(o for o in available if o >= L_max) + # minimal gamma count + K = 2 * L_max + 1 + return n, K + + def _rotate_atoms(atoms: ase.Atoms, rotations: List[np.ndarray]) -> List[ase.Atoms]: rotated_atoms_list = [] has_cell = atoms.cell is not None and atoms.cell.rank > 0 From b6354b73c8303b031ef79288378cf0363f077830 Mon Sep 17 00:00:00 2001 From: ppegolo Date: Wed, 17 Sep 2025 10:57:24 +0200 Subject: [PATCH 08/14] Fix bug --- .../metatomic/torch/ase_calculator.py | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/python/metatomic_torch/metatomic/torch/ase_calculator.py b/python/metatomic_torch/metatomic/torch/ase_calculator.py index d38979ed..ffd5dfc5 100644 --- a/python/metatomic_torch/metatomic/torch/ase_calculator.py +++ b/python/metatomic_torch/metatomic/torch/ase_calculator.py @@ -1201,30 +1201,19 @@ def _wstd(x): # Energy (B,) if "energy" in results: E = np.asarray(results["energy"], dtype=float) - if E.shape != (B,): - raise ValueError(f"energy must be shape ({B},), got {E.shape}") out["energy"] = _wmean(E) out["energy_rot_std"] = _wstd(E) # Forces (B,N,3) from rotated structures: back-rotate with R^T F' if "forces" in results: F = np.asarray(results["forces"], dtype=float) # (B,N,3) - if F.ndim == 2: - F = F[np.newaxis, ...] - if F.shape[0] != B or F.shape[-1] != 3: - raise ValueError(f"forces must be (B,N,3); got {F.shape} with B={B}") - RT = np.swapaxes(R, 1, 2) - F_back = np.einsum("bnj,bjk->bnk", F, RT, optimize=True) # R^T * F' + F_back = np.einsum("bnj,bjk->bnk", F, R, optimize=True) # F' R out["forces"] = _wmean(F_back) # (N,3) out["forces_rot_std"] = _wstd(F_back) # (N,3) # Stress (B,3,3) from rotated structures: back-rotate with R^T S' R if "stress" in results: S = np.asarray(results["stress"], dtype=float) # (B,3,3) - if S.ndim == 2: - S = S[np.newaxis, ...] - if S.shape != (B, 3, 3): - raise ValueError(f"stress must be (B,3,3); got {S.shape} with B={B}") RT = np.swapaxes(R, 1, 2) tmp = np.einsum("bij,bjk->bik", RT, S, optimize=True) S_back = np.einsum("bik,bkl->bil", tmp, R, optimize=True) # R^T S' R From 6a09b260998cb919bc5cd206a1c4b7013cd7b2f8 Mon Sep 17 00:00:00 2001 From: ppegolo Date: Wed, 17 Sep 2025 13:07:19 +0200 Subject: [PATCH 09/14] Add group symmetrization --- .../metatomic/torch/ase_calculator.py | 142 +++++++++++++++++- 1 file changed, 141 insertions(+), 1 deletion(-) diff --git a/python/metatomic_torch/metatomic/torch/ase_calculator.py b/python/metatomic_torch/metatomic/torch/ase_calculator.py index ffd5dfc5..16d3a515 100644 --- a/python/metatomic_torch/metatomic/torch/ase_calculator.py +++ b/python/metatomic_torch/metatomic/torch/ase_calculator.py @@ -2,7 +2,7 @@ import os import pathlib import warnings -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional, Tuple, Union import metatensor.torch import numpy as np @@ -971,6 +971,7 @@ def __init__( l_max: int = 3, batch_size: Optional[int] = None, return_o3_samples=False, + apply_group_symmetry=False, **kwargs, ): try: @@ -1000,10 +1001,14 @@ def __init__( ) self.return_o3_samples = return_o3_samples + self.apply_group_symmetry = apply_group_symmetry def calculate(self, atoms, properties, system_changes): super().calculate(atoms, properties, system_changes) + if self.apply_group_symmetry: + Q_list, P_list = _get_group_operations(atoms) + compute_forces_and_stresses = "forces" in properties or "stress" in properties if len(self.o3_quadrature_rotations) > 0: @@ -1044,6 +1049,9 @@ def calculate(self, atoms, properties, system_changes): results, self.o3_quadrature_rotations, self.o3_quadrature_weights ) ) + + if self.apply_group_symmetry: + self.results.update(_average_over_group(self.results, Q_list, P_list)) if self.return_o3_samples: self.results["o3_samples"] = results @@ -1221,3 +1229,135 @@ def _wstd(x): out["stress_rot_std"] = _wstd(S_back) # (3,3) return out + + +def _get_group_operations( + atoms: ase.Atoms, symprec: float = 1e-6, angle_tolerance: float = -1.0 +) -> Tuple[List[np.ndarray], List[np.ndarray]]: + """ + Extract point-group rotations Q_g (Cartesian, 3x3) and the corresponding + atom-index permutations P_g (N x N) induced by the space-group operations. + Returns Q_list, Cartesian rotation matrices of the point group, + and P_list, permutation matrices mapping original indexing -> indexing after (R,t), + """ + try: + import spglib + except ImportError as e: + raise ImportError( + "spglib is required to use the O3AveragedCalculator with " + "`apply_group_symmetry=True`. Please install it with " + "`pip install spglib` or `conda install -c conda-forge spglib`" + ) from e + + # Lattice with column vectors a1,a2,a3 (spglib expects (cell, frac, Z)) + A = atoms.cell.array.T # (3,3) + frac = atoms.get_scaled_positions() # (N,3) in [0,1) + numbers = atoms.numbers + N = len(atoms) + + data = spglib.get_symmetry_dataset( + (atoms.cell.array, frac, numbers), + symprec=symprec, + angle_tolerance=angle_tolerance, + ) + R_frac = data.rotations # (n_ops, 3,3), integer + t_frac = data.translations # (n_ops, 3) + Z = numbers + + # Match fractional coords modulo 1 within a tolerance, respecting chemical species + def _match_index(x_new, frac_ref, Z_ref, Z_i, tol=1e-6): + d = np.abs(frac_ref - x_new) # (N,3) + d = np.minimum(d, 1.0 - d) # periodic distance + # Mask by identical species + mask = Z_ref == Z_i + if not np.any(mask): + raise RuntimeError("No matching species found while building permutation.") + # Choose argmin over max-norm within species + idx = np.where(mask)[0] + j = idx[np.argmin(np.max(d[idx], axis=1))] + + # Sanity check + if np.max(d[j]) > tol: + pass + return j + + Q_list, P_list = [], [] + seen = set() + Ainv = np.linalg.inv(A) + + for Rf, tf in zip(R_frac, t_frac): + # Cartesian rotation: Q = A Rf A^{-1} + Q = A @ Rf @ Ainv + # Deduplicate rotations (point group) by rounding + key = tuple(np.round(Q.flatten(), 12)) + if key in seen: + continue + seen.add(key) + + # Build the permutation P from i to j + P = np.zeros((N, N), dtype=int) + new_frac = (frac @ Rf.T + tf) % 1.0 # images after (Rf,tf) + for i in range(N): + j = _match_index(new_frac[i], frac, Z, Z[i]) + P[j, i] = 1 # column i maps to row j + + Q_list.append(Q.astype(float)) + P_list.append(P) + + return Q_list, P_list + + +def _average_over_group( + results: dict, Q_list: List[np.ndarray], P_list: List[np.ndarray] +) -> dict: + """ + Apply the point-group projector in output space. + """ + m = len(Q_list) + if m == 0: + # No symmetry found; return copies + out = {} + if "energy" in results: + out["energy_pg"] = float(results["energy"]) + if "forces" in results: + out["forces_pg"] = np.array(results["forces"], float, copy=True) + if "stress" in results: + S = np.array(results["stress"], float, copy=True) + S = 0.5 * (S + S.T) + out["stress_pg"] = S + out["stress_iso_pg"] = np.eye(3) * (np.trace(S) / 3.0) + out["stress_dev_pg"] = S - out["stress_iso_pg"] + return out + + out = {} + # Energy: unchanged by the projector (scalar invariant) + if "energy" in results: + out["energy"] = float(results["energy"]) + + # Forces: (N,3) row-vectors; projector: (1/|G|) \sum_g P_g^T F Q_g + if "forces" in results: + F = np.asarray(results["forces"], float) + if F.ndim != 2 or F.shape[1] != 3: + raise ValueError(f"'forces' must be (N,3), got {F.shape}") + acc = np.zeros_like(F) + for Q, P in zip(Q_list, P_list): + acc += P.T @ (F @ Q) + out["forces"] = acc / m + + # Stress: (3,3); projector: (1/|G|) \sum_g Q_g^T S Q_g + if "stress" in results: + S = np.asarray(results["stress"], float) + if S.shape != (3, 3): + raise ValueError(f"'stress' must be (3,3), got {S.shape}") + S = 0.5 * (S + S.T) # symmetrize just in case + acc = np.zeros_like(S) + for Q in Q_list: + acc += Q.T @ S @ Q + S_pg = acc / m + out["stress"] = S_pg + # # Expose L=0 projection and deviatoric part for debugging + # S_iso = np.trace(S_pg) / 3.0 + # out["stress_iso_pg"] = np.eye(3) * S_iso + # out["stress_dev_pg"] = S_pg - out["stress_iso_pg"] + + return out From 704b923b39556cf7dc8230c9ae338834bd72dc6f Mon Sep 17 00:00:00 2001 From: ppegolo Date: Sat, 20 Sep 2025 18:48:25 +0200 Subject: [PATCH 10/14] small change --- .../metatomic/torch/ase_calculator.py | 61 +++++++++++++------ 1 file changed, 43 insertions(+), 18 deletions(-) diff --git a/python/metatomic_torch/metatomic/torch/ase_calculator.py b/python/metatomic_torch/metatomic/torch/ase_calculator.py index 16d3a515..8fe059a4 100644 --- a/python/metatomic_torch/metatomic/torch/ase_calculator.py +++ b/python/metatomic_torch/metatomic/torch/ase_calculator.py @@ -991,10 +991,15 @@ def __init__( ) self.l_max = l_max - lebedev_order, n_inplane_rotations = choose_quadrature(l_max) - self.o3_quadrature_rotations, self.o3_quadrature_weights = _get_o3_quadrature( - lebedev_order, n_inplane_rotations - ) + if l_max > 0: + lebedev_order, n_inplane_rotations = choose_quadrature(l_max) + self.o3_quadrature_rotations, self.o3_quadrature_weights = ( + _get_o3_quadrature(lebedev_order, n_inplane_rotations) + ) + else: + # no quadrature + self.o3_quadrature_rotations = np.array([np.eye(3)]) + self.o3_quadrature_weights = np.array([1.0]) self.batch_size = ( batch_size if batch_size is not None else len(self.o3_quadrature_rotations) @@ -1050,11 +1055,12 @@ def calculate(self, atoms, properties, system_changes): ) ) - if self.apply_group_symmetry: - self.results.update(_average_over_group(self.results, Q_list, P_list)) if self.return_o3_samples: self.results["o3_samples"] = results + if self.apply_group_symmetry: + self.results.update(_average_over_group(self.results, Q_list, P_list)) + def choose_quadrature(L_max): available = [ @@ -1120,23 +1126,21 @@ def _get_so3_quadrature(lebedev_order: int, n_rotations: int): from scipy.integrate import lebedev_rule # Lebedev nodes (X: (3, M)) - X, _ = lebedev_rule(lebedev_order) - + X, w = lebedev_rule(lebedev_order) # w sums to 4*pi x, y, z = X alpha = np.arctan2(y, x) # (M,) - beta = np.arccos(np.clip(z, -1.0, 1.0)) # (M,) + beta = np.arccos(z) # (M,) + # beta = np.arccos(np.clip(z, -1.0, 1.0)) # (M,) K = int(n_rotations) gamma = np.linspace(0.0, 2 * np.pi, K, endpoint=False) # (K,) Rot = _rotations_from_angles(alpha, beta, gamma) - Rmats = Rot.as_matrix() # (N, 3, 3) + R_so3 = Rot.as_matrix() # (N, 3, 3) - # Re-orthogonalize the rotation matrices to avoid numerical issues - U, _, Vt = np.linalg.svd(Rmats, full_matrices=False) - Rmats = U @ Vt + w_so3 = np.repeat(w / (4 * np.pi * K), repeats=gamma.size) # (N,) - return Rmats + return R_so3, w_so3 def _get_o3_quadrature(lebedev_order: int, n_rotations: int): @@ -1151,7 +1155,8 @@ def _get_o3_quadrature(lebedev_order: int, n_rotations: int): X, w = lebedev_rule(lebedev_order) # w sums to 4*pi x, y, z = X alpha = np.arctan2(y, x) # (M,) - beta = np.arccos(np.clip(z, -1.0, 1.0)) # (M,) + beta = np.arccos(z) # (M,) + # beta = np.arccos(np.clip(z, -1.0, 1.0)) # (M,) K = int(n_rotations) gamma = np.linspace(0.0, 2 * np.pi, K, endpoint=False) # (K,) @@ -1159,6 +1164,13 @@ def _get_o3_quadrature(lebedev_order: int, n_rotations: int): Rot = _rotations_from_angles(alpha, beta, gamma) R_so3 = Rot.as_matrix() # (N, 3, 3) + # rnd = np.random.uniform(size=(3, 3)) + # rnd = rnd - rnd.T + # import scipy.linalg + + # rnd = scipy.linalg.expm(-rnd) + # R_so3 = R_so3 @ rnd + # SO(3) Haar–probability weights: w_i/(4*pi*K), repeated over gamma w_so3 = np.repeat(w / (4 * np.pi * K), repeats=gamma.size) # (N,) @@ -1212,7 +1224,7 @@ def _wstd(x): out["energy"] = _wmean(E) out["energy_rot_std"] = _wstd(E) - # Forces (B,N,3) from rotated structures: back-rotate with R^T F' + # Forces (B,N,3) from rotated structures: back-rotate with F' R if "forces" in results: F = np.asarray(results["forces"], dtype=float) # (B,N,3) F_back = np.einsum("bnj,bjk->bnk", F, R, optimize=True) # F' R @@ -1312,6 +1324,19 @@ def _average_over_group( ) -> dict: """ Apply the point-group projector in output space. + + Parameters + ---------- + results : dict + Must contain 'energy' (scalar), and/or 'forces' (N,3), and/or 'stress' (3,3). + These are predictions for the *current* structure in the reference frame. + Q_list, P_list : outputs of _get_group_operations + + Returns + ------- + out : dict + Projected quantities with keys: 'energy_pg', 'forces_pg', 'stress_pg'. + For stress, also returns 'stress_iso_pg' (L=0) and 'stress_dev_pg'. """ m = len(Q_list) if m == 0: @@ -1323,7 +1348,7 @@ def _average_over_group( out["forces_pg"] = np.array(results["forces"], float, copy=True) if "stress" in results: S = np.array(results["stress"], float, copy=True) - S = 0.5 * (S + S.T) + # S = 0.5 * (S + S.T) out["stress_pg"] = S out["stress_iso_pg"] = np.eye(3) * (np.trace(S) / 3.0) out["stress_dev_pg"] = S - out["stress_iso_pg"] @@ -1349,7 +1374,7 @@ def _average_over_group( S = np.asarray(results["stress"], float) if S.shape != (3, 3): raise ValueError(f"'stress' must be (3,3), got {S.shape}") - S = 0.5 * (S + S.T) # symmetrize just in case + # S = 0.5 * (S + S.T) # symmetrize just in case acc = np.zeros_like(S) for Q in Q_list: acc += Q.T @ S @ Q From 971c22bbb13ff9069f474669361ac3fc1f82ff00 Mon Sep 17 00:00:00 2001 From: ppegolo Date: Thu, 9 Oct 2025 13:51:31 +0200 Subject: [PATCH 11/14] clean up and add tests --- .../metatomic/torch/ase_calculator.py | 286 ++++++-------- .../tests/symmetrized_ase_calculator.py | 354 ++++++++++++++++++ 2 files changed, 465 insertions(+), 175 deletions(-) create mode 100644 python/metatomic_torch/tests/symmetrized_ase_calculator.py diff --git a/python/metatomic_torch/metatomic/torch/ase_calculator.py b/python/metatomic_torch/metatomic/torch/ase_calculator.py index 8fe059a4..d3027f92 100644 --- a/python/metatomic_torch/metatomic/torch/ase_calculator.py +++ b/python/metatomic_torch/metatomic/torch/ase_calculator.py @@ -2,7 +2,7 @@ import os import pathlib import warnings -from typing import Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import metatensor.torch import numpy as np @@ -186,9 +186,9 @@ def __init__( for name, output in additional_outputs.items(): assert isinstance(name, str) assert isinstance(output, torch.ScriptObject) - assert "explicit_gradients_setter" in output._method_names(), ( - "outputs must be ModelOutput instances" - ) + assert ( + "explicit_gradients_setter" in output._method_names() + ), "outputs must be ModelOutput instances" self._additional_output_requests = additional_outputs @@ -870,97 +870,35 @@ def _get_energy_uncertainty_output(): ) -class SO3AveragedCalculator(ase.calculators.calculator.Calculator): - """ - Take a MetatomicCalculator and average its predictions over a - Lebedev (S^2) x Uniform (S^1) grid of rotations in SO(3). - """ - - implemented_properties = ["energy", "forces", "stress"] - - def __init__( - self, - base_calculator: MetatomicCalculator, - lebedev_order: int = 3, - n_inplane_rotations: int = 4, - batch_size: Optional[int] = None, - **kwargs, - ): - try: - from scipy.integrate import lebedev_rule # noqa: F401 - except ImportError as e: - raise ImportError( - "scipy is required to use the SO3AveragedCalculator, please install " - "it with `pip install scipy` or `conda install scipy`" - ) from e - - super().__init__(**kwargs) - - self.base_calculator = base_calculator - self.lebedev_order = lebedev_order - self.n_inplane_rotations = n_inplane_rotations - - self.so3_quadrature_rotations = _get_so3_quadrature( - lebedev_order, n_inplane_rotations - ) - - self.batch_size = ( - batch_size if batch_size is not None else len(self.so3_quadrature_rotations) - ) - - def calculate(self, atoms, properties, system_changes): - super().calculate(atoms, properties, system_changes) - - compute_forces_and_stresses = "forces" in properties or "stress" in properties - - if len(self.so3_quadrature_rotations) > 0: - rotated_atoms_list = _rotate_atoms(atoms, self.so3_quadrature_rotations) - batch_size = ( - self.batch_size - if self.batch_size is not None - else len(rotated_atoms_list) - ) - batches = [ - rotated_atoms_list[i : i + batch_size] - for i in range(0, len(rotated_atoms_list), batch_size) - ] - results: Dict[str, np.ndarray] = {} - for batch in batches: - try: - batch_results = self.base_calculator.compute_energy( - batch, compute_forces_and_stresses - ) - for key, value in batch_results.items(): - results.setdefault(key, []) - results[key].extend( - [value] if isinstance(value, float) else value - ) - except torch.cuda.OutOfMemoryError as e: - raise RuntimeError( - "Out of memory error encountered during rotational averaging. " - "Please reduce the batch size or use lower rotational " - "averaging parameters. This can be done by setting the " - "`batch_size`, `lebedev_order`, and `n_inplane_rotations` " - "parameters while initializing the calculator." - f"Full error message: {e}" - ) - - # Clean up - try: - torch.cuda.empty_cache() - except Exception: - pass - - results = _compute_rotational_average( - results, self.so3_quadrature_rotations - ) - self.results.update(results) - - -class O3AveragedCalculator(ase.calculators.calculator.Calculator): - """ - Take a MetatomicCalculator and average its predictions over a - Lebedev (S^2) x Uniform (S^1) grid of rotations in O(3). +class SymmetrizedCalculator(ase.calculators.calculator.Calculator): + r""" + Take a MetatomicCalculator and average its predictions to make it (approximately) + equivariant. + + The default is to average over a quadrature of the orthogonal group O(3) composed + this way: + + - Lebedev quadrature of the unit sphere (S^2) + - Equispaced sampling of the unit circle (S^1) + - Both proper and improper rotations are taken into account by including the + inversion operation (if ``include_inversion=True``) + + :param base_calculator: the MetatomicCalculator to be symmetrized + :param l_max: the maximum spherical harmonic degree that the model is expected to + be able to represent. This is used to choose the quadrature order. If ``0``, + no rotational averaging will be performed (it can be useful to average only over + the space group, see ``apply_group_symmetry``). + :param batch_size: number of rotated systems to evaluate at once. If ``None``, all + systems will be evaluated at once (this can lead to high memory usage). + :param include_inversion: if ``True``, the inversion operation will be included in + the averaging. This is required to average over the full orthogonal group O(3). + :param apply_group_symmetry: if ``True``, the results will be averaged over the + discrete space group of rotations for the input system. The group operations are + computed with spglib, and the average is performed after the O(3) averaging + (if any). + :param return_samples: if ``True``, the results of the base calculator on each + rotated system will be returned. Most useful for debugging. + :param \*\*kwargs: additional arguments passed to the ASE Calculator constructor """ implemented_properties = ["energy", "forces", "stress"] @@ -970,10 +908,11 @@ def __init__( base_calculator: MetatomicCalculator, l_max: int = 3, batch_size: Optional[int] = None, - return_o3_samples=False, - apply_group_symmetry=False, - **kwargs, - ): + include_inversion: bool = True, + apply_group_symmetry: bool = False, + return_samples: bool = False, + **kwargs: Any, + ) -> None: try: from scipy.integrate import lebedev_rule # noqa: F401 except ImportError as e: @@ -990,34 +929,43 @@ def __init__( f"l_max={l_max} is too large, the maximum supported value is 131" ) self.l_max = l_max + self.include_inversion = include_inversion if l_max > 0: - lebedev_order, n_inplane_rotations = choose_quadrature(l_max) - self.o3_quadrature_rotations, self.o3_quadrature_weights = ( - _get_o3_quadrature(lebedev_order, n_inplane_rotations) + lebedev_order, n_inplane_rotations = _choose_quadrature(l_max) + self.quadrature_rotations, self.quadrature_weights = _get_quadrature( + lebedev_order, n_inplane_rotations, include_inversion ) else: # no quadrature - self.o3_quadrature_rotations = np.array([np.eye(3)]) - self.o3_quadrature_weights = np.array([1.0]) + self.quadrature_rotations = np.array([np.eye(3)]) + self.quadrature_weights = np.array([1.0]) self.batch_size = ( - batch_size if batch_size is not None else len(self.o3_quadrature_rotations) + batch_size if batch_size is not None else len(self.quadrature_rotations) ) - self.return_o3_samples = return_o3_samples + self.return_samples = return_samples self.apply_group_symmetry = apply_group_symmetry - def calculate(self, atoms, properties, system_changes): - super().calculate(atoms, properties, system_changes) + def calculate( + self, atoms: ase.Atoms, properties: List[str], system_changes: List[str] + ) -> None: + """ + Perform the calculation for the given atoms and properties. - if self.apply_group_symmetry: - Q_list, P_list = _get_group_operations(atoms) + :param atoms: the :py:class:`ase.Atoms` on which to perform the calculation + :param properties: list of properties to compute, among ``energy``, ``forces``, + and ``stress`` + :param system_changes: list of changes to the system since the last call to + ``calculate`` + """ + super().calculate(atoms, properties, system_changes) compute_forces_and_stresses = "forces" in properties or "stress" in properties - if len(self.o3_quadrature_rotations) > 0: - rotated_atoms_list = _rotate_atoms(atoms, self.o3_quadrature_rotations) + if len(self.quadrature_rotations) > 0: + rotated_atoms_list = _rotate_atoms(atoms, self.quadrature_rotations) batches = [ rotated_atoms_list[i : i + self.batch_size] for i in range(0, len(rotated_atoms_list), self.batch_size) @@ -1040,29 +988,32 @@ def calculate(self, atoms, properties, system_changes): "averaging parameters. This can be done by setting the " "`batch_size`, `lebedev_order`, and `n_inplane_rotations` " "parameters while initializing the calculator." - f"Full error message: {e}" - ) - - # Clean up - try: - torch.cuda.empty_cache() - except Exception: - pass + ) from e self.results.update( _compute_rotational_average( - results, self.o3_quadrature_rotations, self.o3_quadrature_weights + results, self.quadrature_rotations, self.quadrature_weights ) ) - if self.return_o3_samples: - self.results["o3_samples"] = results + if self.return_samples: + sample_names = "o3_samples" if self.include_inversion else "so3_samples" + self.results[sample_names] = results if self.apply_group_symmetry: + # Apply the discrete space group of the system a posteriori + Q_list, P_list = _get_group_operations(atoms) self.results.update(_average_over_group(self.results, Q_list, P_list)) -def choose_quadrature(L_max): +def _choose_quadrature(L_max: int) -> Tuple[int, int]: + """ + Choose a Lebedev quadrature order and number of in-plane rotations to integrate + spherical harmonics up to degree ``L_max``. + + :param L_max: maximum spherical harmonic degree + :return: (lebedev_order, n_inplane_rotations) + """ available = [ 3, 5, @@ -1105,6 +1056,13 @@ def choose_quadrature(L_max): def _rotate_atoms(atoms: ase.Atoms, rotations: List[np.ndarray]) -> List[ase.Atoms]: + """ + Create a list of copies of ``atoms``, rotated by each of the given ``rotations``. + + :param atoms: the :py:class:`ase.Atoms` to be rotated + :param rotations: (N, 3, 3) array of orthogonal matrices + :return: list of N :py:class:`ase.Atoms`, each rotated by the corresponding matrix + """ rotated_atoms_list = [] has_cell = atoms.cell is not None and atoms.cell.rank > 0 for rot in rotations: @@ -1119,9 +1077,17 @@ def _rotate_atoms(atoms: ase.Atoms, rotations: List[np.ndarray]) -> List[ase.Ato return rotated_atoms_list -def _get_so3_quadrature(lebedev_order: int, n_rotations: int): +def _get_quadrature(lebedev_order: int, n_rotations: int, include_inversion: bool): """ Lebedev(S^2) x uniform angle quadrature on SO(3). + If include_inversion=True, extend to O(3) by adding inversion * R. + + :param lebedev_order: order of the Lebedev quadrature on the unit sphere + :param n_rotations: number of in-plane rotations per Lebedev node + :param include_inversion: if ``True``, include the inversion operation in the + quadrature + :return: (N, 3, 3) array of orthogonal matrices, and (N,) array of weights + associated to each matrix """ from scipy.integrate import lebedev_rule @@ -1138,42 +1104,12 @@ def _get_so3_quadrature(lebedev_order: int, n_rotations: int): Rot = _rotations_from_angles(alpha, beta, gamma) R_so3 = Rot.as_matrix() # (N, 3, 3) - w_so3 = np.repeat(w / (4 * np.pi * K), repeats=gamma.size) # (N,) - - return R_so3, w_so3 - - -def _get_o3_quadrature(lebedev_order: int, n_rotations: int): - """ - Lebedev(S^2) x uniform angle quadrature on O(3). - Returns an array of shape (2N, 3, 3) with orthogonal matrices, - the first N in SO(3), the next N in its coset with inversion. - """ - from scipy.integrate import lebedev_rule - - # Lebedev nodes (X: (3, M)) - X, w = lebedev_rule(lebedev_order) # w sums to 4*pi - x, y, z = X - alpha = np.arctan2(y, x) # (M,) - beta = np.arccos(z) # (M,) - # beta = np.arccos(np.clip(z, -1.0, 1.0)) # (M,) - - K = int(n_rotations) - gamma = np.linspace(0.0, 2 * np.pi, K, endpoint=False) # (K,) - - Rot = _rotations_from_angles(alpha, beta, gamma) - R_so3 = Rot.as_matrix() # (N, 3, 3) - - # rnd = np.random.uniform(size=(3, 3)) - # rnd = rnd - rnd.T - # import scipy.linalg - - # rnd = scipy.linalg.expm(-rnd) - # R_so3 = R_so3 @ rnd - # SO(3) Haar–probability weights: w_i/(4*pi*K), repeated over gamma w_so3 = np.repeat(w / (4 * np.pi * K), repeats=gamma.size) # (N,) + if not include_inversion: + return R_so3, w_so3 + # Extend to O(3) by appending inversion * R P = -np.eye(3) R_o3 = np.concatenate([R_so3, P @ R_so3], axis=0) # (2N, 3, 3) @@ -1251,12 +1187,19 @@ def _get_group_operations( atom-index permutations P_g (N x N) induced by the space-group operations. Returns Q_list, Cartesian rotation matrices of the point group, and P_list, permutation matrices mapping original indexing -> indexing after (R,t), + + :param atoms: input structure + :param symprec: tolerance for symmetry finding + :param angle_tolerance: tolerance for symmetry finding (in degrees). If less than 0, + a value depending on ``symprec`` will be chosen automatically by spglib. + :return: List of rotation matrices and permutation matrices. + """ try: import spglib except ImportError as e: raise ImportError( - "spglib is required to use the O3AveragedCalculator with " + "spglib is required to use the SymmetrizedCalculator with " "`apply_group_symmetry=True`. Please install it with " "`pip install spglib` or `conda install -c conda-forge spglib`" ) from e @@ -1325,17 +1268,14 @@ def _average_over_group( """ Apply the point-group projector in output space. - Parameters - ---------- - results : dict - Must contain 'energy' (scalar), and/or 'forces' (N,3), and/or 'stress' (3,3). - These are predictions for the *current* structure in the reference frame. - Q_list, P_list : outputs of _get_group_operations - - Returns - ------- - out : dict - Projected quantities with keys: 'energy_pg', 'forces_pg', 'stress_pg'. + :param results: Must contain 'energy' (scalar), and/or 'forces' (N,3), and/or + 'stress' (3,3). These are predictions for the current structure in the reference + frame. + :param Q_list: Rotation matrices of the point group, from + :py:func:`_get_group_operations` + :param P_list: Permutation matrices of the point group, from + :py:func:`_get_group_operations` + :return out: Projected quantities with keys: 'energy_pg', 'forces_pg', 'stress_pg'. For stress, also returns 'stress_iso_pg' (L=0) and 'stress_dev_pg'. """ m = len(Q_list) @@ -1355,7 +1295,7 @@ def _average_over_group( return out out = {} - # Energy: unchanged by the projector (scalar invariant) + # Energy: unchanged by the projector (scalar) if "energy" in results: out["energy"] = float(results["energy"]) @@ -1380,9 +1320,5 @@ def _average_over_group( acc += Q.T @ S @ Q S_pg = acc / m out["stress"] = S_pg - # # Expose L=0 projection and deviatoric part for debugging - # S_iso = np.trace(S_pg) / 3.0 - # out["stress_iso_pg"] = np.eye(3) * S_iso - # out["stress_dev_pg"] = S_pg - out["stress_iso_pg"] return out diff --git a/python/metatomic_torch/tests/symmetrized_ase_calculator.py b/python/metatomic_torch/tests/symmetrized_ase_calculator.py new file mode 100644 index 00000000..977b3a18 --- /dev/null +++ b/python/metatomic_torch/tests/symmetrized_ase_calculator.py @@ -0,0 +1,354 @@ +import numpy as np +import pytest +from ase import Atoms + +from metatomic.torch.ase_calculator import SymmetrizedCalculator, _get_quadrature + + +def _body_axis_from_atoms(atoms: Atoms) -> np.ndarray: + """ + Return the normalized vector connecting the two farthest atoms. + + :param atoms: Atomic configuration. + :return: Normalized 3D vector defining the body axis. + """ + pos = atoms.get_positions() + if len(pos) < 2: + return np.array([0.0, 0.0, 1.0]) + d2 = np.sum((pos[:, None, :] - pos[None, :, :]) ** 2, axis=-1) + i, j = np.unravel_index(np.argmax(d2), d2.shape) + b = pos[j] - pos[i] + nrm = np.linalg.norm(b) + return b / nrm if nrm > 0 else np.array([0.0, 0.0, 1.0]) + + +def _legendre_0_1_2_3(c: float) -> tuple[float, float, float, float]: + """ + Compute Legendre polynomials P0..P3(c). + + :param c: Cosine between the body axis and the lab z-axis. + :return: Tuple (P0, P1, P2, P3). + """ + P0 = 1.0 + P1 = c + P2 = 0.5 * (3 * c * c - 1.0) + P3 = 0.5 * (5 * c * c * c - 3 * c) + return P0, P1, P2, P3 + + +class MockAnisoCalculator: + """ + Deterministic, rotation-dependent mock for testing SymmetrizedCalculator. + + Components: + - Energy: E_true + a1*P1 + a2*P2 + a3*P3 + - Forces: F_true + (b1*P1 + b2*P2 + b3*P3)*ẑ + optional tensor L=2 term + - Stress: p_iso*I + (c2*P2 + c3*P3)*D + + :param a: Coefficients for Legendre P0..P3 in the energy. + :param b: Coefficients for P1..P3 in the forces (spurious vector parts). + :param c: Coefficients for P2,P3 in the stress (spurious deviators). + :param p_iso: Isotropic (true) part of the stress tensor. + :param tensor_forces: If True, add L=2 tensor-coupled force term. + :param tensor_amp: Amplitude of the tensor-coupled force component. + """ + + def __init__( + self, + a: tuple[float, float, float, float] = (0.0, 0.0, 0.0, 0.0), + b: tuple[float, float, float] = (0.0, 0.0, 0.0), + c: tuple[float, float] = (0.0, 0.0), + p_iso: float = 1.0, + tensor_forces: bool = False, + tensor_amp: float = 0.5, + ) -> None: + self.a0, self.a1, self.a2, self.a3 = a + self.b1, self.b2, self.b3 = b + self.c2, self.c3 = c + self.p_iso = p_iso + self.tensor_forces = tensor_forces + self.tensor_amp = tensor_amp + + def compute_energy( + self, + batch: list[Atoms], + compute_forces_and_stresses: bool = False, + ) -> dict[str, list[np.ndarray | float]]: + """ + Compute deterministic, rotation-dependent properties for each batch entry. + + :param batch: List of atomic configurations. + :param compute_forces_and_stresses: Unused flag for API compatibility. + :return: Dictionary with lists of energies, forces, and stresses. + """ + out: dict[str, list[np.ndarray | float]] = { + "energy": [], + "forces": [], + "stress": [], + } + zhat = np.array([0.0, 0.0, 1.0]) + D = np.diag([1.0, -1.0, 0.0]) + + for atoms in batch: + pos = atoms.get_positions() + b = _body_axis_from_atoms(atoms) + c = float(np.dot(b, zhat)) + P0, P1, P2, P3 = _legendre_0_1_2_3(c) + + # Energy + E_true = float(np.sum(pos**2)) + E = E_true + self.a0 * P0 + self.a1 * P1 + self.a2 * P2 + self.a3 * P3 + + # Forces + F_true = pos.copy() + F_spur = (self.b1 * P1 + self.b2 * P2 + self.b3 * P3) * zhat[None, :] + F = F_true + F_spur + + if self.tensor_forces: + # Build rotation R such that R ẑ = b + v = np.cross(zhat, b) + s = np.linalg.norm(v) + cth = np.dot(zhat, b) + if s < 1e-15: + R = np.eye(3) if cth > 0 else -np.eye(3) + else: + vx = np.array( + [[0, -v[2], v[1]], [v[2], 0, -v[0]], [-v[1], v[0], 0]] + ) + R = np.eye(3) + vx + vx @ vx * ((1 - cth) / (s**2)) + T = R @ D @ R.T + F_tensor = self.tensor_amp * (T @ zhat) + F = F + F_tensor[None, :] + + # Stress + S = self.p_iso * np.eye(3) + (self.c2 * P2 + self.c3 * P3) * D + + out["energy"].append(E) + out["forces"].append(F) + out["stress"].append(S) + return out + + +@pytest.fixture +def dimer() -> Atoms: + """ + Create a small asymmetric geometry with a well-defined body axis. + + :return: ASE Atoms object with the H2 molecule. + """ + return Atoms("H2", positions=[[0, 0, 0], [0.3, 0.2, 1.0]]) + + +def test_quadrature_normalization() -> None: + """Verify normalization and determinant signs of the quadrature.""" + R, w = _get_quadrature(lebedev_order=11, n_rotations=5, include_inversion=True) + assert np.isclose(np.sum(w), 1.0) + dets = np.linalg.det(R) + assert np.all(np.isin(np.round(dets).astype(int), [-1, 1])) + + +@pytest.mark.parametrize("Lmax, expect_removed", [(0, False), (3, True)]) +def test_energy_L_components_removed( + dimer: Atoms, Lmax: int, expect_removed: bool +) -> None: + """ + Verify that spurious energy components vanish once rotational averaging is applied. + For Lmax>0, all use the same minimal Lebedev rule (order=3). + """ + a = (1.0, 1.0, 1.0, 1.0) + base = MockAnisoCalculator(a=a) + calc = SymmetrizedCalculator(base, l_max=Lmax) + dimer.calc = calc + e = dimer.get_potential_energy() + E_true = float(np.sum(dimer.positions**2)) + if expect_removed: + assert np.isclose(e, E_true + a[0], atol=1e-10) + else: + assert not np.isclose(e, E_true + a[0], atol=1e-10) + + +def test_force_backrotation_exact(dimer: Atoms) -> None: + """ + Check that forces are back-rotated exactly when no spurious terms are present. + + :param dimer: Test atomic structure. + """ + base = MockAnisoCalculator(b=(0, 0, 0)) + calc = SymmetrizedCalculator(base, l_max=3) + dimer.calc = calc + F = dimer.get_forces() + assert np.allclose(F, dimer.positions, atol=1e-12) + + +def test_tensorial_L2_force_cancellation(dimer: Atoms) -> None: + """ + Tensor-coupled (L=2) force components must vanish under O(3) averaging. + + Since the minimal Lebedev order used internally is 3, all quadratures + integrate L=2 components exactly; we only check for correct cancellation. + """ + base = MockAnisoCalculator(tensor_forces=True, tensor_amp=1.0) + + for Lmax in [1, 2, 3]: + calc = SymmetrizedCalculator(base, l_max=Lmax) + dimer.calc = calc + F = dimer.get_forces() + assert np.allclose(F, dimer.positions, atol=1e-10) + + +def test_stress_isotropization(dimer: Atoms) -> None: + """ + Check that stress deviatoric parts (L=2,3) vanish under full O(3) averaging. + + :param dimer: Test atomic structure. + """ + base = MockAnisoCalculator(c=(1.0, 1.0), p_iso=5.0) + calc = SymmetrizedCalculator(base, l_max=3, include_inversion=True) + dimer.calc = calc + S = dimer.get_stress(voigt=False) + iso = np.trace(S) / 3.0 + assert np.allclose(S, np.eye(3) * iso, atol=1e-10) + assert np.isclose(iso, 5.0, atol=1e-10) + + +def test_cancellation_vs_Lmax(dimer: Atoms) -> None: + """ + Residual anisotropy must vanish once rotational averaging is applied. + All quadratures with Lmax>0 are equivalent (Lebedev order=3). + """ + a = (0.0, 0.0, 1.0, 1.0) + base = MockAnisoCalculator(a=a) + E_true = float(np.sum(dimer.positions**2)) + + # No averaging + calc0 = SymmetrizedCalculator(base, l_max=0) + dimer.calc = calc0 + e0 = dimer.get_potential_energy() + + # Averaged + calc3 = SymmetrizedCalculator(base, l_max=3) + dimer.calc = calc3 + e3 = dimer.get_potential_energy() + + assert not np.isclose(e0, E_true, atol=1e-10) + assert np.isclose(e3, E_true, atol=1e-10) + + +def test_joint_energy_force_consistency(dimer: Atoms) -> None: + """ + Combined test: both energy and forces are consistent and invariant. + + :param dimer: Test atomic structure. + """ + base = MockAnisoCalculator(a=(1, 1, 1, 1), b=(0, 0, 0)) + calc = SymmetrizedCalculator(base, l_max=3) + dimer.calc = calc + e = dimer.get_potential_energy() + f = dimer.get_forces() + assert np.isclose(e, np.sum(dimer.positions**2) + 1.0, atol=1e-10) + assert np.allclose(f, dimer.positions, atol=1e-12) + + +def test_rotate_atoms_preserves_geometry(tmp_path): + """Check that _rotate_atoms applies rotations correctly and preserves distances.""" + from scipy.spatial.transform import Rotation + + from metatomic.torch.ase_calculator import _rotate_atoms + + # Build simple cubic cell with 2 atoms along x + atoms = Atoms("H2", positions=[[0, 0, 0], [1, 0, 0]], cell=np.eye(3)) + R = Rotation.from_euler("z", 90, degrees=True).as_matrix()[None, ...] # 90° about z + + rotated = _rotate_atoms(atoms, R)[0] + # Positions should now align along y + assert np.allclose( + rotated.positions[1] - rotated.positions[0], [0, 1, 0], atol=1e-12 + ) + # Cell rotated + assert np.allclose(rotated.cell[0], [0, 1, 0], atol=1e-12) + # Distances preserved + d0 = atoms.get_distance(0, 1) + d1 = rotated.get_distance(0, 1) + assert np.isclose(d0, d1, atol=1e-12) + + +def test_choose_quadrature_rules(): + """Check that _choose_quadrature selects appropriate rules.""" + from metatomic.torch.ase_calculator import _choose_quadrature + + for L in [0, 5, 17, 50]: + lebedev_order, n_gamma = _choose_quadrature(L) + assert lebedev_order >= L + assert n_gamma == 2 * L + 1 + + +def test_get_quadrature_properties(): + """Check properties of the quadrature returned by _get_quadrature.""" + from metatomic.torch.ase_calculator import _get_quadrature + + R, w = _get_quadrature(lebedev_order=11, n_rotations=5, include_inversion=False) + assert np.isclose(np.sum(w), 1.0) + assert np.allclose([np.dot(r.T, r) for r in R], np.eye(3), atol=1e-12) + assert np.allclose(np.linalg.det(R), 1.0, atol=1e-12) + + R_inv, w_inv = _get_quadrature( + lebedev_order=11, n_rotations=5, include_inversion=True + ) + assert len(R_inv) == 2 * len(R) + dets = np.linalg.det(R_inv) + assert np.all(np.isin(np.sign(dets).astype(int), [-1, 1])) + assert np.isclose(np.sum(w_inv), 1.0) + + +def test_compute_rotational_average_identity(): + """Check that _compute_rotational_average produces correct averages.""" + from metatomic.torch.ase_calculator import _compute_rotational_average + + R = np.repeat(np.eye(3)[None, :, :], 3, axis=0) + w = np.ones(3) / 3 + results = { + "energy": np.array([1.0, 2.0, 3.0]), + "forces": np.array([[[1, 0, 0]], [[0, 1, 0]], [[0, 0, 1]]]), + "stress": np.array([np.eye(3), 2 * np.eye(3), 3 * np.eye(3)]), + } + out = _compute_rotational_average(results, R, w) + assert np.isclose(out["energy"], np.mean(results["energy"])) + assert np.allclose(out["forces"], np.mean(results["forces"], axis=0)) + assert np.allclose(out["stress"], np.mean(results["stress"], axis=0)) + + +def test_average_over_fcc_group(): + """ + Check that averaging over the space group of an FCC crystal + produces an isotropic (scalar) stress tensor. + """ + from metatomic.torch.ase_calculator import ( + _average_over_group, + _get_group_operations, + ) + + # FCC conventional cubic cell (4 atoms) + a0 = 4.05 + atoms = Atoms( + "Cu4", + positions=[ + [0, 0, 0], + [0, 0.5, 0.5], + [0.5, 0, 0.5], + [0.5, 0.5, 0], + ], + cell=a0 * np.eye(3), + pbc=True, + ) + + # Create an intentionally anisotropic stress + stress = np.array([[10.0, 1.0, 0.0], [1.0, 5.0, 0.0], [0.0, 0.0, 1.0]]) + results = {"stress": stress} + + Q_list, P_list = _get_group_operations(atoms) + out = _average_over_group(results, Q_list, P_list) + S_pg = out["stress"] + + # The averaged stress must be isotropic: S_pg = (trace/3)*I + iso = np.trace(S_pg) / 3.0 + assert np.allclose(S_pg, np.eye(3) * iso, atol=1e-8) From 6eec538516020736c6f1abcead10162cde7b3ed2 Mon Sep 17 00:00:00 2001 From: ppegolo Date: Thu, 9 Oct 2025 14:04:00 +0200 Subject: [PATCH 12/14] lint --- .../metatomic_torch/metatomic/torch/ase_calculator.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/python/metatomic_torch/metatomic/torch/ase_calculator.py b/python/metatomic_torch/metatomic/torch/ase_calculator.py index 7e886d0f..0840bae0 100644 --- a/python/metatomic_torch/metatomic/torch/ase_calculator.py +++ b/python/metatomic_torch/metatomic/torch/ase_calculator.py @@ -217,9 +217,9 @@ def __init__( for name, output in additional_outputs.items(): assert isinstance(name, str) assert isinstance(output, torch.ScriptObject) - assert ( - "explicit_gradients_setter" in output._method_names() - ), "outputs must be ModelOutput instances" + assert "explicit_gradients_setter" in output._method_names(), ( + "outputs must be ModelOutput instances" + ) self._additional_output_requests = additional_outputs @@ -1221,7 +1221,7 @@ def _match_index(x_new, frac_ref, Z_ref, Z_i, tol=1e-6): seen = set() Ainv = np.linalg.inv(A) - for Rf, tf in zip(R_frac, t_frac): + for Rf, tf in zip(R_frac, t_frac, strict=False): # Cartesian rotation: Q = A Rf A^{-1} Q = A @ Rf @ Ainv # Deduplicate rotations (point group) by rounding @@ -1286,7 +1286,7 @@ def _average_over_group( if F.ndim != 2 or F.shape[1] != 3: raise ValueError(f"'forces' must be (N,3), got {F.shape}") acc = np.zeros_like(F) - for Q, P in zip(Q_list, P_list): + for Q, P in zip(Q_list, P_list, strict=False): acc += P.T @ (F @ Q) out["forces"] = acc / m From a935bf4da0cf53490266350b1dadb1846cb1b3d9 Mon Sep 17 00:00:00 2001 From: ppegolo Date: Thu, 9 Oct 2025 14:04:12 +0200 Subject: [PATCH 13/14] add deps for testing --- tox.ini | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tox.ini b/tox.ini index 86a3b3e8..795382c7 100644 --- a/tox.ini +++ b/tox.ini @@ -150,6 +150,9 @@ deps = # for metatensor-lj-test setuptools-scm cmake + # for symmetrized calculator + scipy + spglib changedir = python/metatomic_torch commands = From bb96a6eebe06c249309ab6387d2784d516df4af4 Mon Sep 17 00:00:00 2001 From: ppegolo Date: Thu, 9 Oct 2025 14:06:50 +0200 Subject: [PATCH 14/14] Add mention in the docs --- docs/src/engines/ase.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/src/engines/ase.rst b/docs/src/engines/ase.rst index da8bce2d..fed91cdd 100644 --- a/docs/src/engines/ase.rst +++ b/docs/src/engines/ase.rst @@ -23,6 +23,8 @@ Supported model outputs :py:meth:`ase.Atoms.get_forces`, …); - arbitrary outputs can be computed for any :py:class:`ase.Atoms` using :py:meth:`MetatomicCalculator.run_model`; +- for non-equivariant architectures like PET, rotatonally-averaged energies, forces, + and stresses can be computed using :py:class:`SymmetrizedCalculator`. How to install the code ^^^^^^^^^^^^^^^^^^^^^^^