Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 57 additions & 2 deletions src/sage/rings/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,17 @@ class ProductTree:
sage: R.<x> = F[]
sage: ms = [x - a^i for i in range(1024)] # roots of unity
sage: ys = [F.random_element() for _ in range(1024)] # input vector
sage: zs = ProductTree(ms).remainders(R(ys)) # compute FFT!
sage: tree = ProductTree(ms)
sage: zs = tree.remainders(R(ys)) # compute FFT!
sage: zs == [R(ys) % m for m in ms]
True

Similarly, the :meth:`interpolation` method can be used to implement
the inverse Fast Fourier Transform::

sage: tree.interpolation(zs).padded_list(len(ys)) == ys
True

This class encodes the tree as *layers*: Layer `0` is just a tuple
of the leaves. Layer `i+1` is obtained from layer `i` by replacing
each pair of two adjacent elements by their product, starting from
Expand Down Expand Up @@ -177,7 +184,6 @@ def remainders(self, x):
The base ring must support the ``%`` operator for this
method to work.


INPUT:

- ``x`` -- an element of the base ring of this product tree
Expand All @@ -199,6 +205,55 @@ def remainders(self, x):
X = [X[i // 2] % V[i] for i in range(len(V))]
return X

_crt_bases = None

def interpolation(self, xs):
r"""
Given a sequence ``xs`` of values, one per leaf, return a
single element `x` which is congruent to the `i`\th value in
``xs`` modulo the `i`\th leaf, for all `i`.

This is an explicit version of the Chinese remainder theorem;
see also :meth:`CRT`. Using this product tree is faster for
repeated calls since the required CRT bases are cached after
the first run.

The base ring must support the :func:`xgcd` function for this
method to work.

EXAMPLES::

sage: from sage.rings.generic import ProductTree
sage: vs = prime_range(100)
sage: tree = ProductTree(vs)
sage: tree.interpolation([1, 1, 2, 1, 9, 1, 7, 15, 8, 20, 15, 6, 27, 11, 2, 6, 0, 25, 49, 5, 51, 4, 19, 74, 13])
1085749272377676749812331719267

This method is faster than :func:`CRT` for repeated calls with
the same moduli::

sage: vs = prime_range(1000,2000)
sage: rs = lambda: [randrange(1,100) for _ in vs]
sage: tree = ProductTree(vs)
sage: %timeit CRT(rs(), vs) # not tested
372 µs ± 3.34 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
sage: %timeit tree.interpolation(rs()) # not tested
146 µs ± 479 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
"""
if self._crt_bases is None:
from sage.arith.misc import CRT_basis
self._crt_bases = []
for V in self.layers[:-1]:
B = tuple(CRT_basis(V[i:i+2]) for i in range(0, len(V), 2))
self._crt_bases.append(B)
if len(xs) != len(self.layers[0]):
raise ValueError('number of given elements must equal the number of leaves')
for basis, layer in zip(self._crt_bases, self.layers[1:]):
xs = [sum(c*x for c, x in zip(cs, xs[2*i:2*i+2])) % mod
for i, (cs, mod) in enumerate(zip(basis, layer))]
assert len(xs) == 1
return xs[0]


def prod_with_derivative(pairs):
r"""
Expand Down