diff --git a/src/sage/rings/generic.py b/src/sage/rings/generic.py index 99bf690bef6..e0145c8e828 100644 --- a/src/sage/rings/generic.py +++ b/src/sage/rings/generic.py @@ -28,10 +28,17 @@ class ProductTree: sage: R. = F[] sage: ms = [x - a^i for i in range(1024)] # roots of unity sage: ys = [F.random_element() for _ in range(1024)] # input vector - sage: zs = ProductTree(ms).remainders(R(ys)) # compute FFT! + sage: tree = ProductTree(ms) + sage: zs = tree.remainders(R(ys)) # compute FFT! sage: zs == [R(ys) % m for m in ms] True + Similarly, the :meth:`interpolation` method can be used to implement + the inverse Fast Fourier Transform:: + + sage: tree.interpolation(zs).padded_list(len(ys)) == ys + True + This class encodes the tree as *layers*: Layer `0` is just a tuple of the leaves. Layer `i+1` is obtained from layer `i` by replacing each pair of two adjacent elements by their product, starting from @@ -177,7 +184,6 @@ def remainders(self, x): The base ring must support the ``%`` operator for this method to work. - INPUT: - ``x`` -- an element of the base ring of this product tree @@ -199,6 +205,55 @@ def remainders(self, x): X = [X[i // 2] % V[i] for i in range(len(V))] return X + _crt_bases = None + + def interpolation(self, xs): + r""" + Given a sequence ``xs`` of values, one per leaf, return a + single element `x` which is congruent to the `i`\th value in + ``xs`` modulo the `i`\th leaf, for all `i`. + + This is an explicit version of the Chinese remainder theorem; + see also :meth:`CRT`. Using this product tree is faster for + repeated calls since the required CRT bases are cached after + the first run. + + The base ring must support the :func:`xgcd` function for this + method to work. + + EXAMPLES:: + + sage: from sage.rings.generic import ProductTree + sage: vs = prime_range(100) + sage: tree = ProductTree(vs) + sage: tree.interpolation([1, 1, 2, 1, 9, 1, 7, 15, 8, 20, 15, 6, 27, 11, 2, 6, 0, 25, 49, 5, 51, 4, 19, 74, 13]) + 1085749272377676749812331719267 + + This method is faster than :func:`CRT` for repeated calls with + the same moduli:: + + sage: vs = prime_range(1000,2000) + sage: rs = lambda: [randrange(1,100) for _ in vs] + sage: tree = ProductTree(vs) + sage: %timeit CRT(rs(), vs) # not tested + 372 µs ± 3.34 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each) + sage: %timeit tree.interpolation(rs()) # not tested + 146 µs ± 479 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each) + """ + if self._crt_bases is None: + from sage.arith.misc import CRT_basis + self._crt_bases = [] + for V in self.layers[:-1]: + B = tuple(CRT_basis(V[i:i+2]) for i in range(0, len(V), 2)) + self._crt_bases.append(B) + if len(xs) != len(self.layers[0]): + raise ValueError('number of given elements must equal the number of leaves') + for basis, layer in zip(self._crt_bases, self.layers[1:]): + xs = [sum(c*x for c, x in zip(cs, xs[2*i:2*i+2])) % mod + for i, (cs, mod) in enumerate(zip(basis, layer))] + assert len(xs) == 1 + return xs[0] + def prod_with_derivative(pairs): r"""